New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Torch cond operator, python dispatch, pyoperator #83154
Conversation
🔗 Helpful links
✅ No Failures (3 Pending)As of commit 670f3d2 (more details on the Dr. CI page): Expand to see more💚 💚 Looks good so far! There are no failures yet. 💚 💚 This comment was automatically generated by Dr. CI (expand for details).Please report bugs/suggestions to the (internal) Dr. CI Users group. |
Massive credit to @zou3519's work, this just keeps it moving |
This is kind of nasty but it works. I attempted to fix FX first but the inspect logic is impenetrable. Signed-off-by: Edward Z. Yang <ezyang@fb.com> [ghstack-poisoned]
Fixes #83251 Signed-off-by: Edward Z. Yang <ezyang@fb.com> [ghstack-poisoned]
…proxy tensor to test it" Fixes #83251 Signed-off-by: Edward Z. Yang <ezyangfb.com> [ghstack-poisoned]
…test it" Fixes #83251 Signed-off-by: Edward Z. Yang <ezyangfb.com> [ghstack-poisoned]
With a sample usage in proxy tensor to show how they can shorten your code dramatically. Signed-off-by: Edward Z. Yang <ezyang@fb.com> [ghstack-poisoned]
# We could get an out tensor by running the real ops | ||
# But to avoid running real code in tracing, we just use a dummy tensor. | ||
# if pred: | ||
# out = true_fn(*operands) | ||
# else: | ||
# out = false_fn(*operands) | ||
out = torch.zeros([]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure this is going to fly, if we're doing real tracing (make_fx(f, mode='real')
), it is supposed to end up executing real operations on the Tensor.
If we're tracing a function that calls cond and then passes the output of cond to another pytorch operation (like torch.sum), then torch.sum
is going to receive this dummy Tensor and then will actually execute out.sum()
. This will be a problem if the other operation is something that does rely on the output to be a real Tensor.
Here's an hacky idea that might work:
- create
new_true_fn
.new_true_fn
executestrue_fn
and then stores the output oftrue_fn
somewhere.
saved_true_out = []
def new_true_fn(*operands):
out = true_fn(operands)
saved_out.append(out)
return out
- Do
true_graph = get_isolated_graphmodule(new_true_fn, operands, {})
(and ditto for new_false_fn). - Now, instead of
out = torch.zeros([])
, we can return eitherout = saved_true_out[0]
orout = saved_false_out[0]
. - Furthermore, we can compare
saved_true_out
andsaved_false_out
for if "all the properties match" (though maybe that is what your checks above are doing, I'm not familiar with the code).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We might as well just run the real thing without all the extra pizzaz around new_true_fn
though. Just uncomment my commented code.
Will change.
Furthermore, we can compare saved_true_out and saved_false_out for if "all the properties match" (though maybe that is what your checks above are doing, I'm not familiar with the code).
We already do that above, yes.
for i in range(0, len(flat_true_outs)): | ||
true_out = flat_true_outs[i] | ||
false_out = flat_false_outs[i] | ||
assert(true_out.meta == false_out.meta) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Works for now, but I'm not sure it's comprehensive. E.g. TensorMetadata doesn't seem to have layout (torch.strided, torch.sparse_coo).
Also nit: don't use parenthesis while calling assert: assert true_out.meta == false_out.meta
. The reason is that it is easy to accidentally turn the parenthesis into a tuple (e.g. assert (true_out.meta == false_out.meta,)
and assert on a tuple checks that it has at least one element
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Meta is {'tensor_meta': TensorMetadata(shape=torch.Size([4]), dtype=torch.float32, requires_grad=False, stride=(1,), memory_format=torch.contiguous_format, is_quantized=False, qparams={})}
for now - we can always add more to meta, right?
One needs to add empty |
@pytorchbot merge |
@pytorchbot successfully started a merge job. Check the current status here. |
Merge failed |
@pytorchbot merge |
@pytorchbot successfully started a merge job. Check the current status here. |
Hey @voznesenskym. |
Summary: Fixes #ISSUE_NUMBER X-link: pytorch/pytorch#83154 Approved by: https://github.com/ezyang Reviewed By: weiwangmeta Differential Revision: D39034501 Pulled By: voznesenskym fbshipit-source-id: 7be6caa7e3c7345f50671a7cef8a5fd0b2565a21
Summary: Fixes #ISSUE_NUMBER Pull Request resolved: #83154 Approved by: https://github.com/ezyang Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/ced2ca8f867b376c5b4e495f183aeba78a27c0c4 Reviewed By: weiwangmeta Differential Revision: D39034501 Pulled By: voznesenskym fbshipit-source-id: 7be6caa7e3c7345f50671a7cef8a5fd0b2565a21
Fixes #ISSUE_NUMBER