
[AOTInductor] Implement autograd eager backend for native triton kernels #110403

Closed
wants to merge 5 commits

Conversation

@oulgen oulgen requested review from a team, albanD and soulitzer as code owners October 2, 2023 18:56
pytorch-bot bot commented Oct 2, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/110403

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit b8a9266 with merge base efb73fe:

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: fx label Oct 2, 2023
oulgen added a commit that referenced this pull request Oct 2, 2023
ghstack-source-id: 1eecfac35d1796ee51e5f174fdafb796ea3b1f91
Pull Request resolved: #110403
@oulgen oulgen added the ciflow/trunk label Oct 2, 2023
if isinstance(proxy, fx.Proxy):
    set_meta(proxy, e)

# example use case: allreduce_ returns ([tensor], work)
Contributor Author

I will update this comment.

@albanD albanD removed their request for review October 2, 2023 21:55
oulgen added a commit that referenced this pull request Oct 2, 2023
ghstack-source-id: e41cd757f7a3f29c44855fb1ba45cbdf8830bbcb
Pull Request resolved: #110403
@zou3519 zou3519 left a comment


LGTM, minor nits

Comment on lines 50 to 54
# FX graph needs __name__ and __module__ attributes
fn.__name__ = func_overload.__name__  # type:ignore[attr-defined]
if not hasattr(fn, "__module__"):
    # Super hacky but on AMD __module__ is not set
    fn.__module__ = "itertools"
Contributor

Could we unify this logic with the Dynamo logic that does something similar? Maybe put it into a helper function somewhere? Also, we might want to note somewhere that we are doing a hack to work around the fact that FX does not allow functions in the graph.
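
A minimal sketch of what such a shared helper could look like; the name and location are hypothetical, and the AMD/itertools workaround is copied from the snippet above:

def _prepare_fn_for_fx_graph(fn, func_overload):
    # FX graph targets need __name__ and __module__ attributes.
    fn.__name__ = func_overload.__name__  # type: ignore[attr-defined]
    if not hasattr(fn, "__module__"):
        # Workaround: on AMD, __module__ is not set on these callables.
        fn.__module__ = "itertools"
    return fn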

fn,
(),
proxy_args,
name="triton_kernel_wrapper_mutation",
Contributor

Nit: the name should probably be overload.__name__; in the functional case it is better for the name of the output to be triton_kernel_wrapper_functional than triton_kernel_wrapper_mutation.
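
A sketch of the suggested tweak, assuming the create_proxy call shown above (fn and func_overload as in that snippet):

out_proxy = proxy_mode.tracer.create_proxy(
    "call_function",
    fn,
    (),
    proxy_args,
    # Derive the node name from the overload so the functional variant traces
    # as "triton_kernel_wrapper_functional" rather than "triton_kernel_wrapper_mutation".
    name=func_overload.__name__,
)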

torch/_higher_order_ops/triton_kernel_wrap.py
@@ -170,6 +170,15 @@ def from_functional(self):
    torch._sync(self)
    return torch._from_functional_tensor(self.elem)

def replace_(self, output) -> None:
    torch._functionalize_replace(self.elem, output)  # type: ignore[attr-defined]
Contributor

Nit: add the type annotation to https://github.com/pytorch/pytorch/blob/main/torch/_C/__init__.pyi.in to make mypy happy, but it looks like we didn't care much about that in this file
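
A sketch of the stub that could be added there; the exact parameter names are guesses based on the replace_() call site above:

# torch/_C/__init__.pyi.in (hypothetical entry)
def _functionalize_replace(tensor: Tensor, output: Tensor) -> None: ...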


node_args = {"grid": grid, "kwargs": kwargs}
proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
out_proxy = proxy_mode.tracer.create_proxy(
Contributor

Comment that is probably more relevant for the previous dynamo PR (although fine to deal with as a follow-up): it's probably worth having a test for the case where the triton kernel (and maybe the grid lambda) is an inner function. Does everything still work out? (The graph that we compile probably directly holds a reference to that inner function, so will the lifetime of the inner function match what its lifetime would have been in eager mode?)

@torch.compile
def f(x, y):
    @triton.jit
    def kernel(in_ptr, out_ptr0, BLOCK_SIZE: "tl.constexpr"):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        ...
    # if grid/kernel hold onto some global state (like a large cuda tensor),
    # then that memory will stick around until grid/kernel is cleaned up.
    return kernel[grid](x)
...

Contributor Author

Will also do as a follow-up.


def trace_triton_kernel_wrapper(proxy_mode, func_overload, *, kernel, grid, kwargs):
    with disable_proxy_modes_tracing():
        out = func_overload(kernel=kernel, grid=grid, kwargs=kwargs)
Contributor

I believe this is always supposed to return None, right? Since this will redispatch to the FakeTensorMode implementation, which returns None (since this higher order op is only ever used for triton kernels).

Maybe add an assert?
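
A sketch of what such an assert could look like; as the reply below notes, only the mutation variant returns None (the functional variant returns kwargs), so the check would have to be limited to that op:

with disable_proxy_modes_tracing():
    out = func_overload(kernel=kernel, grid=grid, kwargs=kwargs)
    # The mutation op redispatches to the FakeTensorMode impl, which returns None;
    # the functional op instead returns the (cloned) kwargs dict.
    if func_overload is triton_kernel_wrapper_mutation:
        assert out is None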

Contributor Author

Only the mutation version returns None; the functional version returns kwargs.

return None


@triton_kernel_wrapper_mutation.py_functionalize_impl
Contributor

In this implementation, we're also relying on the fact that both grid and kernel will never mutate the metadata of any of the inputs. I think we said we'll (eventually) try to check for mutations from grid in dynamo and error. But is it possible for a triton kernel to mutate the metadata of a tensor input? e.g. in_tensor.transpose_(1, 0).

Contributor Author

We do the clone/replace for the triton kernel inputs; if the kernel mutates the metadata, will that not be sufficient? @jansel, any comments on this?

As for the grid, @jansel recommended doing something similar to what cond does for mutation checking, so I was planning on doing something like that.

Contributor

This FakeTensorMode rule for the functional op:

    with mode:
        return {
            key: (clone_preserve_strides(val) if isinstance(val, Tensor) else val)
            for key, val in kwargs.items()
        }

is implicitly assuming that the outputs of the triton kernel always have the same metadata (sizes/strides/etc.) as the inputs. So if we're handed a triton kernel that mutates the sizes/strides of a tensor input, our shape-prop logic will be wrong during tracing.

To be fair though - I'm imagining that it's either impossible or very rare for a triton kernel to do this - so if it is possible, the answer is probably to detect/ban it in dynamo.

return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)


@triton_kernel_wrapper_mutation.py_impl(ProxyTorchDispatchMode)
Contributor

To be fair, I don't think we'll ever actually hit this in the "vanilla" torch.compile flow: functionalization runs above ProxyTorchDispatchMode, so we will always have converted triton_kernel_wrapper_mutation calls into triton_kernel_wrapper_functional calls in the graph before hitting the proxy mode.

Mostly just an FYI; still agreed that we need this (for example, pre_dispatch tracing today will hit this, since it does not (yet) use functionalization).

@@ -194,6 +194,18 @@ def wrap_with_proxy(e, proxy, constant):
        # example use case: allreduce_ returns ([tensor], work)
        for idx, ee in enumerate(e):
            wrap_with_proxy(ee, proxy[idx], get_constant(idx))
    elif isinstance(e, dict):
        assert constant is None
Contributor

Maybe worth a comment: in theory we could support const-prop when proxy-tensor-tracing operators that return dicts of tensors, but we have no use case for it today (since the only op we currently trace that can return a dict is triton_kernel_wrapper_functional/mutation, which does not participate in const-prop).
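
A sketch of where such a comment could live, applied to the new branch shown in the diff above:

elif isinstance(e, dict):
    # In theory we could support const-prop for ops that return dicts of tensors,
    # but there is no use case today: the only traced op that returns a dict is
    # triton_kernel_wrapper_functional/mutation, which does not participate in const-prop.
    assert constant is None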

@@ -1541,6 +1541,73 @@ def add_kernel(
        # Make sure it is NOT modified
        self.assertEqual(output, torch.zeros_like(t1))

    @requires_cuda()
    @requires_triton()
    def test_triton_kernel_functionalize(self):
Contributor

I think these tests are fine for this PR, but when we add inductor support it would be great to have some more comprehensive tests (just testing that the outputs match eager mode):

(1) triton kernel that mutates an input to the graph
(2) triton kernel that mutates a view of an input to the graph
(3) graph that returns a view/alias of one of the mutated arguments to a triton kernel

For the first case (simple input mutation), there are ~3 individual cases worth testing. This is because these cases impact whether or not AOTAutograd sends the copy_() (input mutation) directly to inductor in the graph, so they affect whether or not inductor will be able to optimize away the clone.

(a) the graph is an inference graph (test is under no_grad(), or no inputs to the graph require gradient)
(b) graph is a training graph (at least one input requires grad), but the arguments to the triton kernel do not require grad
(c) graph is a training graph, and the argument(s) to the triton kernel also require grad.

Contributor Author

Will add as a follow-up, although I need some pointers on how to implement some of these test cases.

Right now all my tests are (1), I believe, and none of my tensors require grad (I did not set requires_grad=True).

  • How do I make inputs to the triton kernel require grad?
  • Do I need to test things differently when tensors require grad, or do I still check that outputs match?

Contributor

How do I make inputs to the triton kernel require grad?

You can just set your tensor inputs to the compiled code to require grad:

@torch.compile
def f(x):
    return kernel[grid](x, ...)

x = torch.randn(4, requires_grad=True)
out = f(x)

Do I need to test things differently when tensors require grad, or do I still check that outputs match?

Testing just the forward is fine - but you're right, a more comprehensive test would involve also running autograd and checking that the gradients match between eager and compile. Example:

@torch.compile
def f_triton_input_mutation_of_view(x, out):
    out_view = out[:, 1]
    kernel[grid](x=x, out=out_view)
    return torch.mul(out_view, 2)

x = torch.randn(4, 4)
out = torch.randn(4, 4, requires_grad=True)
# run autograd through the compiled function
f_triton_input_mutation_of_view(x, out).sum().backward()

# Run the same with ref inputs on eager, check that the grads match
self.assertEqual(out.grad, out_ref.grad)

@bdhirsh bdhirsh left a comment

Looks good! Left some minor comments that don't all need to be addressed in this PR.

oulgen added a commit that referenced this pull request Oct 3, 2023
ghstack-source-id: 295f7ad5da97982c28ff91e88874053217c9c66b
Pull Request resolved: #110403
oulgen added a commit that referenced this pull request Oct 3, 2023
ghstack-source-id: 7623fe472754dfef821250f3c5c49f4abdb30ab4
Pull Request resolved: #110403

oulgen commented Oct 4, 2023

@pytorchbot merge

@pytorchmergebot

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

pytorchmergebot pushed a commit that referenced this pull request Oct 4, 2023
Pull Request resolved: #110486
Approved by: https://github.com/jansel
ghstack dependencies: #110403
@facebook-github-bot facebook-github-bot deleted the gh/oulgen/3/head branch October 8, 2023 14:23
Labels
ciflow/inductor · ciflow/trunk · Merged · module: dynamo · release notes: fx

4 participants