[Autograd] Cond Higher-Order Operation #126911

Open
wants to merge 47 commits into base: main
Changes from 10 commits
Commits (47)
cdc9081
Changes required to be able to build PyTorch in conda environment wit…
bohnstingl May 9, 2024
38f8ae3
Merge branch 'main' of github.com:pytorch/pytorch into cond_autograd
bohnstingl May 9, 2024
856abe5
WIP: Cond Autograd
bohnstingl May 9, 2024
4ff9e91
WIP: cond autograd forward
bohnstingl May 10, 2024
e3245b4
Fixed cond autograd
bohnstingl May 21, 2024
0729ac5
Executed 'test_control_flow.py' and 'test_higher_order_ops.py'
bohnstingl May 22, 2024
8a70403
Added additional testcases
bohnstingl May 22, 2024
e83e375
Merge branch 'main' of github.com:pytorch/pytorch into cond_autograd
bohnstingl May 24, 2024
f6cfa96
Fixed import issues
bohnstingl May 24, 2024
9617494
Fixed test_cond_vmap_single_input_with_closure testcase
bohnstingl May 24, 2024
cf6f1b1
Introduced lower level function for create_fw_bw_graph
bohnstingl May 25, 2024
40d584c
Updated low-level create_fw_bw_graph function (compatible with while_…
bohnstingl May 25, 2024
6ea5d04
Cleanup
bohnstingl May 25, 2024
34ac604
Added some more test cases with loops
bohnstingl May 26, 2024
908f7f8
Cleaned utils.py
bohnstingl May 27, 2024
25c9262
Fixed lintrunner and some of the cond testcases
bohnstingl May 30, 2024
8e2203f
Updated low-level create_fw_bw_graph function
bohnstingl Jun 1, 2024
630549a
Updated low-level create_fw_bw_graph function
bohnstingl Jun 3, 2024
d2d77eb
Merge branch 'main' of github.com:pytorch/pytorch into cond_autograd
bohnstingl Jul 8, 2024
0af022b
Merge branch 'main' of github.com:pytorch/pytorch into cond_autograd
bohnstingl Jul 9, 2024
016e1d6
Merge branch 'pytorch:main' into cond_autograd
bohnstingl Jul 9, 2024
9d4b4c4
Updated testcases for cond
bohnstingl Jul 9, 2024
a547f55
Merge branch 'cond_autograd' of github.com:bohnstingl/pytorch into co…
bohnstingl Jul 9, 2024
9ae1e77
cond autograd
ydwu4 Jul 12, 2024
9e1388d
Integrated additions from https://github.com/pytorch/pytorch/pull/130607
bohnstingl Jul 12, 2024
e50f0cf
Updated cond.py
bohnstingl Jul 12, 2024
70b2ed2
Updated cpuinfo to most recent version
bohnstingl Jul 12, 2024
15e3ddd
Reverted cpuinfo version
bohnstingl Jul 12, 2024
e3e5636
Updated third_party modules to pytorch:main version
bohnstingl Jul 12, 2024
5698b41
Merge branch 'cond_autograd' of github.com:bohnstingl/pytorch into HEAD
bohnstingl Jul 12, 2024
aa3ef8c
Merge branch 'gh/ydwu4/132/base' of github.com:pytorch/pytorch into c…
bohnstingl Jul 12, 2024
969e79e
Merge branch 'cond_autograd' of github.com:bohnstingl/pytorch into co…
bohnstingl Jul 12, 2024
ddccf10
Update on "[NOT FOR REVIEW] cond autograd."
ydwu4 Jul 12, 2024
6458271
Update base for Update on "[NOT FOR REVIEW] cond autograd."
ydwu4 Jul 13, 2024
4c1567a
Update on "[NOT FOR REVIEW] cond autograd."
ydwu4 Jul 13, 2024
f8f697c
Update base for Update on "[NOT FOR REVIEW] cond autograd."
ydwu4 Jul 13, 2024
78778ad
Update on "[NOT FOR REVIEW] cond autograd."
ydwu4 Jul 13, 2024
3313990
Update base for Update on "[NOT FOR REVIEW] cond autograd."
ydwu4 Jul 13, 2024
d3d3254
Update on "[NOT FOR REVIEW] cond autograd."
ydwu4 Jul 13, 2024
14b3582
Merge branch 'gh/ydwu4/132/head' of github.com:pytorch/pytorch into c…
bohnstingl Jul 13, 2024
a1e98f9
Merge branch 'gh/ydwu4/132/base' of github.com:pytorch/pytorch into c…
bohnstingl Jul 13, 2024
8d03c22
Fixed some more testcases
bohnstingl Jul 14, 2024
d812140
Update base for Update on "[NOT FOR REVIEW] cond autograd."
ydwu4 Jul 15, 2024
e0169c6
Update on "[NOT FOR REVIEW] cond autograd."
ydwu4 Jul 15, 2024
f1e1412
Merge branch 'gh/ydwu4/132/base' of github.com:pytorch/pytorch into c…
bohnstingl Jul 15, 2024
2456978
Merge branch 'gh/ydwu4/132/head' of github.com:pytorch/pytorch into c…
bohnstingl Jul 15, 2024
7328e71
Final cleanup
bohnstingl Jul 16, 2024
2 changes: 1 addition & 1 deletion functorch/experimental/control_flow.py
@@ -5,4 +5,4 @@
_stack_pytree,
_unstack_pytree,
map,
)
)
417 changes: 391 additions & 26 deletions test/functorch/test_control_flow.py

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion torch/__init__.py
@@ -1912,7 +1912,6 @@ def fn(model: Callable):


from torch import export as export
Contributor: We should revert the change to this file.


from torch._higher_order_ops import cond

def _register_device_module(device_type, module):
145 changes: 137 additions & 8 deletions torch/_higher_order_ops/cond.py
@@ -6,6 +6,7 @@
import torch.utils._pytree as pytree

from torch._C import DispatchKey
from torch._dispatch.python import suspend_functionalization
from torch._C._functorch import (
_add_batch_dim,
get_unwrapped,
@@ -19,22 +20,27 @@
_has_potential_branch_input_alias,
_has_potential_branch_input_mutation,
_set_compilation_env,
autograd_not_implemented,
reenter_make_fx,
unique_graph_id,
UnsupportedAliasMutationException,
)

from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch._subclasses.functional_tensor import (
disable_functional_mode,
)
from torch.fx.experimental.proxy_tensor import (
disable_proxy_modes_tracing,
make_fx,
_temp_remove_pre_dispatch_torch_function_mode,
ProxyTorchDispatchMode,
track_tensor_tree,
)
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch.utils._python_dispatch import _get_current_dispatch_mode

from torch.fx.experimental.proxy_tensor import _temp_remove_pre_dispatch_torch_function_mode
from .utils import _from_fun, clone_outputs_aliasing_inputs, prepare_fw_with_masks

@exposed_in("torch")
def cond(pred, true_fn, false_fn, operands):
@@ -101,8 +107,6 @@ def false_fn(x: torch.Tensor):
.. warning::
Temporal Limitations:

- `cond` only supports **inference** right now. Autograd will be supported in the future.

- The **output** of branches must be a **single Tensor**. Pytree of tensors will be supported in the future.

"""
@@ -142,12 +146,111 @@ def _validate_input(pred, true_fn, false_fn, operands):
pred, true_fn, false_fn, operands
)


"""
We're going to define a `cond_op` operation.
In order to do this, we need implementations for each of the dispatch keys.
"""
cond_op = HigherOrderOperator("cond")
cond_op.__module__ = "torch.ops.higher_order"

def create_fw_bw_graph(true_fn, false_fn, *operands):

from torch._functorch.aot_autograd import AOTConfig, create_joint
dummy_aot_config = AOTConfig(
fw_compiler=None, # type: ignore[arg-type]
bw_compiler=None, # type: ignore[arg-type]
partition_fn=None, # type: ignore[arg-type]
decompositions={},
num_params_buffers=0,
aot_id=0,
keep_inference_input_mutations=False,
)

# Note: [HOP create fw_bw graph] We create "clean" environments for make_fx by suspending all dispatch keys
# between the Autograd and Python keys. Currently, we only suspend functionalization, but more can be
# added when required. We would hit two problems if we did not suspend functionalization:
#
# 1. make_fx fails to capture operations on inputs: the inputs are wrapped as _to_functional_tensor_wrapper,
# but they will be unwrapped before entering ProxyTorchDispatchMode as part of the dispatching.
# However, it is the outside wrapper that the tracer creates proxies for. This causes the tracer to fail to
# fetch the proxy for the inputs and to miss any operations on them.
#
# 2. make_fx fails to capture outputs: the outputs after ProxyTorchDispatchMode are further
# wrapped as FunctionalTensorWrapper in the Functionalize key after return. However, the tracer
# only associates the inner tensor with a proxy in ProxyTorchDispatchMode. Therefore,
# when creating the output node, it fails to associate the wrapped tensor with its proxy.
# Instead, it creates a _tensor_constant as output.

with suspend_functionalization(), disable_functional_mode():
with disable_proxy_modes_tracing():

num_mapped_args = len(operands)
unwrapped_mapped_operands = pytree.tree_map(_from_fun, operands)
example_operands = unwrapped_mapped_operands

# Note: true_fn and false_fn produce outputs with the same shape,
# so we can generate the example outputs from true_fn alone.
example_flat_out = pytree.tree_map(
_from_fun, true_fn(*example_operands)
)
if any(
not isinstance(out, torch.Tensor)
for out in example_flat_out
if out is not None
):
raise RuntimeError(
"Expect outputs of map only contains tensors or None. "
f"Got types {[type(out) for out in example_flat_out]}."
)
example_grad = [_from_fun(out) for out in example_flat_out]

fw_true_graph = make_fx(true_fn)(*example_operands)
fw_false_graph = make_fx(false_fn)(*example_operands)

def joint_f_true(*joint_mapped_args):
mapped_input = joint_mapped_args[:num_mapped_args]
mapped_grads = joint_mapped_args[num_mapped_args:]

joint = create_joint(prepare_fw_with_masks(true_fn), aot_config=dummy_aot_config)
_, grads = joint(
list(mapped_input),
[
grad
for grad in mapped_grads
if grad is not None and grad.requires_grad
],
)

# In order to keep cond functional for the backward graph,
# we clone outputs that alias inputs.
maybe_clone = clone_outputs_aliasing_inputs(joint_mapped_args)

return pytree.tree_map(maybe_clone, grads)

def joint_f_false(*joint_mapped_args):
mapped_input = joint_mapped_args[:num_mapped_args]
mapped_grads = joint_mapped_args[num_mapped_args:]

joint = create_joint(prepare_fw_with_masks(false_fn), aot_config=dummy_aot_config)
_, grads = joint(
list(mapped_input),
[
grad
for grad in mapped_grads
if grad is not None and grad.requires_grad
],
)

# In order to keep cond functional for the backward graph,
# we clone outputs that alias inputs.
maybe_clone = clone_outputs_aliasing_inputs(joint_mapped_args)

return pytree.tree_map(maybe_clone, grads)

joint_operands_grads = list(example_operands) + list(example_grad)
joint_true_graph = make_fx(joint_f_true)(*joint_operands_grads)
joint_false_graph = make_fx(joint_f_false)(*joint_operands_grads)
return fw_true_graph, fw_false_graph, joint_true_graph, joint_false_graph
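
For intuition (not part of the diff): each joint graph traced above bundles a branch's forward and backward passes. It takes the branch operands followed by the output gradients and returns gradients with respect to the operands. In plain eager autograd, the true-branch computation corresponds roughly to this sketch (illustrative only, not the traced graph itself):

```python
import torch

def true_fn(x: torch.Tensor):
    return x.sin()

x = torch.randn(3, requires_grad=True)
grad_out = torch.ones(3)

# What joint_true_graph computes, expressed with eager autograd:
out = true_fn(x)                                      # forward of the true branch
(grad_x,) = torch.autograd.grad(out, (x,), grad_out)  # gradient w.r.t. the operand
```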


def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
@@ -243,10 +346,36 @@ def cond_op_dense(pred, true_fn, false_fn, operands):
return false_fn(*operands)


cond_op.py_impl(DispatchKey.Autograd)(
autograd_not_implemented(cond_op, deferred_error=True)
)
class CondAutogradOp(torch.autograd.Function):
@staticmethod
def forward(ctx, pred, fw_true_graph, fw_false_graph, joint_true_graph, joint_false_graph, num_mapped_args, *operands):
ctx._pred = pred
ctx._joint_true_graph = joint_true_graph
ctx._joint_false_graph = joint_false_graph
ctx.save_for_backward(*operands)

with torch._C._AutoDispatchBelowAutograd():
with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit():
return torch.compile(cond_op, backend="eager", fullgraph=True)(
pred, fw_true_graph, fw_false_graph, operands
)

@staticmethod
def backward(ctx, *flat_grads):
operands = ctx.saved_tensors

with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit():
grads = torch.compile(cond_op, backend="eager", fullgraph=True)(
ctx._pred, ctx._joint_true_graph, ctx._joint_false_graph, operands + flat_grads
)
return None, None, None, None, None, None, *grads

@cond_op.py_impl(DispatchKey.Autograd)
def cond_autograd(pred, true_fn, false_fn, operands):
num_mapped_args = len(operands)
fw_true_graph, fw_false_graph, joint_true_graph, joint_false_graph = create_fw_bw_graph(true_fn, false_fn, *operands)
flat_out = CondAutogradOp.apply(pred, fw_true_graph, fw_false_graph, joint_true_graph, joint_false_graph, num_mapped_args, *operands)
return flat_out
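
The registration above follows the standard torch.autograd.Function pattern: non-differentiable arguments (the predicate, the traced graphs, and num_mapped_args) receive None gradients in backward. A stripped-down, hypothetical analogue of the same pattern, not the operator itself:

```python
import torch

class SelectBranch(torch.autograd.Function):
    """Toy analogue of CondAutogradOp: run one of two computations in forward
    and replay the matching derivative in backward."""

    @staticmethod
    def forward(ctx, pred, x):
        ctx._pred = bool(pred)
        ctx.save_for_backward(x)
        return x.sin() if ctx._pred else x.cos()

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        grad_x = grad_out * (x.cos() if ctx._pred else -x.sin())
        # The predicate gets no gradient, mirroring the Nones returned above.
        return None, grad_x

x = torch.randn(4, requires_grad=True)
y = SelectBranch.apply(torch.tensor(True), x)
y.sum().backward()  # x.grad equals x.cos()
```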

@cond_op.py_impl(ProxyTorchDispatchMode)
def inner(mode, pred, true_fn, false_fn, operands):
96 changes: 4 additions & 92 deletions torch/_higher_order_ops/map.py
@@ -2,7 +2,7 @@
import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._dispatch.python import suspend_functionalization
from torch._functorch.aot_autograd import AOTConfig, create_joint, from_fun
from torch._functorch.aot_autograd import AOTConfig, create_joint

from torch._higher_order_ops.utils import (
_has_potential_branch_input_alias,
@@ -14,15 +14,14 @@
from torch._subclasses.fake_tensor import FakeTensorMode
from torch._subclasses.functional_tensor import (
disable_functional_mode,
FunctionalTensor,
)
from torch.fx.experimental.proxy_tensor import (
disable_proxy_modes_tracing,
make_fx,
ProxyTorchDispatchMode,
track_tensor_tree,
)
from torch.multiprocessing.reductions import StorageWeakRef
from .utils import _from_fun, clone_outputs_aliasing_inputs, prepare_fw_with_masks, _unstack_pytree, _stack_pytree


# TODO: We add this to prevent dymamo from tracing into map_wrapper,
@@ -68,31 +67,6 @@ def create_fw_bw_graph(f, num_mapped_args, *args):
with suspend_functionalization(), disable_functional_mode():
with disable_proxy_modes_tracing():

def _from_fun(t):
if isinstance(t, torch.Tensor):
if t.dtype != torch.bool:
return torch.empty_strided(
t.size(),
t.stride(),
dtype=t.dtype,
requires_grad=t.requires_grad,
)
else:
# clone of a functional tensor produces a functional tensor
# but we want to avoid it so we clone a non-functional version
maybe_unfunc_t = t
if isinstance(t, FunctionalTensor):
torch._sync(t)
maybe_unfunc_t = from_fun(t)
elif torch._is_functional_tensor(t):
# need to handle both types of functionalization here:
# these are the tensors that came from the user,
# which could be either FunctionalTensorWrapper or FunctionalTensor
torch._sync(t)
maybe_unfunc_t = torch._from_functional_tensor(t)
return maybe_unfunc_t.clone()
return t

unwrapped_mapped_xs = pytree.tree_map(_from_fun, mapped_xs)
example_xs = _unstack_pytree(unwrapped_mapped_xs)[0]

@@ -123,16 +97,7 @@ def joint_f(*example_args):
mapped_input = joint_mapped_args[:num_mapped_args]
mapped_grads = joint_mapped_args[num_mapped_args:]

def fw_with_masks(*args):
fw_out = f(*args)
return fw_out, [
True
if isinstance(ret, torch.Tensor) and ret.requires_grad
else False
for ret in fw_out
]

joint = create_joint(fw_with_masks, aot_config=dummy_aot_config)
joint = create_joint(prepare_fw_with_masks(f), aot_config=dummy_aot_config)
_, grads = joint(
list(mapped_input) + list(args),
[
@@ -144,19 +109,7 @@ def fw_with_masks(*args):

# In order to keep map functional for backward graph,
# we clone outputs that are aliasing inputs
input_storage = {
StorageWeakRef(arg._typed_storage())
for arg in example_args
if isinstance(arg, torch.Tensor)
}

def maybe_clone(t):
if (
isinstance(t, torch.Tensor)
and StorageWeakRef(t._typed_storage()) in input_storage
):
return t.clone()
return t
maybe_clone = clone_outputs_aliasing_inputs(example_args)

return pytree.tree_map(maybe_clone, grads)
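
The inlined helpers removed in this file (fw_with_masks and the aliasing-aware maybe_clone) are replaced by prepare_fw_with_masks and clone_outputs_aliasing_inputs imported from torch/_higher_order_ops/utils.py. A sketch reconstructed from the removed code; the actual utils implementations may differ in detail:

```python
import torch
from torch.multiprocessing.reductions import StorageWeakRef

def prepare_fw_with_masks(fn):
    # Wrap fn so it also reports which outputs require grad.
    def fw_with_masks(*args):
        fw_out = fn(*args)
        return fw_out, [
            isinstance(ret, torch.Tensor) and ret.requires_grad for ret in fw_out
        ]

    return fw_with_masks

def clone_outputs_aliasing_inputs(args):
    # Return a callable that clones any tensor sharing storage with an input,
    # which keeps the traced backward graph functional.
    input_storage = {
        StorageWeakRef(arg._typed_storage())
        for arg in args
        if isinstance(arg, torch.Tensor)
    }

    def maybe_clone(t):
        if (
            isinstance(t, torch.Tensor)
            and StorageWeakRef(t._typed_storage()) in input_storage
        ):
            return t.clone()
        return t

    return maybe_clone
```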

@@ -255,47 +208,6 @@ def expand_tensor(t):
expanded_outs, out_proxy, constant=None, tracer=proxy_mode.tracer
)


def _unstack_pytree(xs):
flat_xs, inspec = pytree.tree_flatten(xs)
if not all(isinstance(xs, torch.Tensor) for xs in flat_xs):
raise RuntimeError(f"Leaves of xs must be Tensor {flat_xs}")

if not all(xs.shape[0] == flat_xs[0].shape[0] for xs in flat_xs):
raise RuntimeError(
f"Leaves of xs must have same leading dimension size {[xs.shape for xs in flat_xs]}"
)

a = zip(*flat_xs)

pytrees = []
for tuple in a:
pytrees.append(pytree.tree_unflatten(tuple, inspec))
return pytrees


def _stack_pytree(pytrees):
flat_out = []
out_spec = None
for pt in pytrees:
flat_pt, out_spec = pytree.tree_flatten(pt)
flat_out.append(flat_pt)
assert out_spec is not None
b = zip(*flat_out)
stacked_out = []
for leaves in b:
if all(isinstance(leaf, torch.Tensor) for leaf in leaves):
stacked_out.append(torch.stack(leaves))
elif all(leaf is None for leaf in leaves):
# Backward graph can return None output when forward inputs doesn't require grad.
# When we eagerly execute backward graph, we need to call _stack_pytree on its output,
# therefore we need to deal with None output.
stacked_out.append(None) # type: ignore[arg-type]
else:
raise RuntimeError(f"Cannot stack {leaves}.")
return pytree.tree_unflatten(stacked_out, out_spec)


@map_impl.py_impl(DispatchKey.CompositeExplicitAutograd)
def map_dense(f, xs, pos_args):
pytrees = []