
Torch cond operator, python dispatch, pyoperator #83154

Closed
wants to merge 86 commits

Conversation

voznesenskym
Collaborator

Fixes #ISSUE_NUMBER

@facebook-github-bot
Contributor

facebook-github-bot commented Aug 10, 2022


✅ No Failures (3 Pending)

As of commit 670f3d2 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.

Please report bugs/suggestions to the (internal) Dr. CI Users group.


@voznesenskym
Collaborator Author

Massive credit to @zou3519's work; this just keeps it moving.

@voznesenskym voznesenskym changed the title [WIP] [RFC] [Not ready for review yet] Torch cond operator [WIP] [RFC] [Not ready for review yet] Torch cond operator, python dispatch, etc etc Aug 10, 2022
torch/fx/proxy.py Outdated
voznesenskym and others added 3 commits August 10, 2022 22:23
This is kind of nasty but it works.  I attempted to fix FX
first but the inspect logic is impenetrable.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

[ghstack-poisoned]
@voznesenskym
Collaborator Author

voznesenskym commented Aug 11, 2022

Currently not working. pred is captured.
[screenshot of the traced graph]

Which means that:

result_true = graph.forward(x, torch.tensor(True))
print("True:", result_true)
result_false = graph.forward(x, torch.tensor(False))
print("False:", result_false)

produce identical results. Need to figure this out. The true_fn and false_fn parts are traced fine and work; changing the bool tensor passed to make_fx changes the result, when we really should be capturing pred symbolically.

Need to figure out why those inputs aren't going to the conditional, but some tensor constants are getting created in their stead.
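
For context, here is a minimal sketch of the kind of tracing setup being described. The function bodies and the import location of cond are assumptions for illustration (in this PR the operator lives in the top-level control_flow.py); the point is only that pred should stay a symbolic graph input instead of being baked in as a constant.

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from control_flow import cond  # assumption: cond as exposed by this PR's control_flow.py

def true_fn(x):
    return x.sin()

def false_fn(x):
    return x.cos()

def f(x, pred):
    return cond(pred, true_fn, false_fn, [x])

x = torch.randn(4)
# Trace with a concrete pred. The bug described above is that pred ends up
# captured as a constant in the traced graph, so calling graph.forward with
# torch.tensor(True) and torch.tensor(False) returns the same output.
graph = make_fx(f)(x, torch.tensor(True))
print(graph.code)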

torch/fx/experimental/proxy_tensor.py Outdated
control_flow.py Outdated
voznesenskym and others added 7 commits August 11, 2022 21:04
Fixes #83251

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

[ghstack-poisoned]
…proxy tensor to test it"

Fixes #83251

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

[ghstack-poisoned]
…test it"

Fixes #83251

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

[ghstack-poisoned]
With a sample usage in proxy tensor to show how they can shorten
your code dramatically.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

[ghstack-poisoned]
Comment on lines 93 to 99
# We could get an out tensor by running the real ops.
# But to avoid running real code in tracing, we just use a dummy tensor.
# if pred:
#     out = true_fn(*operands)
# else:
#     out = false_fn(*operands)
out = torch.zeros([])
Contributor

@zou3519 zou3519 Aug 24, 2022


I'm not sure this is going to fly. If we're doing real tracing (make_fx(f, mode='real')), it is supposed to end up executing real operations on the Tensor.

If we're tracing a function that calls cond and then passes the output of cond to another pytorch operation (like torch.sum), then torch.sum is going to receive this dummy Tensor and will actually execute out.sum(). This will be a problem if the other operation relies on the output being a real Tensor.

Here's a hacky idea that might work:

  • Create new_true_fn. new_true_fn executes true_fn and then stores the output of true_fn somewhere.
saved_true_out = []
def new_true_fn(*operands):
  out = true_fn(*operands)
  saved_true_out.append(out)
  return out
  • Do true_graph = get_isolated_graphmodule(new_true_fn, operands, {}) (and ditto for new_false_fn).
  • Now, instead of out = torch.zeros([]), we can return either out = saved_true_out[0] or out = saved_false_out[0].
  • Furthermore, we can compare saved_true_out and saved_false_out to check whether "all the properties match" (though maybe that is what your checks above are doing, I'm not familiar with the code).
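
For reference, a consolidated sketch of the suggestion above; true_fn, false_fn, and operands are illustrative stand-ins for the values available at the cond tracing site, and get_isolated_graphmodule is assumed to be the helper in torch.fx.experimental.proxy_tensor.

import torch
from torch.fx.experimental.proxy_tensor import get_isolated_graphmodule

def true_fn(x):
    return x.sin()

def false_fn(x):
    return x.cos()

operands = (torch.randn(4),)

saved_true_out, saved_false_out = [], []

def new_true_fn(*operands):
    out = true_fn(*operands)
    saved_true_out.append(out)
    return out

def new_false_fn(*operands):
    out = false_fn(*operands)
    saved_false_out.append(out)
    return out

true_graph = get_isolated_graphmodule(new_true_fn, operands, {})
false_graph = get_isolated_graphmodule(new_false_fn, operands, {})

# Instead of the dummy torch.zeros([]), hand a real output to downstream ops
# (e.g. torch.sum) so they trace against a tensor with the right properties.
out = saved_true_out[0]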

Collaborator Author


We might as well just run the real thing without all the extra pizzazz around new_true_fn, though. Just uncomment my commented code.

Will change.

Furthermore, we can compare saved_true_out and saved_false_out to check whether "all the properties match" (though maybe that is what your checks above are doing, I'm not familiar with the code).

We already do that above, yes.

for i in range(0, len(flat_true_outs)):
    true_out = flat_true_outs[i]
    false_out = flat_false_outs[i]
    assert(true_out.meta == false_out.meta)
Contributor


Works for now, but I'm not sure it's comprehensive. E.g. TensorMetadata doesn't seem to have layout (torch.strided, torch.sparse_coo).

Also nit: don't use parentheses when calling assert: assert true_out.meta == false_out.meta. The reason is that it is easy to accidentally turn the parentheses into a tuple (e.g. assert (true_out.meta == false_out.meta,)), and assert on a tuple only checks that it has at least one element.
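
To illustrate the pitfall, a minimal, self-contained example (not code from this PR):

# A non-empty tuple is always truthy, so this assert silently passes even
# though the comparison inside it is False (CPython also emits a
# SyntaxWarning: "assertion is always true, perhaps remove parentheses?").
assert (1 == 2,)

# The intended form fails as expected with an AssertionError:
assert 1 == 2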

Collaborator Author


Meta is {'tensor_meta': TensorMetadata(shape=torch.Size([4]), dtype=torch.float32, requires_grad=False, stride=(1,), memory_format=torch.contiguous_format, is_quantized=False, qparams={})} for now - we can always add more to meta, right?

@malfet
Contributor

malfet commented Aug 24, 2022

One needs to add an empty dispatch/__init__.py, otherwise it will not be packaged into either the wheel or the conda package.
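
For context, a minimal sketch of why the empty __init__.py matters, assuming the build enumerates packages with setuptools.find_packages and that the new directory is torch/dispatch (both inferred; run from the repo root):

# find_packages only discovers directories that contain an __init__.py,
# so a new subpackage needs one even if the file is empty.
from setuptools import find_packages

packages = find_packages(include=["torch", "torch.*"])

# Without torch/dispatch/__init__.py, "torch.dispatch" never shows up here,
# and its modules are left out of both the wheel and the conda package.
print("torch.dispatch" in packages)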

@voznesenskym voznesenskym changed the title [WIP] [RFC] Torch cond operator, python dispatch, etc etc Torch cond operator, python dispatch, pyoperator Aug 24, 2022
@voznesenskym
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Collaborator

@pytorchbot successfully started a merge job. Check the current status here.
The merge job was triggered without a flag. This means that your change will be merged once all checks on your PR have passed (ETA: 0-4 Hours). If this is not the intended behavior, feel free to use some of the other merge options in the wiki.
Please reach out to the PyTorch DevX Team with feedback or questions!

@pytorchmergebot
Collaborator

Merge failed
Reason: View failures on hud. Refusing to merge as mandatory check(s) Lint, pull failed for rule Core Maintainers.
Raised by workflow job

@voznesenskym
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Collaborator

@pytorchbot successfully started a merge job. Check the current status here.
The merge job was triggered without a flag. This means that your change will be merged once all checks on your PR have passed (ETA: 0-4 Hours). If this is not the intended behavior, feel free to use some of the other merge options in the wiki.
Please reach out to the PyTorch DevX Team with feedback or questions!

@github-actions

Hey @voznesenskym.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

facebook-github-bot pushed a commit to pytorch/functorch that referenced this pull request Aug 27, 2022
Summary:
Fixes #ISSUE_NUMBER

X-link: pytorch/pytorch#83154
Approved by: https://github.com/ezyang

Reviewed By: weiwangmeta

Differential Revision: D39034501

Pulled By: voznesenskym

fbshipit-source-id: 7be6caa7e3c7345f50671a7cef8a5fd0b2565a21
facebook-github-bot pushed a commit that referenced this pull request Aug 27, 2022
Summary:
Fixes #ISSUE_NUMBER

Pull Request resolved: #83154
Approved by: https://github.com/ezyang

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/ced2ca8f867b376c5b4e495f183aeba78a27c0c4

Reviewed By: weiwangmeta

Differential Revision: D39034501

Pulled By: voznesenskym

fbshipit-source-id: 7be6caa7e3c7345f50671a7cef8a5fd0b2565a21
@github-actions github-actions bot deleted the voz/ctfl_proto branch February 25, 2024 01:50