
Port cat kernel to structured kernels. #68640

Closed
wants to merge 29 commits

Conversation

Tracking issue: #55070

[ghstack-poisoned]
pytorch-probot bot commented Nov 19, 2021

CI Flow Status


Ruleset - Version: v1
Ruleset - File: https://github.com/pytorch/pytorch/blob/94aec9cc15c1aeb35d1a549dab4913140c7fefaa/.github/generated-ciflow-ruleset.json
PR ciflow labels: ciflow/default

Workflows Labels (bold enabled) Status
Triggered Workflows
linux-bionic-py3.6-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/noarch, ciflow/xla ✅ triggered
linux-docs ciflow/all, ciflow/cpu, ciflow/default, ciflow/docs, ciflow/linux ✅ triggered
linux-vulkan-bionic-py3.6-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/vulkan ✅ triggered
linux-xenial-cuda11.3-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/default, ciflow/linux ✅ triggered
linux-xenial-py3-clang5-mobile-build ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile ✅ triggered
linux-xenial-py3-clang5-mobile-custom-build-static ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile ✅ triggered
linux-xenial-py3.6-clang7-asan ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/sanitizers ✅ triggered
linux-xenial-py3.6-clang7-onnx ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/onnx ✅ triggered
linux-xenial-py3.6-gcc5.4 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux ✅ triggered
linux-xenial-py3.6-gcc7 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux ✅ triggered
linux-xenial-py3.6-gcc7-bazel-test ciflow/all, ciflow/bazel, ciflow/cpu, ciflow/default, ciflow/linux ✅ triggered
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux ✅ triggered
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux ✅ triggered
win-vs2019-cpu-py3 ciflow/all, ciflow/cpu, ciflow/default, ciflow/win ✅ triggered
win-vs2019-cuda11.3-py3 ciflow/all, ciflow/cuda, ciflow/default, ciflow/win ✅ triggered
Skipped Workflows
caffe2-linux-xenial-py3.6-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux 🚫 skipped
docker-builds ciflow/all 🚫 skipped
ios-12-5-1-arm64 ciflow/all, ciflow/ios, ciflow/macos 🚫 skipped
ios-12-5-1-arm64-coreml ciflow/all, ciflow/ios, ciflow/macos 🚫 skipped
ios-12-5-1-arm64-custom-ops ciflow/all, ciflow/ios, ciflow/macos 🚫 skipped
ios-12-5-1-arm64-full-jit ciflow/all, ciflow/ios, ciflow/macos 🚫 skipped
ios-12-5-1-arm64-metal ciflow/all, ciflow/ios, ciflow/macos 🚫 skipped
ios-12-5-1-x86-64 ciflow/all, ciflow/ios, ciflow/macos 🚫 skipped
ios-12-5-1-x86-64-coreml ciflow/all, ciflow/ios, ciflow/macos 🚫 skipped
ios-12-5-1-x86-64-full-jit ciflow/all, ciflow/ios, ciflow/macos 🚫 skipped
libtorch-linux-xenial-cuda10.2-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux 🚫 skipped
libtorch-linux-xenial-cuda11.3-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux 🚫 skipped
linux-bionic-cuda10.2-py3.9-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/slow 🚫 skipped
linux-docs-push ciflow/all, ciflow/cpu, ciflow/linux, ciflow/scheduled 🚫 skipped
macos-10-15-py3-arm64 ciflow/all, ciflow/macos 🚫 skipped
macos-10-15-py3-lite-interpreter-x86-64 ciflow/all, ciflow/macos 🚫 skipped
macos-11-py3-x86-64 ciflow/all, ciflow/macos 🚫 skipped
parallelnative-linux-xenial-py3.6-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux 🚫 skipped
periodic-libtorch-linux-bionic-cuda11.5-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-libtorch-linux-xenial-cuda11.1-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-linux-bionic-cuda11.5-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-linux-xenial-cuda10.2-py3-gcc7-slow-gradcheck ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled, ciflow/slow, ciflow/slow-gradcheck 🚫 skipped
periodic-linux-xenial-cuda11.1-py3.6-gcc7-debug ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-win-vs2019-cuda11.1-py3 ciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win 🚫 skipped

You can add a comment to the PR and tag @pytorchbot with the following commands:
# ciflow rerun, "ciflow/default" will always be added automatically
@pytorchbot ciflow rerun

# ciflow rerun with additional labels "-l <ciflow/label_name>", which is equivalent to adding these labels manually and trigger the rerun
@pytorchbot ciflow rerun -l ciflow/scheduled -l ciflow/slow

For more information, please take a look at the CI Flow Wiki.

@facebook-github-bot
Copy link
Contributor

facebook-github-bot commented Nov 19, 2021


💊 CI failures summary and remediations

As of commit 95b2799 (more details on the Dr. CI page):


  • 1/1 failures introduced in this PR

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages

See GitHub Actions build pull / linux-xenial-py3.7-gcc5.4 / test (backwards_compat, 1, 1, linux.2xlarge) (1/1)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

2022-04-15T20:33:43.4822712Z processing existing schema:  text(__torch__.torch.classes.profiling.SourceRef _0) -> (str _0)
2022-04-15T20:33:43.4823566Z processing existing schema:  count(__torch__.torch.classes.profiling.InstructionStats _0) -> (int _0)
2022-04-15T20:33:43.4824818Z processing existing schema:  duration_ns(__torch__.torch.classes.profiling.InstructionStats _0) -> (int _0)
2022-04-15T20:33:43.4825925Z processing existing schema:  source(__torch__.torch.classes.profiling.SourceStats _0) -> (__torch__.torch.classes.profiling.SourceRef _0)
2022-04-15T20:33:43.4827510Z processing existing schema:  line_map(__torch__.torch.classes.profiling.SourceStats _0) -> (Dict(int, __torch__.torch.classes.profiling.InstructionStats) _0)
2022-04-15T20:33:43.4828866Z processing existing schema:  __init__(__torch__.torch.classes.profiling._ScriptProfile _0) -> (NoneType _0)
2022-04-15T20:33:43.4829778Z processing existing schema:  enable(__torch__.torch.classes.profiling._ScriptProfile _0) -> (NoneType _0)
2022-04-15T20:33:43.4830917Z processing existing schema:  disable(__torch__.torch.classes.profiling._ScriptProfile _0) -> (NoneType _0)
2022-04-15T20:33:43.4832453Z processing existing schema:  _dump_stats(__torch__.torch.classes.profiling._ScriptProfile _0) -> (__torch__.torch.classes.profiling.SourceStats[] _0)
2022-04-15T20:33:43.4834453Z processing existing schema:  __init__(__torch__.torch.classes.dist_rpc.WorkerInfo _0, str _1, int _2) -> (NoneType _0)
2022-04-15T20:33:43.4836391Z The PR is introducing backward incompatible changes to the operator library. Please contact PyTorch team to confirm whether this change is wanted or not. 
2022-04-15T20:33:43.4836424Z 
2022-04-15T20:33:43.4836506Z Broken ops: [
2022-04-15T20:33:43.4836771Z 	aten::as_strided_copy(Tensor self, int[] size, int[] stride, int? storage_offset=None) -> (Tensor)
2022-04-15T20:33:43.4836922Z 	aten::_values_copy(Tensor self) -> (Tensor)
2022-04-15T20:33:43.4837062Z 	aten::alias_copy(Tensor self) -> (Tensor)
2022-04-15T20:33:43.4837342Z 	aten::_nested_from_padded(Tensor padded, Tensor cpu_nested_shape_example, bool fuse_transform_0213=False) -> (Tensor)
2022-04-15T20:33:43.4837499Z 	aten::_fw_primal_copy(Tensor self, int level) -> (Tensor)
2022-04-15T20:33:43.4837701Z 	aten::_make_dual_copy(Tensor primal, Tensor tangent, int level) -> (Tensor)
2022-04-15T20:33:43.4837851Z 	aten::view_as_real_copy(Tensor self) -> (Tensor)
2022-04-15T20:33:43.4838011Z 	aten::view_as_complex_copy(Tensor self) -> (Tensor)

This comment was automatically generated by Dr. CI.

facebook-github-bot added the oncall: jit label (Add this issue/PR to JIT oncall triage queue) on Nov 19, 2021
ysiraichi added a commit that referenced this pull request Nov 19, 2021
Tracking issue: #55070

ghstack-source-id: eb657eaa09ae8ea2ee2ae4c84f7c0a7589efe02a
Pull Request resolved: #68640
ysiraichi added a commit that referenced this pull request Nov 19, 2021
Tracking issue: #55070

ghstack-source-id: 96b892bc36ddf4b82b72af1fcb1496f4d56276c1
Pull Request resolved: #68640
bool all_contiguous = true;
bool all_same_dtype = true;
bool all_same_sizes_and_stride = true;
auto memory_format = cat_compute_output_memory_format(tensors);
Collaborator Author

CPU and CUDA kernels computed the memory format in different ways. Here, I adopted how CUDA used to do it. Not sure if this is the best way to go about it. Any thoughts?
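
For intuition, here is a rough Python-level sketch of that rule; this is an assumption based on the discussion in this thread, not the actual cat_compute_output_memory_format implementation (which is C++), and the suggest_memory_format stand-in is deliberately simplified.

# Assumed rule: propagate a memory format to the output only when every input
# suggests the same one; otherwise fall back to contiguous.
import torch

def _suggest_memory_format(t):
    # crude stand-in for c10::suggest_memory_format (only handles the 4-D NHWC case)
    if t.dim() == 4 and t.is_contiguous(memory_format=torch.channels_last):
        return torch.channels_last
    return torch.contiguous_format

def cat_output_memory_format(tensors):
    formats = {_suggest_memory_format(t) for t in tensors}
    return formats.pop() if len(formats) == 1 else torch.contiguous_format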

Contributor

cc @ngimel for another opinion. I found the PR that added support for cuda, and a bit later for cpu. There's no explicit mention of why they differ, so, going with the way cuda does it now seems reasonable to me.

@bdhirsh (Contributor) commented Dec 10, 2021

@mruberry helped point me to some discussion about this: #62560 (comment).

It sounds like making cuda's behavior the general behavior is the right move in this PR. We should also add a test (if one doesn't exist already), that confirms that memory_format behavior is the same across cpu and cuda.

@mruberry also pointed out another thing we need to fix in this PR: see #64709. Apparently, there are parts of the codebase that call cat and expect cat(out=...) to resize the output tensor. That'll start printing warnings now that cat is structured. We should fix all of the places in our codebase that call cat(out=...) and expect the resize to happen, so that users don't start seeing warnings they can't fix.

The easiest way to do that is probably to copy the conditions in resize_output() that create the warning, run them directly in the cat meta function, and raise an error instead of a warning. Then fix all the errors that show up in CI.
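
As a rough illustration, here is a hypothetical Python sketch of those two checks (this is not the test that was actually added, and the warning text matched at the end is an assumption):

# (1) cat should propagate memory_format the same way on CPU and CUDA;
# (2) cat(out=...) should warn when it has to resize a non-empty output tensor.
import warnings
import torch

def check_memory_format_consistency():
    devices = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
    for device in devices:
        a = torch.randn(2, 3, 4, 4, device=device).contiguous(memory_format=torch.channels_last)
        b = torch.randn(2, 3, 4, 4, device=device).contiguous(memory_format=torch.channels_last)
        res = torch.cat((a, b))
        # all inputs are channels_last, so the output should be too, on both backends
        assert res.is_contiguous(memory_format=torch.channels_last), device

def check_out_resize_warns():
    a, b = torch.randn(2, 3), torch.randn(2, 3)
    out = torch.empty(1)  # wrong, non-empty shape: cat has to resize it
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        torch.cat((a, b), out=out)
    assert out.shape == (4, 3)
    assert any("resized" in str(w.message) for w in caught)

check_memory_format_consistency()
check_out_resize_warns()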

Contributor

Existing cat memory_format test is here:

def test_cat_out_memory_format(self, device):

Contributor

Thanks @ysiraichi, indeed the CUDA and CPU memory format calculation is inconsistent. I opened issue #63998 (about exactly this) earlier and wanted to work on a fix, but never got around to it because of some new responsibilities until earlier this week. I just came across this PR by chance and I see you have already incorporated the change.

So, with this PR, #63998 should also be fixed.

@bdhirsh: when I added the existing cat memory format test, I kept the different memory format behaviour in mind. It will now need a slight modification with your update making CPU and CUDA consistent, which should be quite straightforward (Edit: Yukio already covered that!).

@@ -626,16 +626,22 @@ def test_cat_out(self, device):
 y = torch.randn((4, 6), device=device)

 with self.assertRaisesRegex(
-    RuntimeError, r"unsupported operation:.* input tensor 0"):
+    RuntimeError,
+    r"unsupported operation: some elements of the input tensor and "
Collaborator Author

Error messages changed here because I'm using the at::assert_no_overlap function. Should I revert that to show better error messages?

Contributor

This seems reasonable. The previous error message was also related to overlap:

>>> torch.cat([x, y], out=x)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: 0unsupported operation: the input tensors cannot refer to any of the output memory locations. Found overlap in input tensor 0

Collaborator Author

Yeah. I was a bit worried that not having the number of the input tensor that had an overlap would be bad. :)

@@ -751,7 +756,7 @@ def test_cat_out_memory_format(self, device):
 res2_cpu = torch.cat((a_cpu, b_cpu), out=out_cpu)

 self.assertTrue(res2_cuda.is_contiguous(memory_format=torch.contiguous_format))
-self.assertTrue(res2_cpu.is_contiguous(memory_format=torch.channels_last))
+self.assertTrue(res2_cpu.is_contiguous(memory_format=torch.contiguous_format))
Collaborator Author

Since I'm using only one method for inferring the memory format, CPU and CUDA behave consistently.

ysiraichi added the module: structured kernels label (Related to new structured kernels functionality) on Nov 19, 2021
ysiraichi added a commit that referenced this pull request Nov 24, 2021
Tracking issue: #55070

ghstack-source-id: 672014c0dd712eb9ff82fac784f2079321fcf39f
Pull Request resolved: #68640
ysiraichi added a commit that referenced this pull request Nov 24, 2021
Tracking issue: #55070

ghstack-source-id: 96bfc16bcb05bfd99eaf0eba1ce1a28265a94802
Pull Request resolved: #68640
ysiraichi added a commit that referenced this pull request Nov 25, 2021
Tracking issue: #55070

ghstack-source-id: 29a4e1af550449b1b182a73cfea37218684e7e22
Pull Request resolved: #68640
ysiraichi marked this pull request as ready for review on November 26, 2021 09:44
CPU: _cat_out_cpu
CUDA: cat_out_cuda
QuantizedCPU: cat_out_quantized_cpu

@bdhirsh (Contributor) commented Nov 29, 2021

@zou3519 do you happen to know the history behind why we had both an at::cat and at::_cat? I couldn't dig up much of a reason from git blame. Although it seems useful to try to kill it as part of this structured kernel port.

Contributor

I don't know why we have both. I agree that if we can kill _cat we should; there doesn't seem to be a need for both _cat and cat to exist.

dispatch:
CompositeExplicitAutograd: cat
SparseCPU, SparseCUDA: cat_sparse
Contributor

cleaning up the sparse logic :)

Contributor

I guess for consistency we'd like to have an out= variant for sparse kernels, but it looks like that's not true today. It's probably not necessary to try to fix that in this PR.

@@ -739,8 +745,7 @@ def test_cat_out_memory_format(self, device):
 self.assertTrue(res1_cpu.is_contiguous(memory_format=torch.contiguous_format))

 # Case 2: if out= is not the correct shape then the output it is resized internally
-# - For the CPU variant the memory format is that of the first tensor
-# - For the CUDA variant it only propagates memory format if all the tensors have
+# - For both CPU and CUDA variants, it only propagates memory format if all the tensors have
Contributor

@ngimel, is this change in memory format propagation on CPU an issue? It seems useful to clean this logic up so it's backend-agnostic, but it sounds mildly BC-breaking.
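
For concreteness, a small hypothetical example of the behaviour change under discussion (assuming the new backend-agnostic rule):

# With mixed input memory formats, the resized out= tensor now ends up contiguous on
# CPU as well, instead of inheriting the first input's channels_last layout.
import torch

a = torch.randn(2, 3, 4, 4).contiguous(memory_format=torch.channels_last)
b = torch.randn(2, 3, 4, 4)   # default contiguous_format
out = torch.empty(0)          # wrong shape, so cat resizes it
res = torch.cat((a, b), out=out)
# the inputs disagree on memory format, so nothing is propagated: output is contiguous
assert res.is_contiguous(memory_format=torch.contiguous_format)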

ezyang commented Mar 8, 2022

@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

ysiraichi added a commit to ysiraichi/pytorch that referenced this pull request Mar 10, 2022
Tracking issue: pytorch#55070

ghstack-source-id: cc4bb13cb10c842b31b3ff5e3be2b5360e927d39
Pull Request resolved: pytorch#68640
ysiraichi (Collaborator Author)

@ngimel @ezyang @bdhirsh
I ran some benchmarks on different branches in this stack of PRs:

  • bc512253d5: the baseline
  • #69607: similar to this PR, but generalizes ITensorListRef to IListRef
  • #73351: uses IListRef in the dispatcher, eliminating wasteful copies and ref-count bumps
Size        | Branch     | Time CPU (us)        | Time CUDA (us)  | Instructions
[1,2]       | bc512253d5 | 97.01 (2.28)         | 356.75 (70.23)  | 68323495.50 (12375.54)
[1,2]       | #69607     | 105.14 (3.65)        | 330.70 (6.69)   | 74569195.40 (12037.75)
[1,2]       | #73351     | 93.15 (4.09)         | 326.67 (3.44)   | 75222670.40 (20938.94)
[1,2,2]     | bc512253d5 | 102.01 (1.11)        | 347.10 (18.40)  | 71919045.40 (2468.85)
[1,2,2]     | #69607     | 106.51 (4.32)        | 363.83 (66.13)  | 78584658.80 (17226.44)
[1,2,2]     | #73351     | 94.33 (0.86)         | 361.96 (73.93)  | 79226327.90 (20665.40)
[3,256,256] | bc512253d5 | 224157.47 (27083.82) | 2970.30 (6.33)  | 12363011366.90 (2486.03)
[3,256,256] | #69607     | 239175.47 (36902.33) | 2967.31 (8.96)  | 12363671642.80 (3776.23)
[3,256,256] | #73351     | 235462.35 (34238.07) | 2968.43 (1.98)  | 12356348936.20 (2990.37)
Benchmark

# torch.utils.benchmark provides the Timer / Compare helpers used below
import torch.utils.benchmark as benchmark

ntensors = 1000

sizes = [
    (1, 2),
    (1, 2, 2),
    (3, 256, 256),
]

results = []
results_callgrind = []

for size in sizes:
    size_str = f"""[{",".join(str(s) for s in size)}]"""
    timer = benchmark.Timer(
        stmt="torch.cat(xs)",
        setup=f"import torch; xs = [torch.rand(*{size}) for _ in range({ntensors})];",
        label="Cat",
        sub_label=size_str,
        description="time (cpu)"
    )

    timer_cuda = benchmark.Timer(
        stmt="torch.cat(xs); torch.cuda.synchronize()",
        setup=f"import torch; xs = [torch.rand(*{size}, device='cuda') for _ in range({ntensors})]; torch.cuda.synchronize()",
        label="Cat",
        sub_label=size_str,
        description="time (cuda)"
    )

    results.append(timer.blocked_autorange(min_run_time=1))
    results.append(timer_cuda.blocked_autorange(min_run_time=1))
    results_callgrind.append(timer.collect_callgrind())

compare = benchmark.Compare(results)
compare.print()

for r in results_callgrind:
    print(r)

ezyang commented Apr 12, 2022

I would love to start merging this stack but there are still build failures

/var/lib/jenkins/workspace/aten/src/ATen/native/cuda/Shape.cu:17:10: fatal error: ATen/ops/_cat_native.h: No such file or directory
 #include <ATen/ops/_cat_native.h>
          ^~~~~~~~~~~~~~~~~~~~~~~~

ezyang commented Apr 12, 2022

Since you've got the benchmark, I wonder if you can run a little experiment, which is to try NOT materializing in the body of the kernel, and being willing to iterate multiple times. It may be that the dynamic allocation is swamping the predictable branches, and so we'd rather pump up the instruction count and avoid the dynamic alloc. Would be nice to know one way or another.
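
As a loose, pure-Python analogy of the experiment being proposed (this is not the ATen kernel code; all names here are illustrative):

# "Materializing" pays one up-front allocation so the inputs can be revisited cheaply;
# the lazy variant avoids that allocation but walks the inputs once per pass.
import torch

def cat_checks_materialized(tensor_iter, dim=0):
    tensors = list(tensor_iter)                                         # dynamic allocation
    out_dim0 = sum(t.shape[dim] for t in tensors)                       # pass 1
    all_same_dtype = all(t.dtype == tensors[0].dtype for t in tensors)  # pass 2 is cheap
    return out_dim0, all_same_dtype

def cat_checks_lazy(make_tensor_iter, dim=0):
    out_dim0 = sum(t.shape[dim] for t in make_tensor_iter())            # pass 1
    first_dtype = next(iter(make_tensor_iter())).dtype
    all_same_dtype = all(t.dtype == first_dtype for t in make_tensor_iter())  # pass 2
    return out_dim0, all_same_dtype

xs = [torch.rand(1, 2) for _ in range(1000)]
assert cat_checks_materialized(iter(xs)) == cat_checks_lazy(lambda: iter(xs))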

ysiraichi (Collaborator Author)

I'm working on the build failures.

I wonder if you can run a little experiment, which is to try NOT materializing in the body of the kernel

Sure! I will do that once CI is green.

ezyang commented Apr 14, 2022

@pytorchbot merge this

github-actions (bot)

Hey @ysiraichi.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

ysiraichi added the topic: not user facing and release notes: cpp labels on Apr 15, 2022
ezyang reopened this on Apr 15, 2022
ezyang commented Apr 15, 2022

this is about to get yanked from the diff train because it breaks internal users. I just need to update call sites to not use native::

facebook-github-bot pushed a commit that referenced this pull request Apr 19, 2022
Summary:
Tracking issue: #55070

Pull Request resolved: #68640

Approved by: https://github.com/ezyang

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/22a10ce51310e690745accce0910d740b82a1503

Reviewed By: dagitses

Differential Revision: D34521686

Pulled By: mehtanirav

fbshipit-source-id: 58434529551e9e09939f6e28367316a6a20d7774
mehtanirav (Contributor)

Fixed the internal references and a few other linter errors in the internal diff stack.

mehtanirav closed this on Apr 19, 2022
ysiraichi (Collaborator Author)

@ezyang
I tried benchmarking again, as you suggested.
Here are some things you should know before reading the table:

  • The baseline branch was updated bc512253d5 -> 9b639f263d
  • Ran each configuration 10 times and took the mean and standard deviation (shown in parentheses)
  • Not sure why the timings for CUDA are so different from before
Size        | Branch                | Time CPU (us)        | Time CUDA (us)  | Instructions
[1,2]       | 9b639f263d            | 108.26 (6.94)        | 583.42 (15.39)  | 70205164.80 (3087.50)
[1,2]       | #69607 No Materialize | 117.22 (5.04)        | 379.66 (13.22)  | 78516968.20 (4285.81)
[1,2]       | #69607                | 112.24 (19.60)       | 425.72 (8.77)   | 74005639.00 (2517.92)
[1,2,2]     | 9b639f263d            | 112.11 (11.83)       | 656.60 (39.17)  | 73810091.60 (2313.66)
[1,2,2]     | #69607 No Materialize | 127.76 (8.92)        | 398.07 (19.09)  | 82418765.00 (4782.92)
[1,2,2]     | #69607                | 111.87 (8.63)        | 423.01 (9.66)   | 78007370.90 (8175.53)
[3,256,256] | 9b639f263d            | 231682.92 (96260.47) | 3276.22 (52.04) | 21130733648.90 (6350595776.36)
[3,256,256] | #69607 No Materialize | 212985.95 (52979.87) | 3054.36 (35.59) | 22370734623.30 (5351717520.03)
[3,256,256] | #69607                | 223969.03 (45350.19) | 3075.21 (20.65) | 31795904828.90 (1386977188.98)

mikeiovine pushed a commit that referenced this pull request Apr 20, 2022
#68640 broke our build by porting `cat` structured kernels, not sure how CI didn't catch this

Differential Revision: [D35780296](https://our.internmc.facebook.com/intern/diff/D35780296/)

[ghstack-poisoned]
ezyang commented Apr 20, 2022

I'm having a little difficulty interpreting the numbers here. What does the parenthesized number mean? If I go only by the non-parenthesized one, it seems like no materialize is better despite higher instruction count (what I suspected) and we should use that.

facebook-github-bot pushed a commit that referenced this pull request Apr 20, 2022
Summary:
Pull Request resolved: #76111

#68640 broke our build by porting `cat` structured kernels, not sure how CI didn't catch this
ghstack-source-id: 154335722

Test Plan: CI

Reviewed By: navahgar, ajyu

Differential Revision: D35780296

fbshipit-source-id: 0a262eb06a8d619227e5db10b6a775bf0b2e17c1
pytorchmergebot pushed a commit that referenced this pull request Apr 20, 2022
Summary:
Pull Request resolved: #76111

#68640 broke our build by porting `cat` structured kernels, not sure how CI didn't catch this
ghstack-source-id: 154335722

Test Plan: CI

Reviewed By: navahgar, ajyu

Differential Revision: D35780296

fbshipit-source-id: 0a262eb06a8d619227e5db10b6a775bf0b2e17c1
(cherry picked from commit aea6fbf)
ysiraichi (Collaborator Author)

What does the parenthesized number mean?

It's the standard deviation. I left it there just to give a sense of how much the runs varied.

it seems like no materialize is better despite higher instruction count (what I suspected) and we should use that.

Agreed.

facebook-github-bot deleted the gh/ysiraichi/37/head branch on April 23, 2022 14:17
Labels: cla signed, module: structured kernels, oncall: jit, open source, release notes: cpp, topic: not user facing