[HigherOrderOp] Should automatically pop modes (#109157)
Fixes #108282

Pull Request resolved: #109157
Approved by: https://github.com/zou3519
yanboliang authored and pytorchmergebot committed Sep 18, 2023
1 parent 73ac814 commit 8a567bb
Showing 6 changed files with 83 additions and 130 deletions.
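Every file in this commit applies the same pattern: a HigherOrderOperator's per-mode Python implementation used to fetch and pop the active dispatch mode itself (via _get_current_dispatch_mode / _pop_mode_temporarily); after this change, dispatch pops the mode once and passes it to the handler as an explicit first argument, and handlers that need the mode active again (the FakeTensorMode ones) re-enter it with `with mode:`. Below is a minimal, self-contained sketch of that calling convention using toy stand-ins (ToyMode and ToyHigherOrderOp are illustrative names, not PyTorch classes):

from contextlib import contextmanager

_mode_stack = []                      # toy stand-in for the dispatch mode stack

@contextmanager
def pop_mode_temporarily():
    # What dispatch now does on the handler's behalf: pop the innermost mode,
    # hand it over, and push it back afterwards.
    mode = _mode_stack.pop()
    try:
        yield mode
    finally:
        _mode_stack.append(mode)

class ToyMode:
    def __enter__(self):
        _mode_stack.append(self)
        return self

    def __exit__(self, *exc):
        _mode_stack.pop()
        return False

class ToyHigherOrderOp:
    def __init__(self):
        self.python_key_mode_table = {}

    def py_impl(self, mode_type):
        def register(fn):
            self.python_key_mode_table[mode_type] = fn
            return fn
        return register

    def __call__(self, *args):
        if _mode_stack:
            handler = self.python_key_mode_table[type(_mode_stack[-1])]
            # New convention: the dispatcher pops the mode and passes it in.
            with pop_mode_temporarily() as mode:
                return handler(mode, *args)
        return sum(args)              # "dense" fallback when no mode is active

toy_op = ToyHigherOrderOp()

@toy_op.py_impl(ToyMode)
def toy_op_under_mode(mode, *args):
    # The handler receives the already-popped mode as its first argument
    # instead of calling _get_current_dispatch_mode() itself.
    return ("handled under", type(mode).__name__, sum(args))

with ToyMode():
    print(toy_op(1, 2, 3))            # ('handled under', 'ToyMode', 6)
print(toy_op(1, 2, 3))                # 6
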
22 changes: 8 additions & 14 deletions functorch/experimental/_map.py
@@ -23,10 +23,6 @@
     track_tensor_tree,
 )
 from torch.multiprocessing.reductions import StorageWeakRef
-from torch.utils._python_dispatch import (
-    _get_current_dispatch_mode,
-    _pop_mode_temporarily,
-)


 # TODO: We add this to prevent dymamo from tracing into map_wrapper,
@@ -311,19 +307,17 @@ def map_autograd(f, num_mapped_args, *args):


 @map_impl.py_impl(ProxyTorchDispatchMode)
-def map_proxy_torch_dispatch_mode(f, num_mapped, *args):
-    mode = _get_current_dispatch_mode()
-    assert mode is not None, "Mode should always be enabled for python fallback key"
-    with _pop_mode_temporarily() as mode:
-        if mode.enable_tracing:
-            return trace_map(mode, map_impl, f, num_mapped, *args)
-        else:
-            return map_impl(f, num_mapped, *args)
+def map_proxy_torch_dispatch_mode(mode, f, num_mapped, *args):
+    if mode.enable_tracing:
+        return trace_map(mode, map_impl, f, num_mapped, *args)
+    else:
+        return map_impl(f, num_mapped, *args)


 @map_impl.py_impl(FakeTensorMode)
-def map_fake_tensor_mode(f, num_mapped, *args):
-    return map_dense(f, num_mapped, *args)
+def map_fake_tensor_mode(mode, f, num_mapped, *args):
+    with mode:
+        return map_dense(f, num_mapped, *args)


 @map_impl.py_impl(DispatchKey.Functionalize)
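Why the FakeTensorMode handlers now wrap their body in `with mode:`: dispatch pops the mode before calling the handler, so the handler has to push it back for fake-tensor propagation to work inside the body (here, map_dense calling f on each slice). A hedged sketch of the same idea outside of map_impl — toy_map_dense and toy_map_fake_tensor_mode are hypothetical names, not the functions above:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

def toy_map_dense(f, xs):
    # Stand-in for map_dense: apply f to each slice along dim 0 and restack.
    return torch.stack([f(x) for x in xs])

def toy_map_fake_tensor_mode(mode, f, xs):
    # Dispatch popped the mode before calling us, so push it back for the body.
    with mode:
        return toy_map_dense(f, xs)

fake_mode = FakeTensorMode()
with fake_mode:
    xs = torch.empty(4, 3)            # a FakeTensor carrying only metadata

# Simulate dispatch handing us the popped mode explicitly.
out = toy_map_fake_tensor_mode(fake_mode, lambda x: x.sin(), xs)
print(type(out).__name__, tuple(out.shape))   # FakeTensor (4, 3)
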
28 changes: 11 additions & 17 deletions torch/_export/wrappers.py
@@ -8,10 +8,6 @@
 from torch._subclasses.fake_tensor import FakeTensorMode
 from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
 from torch.utils import _pytree as pytree
-from torch.utils._python_dispatch import (
-    _get_current_dispatch_mode,
-    _pop_mode_temporarily,
-)


 _export_tracepoint = HigherOrderOperator("_export_tracepoint")
@@ -26,22 +22,20 @@


 @_export_tracepoint.py_impl(ProxyTorchDispatchMode)
-def export_tracepoint_dispatch_mode(*args, **kwargs):
-    mode = _get_current_dispatch_mode()
-    assert mode is not None, "Mode should always be enabled for python fallback key"
-    with _pop_mode_temporarily() as mode:
-        if not mode.enable_tracing:
-            return _export_tracepoint(*args, **kwargs)
-        p_args, p_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, (args, kwargs))
-        proxy = mode.tracer.create_proxy(
-            "call_function", _export_tracepoint, p_args, p_kwargs
-        )
-        return track_tensor_tree(args, proxy, constant=None, tracer=mode.tracer)
+def export_tracepoint_dispatch_mode(mode, *args, **kwargs):
+    if not mode.enable_tracing:
+        return _export_tracepoint(*args, **kwargs)
+    p_args, p_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, (args, kwargs))
+    proxy = mode.tracer.create_proxy(
+        "call_function", _export_tracepoint, p_args, p_kwargs
+    )
+    return track_tensor_tree(args, proxy, constant=None, tracer=mode.tracer)


 @_export_tracepoint.py_impl(FakeTensorMode)
-def export_tracepoint_fake_tensor_mode(*args, **kwargs):
-    return args
+def export_tracepoint_fake_tensor_mode(mode, *args, **kwargs):
+    with mode:
+        return args


 @_export_tracepoint.py_impl(DispatchKey.Functionalize)
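For reference, the proxy-mode handler above follows the usual recipe for recording a higher-order op into an FX graph: map each concrete argument to its proxy, emit a call_function node through the mode's tracer, and return the concrete values (track_tensor_tree additionally associates them with the new node). A toy, torch-free illustration of that flow — ToyTracer and export_tracepoint_like are made-up names and the kwarg is arbitrary:

class ToyTracer:
    def __init__(self):
        self.graph = []               # recorded (kind, target, args, kwargs) nodes
        self.proxies = {}             # id(value) -> proxy name

    def unwrap_proxy(self, value):
        # Real tracers map concrete tensors back to the Proxy that produced them.
        return self.proxies.get(id(value), value)

    def create_proxy(self, kind, target, args, kwargs):
        self.graph.append((kind, target, args, kwargs))
        return f"%node{len(self.graph)}"

def export_tracepoint_like(tracer, *args, **kwargs):
    # Same shape as the handler above: proxy the args, record the node,
    # return the concrete values (track_tensor_tree would also link them
    # to the new proxy).
    p_args = tuple(tracer.unwrap_proxy(a) for a in args)
    tracer.create_proxy("call_function", "_export_tracepoint", p_args, kwargs)
    return args

tracer = ToyTracer()
x = object()                          # stands in for a tensor
tracer.proxies[id(x)] = "%x"
out = export_tracepoint_like(tracer, x, kind="module_call")
assert out == (x,)                    # the tracepoint is a passthrough
print(tracer.graph)
# [('call_function', '_export_tracepoint', ('%x',), {'kind': 'module_call'})]
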
41 changes: 11 additions & 30 deletions torch/_higher_order_ops/cond.py
@@ -26,10 +26,7 @@
 )
 from torch.fx.passes.shape_prop import _extract_tensor_metadata
 from torch.multiprocessing.reductions import StorageWeakRef
-from torch.utils._python_dispatch import (
-    _get_current_dispatch_mode,
-    _pop_mode_temporarily,
-)
+from torch.utils._python_dispatch import _get_current_dispatch_mode


 @contextmanager
@@ -281,35 +278,19 @@ def cond_op_dense(pred, true_fn, false_fn, operands):


 @cond_op.py_impl(ProxyTorchDispatchMode)
-def inner(pred, true_fn, false_fn, operands):
-    # TODO Move this to proper utility function
-    from torch._ops import mode_stack_per_key, temporarily_pop_mode
-
-    # torch.cond expects ProxyTorchDispatchMode to **still** be on the stack
-    # at the time that its proxy implementation is called.
-    # However, the mode can live in one of two places, depending on
-    # whether we're doing pre_dispatch tracing or normal tracing.
-    pre_dispatch_modes = mode_stack_per_key().get(DispatchKey.PreDispatch, [])  # type: ignore[attr-defined]
-    if len(pre_dispatch_modes) > 0:
-        with temporarily_pop_mode(pre_dispatch_modes) as mode:
-            if mode.enable_tracing:
-                return trace_cond(mode, cond_op, pred, true_fn, false_fn, operands)
-            else:
-                return cond_op(pred, true_fn, false_fn, operands)
-    mode = _get_current_dispatch_mode()
-    assert mode is not None, "Mode should always be enabled for python fallback key"
-    with _pop_mode_temporarily() as mode:
-        if mode.enable_tracing:
-            return trace_cond(mode, cond_op, pred, true_fn, false_fn, operands)
-        else:
-            return cond_op(pred, true_fn, false_fn, operands)
+def inner(mode, pred, true_fn, false_fn, operands):
+    if mode.enable_tracing:
+        return trace_cond(mode, cond_op, pred, true_fn, false_fn, operands)
+    else:
+        return cond_op(pred, true_fn, false_fn, operands)


 @cond_op.py_impl(FakeTensorMode)
-def cond_fake_tensor_mode(pred, true_fn, false_fn, operands):
-    true_outs = true_fn(*operands)
-    flat_true_outs, _ = pytree.tree_flatten(true_outs)
-    flat_false_outs, _ = pytree.tree_flatten(false_fn(*operands))
+def cond_fake_tensor_mode(mode, pred, true_fn, false_fn, operands):
+    with mode:
+        true_outs = true_fn(*operands)
+        flat_true_outs, _ = pytree.tree_flatten(true_outs)
+        flat_false_outs, _ = pytree.tree_flatten(false_fn(*operands))
     if len(flat_true_outs) != len(flat_false_outs):
         raise RuntimeError("Unmatched number of outputs from cond() branches.")

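The new cond_fake_tensor_mode runs both branches under the re-entered FakeTensorMode and then checks that they agree. A hedged sketch of that consistency check in isolation (check_cond_branches is a hypothetical helper; the real implementation performs further metadata checks beyond the output count):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.utils import _pytree as pytree

def check_cond_branches(mode, true_fn, false_fn, operands):
    # Dispatch popped the FakeTensorMode, so re-enter it to run both branches
    # on fake operands, then make sure the branch outputs line up.
    with mode:
        flat_true, _ = pytree.tree_flatten(true_fn(*operands))
        flat_false, _ = pytree.tree_flatten(false_fn(*operands))
    if len(flat_true) != len(flat_false):
        raise RuntimeError("Unmatched number of outputs from cond() branches.")
    return flat_true

mode = FakeTensorMode()
with mode:
    x = torch.empty(2, 3)

outs = check_cond_branches(mode, lambda t: (t + 1,), lambda t: (t - 1,), (x,))
print(len(outs), tuple(outs[0].shape))   # 1 (2, 3)
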
30 changes: 8 additions & 22 deletions torch/_higher_order_ops/out_dtype.py
@@ -7,10 +7,6 @@
     track_tensor_tree,
     maybe_handle_decomp,
 )
-from torch.utils._python_dispatch import (
-    _get_current_dispatch_mode,
-    _pop_mode_temporarily,
-)
 from torch._C import DispatchKey, _ExcludeDispatchKeyGuard, DispatchKeySet
 from torch._functorch.eager_transforms import (
     _unwrap_all_tensors_from_functional,
@@ -159,36 +155,26 @@ def out_dtype_fallback(op, output_dtype, *args):

 @out_dtype.py_impl(ProxyTorchDispatchMode)
 def out_dtype_proxy(
+    mode: ProxyTorchDispatchMode,
     op: torch._ops.OpOverload,
     output_dtype: torch.dtype,
     *args
 ):
-    # TODO Move this to proper utility function
-    from torch._ops import mode_stack_per_key, temporarily_pop_mode
-    pre_dispatch_modes = mode_stack_per_key().get(DispatchKey.PreDispatch, [])  # type: ignore[attr-defined]
-    if len(pre_dispatch_modes) > 0:
-        with temporarily_pop_mode(pre_dispatch_modes) as mode:
-            if mode.enable_tracing:
-                return trace_out_dtype(mode, out_dtype, op, output_dtype, *args)
-            else:
-                return out_dtype(op, output_dtype, *args)
-
-    mode = _get_current_dispatch_mode()
-    assert (mode is not None), "Mode should always be enabled for python fallback key"
-    with _pop_mode_temporarily() as mode:
-        if mode.enable_tracing:
-            return trace_out_dtype(mode, out_dtype, op, output_dtype, *args)
-        else:
-            return out_dtype(op, output_dtype, *args)
+    if mode.enable_tracing:
+        return trace_out_dtype(mode, out_dtype, op, output_dtype, *args)
+    else:
+        return out_dtype(op, output_dtype, *args)


 @out_dtype.py_impl(FakeTensorMode)
 def out_dtype_fake_tensor_mode(
+    mode: FakeTensorMode,
     op: torch._ops.OpOverload,
     output_dtype: torch.dtype,
     *args
 ):
-    return out_dtype_dense(op, output_dtype, *args)
+    with mode:
+        return out_dtype_dense(op, output_dtype, *args)


 @out_dtype.py_impl(DispatchKey.Functionalize)
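As background for the handlers above: out_dtype(op, output_dtype, *args) computes op as though its tensor arguments had been promoted to output_dtype and returns the result in that dtype, with int8 matmul accumulating into int32 as the motivating case. A rough eager reference under those assumptions — this is not the HOP itself and glosses over the exact promotion rules:

import torch

def out_dtype_reference(op, output_dtype, *args):
    # Promote tensor args to the requested dtype, run the op, and return the
    # result in that dtype (the real HOP is more careful about promotion).
    promoted = [a.to(output_dtype) if isinstance(a, torch.Tensor) else a for a in args]
    return op(*promoted).to(output_dtype)

a = torch.randint(-128, 127, (4, 8), dtype=torch.int8)
b = torch.randint(-128, 127, (8, 4), dtype=torch.int8)
# int8 x int8 matmul accumulated in int32 -- the motivating use case
res = out_dtype_reference(torch.mm, torch.int32, a, b)
print(res.dtype, tuple(res.shape))    # torch.int32 (4, 4)
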
18 changes: 14 additions & 4 deletions torch/_ops.py
@@ -243,19 +243,23 @@ def dispatch(self, dispatch_key, *args, **kwargs):
             return dispatch_functorch(self, args, kwargs)

         if dispatch_key == torch._C.DispatchKey.Python:
-            # TODO(voz): We should walk all the nodes here / turn it into a list, topmode is ok for now.
+            # The place to handle ProxyTorchDispatchMode, FakeTensorMode, etc
+            from torch.utils._python_dispatch import _pop_mode_temporarily
+
             curr_mode = _get_current_dispatch_mode()
             assert (
                 curr_mode is not None
             ), "Illegal invocation of dispatch on torch._C.DispatchKey.Python without a mode."
             assert (
                 type(curr_mode) in self.python_key_mode_table
             ), f"Current active mode {curr_mode} not registered"
-            # TODO(voz): The idea behind this is that we do not yet support dispatch by key + mode, only key.
-            return self.python_key_mode_table[type(curr_mode)](*args, **kwargs)
+            handler = self.python_key_mode_table[type(curr_mode)]
+            with _pop_mode_temporarily() as mode:
+                return handler(mode, *args, **kwargs)

         functionality_key = torch._C._to_functionality_key(dispatch_key)  # type: ignore[attr-defined]
         if functionality_key in mode_stack_per_key():
+            # The place to handle DispatchKey.PreDispatch
             curr_stack = mode_stack_per_key()[functionality_key]
             # The check for Python in the exclude set is so we properly respect `with no_dispatch()`
             # calls inside of a mode.
@@ -265,7 +269,13 @@
                 DispatchKey.Python
             ):
                 curr_mode = curr_stack[-1]
-                return self.python_key_mode_table[type(curr_mode)](*args, **kwargs)
+                pre_dispatch_modes = mode_stack_per_key().get(
+                    DispatchKey.PreDispatch, []  # type: ignore[attr-defined]
+                )
+                handler = self.python_key_mode_table[type(curr_mode)]
+                if len(pre_dispatch_modes) > 0:
+                    with temporarily_pop_mode(pre_dispatch_modes) as mode:
+                        return handler(mode, *args, **kwargs)

         final_key = resolve_key(self, dispatch_key)

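Taken together, the two dispatch() hunks mean an author of a new HigherOrderOperator writes mode handlers that accept the already-popped mode as their first argument, whether the mode lives on the Python-key stack or the PreDispatch stack. A hedged registration sketch for a hypothetical op (my_hop does nothing useful, and these internal APIs can shift between PyTorch versions):

from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode

my_hop = HigherOrderOperator("my_hop")     # hypothetical op, for illustration only

@my_hop.py_impl(FakeTensorMode)
def my_hop_fake_tensor_mode(mode, fn, x):
    # Re-enter the popped FakeTensorMode so fn runs on fake tensors.
    with mode:
        return fn(x)

@my_hop.py_impl(ProxyTorchDispatchMode)
def my_hop_proxy_mode(mode, fn, x):
    # `mode` arrives already popped; when tracing is disabled, just re-dispatch.
    if not mode.enable_tracing:
        return my_hop(fn, x)
    # A real op would record a call_function proxy here, as cond/map/out_dtype do.
    return fn(x)
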
74 changes: 31 additions & 43 deletions torch/_prims/rng_prims.py
@@ -16,10 +16,6 @@
     track_tensor_tree,
 )
 from torch.types import _device, _dtype
-from torch.utils._python_dispatch import (
-    _get_current_dispatch_mode,
-    _pop_mode_temporarily,
-)


 rngprim_namespace = "rngprims"
@@ -192,27 +188,23 @@ def impl_backend_select(op, *args, **kwargs):
         return impl(op, *args, **kwargs)

     @run_and_save_rng_state.py_impl(FakeTensorMode)
-    def impl_fake_tensor_mode(op, *args, **kwargs):
+    def impl_fake_tensor_mode(mode, op, *args, **kwargs):
         # Check device to call the right impl
-        return impl_backend_select(op, *args, **kwargs)
+        with mode:
+            return impl_backend_select(op, *args, **kwargs)

     @run_and_save_rng_state.py_impl(ProxyTorchDispatchMode)
-    def impl_proxy_dispatch_mode(op, *args, **kwargs):
-        mode = _get_current_dispatch_mode()
-        assert mode is not None
-        with _pop_mode_temporarily() as mode:
-            if mode.enable_tracing:
-                out = impl_fake_tensor_mode(op, *args, **kwargs)
-                proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, (op, *args))
-                proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)
-                out_proxy = mode.tracer.create_proxy(
-                    "call_function", run_and_save_rng_state, proxy_args, proxy_kwargs
-                )
-                return track_tensor_tree(
-                    out, out_proxy, constant=None, tracer=mode.tracer
-                )
-            else:
-                return run_and_save_rng_state(op, *args, **kwargs)
+    def impl_proxy_dispatch_mode(mode, op, *args, **kwargs):
+        if mode.enable_tracing:
+            out = impl_backend_select(op, *args, **kwargs)
+            proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, (op, *args))
+            proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)
+            out_proxy = mode.tracer.create_proxy(
+                "call_function", run_and_save_rng_state, proxy_args, proxy_kwargs
+            )
+            return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
+        else:
+            return run_and_save_rng_state(op, *args, **kwargs)

     return run_and_save_rng_state

@@ -247,25 +239,20 @@ def impl_cpu(rng_state, op, *args, **kwargs):
         return out

     @run_with_rng_state.py_impl(ProxyTorchDispatchMode)
-    def impl_proxy_dispatch_mode(rng_state, op, *args, **kwargs):
-        mode = _get_current_dispatch_mode()
-        assert mode is not None
-        with _pop_mode_temporarily() as mode:
-            if mode.enable_tracing:
-                with disable_proxy_modes_tracing():
-                    out = run_with_rng_state(rng_state, op, *args, **kwargs)
-                proxy_args = pytree.tree_map(
-                    mode.tracer.unwrap_proxy, (rng_state, op, *args)
-                )
-                proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)
-                out_proxy = mode.tracer.create_proxy(
-                    "call_function", run_with_rng_state, proxy_args, proxy_kwargs
-                )
-                return track_tensor_tree(
-                    out, out_proxy, constant=None, tracer=mode.tracer
-                )
-            else:
-                return run_with_rng_state(rng_state, op, *args, **kwargs)
+    def impl_proxy_dispatch_mode(mode, rng_state, op, *args, **kwargs):
+        if mode.enable_tracing:
+            with disable_proxy_modes_tracing():
+                out = run_with_rng_state(rng_state, op, *args, **kwargs)
+            proxy_args = pytree.tree_map(
+                mode.tracer.unwrap_proxy, (rng_state, op, *args)
+            )
+            proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)
+            out_proxy = mode.tracer.create_proxy(
+                "call_function", run_with_rng_state, proxy_args, proxy_kwargs
+            )
+            return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
+        else:
+            return run_with_rng_state(rng_state, op, *args, **kwargs)

     @run_with_rng_state.py_impl(DispatchKey.BackendSelect)
     def impl_backend_select(rng_state, op, *args, **kwargs):
@@ -276,10 +263,11 @@ def impl_backend_select(rng_state, op, *args, **kwargs):
         return impl(rng_state, op, *args, **kwargs)

     @run_with_rng_state.py_impl(FakeTensorMode)
-    def impl_fake_tensor_mode(rng_state, op, *args, **kwargs):
+    def impl_fake_tensor_mode(mode, rng_state, op, *args, **kwargs):
         # Skip setting the set_rng_state as it does not work well with fake tensors.
         # And it does not matter for the fake tensor mode.
-        return op(*args, **kwargs)
+        with mode:
+            return op(*args, **kwargs)

     return run_with_rng_state

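For context on what these two ops promise, independent of the mode plumbing this commit changes: run_and_save_rng_state captures the RNG state alongside an op's output, and run_with_rng_state replays the op under that saved state, which is what lets recomputation (e.g. activation checkpointing) reproduce the same random numbers. A plain eager sketch of that contract using public CPU RNG APIs — the *_ref helpers are illustrative, not the ops themselves:

import torch

def run_and_save_rng_state_ref(op, *args, **kwargs):
    # Capture the CPU RNG state together with the op's output.
    state = torch.get_rng_state()
    return state, op(*args, **kwargs)

def run_with_rng_state_ref(state, op, *args, **kwargs):
    # Replay the op under the saved state, then restore the current state.
    current = torch.get_rng_state()
    torch.set_rng_state(state)
    try:
        return op(*args, **kwargs)
    finally:
        torch.set_rng_state(current)

state, out1 = run_and_save_rng_state_ref(torch.rand, 3)
_ = torch.rand(3)                         # advance the RNG arbitrarily
out2 = run_with_rng_state_ref(state, torch.rand, 3)
print(torch.equal(out1, out2))            # True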
