
Conversation

bohnstingl
Collaborator

@bohnstingl bohnstingl commented Jun 22, 2024

This PR aims to implement a host-side associative scan functionality. It relies solely on PyTorch compiled code and does not yet explicitly invoke tl.associative_scan.

This is related to #95408

@Chillee

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @rec


pytorch-bot bot commented Jun 22, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/129307

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 59a85ca with merge base 0738916:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@bohnstingl bohnstingl marked this pull request as ready for review June 22, 2024 14:08
Collaborator

@lezcano lezcano left a comment


without proper parallelisation this is going to be very slow. Can you post benchmarks?

@bohnstingl
Collaborator Author

bohnstingl commented Jun 23, 2024

Thank you for your comment. I carried out two benchmarks, one for cumsum and one for mm, comparing against a naive eager implementation, torch.cumsum, and JAX. The cumsum code that I used is:

import torch
import torch.utils._pytree as pytree
from torch._higher_order_ops.associative_scan import associative_scan as device_scan
import torch._dynamo.config as config
import jax
import jax.numpy as jnp
from triton.testing import do_bench
import matplotlib.pyplot as plt
import numpy as np

config.cache_size_limit = 10000
torch.set_default_device('cuda')

# Naive eager reference: sequential scan over the scan dimension.
def associative_scan_fct(operator, input, dim=0, reverse=False):
    inp_leaves, spec = pytree.tree_flatten(input)
    result_flat = []
    num_leaves = len(inp_leaves)
    op = reversed if reverse else lambda x: x

    for ind in op(range(inp_leaves[0].size(dim))):
        r = [
            inp_leaves[leave_ind][(slice(None),) * dim + (ind,)]
            for leave_ind in range(num_leaves)
        ]
        if (ind > 0 and not reverse) or (
            ind < (inp_leaves[0].size(dim) - 1) and reverse
        ):
            r = operator(
                pytree.tree_unflatten(result_flat[-1], spec),
                pytree.tree_unflatten(r, spec),
            )
        r_flat, _ = pytree.tree_flatten(r)
        result_flat.append(r_flat)
    results = [
        torch.stack([e[leave_ind] for e in op(result_flat)], dim)
        for leave_ind in range(num_leaves)
    ]
    return pytree.tree_unflatten(results, spec)

comp_host_scan = torch.compile(device_scan)
comp_cumsum = torch.compile(torch.cumsum)

jax_scan = jax.lax.associative_scan
compiled_jax_scan = jax.jit(jax_scan, static_argnums=(0,))

def scan_fct(i, j):
    return i + j

mm_inps = lambda S, B, construct: construct(S, B, B)
input_func = mm_inps

jax_construct = lambda *shape: jnp.ones(shape, dtype=jnp.float32)
torch_construct = lambda *shape: torch.ones(*shape, dtype=torch.float32)
B = 100
jax_times_compiled = []
eager_times = []
compiled_host_side_times = []
compiled_device_side_times = []
upper_limit = 15000
lower_limit = 2000
for S in range(lower_limit, upper_limit, 2000):
    inp = input_func(S, B, torch_construct)
    jax_inp = input_func(S, B, jax_construct)

    eager_times.append(do_bench(lambda: associative_scan_fct(scan_fct, inp, dim=0)))
    compiled_host_side_times.append(do_bench(lambda: comp_host_scan(scan_fct, inp, dim=0, reverse=False, host_side=True)))
    compiled_device_side_times.append(do_bench(lambda: comp_cumsum(inp, 0)))
    jax_times_compiled.append(do_bench(lambda: compiled_jax_scan(scan_fct, jax_inp)))

print("eager: ", eager_times)
print("compiled host-side: ", compiled_host_side_times)
print("compiled device-side: ", compiled_device_side_times)
print("jax compiled: ", jax_times_compiled)

plt.figure()
# plt.plot(np.arange(lower_limit, upper_limit, 2000), eager_times, label='Eager')
plt.plot(np.arange(lower_limit, upper_limit, 2000), compiled_host_side_times, label='Associative scan host-side compiled')
plt.plot(np.arange(lower_limit, upper_limit, 2000), compiled_device_side_times, label='Associative scan device-side compiled')
plt.plot(np.arange(lower_limit, upper_limit, 2000), jax_times_compiled, label='Jax compiled')
plt.xlabel('Length')
plt.ylabel('Time [ms]')
plt.legend()
plt.show()
[plot: scan time vs. sequence length for the cumsum benchmark]

I acknowledge that the host-side scan is slower than JAX or torch.cumsum in this scenario, but it is far faster than the eager implementation in pure PyTorch (which I measured but did not plot).

For the mm scenario, the numbers look even better, with the host-side scan being slightly faster than JAX.

import torch
import torch.utils._pytree as pytree
from torch._higher_order_ops.associative_scan import associative_scan as device_scan
import torch._dynamo.config as config
import jax
import jax.numpy as jnp
from triton.testing import do_bench
import matplotlib.pyplot as plt
import numpy as np

config.cache_size_limit = 10000
torch.set_default_device('cuda')

def associative_scan_fct(operator, input, dim=0, reverse=False):
    inp_leaves, spec = pytree.tree_flatten(input)
    result_flat = []
    num_leaves = len(inp_leaves)
    op = reversed if reverse else lambda x: x

    for ind in op(range(inp_leaves[0].size(dim))):
        r = [
            inp_leaves[leave_ind][(slice(None),) * dim + (ind,)]
            for leave_ind in range(num_leaves)
        ]
        if (ind > 0 and not reverse) or (
            ind < (inp_leaves[0].size(dim) - 1) and reverse
        ):
            r = operator(
                pytree.tree_unflatten(result_flat[-1], spec),
                pytree.tree_unflatten(r, spec),
            )
        r_flat, _ = pytree.tree_flatten(r)
        result_flat.append(r_flat)
    results = [
        torch.stack([e[leave_ind] for e in op(result_flat)], dim)
        for leave_ind in range(num_leaves)
    ]
    return pytree.tree_unflatten(results, spec)

comp_device_scan = torch.compile(device_scan)

jax_scan = jax.lax.associative_scan
compiled_jax_scan = jax.jit(jax_scan, static_argnums=(0,))

def mm_chain_scan_jax(i, j):
    return jax.vmap(jnp.matmul)(i, j)

def mm_chain_scan_torch(i, j):
    return torch.vmap(torch.mm)(i, j)

def mm_chain_scan_torch_wovmap(i, j):
    return torch.mm(i, j)

mm_inps = lambda S, B, construct: construct(S, B, B)
scan_func = mm_chain_scan_torch
scan_func_wovmap = mm_chain_scan_torch_wovmap
scan_func_jax = mm_chain_scan_jax
input_func = mm_inps

jax_construct = lambda *shape: jnp.ones(shape, dtype=jnp.float32)
torch_construct = lambda *shape: torch.ones(*shape, dtype=torch.float32)
B = 100
jax_times_compiled = []
eager_times = []
compiled_host_side_times = []
upper_limit = 15000
lower_limit = 2000
for S in range(lower_limit, upper_limit, 2000):
    inp = input_func(S, B, torch_construct)
    jax_inp = input_func(S, B, jax_construct)

    eager_times.append(do_bench(lambda: associative_scan_fct(scan_func_wovmap, inp, dim=0)))
    compiled_host_side_times.append(do_bench(lambda: comp_device_scan(scan_func, inp, dim=0, reverse=False, host_side=True)))
    jax_times_compiled.append(do_bench(lambda: compiled_jax_scan(scan_func_jax, jax_inp)))

print("eager: ", eager_times)
print("compiled host-side: ", compiled_host_side_times)
print("jax compiled: ", jax_times_compiled)

plt.figure()
plt.plot(np.arange(lower_limit, upper_limit, 2000), compiled_host_side_times, label='Associative scan host-side compiled')
plt.plot(np.arange(lower_limit, upper_limit, 2000), jax_times_compiled, label='Jax compiled')
plt.legend()
plt.show()
[plot: scan time vs. sequence length for the mm benchmark]

As a side note, this code is largely based on the code from @Chillee.

@lezcano
Collaborator

lezcano commented Jun 24, 2024

Yeah, I don't have that much of an issue with the scaling with the batch size. I do have it with the scan dimension, which is where the work is not done in parallel.

@janeyx99 janeyx99 added the triaged label (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Jun 24, 2024
@Chillee
Collaborator

Chillee commented Jun 24, 2024

@lezcano This is parallelized along the scan dimension - the scan elements are batched together using a parallel scan algorithm.
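
For context, a minimal sketch of the technique being described here, a Hillis–Steele / recursive-doubling scan. This is not the code in this PR, and it assumes combine_fn accepts batched slices along the scan dimension.

import torch

def parallel_scan_sketch(combine_fn, x, dim=0):
    # Inclusive scan in O(log n) sweeps: at each sweep every element is
    # combined with the element `offset` positions earlier, and all of these
    # combinations are issued as one batched call to combine_fn.
    x = x.movedim(dim, 0)
    out = x.clone()
    n = out.size(0)
    offset = 1
    while offset < n:
        combined = combine_fn(out[:-offset], out[offset:])
        out = torch.cat([out[:offset], combined], dim=0)
        offset *= 2
    return out.movedim(0, dim)

# e.g. parallel_scan_sketch(torch.add, torch.ones(8, 3), dim=0) matches torch.cumsum(..., 0)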

@bohnstingl
Collaborator Author

Is there any way that one can introduce more parallelism or any other way to make it more efficient?

Collaborator

@lezcano lezcano left a comment


Ah, I see in which sense this is parallel. Given how different it is from the algorithm used on GPU, it would still be good to benchmark the operation varying the batched dimension to see what performance characteristics it has.

Overall the idea looks good, but I left a few comments.

@lezcano
Collaborator

lezcano commented Jun 25, 2024

Also, I see no reason why this algorithm should only be used on CPU. We could also use it to support more generic scans on device, as Triton just supports pointwise ops.

@vadimkantorov
Contributor

vadimkantorov commented Jun 25, 2024

For general scans:

  • it would be good (even if inefficient) to support arbitrary cells (e.g. passing a custom RNN cell to the scan primitive and getting a somewhat perf/memory-efficient custom RNN model) - this PR fills this gap, right?
  • it might be good to support loop unrolling during inference or CUDA codegen: e.g. compile several iterations of the loop together to minimize the number of kernel launches (maybe this could be some sort of torch.compile arg). For long sequences this could reduce the number of kernels to e.g. SEQLEN/8, where the per-step kernel launch overhead can be significant; see the sketch after this list.
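
As a rough illustration of the unrolling idea above (not something this PR implements): compile a fixed block of K recurrence steps into one region so that each compiled call covers K sequence positions. The cell, the value of K, and whether the per-step kernels actually fuse are assumptions that depend on the model and the backend.

import torch

def make_unrolled_block(cell, K=8):
    # `cell` is a hypothetical step function h_new = cell(h, x_t).
    @torch.compile
    def run_block(h, xs):
        # xs: (K, batch, features); the Python loop is unrolled at trace time,
        # giving the compiler K steps to fuse / schedule together.
        for x_t in xs:
            h = cell(h, x_t)
        return h
    return run_block

# hypothetical usage, assuming the sequence length is a multiple of K:
# run_block = make_unrolled_block(my_cell, K=8)
# for chunk in x.split(8, dim=0):
#     h = run_block(h, chunk)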

@bohnstingl
Collaborator Author

@lezcano Thank you very much for your comments. Indeed, this feature is not intended to be CPU-only; the name host-side might be misleading here. The feature performs the associative scan on the host side in the sense that it does not use Triton (e.g. tl.associative_scan) or other device-specific features. It can be used on both GPUs and CPUs; I added a separate test case for that as well.

@vadimkantorov This feature partly fills this gap if the RNN you are feeding in is linear, i.e. its recurrence can be expressed as an associative operation (see the sketch below). A feature for an arbitrary RNN is certainly something that I am also highly interested in, but the closest feature I can see in this direction is a while_loop operator, where the body is your arbitrary RNN.
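
For illustration, a sketch of what "performed in an associative way" means for a linear cell; this is not code from the PR. A linear recurrence h_t = a_t * h_{t-1} + b_t can be scanned with an associative combine over (a, b) pairs:

import torch

def linear_recurrence_combine(prev, curr):
    # Each element is a pair (a, b) representing the affine map h -> a * h + b.
    # Composing two such maps is associative, so the recurrence can be scanned.
    a1, b1 = prev
    a2, b2 = curr
    return a2 * a1, a2 * b1 + b2

# hypothetical usage with the scan from this PR (pytree/tuple input),
# assuming h_{-1} = 0 so that the scanned b-component equals h_t:
# a, b = torch.rand(T, B), torch.rand(T, B)
# _, h = associative_scan(linear_recurrence_combine, (a, b), dim=0)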

Your second idea sounds really interesting, but I wouldn't really know how to do this. Could you help me out with this?

@vadimkantorov
Contributor

vadimkantorov commented Jun 25, 2024

Regarding how to unroll the loop and when we can compile several loop iterations into one kernel - I don't know enough about how to implement this in torch.compile/Inductor. But I heard at some conference that someone used this trick to accelerate WaveRNN inference or something like that. I guess memory placement optimization can be important here too, and maybe for small enough RNN states (like 32-, 64-, or 128-dim as in https://github.com/NVlabs/tiny-cuda-nn) some extra memory optimizations can be done.

This PR is also super useful for wider experimentation with linear RNNs, because the Triton lowering only supports elementwise cells.

Collaborator

@lezcano lezcano left a comment


It'd be nice if we could use this mode automatically as a fallback, but we currently can't, because in this mode we have a different API when there are multiple inputs and outputs. In the generic_scan=False case we need to unpack the inputs and outputs (combine_op takes 2n inputs and returns n outputs), while in the generic_scan=True case the inputs are packed.
Do you reckon it'd be possible to reconcile both APIs? If so, it'd be great to remove the flag and use generic_scan=False if the inputs are on CUDA and the combine op consists only of pointwise ops, and otherwise use the generic one as a fallback.
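
To make the API difference concrete, an illustrative sketch based on the description above; the exact ordering of the 2n unpacked arguments is an assumption, not taken from the code.

# generic_scan=False: with n = 2 inputs the combine_op receives the 2n
# leaves unpacked and returns n leaves.
def combine_unpacked(x_a, x_b, y_a, y_b):
    return x_a + y_a, x_b * y_b

# generic_scan=True: the combine_fn receives two packed pytrees and
# returns a pytree with the same structure.
def combine_packed(x, y):
    return (x[0] + y[0], x[1] * y[1])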

If the generic scan mode is used, the restrictions on the ``combine_fn`` are less strict;
for example, non-pointwise operations are also allowed.
However, the non-generic scan mode uses efficient ``tl.associative_scan`` calls and is thus
more efficient than the generic scan mode.

Example::
Collaborator


Can you add an example of generic_scan=True (one with matmul would be good) and one with two inputs and outputs (e.g. something like cummin)?
In general, I would like to show that the API for generic_scan=True and =False is different (!!) when there are multiple inputs and outputs.
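
For reference, roughly what such examples could look like; the generic_scan flag and the packed-tuple calling convention are assumed from this PR's discussion, so treat the exact signature as illustrative rather than final.

import torch
from torch._higher_order_ops.associative_scan import associative_scan

# generic scan with a non-pointwise combine: a chain of matrix products
def mm_combine(x, y):
    return x @ y  # batched matmul over the stacked scan elements

mats = torch.randn(16, 4, 4)
chain = associative_scan(mm_combine, mats, dim=0, generic_scan=True)

# two inputs / two outputs: a cummin that also carries the index of the minimum
def min_with_index(x, y):
    x_val, x_idx = x
    y_val, y_idx = y
    take_y = y_val < x_val
    return torch.where(take_y, y_val, x_val), torch.where(take_y, y_idx, x_idx)

vals = torch.randn(16, 8)
idxs = torch.arange(16).unsqueeze(-1).expand(16, 8).contiguous()
cummin_vals, cummin_idxs = associative_scan(
    min_with_index, (vals, idxs), dim=0, generic_scan=True
)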

Collaborator Author


I tried to unify the API, so this may not be necessary anymore?

@@ -63,8 +92,15 @@ def associative_scan(
This function must be pure, pointwise, and satisfy the associative property.
Collaborator


The comment in line 86 above is stale, can you update it?

Collaborator Author


Done

@bohnstingl
Collaborator Author

bohnstingl commented Jun 26, 2024

I tried to unify the generic_scan=False and generic_scan=True cases. The default is generic_scan=False, but there are now internal checks that determine whether the generic scan has to be used. Specifically, generic_scan=True is selected if:

  1. Not all tensors are on CUDA, or
  2. The graph of the combine function contains a non-pointwise operation.

However, I thought it might still be good for the user to be able to force the generic_scan=True behavior, so the user can override generic_scan to be true.

@lezcano please let me know whether the checks here are appropriate.

The problem I have with the pointwise check is that it does not work under torch.compile(): using make_fx to create an FX graph does not work in the compile environment. I am wondering whether there is a better way to do this?
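
For illustration, a sketch of the kind of check meant here, run outside of torch.compile: trace the combine function with make_fx and inspect the resulting graph. The POINTWISE_OPS allowlist is a placeholder, not the check used in this PR.

import torch
from torch.fx.experimental.proxy_tensor import make_fx

# placeholder allowlist; a real check would need to cover many more ops
POINTWISE_OPS = {
    torch.ops.aten.add.Tensor,
    torch.ops.aten.sub.Tensor,
    torch.ops.aten.mul.Tensor,
}

def combine_is_pointwise(combine_fn, *example_args):
    gm = make_fx(combine_fn)(*example_args)
    return all(
        node.op != "call_function" or node.target in POINTWISE_OPS
        for node in gm.graph.nodes
    )

# combine_is_pointwise(lambda x, y: x + y, torch.ones(3), torch.ones(3))      -> True
# combine_is_pointwise(lambda x, y: x @ y, torch.ones(3, 3), torch.ones(3, 3)) -> False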

Furthermore, I spotted an issue when the CUDA device is used. I tried to dig into this a little bit here and it appears that for some lengths of the input tensors the results are sporadically wrong, and I don't know why. I also added TODOs in the test cases that reproduce this behavior. Could you help me out with this one?

@lezcano lezcano requested a review from peterbell10 June 27, 2024 07:22
Collaborator

@lezcano lezcano left a comment


So, I think the current code that looks into whether a function is pointwise may not be general enough. At the moment, these checks are done in lower_pointwise_subgraph, within the lowering of this op. The ideas I have to mix both approaches are not super clean.
Pinging @peterbell10 to see if he has any ideas on what the best approach is here, and for a general review / help with the issues you are seeing.

On a different note, if we are able to choose the algorithm cleanly, I would propose to not expose the flag now in the interest of "it's easy to add a flag later if people request it, but it's very difficult to deprecate a given flag".

Also, I assume that this is BC-breaking to some extent, as it changes the way associative_scan works for multiple inputs? I think it's fine to do this, as we are not even rendering the docs for this op (which we should fix at some point... not in this PR though).

@vadimkantorov
Contributor

If the auto check for pointwiseness is too complex, would it be better to actually introduce it as a separate function (instead of a flag)?

@Chillee
Collaborator

Chillee commented Jun 27, 2024

I think we should explicitly expose a flag for whether it's device-side or host-side. Even if it's all pointwise ops, users may want to try the host-side scan regardless.

@bohnstingl
Collaborator Author

bohnstingl commented Jun 27, 2024

Yeah, that's what I initially thought: the user can use this flag to 'force' the generic scan. The other way around may not make much sense, as generic_scan=False has limitations.
However, the BC aspect and the "it's easy to add a flag later if people request it, but it's very difficult to deprecate a given flag" point are also strong arguments for me.

), "The pytree of the output of the operator needs to match the input pytree"

generic_scan_required = False
combine_fn = make_fx(combine_fn)(pytree.tree_unflatten(
Collaborator


Do we have to run this check under dynamo? I would have thought we could call associative_scan_op, and inside the compiler perform this check and if it fails we use the eager implementation as a decomposition.

Collaborator Author

@bohnstingl bohnstingl Jun 27, 2024


Thank you for your review and your comment. Unfortunately, I don't fully understand it and was wondering whether you could elaborate a bit more on it for me?
I was under the impression that the goal is to have a generic scan (which can run either in eager mode or under torch.compile) that is more general than the optimized associative scan that invokes tl.associative_scan, i.e., we don't call the associative_scan_op at the moment. The idea was then to determine a priori which version of the scan to use, the generic scan or the non-generic scan. This involves checking whether the combine_fn contains only pointwise operations, and I didn't know of any better way to do this, but this check fails under dynamo. I am not sure how and where you envision to move this check?

Also, regarding your comment below: I just used a similar methodology as in other HOPs, such as cond. I thought that the result should be that the associative_scan function is always compiled, and that's why this if clause is placed there. Are you suggesting to remove this and expect the user to invoke torch.compile instead, because there could be cases where one doesn't want to have this function compiled?

Collaborator


I am not sure how and where you envision to move this check?

Inside the compiler, when compiling associative_scan_op, we can choose to fall back to the generic implementation instead. Ideally, this could be done in torch/_inductor/decompositions.py.

Also, regarding your comment below. I just used a similar methodology as in other HOPs, such as cond

Looks like cond is being compiled with the eager backend there which doesn't make a great deal of sense to me tbh.

cc @zou3519 @aakhundov any comment on why HOPs are being compiled with eager backend instead of just running as normal eager code?

Collaborator Author


cc @ydwu4

Contributor

@ydwu4 ydwu4 Jul 8, 2024


any comment on why HOPs are being compiled with eager backend instead of just running as normal eager code

We rely on dynamo to (1) check for and raise an error on any side effects in the user-defined function, (2) lift closures as implicit inputs, and (3) flatten/unflatten pytrees. We also considered getting rid of torch.compile, but there are no concrete conclusions yet @Chillee.

combine_fn, input, dim
)

return torch.compile(associative_scan, fullgraph=False)(
Collaborator


If we have an eager implementation then we shouldn't need to compile this.

@bohnstingl
Collaborator Author

@peterbell10
I was wondering whether you could help me out with the two issues that this feature is currently stuck on:

  1. The "random" results that I observe when using the CUDA device: there is a test case showing that for some input lengths the results of the compiled and the non-compiled versions differ. This only happens on the CUDA device; on CPU, both behave the same. You can find my current investigation in the thread above.
  2. The check that the combine_fn contains only pointwise operations: you mentioned that you wanted to do the compilation or the checks in a different form, but I don't quite know what you had in mind. Could you maybe give me some guidance?

Thank you in advance

Extended testcase of torch.flip with data types and backends
@bohnstingl bohnstingl requested review from lezcano and ydwu4 July 30, 2024 07:10
WIP: Issue with view operations in case of matmul
Contributor

@ydwu4 ydwu4 left a comment


Went through the diff and left some comments. Can you do a rebase so I can let the CI run?

@@ -5501,6 +5501,140 @@ def test_optimized_module_training(self):
mod.eval()
self.assertFalse(opt_mod.training)

@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
Contributor


Is this test related to scan? Can we delete it from this PR? You could keep it locally or in a gist if you want to share it.

def add(x: torch.Tensor, y: torch.Tensor):
    return x + y

for device in [torch.device("cpu"), torch.device("cuda")]:
Contributor


The standard way of doing this is adding the @parametrize decorator; there are examples in this file to follow. Can you also change the other tests to use parametrize?

print('Flip test fails for backends: ' + str(fails_for_backend))
self.assertEqual(len(fails_for_backend), 0)

for n in range(20):
Contributor


Why do we want to run this 20 times?

},
)(fct, x, 0, False, False)
self.assertExpectedInline(
gm.print_readable(print_output=False).strip(),
Contributor


The graph produced by print_readable(print_output=False) is sometimes not stable across environments. Can you change these to gm.code.strip()?

gm = make_fx(
scan,
decomposition_table={
associative_scan_op: torch._inductor.decomposition.associative_scan_op_decomp
Contributor


Make this a separate test?

@@ -37,6 +37,23 @@ def promote_to_tensor(x):
return x + tl.zeros((1,), tl.int1)


@triton.jit
Contributor


Are these changes necessary?

@@ -21,10 +21,11 @@
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.utils._sympy.functions import FloorDiv, ModularIndexing
from torch.utils._sympy.symbol import symbol_is_type, SymT
from torch.utils._sympy.value_ranges import bound_sympy
from torch.utils._sympy.value_ranges import bound_sympy, IntInfinity, ValueRanges
Contributor


What's going on in this file?

@@ -52,7 +52,7 @@
if TYPE_CHECKING:
import types

from torch._ops import OpOverload
from torch._ops import OperatorBase, OpOverload
Contributor


What's going on here? Should this be a separate PR?

@@ -322,7 +322,7 @@ class ResolvedExportOptions(ExportOptions):
onnx_registry: OnnxRegistry

# Private only attributes
decomposition_table: dict[torch._ops.OpOverload, Callable]
decomposition_table: dict[torch._ops.OperatorBase | torch._ops.OpOverload, Callable]
Contributor


Separate PR?

@@ -22,7 +22,7 @@ def __init__(
self,
diagnostic_context: diagnostics.DiagnosticContext,
module: torch.fx.GraphModule,
decomposition_table: Mapping[torch._ops.OpOverload, Callable],
decomposition_table: Mapping[torch._ops.OperatorBase | torch._ops.OpOverload, Callable],
Contributor


Same here.

pytorchmergebot pushed a commit that referenced this pull request Aug 16, 2024
This is part of a series of PRs to improve the functionality of `associative_scan`. This specific PR introduces a `reverse` flag to `associative_scan` to establish an interface similar to `jax.associative_scan`. This PR has been derived from #129307.

@ydwu4 @Chillee @zou3519

Pull Request resolved: #133011
Approved by: https://github.com/ydwu4
pytorchmergebot pushed a commit that referenced this pull request Aug 30, 2024
This is part of a series of PRs to improve the functionality of `associative_scan`. This specific PR introduces a `combine_mode`, which can be either `pointwise` (default) or `generic`. In the `generic` case, `associative_scan` is more flexible and also allows non-pointwise functions. This PR has been derived from #129307.

@ydwu4 @Chillee @zou3519

Pull Request resolved: #133012
Approved by: https://github.com/ydwu4
pytorch-bot bot pushed a commit that referenced this pull request Sep 13, 2024
This is part of a series of PRs to improve the functionality of `associative_scan`. This specific PR introduces a `reverse` flag to `associative_scan` to establish an interface similar to `jax.associative_scan`. This PR has been derived from #129307.

@ydwu4 @Chillee @zou3519

Pull Request resolved: #133011
Approved by: https://github.com/ydwu4
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 20, 2024
This is part of a series of PRs to improve the functionality of `associative_scan`. This specific PR introduces a `combine_mode`, which can be either `pointwise` (default) or `generic`. In the `generic` case, `associative_scan` is more flexible and also allows non-pointwise functions. This PR has been derived from pytorch#129307.

@ydwu4 @Chillee @zou3519

Pull Request resolved: pytorch#133012
Approved by: https://github.com/ydwu4
Contributor

github-actions bot commented Oct 7, 2024

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Oct 7, 2024
@bohnstingl
Collaborator Author

Closing this PR.
It has been split into several parts, some of which have already landed in main.
Closed:
#133011
#133012

Open:
#134921
#133013
#136966

@bohnstingl bohnstingl closed this Oct 7, 2024
Labels
module: dynamo, module: inductor, open source, release notes: fx, Stale, triaged