[Autograd] Cond Higher-Order Operation #126911
Conversation
Please also rebase and address the conflicts.
Merged with the PyTorch main branch; added new test cases; consolidated common functions for create_fw_bw_graph with map.py
@ydwu4 Thank you very much for your comments and reviews. I updated the files accordingly. There are two open points, though:
Good work overall! We're close to landing the feature. I left some comments on the PR, mostly about:
- more tests; see the comments for some ideas. Use your imagination to create more tests.
- we should not call torch.compile again when we're dispatching the autograd key; we should just call the operator itself (sketched below).
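A rough sketch of that idea (the class name, graph arguments, and the grads/operands ordering are illustrative assumptions, not necessarily this PR's actual code):

import torch

# Hypothetical sketch: wrap pre-traced forward/joint branch graphs in an
# autograd.Function and call the cond operator itself, instead of
# re-entering torch.compile while dispatching the autograd key.
class CondAutogradOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, pred, fw_true_graph, fw_false_graph,
                joint_true_graph, joint_false_graph, *operands):
        ctx._pred = pred
        ctx._joint_true_graph = joint_true_graph
        ctx._joint_false_graph = joint_false_graph
        ctx.save_for_backward(*operands)
        # Run the operator below the autograd key so we don't recurse.
        with torch._C._AutoDispatchBelowAutograd():
            return torch.ops.higher_order.cond(
                pred, fw_true_graph, fw_false_graph, operands)

    @staticmethod
    def backward(ctx, *flat_grads):
        operands = ctx.saved_tensors
        # Assumed convention: the joint graphs take the upstream grads
        # followed by the saved operands and return the input grads.
        grads = torch.ops.higher_order.cond(
            ctx._pred, ctx._joint_true_graph, ctx._joint_false_graph,
            flat_grads + operands)
        return None, None, None, None, None, *grads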
torch/_higher_order_ops/cond.py
Outdated
return pytree.tree_map(maybe_clone, grads)
def joint_f_false(*joint_mapped_args):
The redundancy between joint_f_true and joint_f_false can be removed if we refactor create_fw_bw_graph to take a single function and its operands (sketched below).
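A minimal sketch of such a refactor (the make_fx-based tracing and the grads-then-operands convention of the joint function are assumptions, not the final implementation; it also assumes at least one operand requires grad):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

# Hypothetical: one helper that builds the forward and joint (backward)
# graphs for a single branch function, so it can be called once for
# true_fn and once for false_fn instead of duplicating joint_f_*.
def create_fw_bw_graph(fn, *operands):
    fw_graph = make_fx(fn)(*operands)

    def joint_fn(*joint_args):
        # Assumed convention: upstream gradients first, then the operands.
        grads = joint_args[:len(joint_args) - len(operands)]
        args = [a.detach().requires_grad_(a.requires_grad)
                for a in joint_args[len(joint_args) - len(operands):]]
        outs = fn(*args)
        outs = (outs,) if isinstance(outs, torch.Tensor) else tuple(outs)
        return torch.autograd.grad(outs, args, grads, allow_unused=True)

    example_outs = fw_graph(*operands)
    example_outs = ((example_outs,) if isinstance(example_outs, torch.Tensor)
                    else tuple(example_outs))
    example_grads = tuple(torch.ones_like(o) for o in example_outs)
    bw_graph = make_fx(joint_fn)(*example_grads, *operands)
    return fw_graph, bw_graph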
Indeed, the redundancy has been alleviated by using the new low-level API. However, there is a problem with this new API; see the comment above.
Fixed test cases; removed unnecessary compile step in autograd
…loop) Fixed issues with cond; added several additional test cases
I tried to incorporate the review comments from above. While creating additional test cases, I stumbled upon a couple of problems:
return torch.autograd.grad(x_new, (x,), grad_out)
# TODO: During compilation, the metadata of the true_fn has the
# requires_grad attribute set to False, while the false_fn has it
# set to True.
This is unexpected... Can we find out why the requires_grad of true_fn is set to False? true_fn should return an output that requires grad, since its input x has requires_grad = True. If not, there could be a bug somewhere.
Very good catch @ydwu4, I think this might even be the root cause of some of the other issues that we see. In this particular case, make_fx is introducing this detach operation even in the forward path. So in the case of
def forward(self, l_args_3_0_):
    l_args_3_0__1 = l_args_3_0_
    relu = torch.nn.functional.relu(l_args_3_0__1, inplace = False); l_args_3_0__1 = None
    return (relu,)
make_fx turns this into
def forward(self, arg0_1):
    relu = torch.ops.aten.relu.default(arg0_1); arg0_1 = None
    detach = torch.ops.aten.detach.default(relu)
    return (relu,)
I don't quite know the reason for this, and I am not familiar with the inner workings of make_fx. The same happens for torch.nn.GRUCell. However, it does not happen if I create a custom nn.Module that implements the ReLU function itself:
class SimpleNN(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.maximum(torch.zeros_like(x), x)

nn_module_false = SimpleNN()

def false_fn(x):
    return nn_module_false(x)
which, when make_fx is applied, gives
def forward(self, arg0_1):
    zeros_like = torch.ops.aten.zeros_like.default(arg0_1, pin_memory = False)
    maximum = torch.ops.aten.maximum.default(zeros_like, arg0_1); zeros_like = arg0_1 = None
    return (maximum,)
Do you have any hints as to how to approach this? I just checked, and this behavior is also present on the main branch of PyTorch.
Yeah, I can take a look. It seems related to whether the inputs require grad. For a normal function like the one below,
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return torch.nn.functional.relu(x)

gm = make_fx(f)(torch.randn(3, 2))
we'll get a normal graph, as shown below, when the input doesn't require gradients:
class f(torch.nn.Module):
    def forward(self, x_1: "f32[3, 2]"):
        # No stacktrace found for following nodes
        relu: "f32[3, 2]" = torch.ops.aten.relu.default(x_1); x_1 = None
        return relu
For cond,
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return torch.cond(x.sum() > 0, lambda x: torch.nn.functional.relu(x), lambda x: x.sin(), (x,))

gm = make_fx(f)(torch.randn(3, 2))

class f(torch.nn.Module):
    def forward(self, x_1: "f32[3, 2]"):
        # No stacktrace found for following nodes
        sum_1: "f32[]" = torch.ops.aten.sum.default(x_1)
        gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 0); sum_1 = None
        true_graph_0 = self.true_graph_0
        false_graph_0 = self.false_graph_0
        conditional = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [x_1]); gt = true_graph_0 = false_graph_0 = x_1 = None
        getitem: "f32[3, 2]" = conditional[0]; conditional = None
        return getitem

class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[3, 2]"):
        # No stacktrace found for following nodes
        relu: "f32[3, 2]" = torch.ops.aten.relu.default(arg0_1); arg0_1 = None
        return (relu,)

class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[3, 2]"):
        # No stacktrace found for following nodes
        sin: "f32[3, 2]" = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
        return (sin,)
The graph looks like the above. But setting requires_grad = True causes the detach to appear. Let me check what's going on...
It's great that you came up with a good test that helps us identify this early, hahah.
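For reference, the grad-requiring repro should just be the same f traced with a grad-requiring input (a minimal sketch):

# Tracing with an input that requires grad is what surfaces the extra
# detach node in the branch graphs.
gm = make_fx(f)(torch.randn(3, 2, requires_grad=True))
gm.print_readable()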
Thank you very much for looking into this!
I will continue to think about more such creative use cases that potentially unveil bugs :-)
Some information that might help you with this issue: I tried to test several components of torch.nn, and it seems that the issue is not present for all of them. For example, the detach operations are not present for
torch.nn.Identity()
torch.nn.Dropout()
torch.nn.Linear()
torch.nn.functional.threshold
torch.nn.functional.glu
torch.nn.functional.gelu
torch.nn.functional.leaky_relu
torch.nn.functional.logsigmoid
torch.nn.functional.softplus
torch.nn.functional.hardtanh
torch.nn.functional.elu
torch.nn.functional.selu
torch.nn.functional.rrelu
but are present for
torch.nn.RNN
torch.nn.LSTM
torch.nn.GRU
torch.nn.Transformer
torch.nn.functional.relu
torch.nn.functional.softmax
torch.nn.functional.tanh
torch.nn.functional.softmin
At first, there was no clear pattern to me, but after quite some investigation into the failing cases, it seems that this behavior is mostly triggered by the tanh, relu, or softmax functions. At least, I could track all the failing cases back to one of these functions. I hope this helps.
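A small probe in the spirit of this study (the helper name and op list are illustrative; only argument-free unary ops are used so they share one wrapper):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

# Trace each op with a grad-requiring input and report whether make_fx
# inserted a detach node into the resulting graph.
def inserts_detach(op):
    gm = make_fx(lambda x: op(x))(torch.randn(3, 2, requires_grad=True))
    return any(node.target == torch.ops.aten.detach.default
               for node in gm.graph.nodes)

for op in (torch.nn.functional.relu, torch.tanh,
           torch.nn.functional.leaky_relu, torch.nn.functional.elu):
    print(getattr(op, "__name__", op), inserts_detach(op))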
Thanks for the thorough study! @zou3519 told me that when autograd saves a Tensor for backward, there's a call to .detach(), and PyTorch has auto codegen for autograd: see derivatives.yaml. The generated code is in VariableTypeEverything.cpp. In the case of relu, the result is saved like this: grad_fn->result_ = SavedVariable(result, true);
So the conclusion is that this is correct behavior. But that doesn't solve our pytree output error, though. Might need to look into it further.
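A quick eager-mode way to observe that saving behavior for relu (the _saved_result attribute is a codegen detail and may differ across versions):

import torch

# relu's grad_fn stores the forward result for use in backward; that
# internal save is where the detach in the traced graph comes from.
x = torch.randn(3, requires_grad=True)
y = torch.relu(x)
print(y.grad_fn)                                 # ReluBackward0
print(torch.equal(y.grad_fn._saved_result, y))   # True: the result was saved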
Good question. Not sure about the details. I checked leaky_relu a bit: the difference between relu and leaky_relu in the generated code is that leaky_relu doesn't store a SavedVariable in grad_fn like relu does. However, from derivatives.yaml, I cannot tell why they're different; they both return auto_element_wise. There might be other places controlling the codegen logic. This is a separate discussion, though (or is this necessary for getting correct pytree output?).
Okay, thank you.
I think for the moment we could just leave a TODO, but we need to come back to this eventually, I guess. I also observed that a similar (potentially related) issue with the requires_grad flag occurs for the test case test_cond_in_forloop, for example. In that case, there is no relu function involved, but still true_out.meta["tensor_meta"] has requires_grad=False, while false_out.meta["tensor_meta"] has it set to True.
Can you check how tensor_meta is created and why it is set to False in one branch?
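For reference, tensor_meta appears to be built by _extract_tensor_metadata in torch.fx.passes.shape_prop, which copies requires_grad straight off the concrete tensor produced for each node (internal API, so treat this as a sketch):

import torch
from torch.fx.passes.shape_prop import _extract_tensor_metadata

# requires_grad in tensor_meta just mirrors the tensor it was extracted
# from, so a branch whose traced output lost requires_grad records False.
x = torch.randn(4, requires_grad=True)
print(_extract_tensor_metadata(x.cos()).requires_grad)           # True
print(_extract_tensor_metadata(x.detach().cos()).requires_grad)  # False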
I tried to go down the rabbit hole and actually found something puzzling to me. In the simple case,
def true_fn(x):
    return x.sin()

def false_fn(x):
    return x.cos()
The metadata are
true_out.meta["tensor_meta"] = TensorMetadata(shape=torch.Size([4]), dtype=torch.float32, requires_grad=False, stride=(1,), memory_format=torch.contiguous_format, is_quantized=False, qparams={})
false_out.meta["tensor_meta"] = TensorMetadata(shape=torch.Size([4]), dtype=torch.float32, requires_grad=False, stride=(1,), memory_format=torch.contiguous_format, is_quantized=False, qparams={})
However, if in one branch the input arguments are just passed through, e.g.,
def true_fn(x):
    return x

def false_fn(x):
    return x.cos()
the metadata become
true_out.meta["tensor_meta"] = TensorMetadata(shape=torch.Size([4]), dtype=torch.float32, requires_grad=True, stride=(1,), memory_format=torch.contiguous_format, is_quantized=False, qparams={})
false_out.meta["tensor_meta"] = TensorMetadata(shape=torch.Size([4]), dtype=torch.float32, requires_grad=False, stride=(1,), memory_format=torch.contiguous_format, is_quantized=False, qparams={})
If one applies even a simple multiplication by 1 to the input argument, the metadata match again:
def true_fn(x):
    return x*1

def false_fn(x):
    return x.cos()
true_out.meta["tensor_meta"] = TensorMetadata(shape=torch.Size([4]), dtype=torch.float32, requires_grad=False, stride=(1,), memory_format=torch.contiguous_format, is_quantized=False, qparams={})
false_out.meta["tensor_meta"] = TensorMetadata(shape=torch.Size([4]), dtype=torch.float32, requires_grad=False, stride=(1,), memory_format=torch.contiguous_format, is_quantized=False, qparams={})
In all of the above cases, the graphs of true_fn and false_fn look as expected:
true_graph._code = 'def forward(self, arg0_1):\n return (arg0_1,)'
false_graph._code = 'def forward(self, arg0_1):\n cos = torch.ops.aten.cos.default(arg0_1); arg0_1 = None\n return (cos,)'
At the moment, this appears to be the culprit, but I don't know where this comes from and less so how to resolve it.
Addition: I figured out that the cos node sets the requires_grad flag in the metadata to False. For example, here are the nodes of the false_fn:
0:
(arg0_1, {'val': FakeTensor(..., size=(4,)), 'tensor_meta': TensorMetadata(shape=torch.Size([4]), dtype=torch.float32, requires_grad=True, stride...us_format, is_quantized=False, qparams={})})
1:
(cos, {'val': FakeTensor(..., size=(4,)), 'tensor_meta': TensorMetadata(shape=torch.Size([4]), dtype=torch.float32, requires_grad=False, strid...us_format, is_quantized=False, qparams={})})
2:
(output, {})
I can see that during graph construction/execution, the args of the call_function node cos still have the requires_grad flag set to True. E.g., if one prints (node.target, node.args[0].meta):
(<OpOverload(op='aten.cos', overload='default')>, {'val': FakeTensor(..., size=(4,)), 'tensor_meta': TensorMetadata(shape=torch.Size([4]), dtype=torch.float32, requires_grad=True, stride=(1,), memory_format=torch.contiguous_format, is_quantized=False, qparams={})})
but the result of the cos node, so in this case the output node, has the requires_grad flag set to False:
('output', {'val': FakeTensor(..., size=(4,)), 'tensor_meta': TensorMetadata(shape=torch.Size([4]), dtype=torch.float32, requires_grad=False, stride=(1,), memory_format=torch.contiguous_format, is_quantized=False, qparams={})})
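For anyone retracing this, a small hypothetical dump helper over a traced branch (gm stands for, e.g., make_fx(false_fn)(x)):

# Print each node's tensor_meta requires_grad flag to spot where the
# metadata flips from True to False along the graph.
for node in gm.graph.nodes:
    tm = node.meta.get("tensor_meta")
    print(node.op, node.name, getattr(tm, "requires_grad", None))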
@ydwu4, do you maybe have any insights into this? Thank you!
return torch.autograd.grad(x_new, (x,), grad_out)
# TODO: During compilation, the metadata of the true_fn has the
# requires_grad attribute set to False, while the false_fn has it
# set to True.
We should fix this.
result = cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},))
self.assertEqual(result, fn({"t": [a, {"b": b}, (c,)]}))
# TODO: Gradient computation for such complex pytree output does not work
This is unexpected.
By the time we enter the autograd dispatch key, true_fn and false_fn should already have become two graph modules that take a tuple as input and produce a tuple as output. All the complicated pytree types should already be flattened (the torch.compile line in torch.cond achieves this). As long as we can handle the tuple type, these tests should pass. We should probably figure out why and fix it.
What I am observing is that the following errors out:
def false_fn(x):
    return {"res": [x["t"][1]["b"], (x["t"][2][0],)]}

result_exp = fn({"t": [a, {"b": b}, (c,)]})
grad_out = torch.ones_like(a)
expected_grads = torch.autograd.grad(result_exp, (a,), grad_out)
with the exception:
Traceback (most recent call last):
  File "/data_malta3_ssd/pytorch/torch/testing/_internal/common_utils.py", line 2756, in wrapper
    method(*args, **kwargs)
  File "/data_malta3_ssd/pytorch/test/functorch/test_control_flow.py", line 925, in test_cond_autograd_same_pytree_output
    expected_grads = torch.autograd.grad(result_exp, (a,), grad_out)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/autograd/__init__.py", line 402, in grad
    grad_outputs_ = _make_grads(
                    ^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/autograd/__init__.py", line 91, in _make_grads
    if out.is_nested or first_grad.is_nested:
       ^^^^^^^^^^^^^
AttributeError: 'str' object has no attribute 'is_nested'
So I don't think this has anything to do with cond; it just doesn't work in this setting even for regular autograd. The respective test case that fails is here.
Oh, yeah, I forgot to mention that we need to flatten result_exp and grad_out to make torch.autograd.grad work; the doc is here.
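Concretely, the flattening could look like this (a sketch; result_exp here is a stand-in pytree mirroring the test above):

import torch
from torch.utils import _pytree as pytree

# torch.autograd.grad only accepts tensors, so flatten the pytree output
# and build one grad_output per leaf before calling it.
a = torch.randn(4, requires_grad=True)
result_exp = {"res": [a.cos(), (a.sin(),)]}   # stand-in for fn(...)
flat_out, _ = pytree.tree_flatten(result_exp)
flat_grad_out = [torch.ones_like(t) for t in flat_out]
expected_grads = torch.autograd.grad(flat_out, (a,), flat_grad_out)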
@@ -183,7 +226,16 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
for i in range(0, len(flat_true_outs)):
    true_out = flat_true_outs[i]
    false_out = flat_false_outs[i]
    if true_out.meta["tensor_meta"] != false_out.meta["tensor_meta"]:
        # TODO: If a torch nn module such as Linear or GRUCell is used, then the
Is this TODO still valid?
if true_out.meta["tensor_meta"] != false_out.meta["tensor_meta"]:
    # TODO: If a torch nn module such as Linear or GRUCell is used, then the
    # meta data of the output is None and cannot be compared
    # TODO: If inside the dictionary, inside the list, the first element
We should figure out why this happens and fix it.
Fixed and added further test cases
This is an updated PR to equip cond with the autograd feature; it replaces the old PR.
@ydwu4 I tried to incorporate your requests already.
Currently, there are two problems that I struggle with solving:
- torch/__init__.py, see here. Therefore, I had to comment out those lines, which resolved the import issues, but I believe cond is not properly exposed as torch.cond.
- hop_db.py
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang