Added host-side associative scan function #129307
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/129307
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures
As of commit 59a85ca with merge base 0738916.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
without proper parallelisation this is going to be very slow. Can you post benchmarks?
Thank you for your comment. I tried to carry out two benchmarks, one for the case of
import torch
import torch.utils._pytree as pytree  # needed by associative_scan_fct below
from torch._higher_order_ops.associative_scan import associative_scan as device_scan
import torch._dynamo.config as config
import jax
import jax.numpy as jnp
from triton.testing import do_bench
import matplotlib.pyplot as plt
import numpy as np

config.cache_size_limit = 10000
torch.set_default_device('cuda')


# Sequential eager reference implementation of the scan
def associative_scan_fct(operator, input, dim=0, reverse=False):
    inp_leaves, spec = pytree.tree_flatten(input)
    result_flat = []
    num_leaves = len(inp_leaves)
    op = reversed if reverse else lambda x: x

    for ind in op(range(inp_leaves[0].size(dim))):
        r = [
            inp_leaves[leave_ind][(slice(None),) * dim + (ind,)]
            for leave_ind in range(num_leaves)
        ]
        if (ind > 0 and not reverse) or (
            ind < (inp_leaves[0].size(dim) - 1) and reverse
        ):
            r = operator(
                pytree.tree_unflatten(result_flat[-1], spec),
                pytree.tree_unflatten(r, spec),
            )
        r_flat, _ = pytree.tree_flatten(r)
        result_flat.append(r_flat)

    results = [
        torch.stack([e[leave_ind] for e in op(result_flat)], dim)
        for leave_ind in range(num_leaves)
    ]
    return pytree.tree_unflatten(results, spec)


comp_host_scan = torch.compile(device_scan)
comp_cumsum = torch.compile(torch.cumsum)
jax_scan = jax.lax.associative_scan
compiled_jax_scan = jax.jit(jax_scan, static_argnums=(0,))


def scan_fct(i, j):
    return i + j


mm_inps = lambda S, B, construct: construct(S, B, B)
input_func = mm_inps
jax_construct = lambda *shape: jnp.ones(shape, dtype=jnp.float32)
torch_construct = lambda *shape: torch.ones(*shape, dtype=torch.float32)

B = 100
jax_times_compiled = []
eager_times = []
compiled_host_side_times = []
compiled_device_side_times = []
upper_limit = 15000
lower_limit = 2000

# Sweep over the length S of the scan dimension
for S in range(lower_limit, upper_limit, 2000):
    inp = input_func(S, B, torch_construct)
    jax_inp = input_func(S, B, jax_construct)
    eager_times.append(do_bench(lambda: associative_scan_fct(scan_fct, inp, dim=0)))
    compiled_host_side_times.append(do_bench(lambda: comp_host_scan(scan_fct, inp, dim=0, reverse=False, host_side=True)))
    compiled_device_side_times.append(do_bench(lambda: comp_cumsum(inp, 0)))
    jax_times_compiled.append(do_bench(lambda: compiled_jax_scan(scan_fct, jax_inp)))

print("eager: ", eager_times)
print("compiled host-side: ", compiled_host_side_times)
print("compiled device-side: ", compiled_device_side_times)
print("jax compiled: ", jax_times_compiled)

plt.figure()
# plt.plot(np.arange(lower_limit, upper_limit, 2000), eager_times, label='Eager')
plt.plot(np.arange(lower_limit, upper_limit, 2000), compiled_host_side_times, label='Associative scan host-side compiled')
plt.plot(np.arange(lower_limit, upper_limit, 2000), compiled_device_side_times, label='Associative scan device-side compiled')
plt.plot(np.arange(lower_limit, upper_limit, 2000), jax_times_compiled, label='Jax compiled')
plt.xlabel('Length')
plt.ylabel('Time [ms]')  # do_bench reports milliseconds
plt.legend()
plt.show()
[Figure: benchmark results]
For the
[Figure: benchmark results]
As a side note, this code is largely based on the code from @Chillee.
Yeah, I don't have that much of an issue with the scaling with the batch size. I do have it with the scan dimension, which is where the work is not done in parallel.
@lezcano This is parallelized along the scan dimension - the scan elements are batched together using a parallel scan algorithm.
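To make "parallelized along the scan dimension" concrete, here is a minimal, library-agnostic sketch of one well-known parallel scan scheme (Hillis-Steele recursive doubling). It is an illustration only and not necessarily the algorithm this PR implements; the helper names `parallel_style_scan`, `matmul2` and `sequential_scan` are hypothetical. All combines within one round are independent, which is what a batched implementation would execute as a single vectorized call, so only about log2(n) rounds are needed. The combine function does not have to be pointwise: a 2x2 matrix product is used as an example.

```python
from operator import add


def parallel_style_scan(combine, xs):
    """Inclusive scan in about log2(n) rounds (Hillis-Steele recursive doubling)."""
    out = list(xs)
    shift = 1
    while shift < len(out):
        # Every position i >= shift combines with the element `shift` places to
        # its left. The combines of one round do not depend on each other.
        out = [
            out[i] if i < shift else combine(out[i - shift], out[i])
            for i in range(len(out))
        ]
        shift *= 2
    return out


def matmul2(a, b):
    """2x2 matrix product on nested tuples: associative but not pointwise."""
    return tuple(
        tuple(sum(a[i][k] * b[k][j] for k in range(2)) for j in range(2))
        for i in range(2)
    )


def sequential_scan(combine, xs):
    """Plain left-to-right reference scan."""
    out = [xs[0]]
    for x in xs[1:]:
        out.append(combine(out[-1], x))
    return out
```

The left operand always covers the earlier elements, so the result also matches the sequential scan for non-commutative operators such as the matrix product.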
Is there any way that one can introduce more parallelism or any other way to make it more efficient?
Ah, I see in which sense this is parallel. Given how different it is to the algorithm used on GPU, it would still be good to benchmark the operation varying the batched dimension to see what performance characteristics it has.
Overall the idea looks good, but I left a few comments.
Also, I see no reason why this algorithm should just be used on CPU. We could also use it to support more generic scans on device, as triton just supports pointwise ops.
For general scans:
@lezcano Thank you very much for your comments. Indeed, this feature is not intended to be for CPU only. The name "host-side" might be misleading here. The feature performs the associative scan operation on the host side in the sense that it does not use Triton, e.g. `tl.associative_scan`, or device-specific features. The feature can be used with GPUs and with CPUs. I added a separate testcase for that as well.
@vadimkantorov This feature partly fills this gap if the RNN that you are feeding is linear, i.e., it can be performed in an associative way. A feature for an arbitrary RNN is certainly something that I am also highly interested in, but the closest feature that I can see in this direction is a while_loop operator, where the body is your arbitrary RNN. Your second idea sounds really interesting, but I wouldn't really know how to do this. Could you help me out with this?
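As a small illustration of why a linear RNN fits an associative scan, consider the recurrence h_t = a_t * h_{t-1} + b_t. Each step is an affine map h -> a*h + b, and composing two affine maps is associative. The sketch below is plain Python with hypothetical helper names (`combine`, `linear_rnn_via_scan`, `linear_rnn_loop`) and assumes an initial state of zero; it applies the combine sequentially for brevity, but because the combine is associative, any parallel scan could be substituted.

```python
def combine(left, right):
    """Compose two affine maps h -> a*h + b, applying `left` first."""
    a1, b1 = left
    a2, b2 = right
    # right(left(h)) = a2 * (a1 * h + b1) + b2
    return (a1 * a2, a2 * b1 + b2)


def linear_rnn_via_scan(a, b):
    """States of h_t = a_t * h_{t-1} + b_t with h_0 = 0, via scanning (a_t, b_t)."""
    acc = (a[0], b[0])
    states = [acc[1]]  # with h_0 = 0, the state is the offset of the composed map
    for pair in zip(a[1:], b[1:]):
        acc = combine(acc, pair)
        states.append(acc[1])
    return states


def linear_rnn_loop(a, b, h0=0):
    """Reference implementation: the recurrence written as an ordinary loop."""
    h, out = h0, []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        out.append(h)
    return out
```

A nonlinear cell (for example one with a tanh between steps) cannot be composed this way, which is why the general case needs something like a loop operator instead.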
Regarding how to unroll the loop and when we can compile several loop iterations in a kernel - I don't know enough about how to implement this in torch.compile/Inductor. But I heard at some conference that someone used this trick to accelerate WaveRNN inference or sth like that. I guess memory placement optimization can be important here too, and maybe for small enough RNN states (like 32-dim or 64-dim or 128-dim like in https://github.com/NVlabs/tiny-cuda-nn) some extra memory optimizations can be done.
This PR is also super-useful for wider experimentation with linear RNNs, because Triton lowering only supports elementwise cells.
It'd be nice if we could use this mode automatically as a fallback, but we currently can't, because in this mode we have a different API when we have multiple inputs-outputs. In the `generic_scan=False` case we need to unpack the inputs and outputs (`combine_op` takes 2n inputs and returns n outputs) while in the `generic_scan=True` case the inputs are packed.
Do you reckon it'd be possible to reconcile both APIs? If so, it'd be great to remove the flag, and use `generic_scan=False` if the inputs are in cuda and the combine op is all pointwise ops, and otherwise use the other one as a fallback.
If the generic scan mode is ``used``, the restrictions on the combine_fn are less strict,
for example also non-pointwise operations are allowed.
However, the non-generic scan mode utilizes efficient ``tl.associative_scan`` calls and is thus
more efficient as the generic scan mode.

Example::
Can you add an example of `generic_scan=True` (one with matmul would be good) and one with two inputs and outputs (e.g. something like `cummin`)?
In general, I would like to show that the API for `generic_scan=True` and `=False` is different (!!) when there are multiple inputs and outputs.
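The difference between the two calling conventions can be sketched in plain Python with a cummin-like scan that carries a value and its index. This is only an illustration of the "2n inputs, n outputs" versus "packed" shapes described in this thread, not the actual `associative_scan` API; all function names here are hypothetical.

```python
def cummin_unpacked(val_l, idx_l, val_r, idx_r):
    """Unpacked convention: 2n positional inputs (here n = 2), n outputs."""
    if val_r < val_l:
        return val_r, idx_r
    return val_l, idx_l


def cummin_packed(left, right):
    """Packed convention: two arguments, each a (value, index) tuple."""
    return cummin_unpacked(*left, *right)


def scan_unpacked(combine, *columns):
    """Reference scan that calls combine(*acc, *row) with unpacked elements."""
    rows = list(zip(*columns))
    acc = rows[0]
    out = [acc]
    for row in rows[1:]:
        acc = tuple(combine(*acc, *row))
        out.append(acc)
    return tuple(list(col) for col in zip(*out))


def scan_packed(combine, *columns):
    """Reference scan that calls combine(acc, row) with packed elements."""
    rows = list(zip(*columns))
    acc = rows[0]
    out = [acc]
    for row in rows[1:]:
        acc = tuple(combine(acc, row))
        out.append(acc)
    return tuple(list(col) for col in zip(*out))


values = [3, 1, 2, 0, 5]
indices = list(range(len(values)))
mins, argmins = scan_packed(cummin_packed, values, indices)
```

Both conventions compute the same result; they differ only in how the combine function receives its arguments, which is why a user-supplied combine function written for one mode cannot be passed unchanged to the other.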
I tried to unify the API and hence this may not be necessary anymore?
@@ -63,8 +92,15 @@ def associative_scan(
This function must be pure, pointwise, and satisfy the associative property.
The comment in line 86 above is stale, can you update it?
Done
Fixed typos Updated documentation
I tried to unify the
@lezcano please let me know whether the checks here are appropriate. The problem that I have with the pointwise check is that it does not work under Furthermore, I spotted an issue when the device |
So, I think the current code that looks into whether a function is pointwise may not be general enough. At the moment, these checks are done in `lower_pointwise_subgraph`, within the lowering of this OP. The ideas I have to mix both approaches are not super clean.
Pinging @peterbell10 to see if he has any ideas on what's the best approach here, and for a general review / help with those issues you are seeing.
On a different note, if we are able to choose the algorithm cleanly, I would propose to not expose the flag now in the interest of "it's easy to add a flag later if people request it, but it's very difficult to deprecate a given flag".
Also, I assume that this is BC breaking to some extent, as it changes the way `associative_scan` works for multiple inputs? I think it's fine to do this, as we are not even rendering the docs for this OP (which we should fix at some point... not this PR tho)
If the auto check for pointwiseness is too complex, better to actually introduce it as a separate function (instead of a flag)?
I think we should explicitly expose a flag for whether it's device-side or host-side. Even if it's all pointwise ops users may want to try the host-side scan regardless. |
Yeah, so that's what I initially thought. The user can use this flag to 'force' the |
        ), "The pytree of the output of the operator needs to match the input pytree"

    generic_scan_required = False
    combine_fn = make_fx(combine_fn)(pytree.tree_unflatten(
Do we have to run this check under dynamo? I would have thought we could call `associative_scan_op`, and inside the compiler perform this check and if it fails we use the eager implementation as a decomposition.
Thank you for your review and your comment. Unfortunately, I don't fully understand it and I was wondering whether you could elaborate a bit more on it for me?
I was under the impression that the goal is to have a `generic_scan` (which can either run in eager mode or under `torch.compile`) that is more general than the optimized associative scan that invokes `tl.associative_scan`, e.g., we don't call the `associative_scan_op` at the moment. The idea was then to determine a priori which version of scan to use, the generic scan or the non-generic scan. This involves checking whether the `combine_fn` contains only pointwise operations and I didn't know of any better way to do this, but this check fails under dynamo. I am not sure how and where you envision to move this check?
Also, regarding your comment below. I just used a similar methodology as in other HOPs, such as cond. I thought that the result should be that the `associative_scan` function is always compiled and that's why this if clause is placed there. Are you suggesting to remove this and expect the user to invoke `torch.compile` instead, because there could be cases where one doesn't want to have this function compiled?
> I am not sure how and where you envision to move this check?
Inside the compiler when compiling `associative_scan_op` we can choose to fall back to the generic implementation instead. Ideally, this could be done in `torch/_inductor/decompositions.py`.
> Also, regarding your comment below. I just used a similar methodology as in other HOPs, such as cond
Looks like cond is being compiled with the eager backend there, which doesn't make a great deal of sense to me tbh.
cc @zou3519 @aakhundov any comment on why HOPs are being compiled with eager backend instead of just running as normal eager code?
cc @ydwu4
> any comment on why HOPs are being compiled with eager backend instead of just running as normal eager code
We rely on dynamo to 1. check and raise an error when there are any side effects in the udf, 2. lift closures as implicit inputs, 3. flatten/unflatten pytrees. We also considered getting rid of torch.compile, but no concrete conclusions yet @Chillee.
        combine_fn, input, dim
    )

return torch.compile(associative_scan, fullgraph=False)(
If we have an eager implementation then we shouldn't need to compile this.
@peterbell10
Thank you in advance |
Extended testcase of torch.flip with data types and backends
WIP: Issue with view operations in case of matmul
Went through the diff and left some comments. Can you do a rebase so I can let the CI run?
@@ -5501,6 +5501,140 @@ def test_optimized_module_training(self):
        mod.eval()
        self.assertFalse(opt_mod.training)

    @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
Is this test related to scan? Can we delete this from this PR? Maybe can keep it locally or in a gist if you want to share it.
def add(x: torch.Tensor, y: torch.Tensor):
    return x + y

for device in [torch.device("cpu"), torch.device("cuda")]:
The standard way of doing this is adding the `@parametrize` decorator. There are examples in this file to follow. Can you also change the other tests to use `parametrize`?
print('Flip test fails for backends: ' + str(fails_for_backend))
self.assertEqual(len(fails_for_backend), 0)

for n in range(20):
Why do we want to run this 20 times?
    },
)(fct, x, 0, False, False)
self.assertExpectedInline(
    gm.print_readable(print_output=False).strip(),
The graph produced by `print_readable(print_output=False)` is sometimes not stable across environments. Can you change them to `gm.code.strip()`?
gm = make_fx(
    scan,
    decomposition_table={
        associative_scan_op: torch._inductor.decomposition.associative_scan_op_decomp
Make this a separate test?
@@ -37,6 +37,23 @@ def promote_to_tensor(x):
    return x + tl.zeros((1,), tl.int1)


@triton.jit
Are these changes necessary?
@@ -21,10 +21,11 @@
 from torch.fx.experimental.symbolic_shapes import ShapeEnv
 from torch.utils._sympy.functions import FloorDiv, ModularIndexing
 from torch.utils._sympy.symbol import symbol_is_type, SymT
-from torch.utils._sympy.value_ranges import bound_sympy
+from torch.utils._sympy.value_ranges import bound_sympy, IntInfinity, ValueRanges
What's going on in this file?
@@ -52,7 +52,7 @@
 if TYPE_CHECKING:
     import types

-    from torch._ops import OpOverload
+    from torch._ops import OperatorBase, OpOverload
What's going on here? Should be a separate pr?
@@ -322,7 +322,7 @@ class ResolvedExportOptions(ExportOptions):
     onnx_registry: OnnxRegistry

     # Private only attributes
-    decomposition_table: dict[torch._ops.OpOverload, Callable]
+    decomposition_table: dict[torch._ops.OperatorBase | torch._ops.OpOverload, Callable]
Separate pr?
@@ -22,7 +22,7 @@ def __init__(
         self,
         diagnostic_context: diagnostics.DiagnosticContext,
         module: torch.fx.GraphModule,
-        decomposition_table: Mapping[torch._ops.OpOverload, Callable],
+        decomposition_table: Mapping[torch._ops.OperatorBase | torch._ops.OpOverload, Callable],
Same here.
This is part of a series of PRs to improve the `associative_scan` functionality. This specific PR introduces a `reverse` flag to `associative_scan` to establish an interface similar to that of `jax.lax.associative_scan`. This PR has been derived from #129307. @ydwu4 @Chillee @zou3519
Pull Request resolved: #133011
Approved by: https://github.com/ydwu4
This is part of a series of PRs to improve the `associative_scan` functionality. This specific PR introduces a `combine_mode`, which can be either `pointwise` (default) or `generic`. In the `generic` case, `associative_scan` is more flexible and also allows non-pointwise functions. This PR has been derived from #129307. @ydwu4 @Chillee @zou3519
Pull Request resolved: #133012
Approved by: https://github.com/ydwu4
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as |
This PR aims to implement a host-side associative scan functionality. This functionality relies solely on PyTorch compiled code and does not yet explicitly invoke `tl.associative_scan`.
This is related to #95408
@Chillee
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @rec