Skip to content

Commit

Permalink
Refactor unwrap_proxy() for proxy tensor tracing. (#104667)
Browse files Browse the repository at this point in the history
Test Plan: CI

Differential Revision: D47241815

Pull Request resolved: #104667
Approved by: https://github.com/tugsbayasgalan
  • Loading branch information
zhxchen17 authored and pytorchmergebot committed Jul 6, 2023
1 parent d0e5c68 commit df281bf
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 21 deletions.
4 changes: 1 addition & 3 deletions functorch/experimental/_cond.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from dataclasses import dataclass
from functools import partial
import torch
from torch.multiprocessing.reductions import StorageWeakRef

Expand All @@ -14,7 +13,6 @@
ProxyTorchDispatchMode,
make_fx,
track_tensor_tree,
unwrap_proxy,
)
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch.utils._python_dispatch import (
Expand Down Expand Up @@ -94,7 +92,7 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):

args = (pred, true_graph, false_graph, operands)

proxy_args = pytree.tree_map(partial(unwrap_proxy, proxy_mode), args)
proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)

out_proxy = proxy_mode.tracer.create_proxy('call_function', func_overload, proxy_args, {},
name="conditional")
Expand Down
5 changes: 1 addition & 4 deletions functorch/experimental/_map.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from functools import partial

import torch
import torch.utils._pytree as pytree
from torch._C import DispatchKey, DispatchKeySet, _ExcludeDispatchKeyGuard
Expand All @@ -13,7 +11,6 @@
make_fx,
ProxyTorchDispatchMode,
track_tensor_tree,
unwrap_proxy,
)
from torch.utils._python_dispatch import (
_get_current_dispatch_mode,
Expand Down Expand Up @@ -184,7 +181,7 @@ def expand_tensor(t):

proxy_mode.tracer.root.register_module(next_name, body_graph)
node_args = (body_graph, num_mapped, *args)
proxy_args = pytree.tree_map(partial(unwrap_proxy, proxy_mode), node_args)
proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
out_proxy = proxy_mode.tracer.create_proxy('call_function', func_overload, proxy_args, {},
name="map_impl")
return track_tensor_tree(expanded_outs, out_proxy, constant=None, tracer=proxy_mode.tracer)
Expand Down
10 changes: 4 additions & 6 deletions torch/_prims/rng_prims.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from functools import partial
from typing import Optional, Tuple

import torch
Expand All @@ -14,7 +13,6 @@
disable_proxy_modes_tracing,
ProxyTorchDispatchMode,
track_tensor_tree,
unwrap_proxy,
)
from torch.types import _device, _dtype
from torch.utils._python_dispatch import (
Expand Down Expand Up @@ -203,8 +201,8 @@ def impl_proxy_dispatch_mode(op, *args, **kwargs):
with _pop_mode_temporarily() as mode:
if mode.enable_tracing:
out = impl_fake_tensor_mode(op, *args, **kwargs)
proxy_args = pytree.tree_map(partial(unwrap_proxy, mode), (op, *args))
proxy_kwargs = pytree.tree_map(partial(unwrap_proxy, mode), kwargs)
proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, (op, *args))
proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)
out_proxy = mode.tracer.create_proxy(
"call_function", run_and_save_rng_state, proxy_args, proxy_kwargs
)
Expand Down Expand Up @@ -254,9 +252,9 @@ def impl_proxy_dispatch_mode(rng_state, op, *args, **kwargs):
with disable_proxy_modes_tracing():
out = run_with_rng_state(rng_state, op, *args, **kwargs)
proxy_args = pytree.tree_map(
partial(unwrap_proxy, mode), (rng_state, op, *args)
mode.tracer.unwrap_proxy, (rng_state, op, *args)
)
proxy_kwargs = pytree.tree_map(partial(unwrap_proxy, mode), kwargs)
proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)
out_proxy = mode.tracer.create_proxy(
"call_function", run_with_rng_state, proxy_args, proxy_kwargs
)
Expand Down
16 changes: 8 additions & 8 deletions torch/fx/experimental/proxy_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,14 +114,6 @@ def get_proxy_slot(obj, tracer, default=no_default, transform=lambda x: x):
def snapshot_fake(val):
return val.detach()

def unwrap_proxy(proxy_mode, e):
if isinstance(e, torch.Tensor):
return get_proxy_slot(e, proxy_mode.tracer, e, lambda e: e.proxy)
elif isinstance(e, (torch.SymInt, torch.SymFloat, torch.SymBool)):
return get_proxy_slot(e.node, proxy_mode.tracer, e, lambda e: e())
else:
return e

def extract_val(val):
if isinstance(val, FakeTensor):
return snapshot_fake(val)
Expand Down Expand Up @@ -458,6 +450,14 @@ def create_arg(self, a: Any):
return a.node.constant
return super().create_arg(a)

def unwrap_proxy(self, e):
if isinstance(e, torch.Tensor):
return get_proxy_slot(e, self, e, lambda e: e.proxy)
elif isinstance(e, (torch.SymInt, torch.SymFloat, torch.SymBool)):
return get_proxy_slot(e.node, self, e, lambda e: e())
else:
return e


@torch._disable_dynamo
def dispatch_trace(
Expand Down

0 comments on commit df281bf

Please sign in to comment.