
[Autograd] Cond Higher-Order Operation #126911

Open · wants to merge 18 commits into base: main

Conversation

bohnstingl (Contributor) commented May 22, 2024

This is an updated PR to equip cond with the autograd feature; it replaces the old PR.

@ydwu4 I tried to incorporate your requests already.

Currently, there are two problems that I am struggling to solve:

  1. There seems to be an import issue when trying to import cond in torch/__init__.py, see here. Commenting out those lines resolved the import issue, but I believe cond is then not properly exposed as torch.cond.
  2. I am not entirely sure how to deal with the opinfo test in hop_db.py.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang

pytorch-bot bot commented May 22, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/126911

Note: Links to docs will display an error until the docs builds have been completed.

❌ 29 New Failures, 2 Unrelated Failures

As of commit 630549a with merge base 47c976b:

NEW FAILURES - The following jobs have failed:

UNSTABLE - The following jobs failed, but the failures were likely due to flakiness present on trunk, and they have been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@bohnstingl bohnstingl marked this pull request as ready for review May 23, 2024 06:10
@bdhirsh bdhirsh added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) May 23, 2024
ydwu4 (Contributor) commented May 23, 2024

Please also do a rebase and address the conflicts.

Merged with the PyTorch main branch
Added new test cases
Consolidated common functions for create_fw_bw_graph with map.py
bohnstingl (Contributor, Author) commented

@ydwu4 Thank you very much for your comments and reviews. I updated the files accordingly. There are two open points though:

  • One test is failing now after the rebase and I can't quite figure out why.
  • I had to comment out these lines in order to make the test work again. Is this acceptable, or how can one resolve this issue?

ydwu4 (Contributor) left a comment


Good work overall! We're close to landing the feature. I left some comments on the PR, mostly about:

  1. More tests; see the comments for some testing ideas. Feel free to use your imagination to create more tests.
  2. We should not call torch.compile again when we're dispatching the Autograd key; we should just call the operator itself (see the sketch below).
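
For point 2, a hedged sketch of what "call the operator itself" could look like: a torch.autograd.Function that saves the operands, re-dispatches to the cond higher-order operator below the Autograd key in the forward, and calls it again with the joint graphs in the backward. All names, the argument layout, and the joint-graph calling convention here are illustrative assumptions, not this PR's actual code.

import torch
from torch._higher_order_ops.cond import cond_op  # the cond HigherOrderOperator (import path assumed)

class CondAutogradOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, pred, fw_true_graph, fw_false_graph,
                joint_true_graph, joint_false_graph, *operands):
        ctx._pred = pred
        ctx._joint_true_graph = joint_true_graph
        ctx._joint_false_graph = joint_false_graph
        ctx.save_for_backward(*operands)
        # Call the cond operator itself below the Autograd key instead of
        # re-entering torch.compile.
        with torch._C._AutoDispatchBelowAutograd():
            return cond_op(pred, fw_true_graph, fw_false_graph, operands)

    @staticmethod
    def backward(ctx, *flat_grads):
        operands = ctx.saved_tensors
        # Assumption: the joint graphs take (*grads, *operands); the real
        # calling convention in the PR may differ.
        grads = cond_op(ctx._pred, ctx._joint_true_graph,
                        ctx._joint_false_graph, flat_grads + operands)
        return None, None, None, None, None, *grads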


return pytree.tree_map(maybe_clone, grads)

def joint_f_false(*joint_mapped_args):
ydwu4 (Contributor) commented May 24, 2024


The redundancy between joint_f_true and joint_f_false can be removed if we refactor create_fw_bw_graph to take a single function and its operands.
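
A hedged sketch of that refactor: a create_fw_bw_graph that takes one branch function and its operands and returns the forward graph plus a joint (forward + backward) graph, so it can be called once for true_fn and once for false_fn. The names, the tangent layout, and the tracing details are illustrative assumptions rather than this PR's code, and the operands are assumed to require grad.

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def create_fw_bw_graph(fn, operands):
    # Forward graph of a single branch.
    fw_graph = make_fx(fn)(*operands)

    # Joint function: recompute the forward, then backprop the incoming
    # tangents to obtain gradients with respect to the operands.
    def joint_fn(*primals_and_tangents):
        n = len(operands)
        primals = primals_and_tangents[:n]
        tangents = primals_and_tangents[n:]
        outs = fn(*primals)
        outs = (outs,) if isinstance(outs, torch.Tensor) else tuple(outs)
        return torch.autograd.grad(outs, primals, tangents, allow_unused=True)

    # Example tangents matching the structure of the forward outputs.
    example_outs = fn(*operands)
    if isinstance(example_outs, torch.Tensor):
        example_outs = (example_outs,)
    tangents = tuple(torch.ones_like(o) for o in example_outs)

    joint_graph = make_fx(joint_fn)(*operands, *tangents)
    return fw_graph, joint_graph

Calling such a helper once per branch would remove the duplicated bodies of joint_f_true and joint_f_false.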

bohnstingl (Contributor, Author) replied:

Indeed, the redundancy has been alleviated by using the new low-level API. However, there is a problem with this new API; see the comment above.

Fixed test cases
Removed unnecessary compile step in Autograd
…loop)

Fixed issues with cond
Added several additional test cases
bohnstingl (Contributor, Author) commented

I tried to incorporate the review comments from above. While creating additional test cases, I ran into a couple of problems:

  1. As mentioned above, the create_fw_bw_graph function is invoked twice, which is somewhat problematic.
  2. Some nn.Modules caused problems when used in the true_fn/false_fn. For example, when the nn.Linear or the nn.GRU module is used, the check in the trace fails, because true_out or false_out do not have metadata in that case. Therefore, I extended this check to avoid the issue (a sketch of the guarded comparison follows the stack trace below).
  3. The test case TestControlFlowTraced.test_cond_make_fx_preserve_stack_trace_for_nodes_in_subgraph is failing and I don't quite know why. The respective dictionary does not track all the metadata correctly. For example, the target output is [('cond', <torch._ops.HigherOrderOperator object at 0x7f5b4d710810>), ('cos', 'cos')], while the produced output is [('cond', <torch._ops.HigherOrderOperator object at 0x7f5b4d710810>)]. Here is the stack trace:
======================================================================
FAIL: test_cond_make_fx_preserve_stack_trace_for_nodes_in_subgraph (__main__.TestControlFlowTraced.test_cond_make_fx_preserve_stack_trace_for_nodes_in_subgraph)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/data_malta3_ssd/pytorch/torch/testing/_internal/common_utils.py", line 2756, in wrapper
    method(*args, **kwargs)
  File "/data_malta3_ssd/pytorch/test/functorch/test_control_flow.py", line 3413, in test_cond_make_fx_preserve_stack_trace_for_nodes_in_subgraph
    self.assertEqual(all_source_fns, new_source_fns)
  File "/data_malta3_ssd/pytorch/torch/testing/_internal/common_utils.py", line 3607, in assertEqual
    error_metas = not_close_error_metas(
                  ^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/testing/_comparison.py", line 1220, in not_close_error_metas
    raise error_meta.to_error() from None  # noqa: RSE102
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: The length of the sequences mismatch: 2 != 1

The failure occurred for item [0]

To execute this test, run the following from the base repo dir:
     python test/functorch/test_control_flow.py -k TestControlFlowTraced.test_cond_make_fx_preserve_stack_trace_for_nodes_in_subgraph
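
Regarding point 2 above, a hedged sketch of the guarded metadata comparison in trace_cond; the skip-on-missing-metadata behavior, the helper name, and the error type are assumptions, and the PR's actual change may look different.

def check_branch_output_metadata(flat_true_outs, flat_false_outs):
    # Compare per-output tensor metadata, but skip entries where a branch
    # (e.g. one using nn.Linear or nn.GRU) produced a node without tensor_meta.
    for true_out, false_out in zip(flat_true_outs, flat_false_outs):
        true_meta = true_out.meta.get("tensor_meta", None)
        false_meta = false_out.meta.get("tensor_meta", None)
        if true_meta is not None and false_meta is not None and true_meta != false_meta:
            raise RuntimeError(
                f"Expected both branches to return outputs with matching "
                f"metadata, but got {true_meta} and {false_meta}"
            )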

@bohnstingl bohnstingl requested a review from ydwu4 May 26, 2024 07:46
return torch.autograd.grad(x_new, (x,), grad_out)

# TODO: During compilation, the metadata of the true_fn has the
# requires_grad attribute set to False, while the false_fn has it
ydwu4 (Contributor) commented:

This is unexpected... Can we find out why the requires_grad of true_fn is set to False? true_fn should return an output that requires grad, since its input x has requires_grad = True. If not, there could be a bug somewhere.

bohnstingl (Contributor, Author) commented May 31, 2024

Very good catch @ydwu4, I think this might even be the root cause of some of the other issues that we see. In this particular case, make_fx introduces this detach operation even in the forward path. So in the case of

def forward(self, l_args_3_0_):
    l_args_3_0__1 = l_args_3_0_
    relu = torch.nn.functional.relu(l_args_3_0__1, inplace = False);  l_args_3_0__1 = None
    return (relu,)

make_fx turns this into

def forward(self, arg0_1):
    relu = torch.ops.aten.relu.default(arg0_1);  arg0_1 = None
    detach = torch.ops.aten.detach.default(relu)
    return (relu,)

I don't quite know the reason for this, and I am not familiar with the inner workings of make_fx.
This is the same for torch.nn.GRUCell. However, this does not happen if I just create a custom nn.Module that itself implements the ReLU function:

class SimpleNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, x):
        return torch.maximum(torch.zeros_like(x), x)
nn_module_false = SimpleNN()

def false_fn(x):
    return nn_module_false(x)

which, when make_fx is applied, gives

def forward(self, arg0_1):
    zeros_like = torch.ops.aten.zeros_like.default(arg0_1, pin_memory = False)
    maximum = torch.ops.aten.maximum.default(zeros_like, arg0_1);  zeros_like = arg0_1 = None
    return (maximum,)

Do you have any hints on how to approach this?
I just checked, and this behavior is also present on the main branch of PyTorch.

ydwu4 (Contributor) commented May 31, 2024

Yeah, I can take a look. It seems related to whether the inputs require grad. For a normal function like the one below,

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return torch.nn.functional.relu(x)

gm = make_fx(f)(torch.randn(3, 2))

we'll get a normal graph, as shown below, when the input doesn't require gradients.

class f(torch.nn.Module):
    def forward(self, x_1: "f32[3, 2]"):
        # No stacktrace found for following nodes
        relu: "f32[3, 2]" = torch.ops.aten.relu.default(x_1);  x_1 = None
        return relu

For cond,

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return torch.cond(x.sum() > 0, lambda x: torch.nn.functional.relu(x), lambda x: x.sin(), (x,))

gm = make_fx(f)(torch.randn(3, 2))
class f(torch.nn.Module):
    def forward(self, x_1: "f32[3, 2]"):
        # No stacktrace found for following nodes
        sum_1: "f32[]" = torch.ops.aten.sum.default(x_1)
        gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 0);  sum_1 = None
        true_graph_0 = self.true_graph_0
        false_graph_0 = self.false_graph_0
        conditional = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [x_1]);  gt = true_graph_0 = false_graph_0 = x_1 = None
        getitem: "f32[3, 2]" = conditional[0];  conditional = None
        return getitem
        
    class <lambda>(torch.nn.Module):
        def forward(self, arg0_1: "f32[3, 2]"):
            # No stacktrace found for following nodes
            relu: "f32[3, 2]" = torch.ops.aten.relu.default(arg0_1);  arg0_1 = None
            return (relu,)
            
    class <lambda>(torch.nn.Module):
        def forward(self, arg0_1: "f32[3, 2]"):
            # No stacktrace found for following nodes
            sin: "f32[3, 2]" = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
            return (sin,)

The graph looks like the above. But setting requires_grad = True causes the detach to appear. Let me check what's going on...

It's great that you came up with a good test that helps us identify this early, haha.
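
For reference, a minimal repro of the requires_grad variant discussed above; the expected extra aten.detach node in the printed graph reflects the behavior reported in this thread.

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return torch.nn.functional.relu(x)

# With a grad-requiring input, autograd saves relu's result for backward,
# which shows up in the traced graph as an extra aten.detach node.
gm = make_fx(f)(torch.randn(3, 2, requires_grad=True))
print(gm.code)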

bohnstingl (Contributor, Author) replied:

Thank you very much for looking into this!
I will continue to think of more such creative use cases that potentially unveil bugs :-)

Some information that might help you with this issue: I tried to test several components of torch.nn, and it seems that the issue is not present for all of them. For example, the detach operations are not present for:

  • torch.nn.Identity()
  • torch.nn.Dropout()
  • torch.nn.Linear()
  • torch.nn.functional.threshold
  • torch.nn.functional.glu
  • torch.nn.functional.gelu
  • torch.nn.functional.leaky_relu
  • torch.nn.functional.logsigmoid
  • torch.nn.functional.softplus
  • torch.nn.functional.hardtanh
  • torch.nn.functional.elu
  • torch.nn.functional.selu
  • torch.nn.functional.rrelu

but are present for:

  • torch.nn.RNN
  • torch.nn.LSTM
  • torch.nn.GRU
  • torch.nn.Transformer
  • torch.nn.functional.relu
  • torch.nn.functional.softmax
  • torch.nn.functional.tanh
  • torch.nn.functional.softmin

At first there was no clear pattern to me, but after quite some investigation into the failing cases, it seems that this behavior is mostly triggered by the tanh, relu, or softmax functions; at least, I could trace all the failing cases back to one of these functions. I hope this helps.
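
A hedged sketch of how such a survey can be automated: trace each candidate through make_fx with a grad-requiring input and check the resulting graph for an aten.detach node. The candidate list below is just a small sample of the functions mentioned above.

import torch
from torch.fx.experimental.proxy_tensor import make_fx

candidates = {
    "relu": torch.nn.functional.relu,
    "leaky_relu": torch.nn.functional.leaky_relu,
    "tanh": torch.tanh,
    "gelu": torch.nn.functional.gelu,
    "softplus": torch.nn.functional.softplus,
}

x = torch.randn(3, 2, requires_grad=True)
for name, fn in candidates.items():
    gm = make_fx(fn)(x)
    has_detach = any(
        node.op == "call_function" and node.target == torch.ops.aten.detach.default
        for node in gm.graph.nodes
    )
    print(f"{name}: detach present = {has_detach}")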

ydwu4 (Contributor) commented Jun 3, 2024

Thanks for the thorough study! @zou3519 told me that when autograd saves a Tensor for backward, there's a call to .detach(), and PyTorch has auto-generated code for autograd: see derivatives.yaml. The generated code is in VariableTypeEverything.cpp. In the case of relu, the result is saved like this: grad_fn->result_ = SavedVariable(result, true);

So the conclusion is that this is correct behavior. But that doesn't solve our pytree output error, though. We might need to look into it further.
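
A quick eager-mode way to see that relu saves its result for the backward pass (the _saved_result attribute name follows from relu's derivative saving result, as described above; this is an illustration, not part of the PR):

import torch

x = torch.randn(3, requires_grad=True)
y = torch.nn.functional.relu(x)
print(type(y.grad_fn))          # ReluBackward0
print(y.grad_fn._saved_result)  # the result saved for the backward pass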

ydwu4 (Contributor) commented Jun 3, 2024

Good question. I'm not sure about the details. I checked leaky_relu a bit: the difference between relu and leaky_relu in the generated code is that leaky_relu doesn't store the SavedVariable in grad_fn like relu does. However, from derivatives.yaml, I cannot tell why they're different; they both use auto_element_wise. There might be other places controlling the code-gen logic. This is a separate discussion though (or is this necessary for getting correct pytree output?).

bohnstingl (Contributor, Author) replied:

Okay, thank you.

I think for the moment we could just leave a TODO, but I guess we need to come back to this eventually. I also observed that a similar (potentially related) issue with the requires_grad flag occurs for the test case test_cond_in_forloop, for example. In that case, there is no relu function involved, but still true_out.meta["tensor_meta"] has requires_grad=False, while false_out.meta["tensor_meta"] has it set to True.

Contributor:

Can you check how tensor_meta is created and why it is set to False in one branch?

bohnstingl (Contributor, Author) commented Jun 5, 2024

I tried to go down the rabbit hole and actually found something that puzzles me. In the simple case,

def true_fn(x):
    return x.sin()
def false_fn(x):
    return x.cos()

The metadata are:

true_out.meta["tensor_meta"] = TensorMetadata(shape=torch.Size([4]), dtype=torch.float32, requires_grad=False, stride=(1,), memory_format=torch.contiguous_format, is_quantized=False, qparams={})
false_out.meta["tensor_meta"] = TensorMetadata(shape=torch.Size([4]), dtype=torch.float32, requires_grad=False, stride=(1,), memory_format=torch.contiguous_format, is_quantized=False, qparams={})

However, if in one branch the input arguments are just passed through, e.g.,

def true_fn(x):
    return x
def false_fn(x):
    return x.cos()

the metadata become

true_out.meta["tensor_meta"] = TensorMetadata(shape=torch.Size([4]), dtype=torch.float32, requires_grad=True, stride=(1,), memory_format=torch.contiguous_format, is_quantized=False, qparams={})
false_out.meta["tensor_meta"] = TensorMetadata(shape=torch.Size([4]), dtype=torch.float32, requires_grad=False, stride=(1,), memory_format=torch.contiguous_format, is_quantized=False, qparams={})

If one applies even a simple multiplication by 1 to the input argument, the metadata match again:

def true_fn(x):
    return x*1
def false_fn(x):
    return x.cos()

true_out.meta["tensor_meta"] = TensorMetadata(shape=torch.Size([4]), dtype=torch.float32, requires_grad=False, stride=(1,), memory_format=torch.contiguous_format, is_quantized=False, qparams={})
false_out.meta["tensor_meta"] = TensorMetadata(shape=torch.Size([4]), dtype=torch.float32, requires_grad=False, stride=(1,), memory_format=torch.contiguous_format, is_quantized=False, qparams={})

In all of the above cases, the code of the two graphs for true_fn and false_fn looks as expected:

true_graph._code = 'def forward(self, arg0_1):\n    return (arg0_1,)'
false_graph._code = 'def forward(self, arg0_1):\n    cos = torch.ops.aten.cos.default(arg0_1);  arg0_1 = None\n    return (cos,)'

At the moment, this appears to be the culprit, but I don't know where it comes from, let alone how to resolve it.

Addition: I figured out that the cos node sets the requires_grad flag in the metadata to False. For example, here are the nodes of false_fn:

0:
(arg0_1, {'val': FakeTensor(..., size=(4,)), 'tensor_meta': TensorMetadata(shape=torch.Size([4]), dtype=torch.float32, requires_grad=True, stride...us_format, is_quantized=False, qparams={})})
1:
(cos, {'val': FakeTensor(..., size=(4,)), 'tensor_meta': TensorMetadata(shape=torch.Size([4]), dtype=torch.float32, requires_grad=False, strid...us_format, is_quantized=False, qparams={})})
2:
(output, {})

I can see that during the graph construction/execution, the args of the call_function node cos still have the requires_grad flag set to True. E.g., if one prints (node.target, node.args[0].meta):

(<OpOverload(op='aten.cos', overload='default')>, {'val': FakeTensor(..., size=(4,)), 'tensor_meta': TensorMetadata(shape=torch.Size([4]), dtype=torch.float32, requires_grad=True, stride=(1,), memory_format=torch.contiguous_format, is_quantized=False, qparams={})})

but the result of the cos node, i.e., in this case the output node, has the requires_grad flag set to False:

('output', {'val': FakeTensor(..., size=(4,)), 'tensor_meta': TensorMetadata(shape=torch.Size([4]), dtype=torch.float32, requires_grad=False, stride=(1,), memory_format=torch.contiguous_format, is_quantized=False, qparams={})})

bohnstingl (Contributor, Author) commented

@ydwu4, do you maybe have any insights into this? Thank you.

return torch.autograd.grad(x_new, (x,), grad_out)

# TODO: During compilation, the metadata of the true_fn has the
# requires_grad attribute set to False, while the false_fn has it
Contributor:

we should fix this.

result = cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},))
self.assertEqual(result, fn({"t": [a, {"b": b}, (c,)]}))

# TODO: Gradient computation for such complex pytree output does not work
Contributor:

This is unexpected.

By the time we enter the Autograd dispatch key, true_fn and false_fn should already have become two graph modules that take a tuple as input and produce a tuple as output. All the complicated pytree types should already be flattened (the torch.compile line in torch.cond achieves this). As long as we can handle the tuple type, these tests should be able to pass. We should probably figure out why this fails and fix it.

bohnstingl (Contributor, Author) replied:

What I am observing is that the following errors out:

def false_fn(x):
    return {"res": [x["t"][1]["b"], (x["t"][2][0],)]}

result_exp = fn({"t": [a, {"b": b}, (c,)]})
grad_out = torch.ones_like(a)
expected_grads = torch.autograd.grad(result_exp, (a,), grad_out)

with the exception:

Traceback (most recent call last):
  File "/data_malta3_ssd/pytorch/torch/testing/_internal/common_utils.py", line 2756, in wrapper
    method(*args, **kwargs)
  File "/data_malta3_ssd/pytorch/test/functorch/test_control_flow.py", line 925, in test_cond_autograd_same_pytree_output
    expected_grads = torch.autograd.grad(result_exp, (a,), grad_out)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/autograd/__init__.py", line 402, in grad
    grad_outputs_ = _make_grads(
                    ^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/autograd/__init__.py", line 91, in _make_grads
    if out.is_nested or first_grad.is_nested:
       ^^^^^^^^^^^^^
AttributeError: 'str' object has no attribute 'is_nested'

So I don't think this has anything to do with cond; it just doesn't work in this setting even for regular autograd. The respective test case that fails is here.

ydwu4 (Contributor) commented Jun 3, 2024

Oh, yeah, I forgot to mention that we need to flatten result_exp and grad_out to make torch.autograd.grad work; the doc is here.
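
A hedged sketch of that flattening, with hypothetical stand-in tensors and a stand-in fn mirroring the nested-pytree branch from the failing test above:

import torch
import torch.utils._pytree as pytree

a = torch.randn(4, requires_grad=True)
b = torch.randn(4, requires_grad=True)
c = torch.randn(4, requires_grad=True)

def fn(x):
    # Stand-in for the branch above: pick entries out of the nested pytree.
    return {"res": [x["t"][1]["b"].sin(), (x["t"][2][0].cos(),)]}

result_exp = fn({"t": [a, {"b": b}, (c,)]})

# torch.autograd.grad only accepts (sequences of) Tensors, so flatten the
# nested dict/list/tuple output first and build matching grad_outputs.
flat_outs, _ = pytree.tree_flatten(result_exp)
flat_grad_outs = [torch.ones_like(out) for out in flat_outs]
expected_grads = torch.autograd.grad(flat_outs, (b, c), flat_grad_outs)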

@@ -183,7 +226,16 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
    for i in range(0, len(flat_true_outs)):
        true_out = flat_true_outs[i]
        false_out = flat_false_outs[i]
        if true_out.meta["tensor_meta"] != false_out.meta["tensor_meta"]:
            # TODO: If a torch nn module such as Linear or GRUCell is used, then the
Contributor:

Is this TODO still valid?

if true_out.meta["tensor_meta"] != false_out.meta["tensor_meta"]:
    # TODO: If a torch nn module such as Linear or GRUCell is used, then the
    # meta data of the output is None and cannot be compared
    # TODO: If inside the dictionary, inside the list, the first element
Contributor:

We should figure out why this happens and fix it.

@ezyang ezyang removed their request for review June 5, 2024 14:25
Labels: module: dynamo, open source, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

4 participants