Add a decomposition for torch.put #115306
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/115306

Note: Links to docs will display an error until the docs builds have been completed. ❌ 5 New Failures as of commit 876e0fd with merge base e5f2ac1.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
torch/_decomp/decompositions.py (Outdated)

```python
flattened = self.flatten().clone()
if accumulate:
    flattened[index] += source
else:
    flattened[index] = source
```
Mutation isn't legal inside decompositions of functional operators. If I'm not mistaken you could just use `torch.index_put` though.
Suggested change:

```diff
-flattened = self.flatten().clone()
-if accumulate:
-    flattened[index] += source
-else:
-    flattened[index] = source
+flattened = self.flatten()
+flattened = torch.index_put(flattened, [index], source, accumulate)
```
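On ordinary inputs, the functional rewrite behaves the same as the mutating version. A minimal sketch, assuming a standalone helper (`put_decomp` is a hypothetical name, not the PR's actual code), with the reshape back to the input's shape added so the example is self-contained:

```python
import torch

# Illustrative sketch of a functional put decomposition built on
# torch.index_put; `put_decomp` is a hypothetical helper name.
def put_decomp(self, index, source, accumulate=False):
    flattened = self.flatten()
    flattened = torch.index_put(
        flattened, [index.flatten()], source.flatten(), accumulate
    )
    return flattened.reshape(self.shape)

x = torch.zeros(2, 3)
out = put_decomp(x, torch.tensor([0, 4]), torch.tensor([1.0, 2.0]))
print(out)
```

No tensor is mutated: `index_put` returns a fresh tensor, which is what makes the rewrite legal inside a decomposition of a functional operator.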
It seems this fails because `put` supports NumPy-style "reverse broadcasting":

```python
>>> x = torch.tensor(1.0)
>>> index = torch.tensor(0)
>>> source = torch.tensor([0.0])
>>> torch.put(x, index, source)
tensor(0.)
```

(which is IMO a bad idea, but it is tested)
I guess `source.reshape(index.shape)` should work, assuming the actual shape conditions are all tested separately.
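To illustrate the suggestion (a sketch, not the PR's code): reshaping `source` to `index`'s shape before flattening lets the `index_put`-based rewrite handle the 0-d reverse-broadcasting case above.

```python
import torch

# Reverse-broadcasting case from the review: 0-d input, 0-d index,
# 1-element source. Reshaping source to index's shape first (as
# suggested) makes torch.index_put accept it.
x = torch.tensor(1.0)
index = torch.tensor(0)
source = torch.tensor([0.0])

flattened = torch.index_put(
    x.flatten(),
    [index.flatten()],
    source.reshape(index.shape).flatten(),
    False,
)
out = flattened.reshape(x.shape)
print(out)
```

Without the reshape, `source` has one more dimension than `index`, which `put` tolerates but a plain `index_put` of the flattened tensors does not.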
Co-authored-by: peterbell10 <peterbell10@live.co.uk>
Benchmark script based on #114813 (comment):

```python
import torch
from torch.testing import make_tensor
from torch.fx.experimental.proxy_tensor import make_fx
from torch.utils.benchmark import Timer, Compare
from torch._inductor.compile_fx import compile_fx_inner, cudagraphify_impl
from torch._inductor.decomposition import decompositions
from itertools import product
from functools import partial

torch._logging.set_logs(output_code=True)

benchmark_name = "put"
Ss = [512]

def gen_inputs():
    make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")
    make_source = partial(torch.randn, dtype=torch.float32, device="cuda")
    def make_idx(n):
        return make_tensor((n,), device="cuda", dtype=torch.int64, low=0, high=n)
    for b, s in product(Ss, Ss):
        yield make_arg((b * s)), make_idx(b), make_source(b)

def benchmark(label, f, x, idx, source):
    return Timer("f([x, idx, source])",
                 globals=locals(),
                 label=benchmark_name,
                 description=label,
                 sub_label=f"{tuple(x.shape)}",
                 num_threads=torch.get_num_threads()).blocked_autorange(min_run_time=2)

def compare(x, idx, source):
    def f(args):
        x, idx, source = args
        val = torch.ops.aten.put(x, idx, source)
        return (val,)
    print(f"{tuple(x.shape)}")
    args = [x, idx, source]

    decomposed = make_fx(f, decomposition_table=decompositions, tracing_mode="fake")(args)
    compiled_decomposed = compile_fx_inner(decomposed, args, cudagraphs=False)
    yield benchmark("Decomposed", compiled_decomposed, *args)

    non_decomposed = make_fx(f, tracing_mode="fake")(args)
    compiled_nondecomposed = compile_fx_inner(non_decomposed, args, cudagraphs=False)
    yield benchmark("Lowering", compiled_nondecomposed, *args)

    # Just show the first two generated kernels
    torch._logging.set_logs(output_code=False)

    cuda_f = cudagraphify_impl(f, args, static_input_idxs=tuple(range(len(args))))
    yield benchmark("Eager", cuda_f, *args)

results = []
for args in gen_inputs():
    for res in compare(*args):
        results.append(res)

compare = Compare(results)
compare.trim_significant_figures()
compare.print()
```
Co-authored-by: peterbell10 <peterbell10@live.co.uk>
torch/_inductor/lowering.py (Outdated)

```python
make_fallback(aten.polygamma)
make_fallback(aten.put)
make_fallback(aten.reflection_pad1d)
make_fallback(aten.replication_pad1d)
```
I think your rebase has gone awry.
You need to remove the `aten.put` fallback since we're adding a decomposition that replaces it.
```python
    norm = x.norm(2, keep_dim, keepdim=True)
    return x * (y / norm), norm


@register_decomposition(aten.put)
```
For the `test_has_decomposition` failure, you can run the test as

```shell
EXPECTTEST_ACCEPT=1 pytest test/test_decomp.py -k HasDecomp
```

and it will update the expected test output files:

- test/expect/HasDecompTest.test_has_decomposition.expect
- test/expect/HasDecompTest.test_aten_core_operators.expect
As in the title. It is an updated copy of #115306 . cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames [ghstack-poisoned]
As in the title. It is an updated copy of #115306.

Pull Request resolved: #120179
Approved by: https://github.com/lezcano, https://github.com/peterbell10, https://github.com/jgong5
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as stale.
landed in #120179
Getting a test failure here that I don't understand. However, from pure Python (CPU), `torch.put` and my decomposition seem to give the same thing.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @amjames @desertfire @chauhang
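One quick way to check this kind of parity on CPU is to compare against the documented in-place `Tensor.put_`. A sketch, where `put_ref` is a hypothetical stand-in for the PR's decomposition (not its actual code):

```python
import torch

# Hypothetical stand-in for the decomposition under review.
def put_ref(x, index, source, accumulate=False):
    flat = torch.index_put(
        x.flatten(), [index.flatten()], source.flatten(), accumulate
    )
    return flat.reshape(x.shape)

x = torch.arange(6.0).reshape(2, 3)
idx = torch.tensor([0, 5])
src = torch.tensor([10.0, 20.0])

for acc in (False, True):
    expected = x.clone()
    expected.put_(idx, src, accumulate=acc)  # eager in-place reference
    assert torch.equal(put_ref(x, idx, src, accumulate=acc), expected)
print("parity OK")
```

The loop exercises both the overwrite and accumulate paths with unique indices; duplicated indices would additionally exercise the scatter-ordering semantics.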