New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[inductor] _sparse_semi_structured_linear
fallback - no meta registration; not on testing path
#114477
[inductor] _sparse_semi_structured_linear
fallback - no meta registration; not on testing path
#114477
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/114477
Note: Links to docs will display an error until the docs builds have been completed. ✅ No Failures as of commit 682d353 with merge base 79ee99e. This comment was automatically generated by Dr. CI and updates every 15 minutes. |
_sparse_semi_structured_linear
fallback - no meta registration; not on testing path; gives wrong answers
Thanks for flagging this, @alexsamardzic. Would you happen to know if there are any special cases for padded/reshaped inputs passed to _sparse_semi_structured_linear? I don't see an issue with the (1, 128) and (64, 128, 128) test cases for _cslt_sparse_mm when I add in the meta registrations used here: #114370. |
I confirm there is an issue with this PR applied when the CUTLASS backend is used. A simpler reproducer is as follows:
Reproducer script:
import torch
from torch.sparse.semi_structured import (
SparseSemiStructuredTensor,
to_sparse_semi_structured,
)
@torch.compile(backend="inductor", fullgraph=True)
def my_linear(input, weight):
return torch.nn.functional.linear(input, weight)
SparseSemiStructuredTensor._FORCE_CUTLASS = True
m, n, k = 1, 32, 64
dtype = torch.half
device = "cuda"
torch.manual_seed(0)
input = torch.rand((m, k), dtype=dtype, device=device)
weight = torch.rand((n, k), dtype=dtype, device=device)
mask = torch.Tensor([1, 0, 0, 1]).to(dtype).to(device).tile((n, k // 4))
dense_weight = weight * mask
dense_result = torch.nn.functional.linear(input, dense_weight)
sparse_weight = to_sparse_semi_structured(dense_weight)
sparse_result = my_linear(input, sparse_weight)
assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
When I comment out
that gets generated below (Above was all for the case of |
…uang/sparse-structured-compile-not-tested
Generated code for this PR:
@pointwise(
size_hints=[2048],
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=(2,))]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_constant_pad_nd_0', 'mutated_arg_names': []},
min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 2048
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x1 = (xindex // 64)
x0 = xindex % 64
x2 = xindex
tmp0 = x1
tmp1 = tl.full([1], 1, tl.int64)
tmp2 = tmp0 < tmp1
tmp3 = tl.load(in_ptr0 + (x0), tmp2, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp4 = tl.full(tmp3.shape, 0.0, tmp3.dtype)
tmp5 = tl.where(tmp2, tmp3, tmp4)
tl.store(out_ptr0 + (x2), tmp5, None)
''')
import triton
import triton.language as tl
from torch._inductor.triton_heuristics import grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
async_compile.wait(globals())
del async_compile
def call(args):
arg0_1, arg1_1, arg2_1 = args
args.clear()
assert_size_stride(arg0_1, (1, 64), (64, 1))
assert_size_stride(arg1_1, (32, 32), (32, 1))
assert_size_stride(arg2_1, (32, 4), (4, 1))
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0) # no-op to ensure context
buf0 = empty((32, 64), device='cuda', dtype=torch.float16)
# Source Nodes: [linear], Original ATen: [aten.constant_pad_nd]
stream0 = get_cuda_stream(0)
triton_poi_fused_constant_pad_nd_0.run(arg0_1, buf0, 2048, grid=grid(2048), stream=stream0)
del arg0_1
# Source Nodes: [linear], Original ATen: [aten._sparse_semi_structured_linear, aten.constant_pad_nd]
buf1 = aten._sparse_semi_structured_linear(buf0, arg1_1, arg2_1)
del arg1_1
del arg2_1
del buf0
buf2 = buf1
return (reinterpret_tensor(buf2, (1, 32), (32, 1), 0), )
Generated code for main:
@pointwise(
size_hints=[2048],
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=(2,))]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_constant_pad_nd_0', 'mutated_arg_names': []},
min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 2048
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x1 = (xindex // 64)
x0 = xindex % 64
x2 = xindex
tmp0 = x1
tmp1 = tl.full([1], 1, tl.int64)
tmp2 = tmp0 < tmp1
tmp3 = tl.load(in_ptr0 + (x0), tmp2, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp4 = tl.full(tmp3.shape, 0.0, tmp3.dtype)
tmp5 = tl.where(tmp2, tmp3, tmp4)
tl.store(out_ptr0 + (x2), tmp5, None)
''')
import triton
import triton.language as tl
from torch._inductor.triton_heuristics import grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
async_compile.wait(globals())
del async_compile
def call(args):
arg0_1, arg1_1, arg2_1 = args
args.clear()
assert_size_stride(arg0_1, (1, 64), (64, 1))
assert_size_stride(arg1_1, (32, 32), (32, 1))
assert_size_stride(arg2_1, (32, 4), (4, 1))
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0) # no-op to ensure context
buf0 = empty((32, 64), device='cuda', dtype=torch.float16)
# Source Nodes: [linear], Original ATen: [aten.constant_pad_nd]
stream0 = get_cuda_stream(0)
triton_poi_fused_constant_pad_nd_0.run(arg0_1, buf0, 2048, grid=grid(2048), stream=stream0)
del arg0_1
# Source Nodes: [linear], Original ATen: [aten._sparse_semi_structured_linear, aten.constant_pad_nd]
buf1 = aten._sparse_semi_structured_linear(buf0, arg1_1, arg2_1)
del arg1_1
del arg2_1
del buf0
buf2 = buf1
return (reinterpret_tensor(buf2, (1, 32), (1, 32), 0), )
So it seems like yes, it's the stride information:
- return (reinterpret_tensor(buf2, (1, 32), (32, 1), 0), )
+ return (reinterpret_tensor(buf2, (1, 32), (1, 32), 0), )
Perhaps I did not capture this in the |
I was able to fix the repro with a change to the meta registration, adding a transpose:
return (
input.new_empty(
output_sizes,
dtype=input.dtype if input.dtype != torch.int8 else torch.int32,
)
.transpose(-1, -2)
.reshape(output_sizes)
)
But the original tests still fail. |
I'm still not particularly familiar with Inductor, could you quickly explain what is the purpose of this method, i.e. how its result is used? Edit: Is it just there to calculate the output tensor shape, according to the inputs? |
Yes, and other metadata like strides, hence the name meta registration |
Actually, this seems to have something to do with dynamic shapes. The strides are not correctly updated for dynamic shapes, even after a transpose!
|
It seems to me that |
@alexsamardzic that's expected, as the K dimension for the weight matrix should be 2-4 sparse. |
OK. Let me quickly explain how CUTLASS backend calculates this, maybe it will help. So the operands for the linear operator are tensors So, maybe taking together how the operator works, regarding the presence of batch dimension(s), with padding in the Edit: part of my point here is that the output in non-compiled case, may be non-contiguous. Edit 2: Yep, adding |
@alexsamardzic this code solves the 1, 32, 64 case but doesn't solve the 1, 128, 128 case. It also seems to avoid the dynamic shapes issue. Based on what you understand about how the strides should be, can you spot any mistakes?
@register_meta(aten._sparse_semi_structured_linear)
def meta_sparse_structured_linear(
input: Tensor,
weight: Tensor,
_meta: Tensor,
bias: Optional[Tensor] = None,
_activation_opt: Optional[str] = None,
):
output_sizes = list(input.shape)
if bias is not None:
assert weight.size(0) == bias.size(0), "output size mismatch"
assert weight.size(1) == input.size(-1) / 2
output_sizes[-1] = weight.size(0)
transposed_strides = input.new_empty(output_sizes).transpose(-1, -2).stride()
return input.new_empty(
output_sizes,
dtype=input.dtype if input.dtype != torch.int8 else torch.int32,
).as_strided(output_sizes, transposed_strides)
What I am seeing is that this gets the strides wrong by a factor of 2. |
I've included the suggested changes, but would appreciate an explanation. Will include the explanation as inline comments in the meta registration. |
As usual - it appears dead simple now: Apparently, for compiled case something is doing squashing of batch dimensions before passing Should work fine for now, but it's kind of fragile: it will break if someone changes this squashing/un-squashing logic down the road, or maybe adds padding for other dimensions. But, the test case should catch it then. |
@alexsamardzic Thanks for the help, I forgot about the non-contiguous output. Yes, I added in 2D flattening for cuSPARSELt, and CUTLASS uses the same code path. @jon-chuang This works because the output of _sparse_semi_structured_linear is non-contiguous, as it's really the output of […]. Now, this doesn't error in the case when we don't have padding: although the stride information is not set properly, allclose doesn't care about this and just checks that all the values are close, which they are (this is still a silent error, imo). However, in the case when we pad the dense matrix, we additionally need to select just the non-padded values out of the resultant matrix, which we do here. From what I can understand, this is what reinterpret_tensor is doing, but since the strides of the tensor returned from the meta registration and those of the output of _sparse_semi_structured_linear are different, we "select" the wrong line, and that's why the test fails. |
Ok, I will stick to asserting on the squashed case. It is complicated and error prone to handle the strides for the unsquashed case. |
_sparse_semi_structured_linear
fallback - no meta registration; not on testing path; gives wrong answers_sparse_semi_structured_linear
fallback - no meta registration; not on testing path
@jcaip ready for final review |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm, thank you!
Not exactly, there are more strides in this case - it's not complicated to calculate them, but we'd have to differentiate between batched and non-batched case, etc. So as we can't test it at the moment, let's indeed put an assert there. In any case, thank you for this work! This debugging session was actually very useful for me, namely I'm at the moment working on adding some sparse MM related features to CUTLASS, and at some point afterwards I hope to add this same thing for the CUTLASS backend of Inductor - so I would hit this stuff anyway. |
_cslt_sparse_mm + additional stride checking in test. Summary: This PR adds in meta registrations for _cslt_sparse_mm. Based on the work @drisspg did in #114370. Additionally, it updates the tests by checking that the strides of the sparse result and the result returned by sparse+compile are the same, to avoid errors like those found in #114477. Test Plan: ``` python test/test_sparse_semi_structured.py -k compile_cusparselt python test/test_sparse_semi_structured.py -k compile_cutlass ``` Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
_cslt_sparse_mm + additional stride checking in test. Summary: This PR adds in meta registrations for _cslt_sparse_mm. Based on the work drisspg did in #114370. Additionally, it updates the tests by checking that the strides of the sparse result and the result returned by sparse+compile are the same, to avoid errors like those found in #114477. Test Plan: ``` python test/test_sparse_semi_structured.py -k compile_cusparselt python test/test_sparse_semi_structured.py -k compile_cutlass ``` Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: e2665c84b666464ee8f91ed8f1055d62a8e35799 Pull Request resolved: #114685
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
… _cslt_sparse_mm + additional stride checking in test." _cslt_sparse_mm + additional stride checking in test. Summary: This PR adds in meta registrations for _cslt_sparse_mm. Based on the work drisspg did in #114370. Additionally, it updates the tests by checking that the strides of the spare result and the result returned by sparse+compile are the same, to avoid errors like those found in #114477. Test Plan: ``` python test/test_sparse_semi_structred -k compile_cusparselt python test/test_sparse_semi_structred -k compile_cutlass ``` Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
…se_mm + additional stride checking in test. Summary: This PR adds in meta registrations for _cslt_sparse_mm. Based on the work drisspg did in #114370. Additionally, it updates the tests by checking that the strides of the spare result and the result returned by sparse+compile are the same, to avoid errors like those found in #114477. Test Plan: ``` python test/test_sparse_semi_structred -k compile_cusparselt python test/test_sparse_semi_structred -k compile_cutlass ``` Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 91a64316a3399c2cb7082403547ee92f02f746a2 Pull Request resolved: #114685
…egistrations for _cslt_sparse_mm + additional stride checking in test." _cslt_sparse_mm + additional stride checking in test. Summary: This PR adds in meta registrations for _cslt_sparse_mm. Based on the work drisspg did in #114370. Additionally, it updates the tests by checking that the strides of the spare result and the result returned by sparse+compile are the same, to avoid errors like those found in #114477. Test Plan: ``` python test/test_sparse_semi_structred -k compile_cusparselt python test/test_sparse_semi_structred -k compile_cutlass ``` Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
… _cslt_sparse_mm + additional stride checking in test." _cslt_sparse_mm + additional stride checking in test. Summary: This PR adds in meta registrations for _cslt_sparse_mm. Based on the work drisspg did in #114370. Additionally, it updates the tests by checking that the strides of the spare result and the result returned by sparse+compile are the same, to avoid errors like those found in #114477. Test Plan: ``` python test/test_sparse_semi_structred -k compile_cusparselt python test/test_sparse_semi_structred -k compile_cutlass ``` Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
…se_mm + additional stride checking in test. Summary: This PR adds in meta registrations for _cslt_sparse_mm. Based on the work drisspg did in #114370. Additionally, it updates the tests by checking that the strides of the spare result and the result returned by sparse+compile are the same, to avoid errors like those found in #114477. Test Plan: ``` python test/test_sparse_semi_structred -k compile_cusparselt python test/test_sparse_semi_structred -k compile_cutlass ``` Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: cc03188401100f6f1268844b99d1239d498dcead Pull Request resolved: #114685
…se_mm + additional stride checking in test. (#114685) _cslt_sparse_mm + additional stride checking in test. Summary: This PR adds in meta registrations for _cslt_sparse_mm. Based on the work @drisspg did in #114370. Additionally, it updates the tests by checking that the strides of the sparse result and the result returned by sparse+compile are the same, to avoid errors like those found in #114477. Test Plan: ``` python test/test_sparse_semi_structured.py -k compile_cusparselt python test/test_sparse_semi_structured.py -k compile_cutlass ``` Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: #114685 Approved by: https://github.com/alexsamardzic, https://github.com/drisspg
…ration; not on testing path (pytorch#114477) Test was wrong in original PR and merged changes were never tested. Further, the sparse op was never actually compiled due to missing `fullgraph=True` and missing meta registration. When meta is added as per this PR, it gives wrong answers when input needs to be padded and when input needs to be reshaped. Is this something to do with the generated inductor code for: ``` constant_pad_nd: "f16[32, 128]" = torch.ops.aten.constant_pad_nd.default(primals_3, [0, 0, 0, 31], 0.0) ... slice_1: "f16[1, 128]" = torch.ops.aten.slice.Tensor(_sparse_semi_structured_linear, 0, 0, 1); _sparse_semi_structured_linear = None ``` and ``` [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] mul: "Sym(s0*s1)" = primals_4 * primals_5 [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] view: "f16[s0*s1, 128]" = torch.ops.aten.view.default(primals_6, [mul, 128]); primals_6 = mul = None ... [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] view_1: "f16[s0, s1, 128]" = torch.ops.aten.view.default(slice_1, [primals_4, primals_5, 128]); slice_1 = None ``` Failing graphs: Padded: ``` [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] TRACED GRAPH [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] ===== Forward graph 5 ===== [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] <eval_with_key>.66 class GraphModule(torch.nn.Module): [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] def forward(self, primals_1: "f16[128, 64]", primals_2: "i16[128, 8]", primals_3: "f16[1, 128]"): [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] # File: /home/jonch/Desktop/Programming/mlsys/pytorch/test/test_sparse_semi_structured.py:145, code: x = self.linear(x) [2023-11-23 13:59:51,102] [0/2] 
torch._functorch.aot_autograd.__aot_graphs: [INFO] constant_pad_nd: "f16[32, 128]" = torch.ops.aten.constant_pad_nd.default(primals_3, [0, 0, 0, 31], 0.0) [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] _sparse_semi_structured_linear: "f16[32, 128]" = torch.ops.aten._sparse_semi_structured_linear.default(constant_pad_nd, primals_1, primals_2); constant_pad_nd = primals_1 = primals_2 = None [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] slice_1: "f16[1, 128]" = torch.ops.aten.slice.Tensor(_sparse_semi_structured_linear, 0, 0, 1); _sparse_semi_structured_linear = None [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] slice_2: "f16[1, 128]" = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 9223372036854775807); slice_1 = None [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] # File: /home/jonch/Desktop/Programming/mlsys/pytorch/test/test_sparse_semi_structured.py:147, code: return torch.nn.functional.relu(x) [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] relu: "f16[1, 128]" = torch.ops.aten.relu.default(slice_2); slice_2 = None [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] alias: "f16[1, 128]" = torch.ops.aten.alias.default(relu) [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] alias_1: "f16[1, 128]" = torch.ops.aten.alias.default(alias); alias = None [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] le: "b8[1, 128]" = torch.ops.aten.le.Scalar(alias_1, 0); alias_1 = None [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] # File: 
/home/jonch/Desktop/Programming/mlsys/pytorch/test/test_sparse_semi_structured.py:145, code: x = self.linear(x) [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] permute: "f16[128, 1]" = torch.ops.aten.permute.default(primals_3, [1, 0]); primals_3 = None [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] return [relu, le, permute] ``` Reshape: ``` [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] <eval_with_key>.69 class GraphModule(torch.nn.Module): [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] def forward(self, primals_1: "f16[128, 64]", primals_2: "i16[128, 8]", primals_3: "f16[128]", primals_4: "Sym(s0)", primals_5: "Sym(s1)", primals_6: "f16[s0, s1, 128]"): [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] # File: /home/jonch/Desktop/Programming/mlsys/pytorch/test/test_sparse_semi_structured.py:145, code: x = self.linear(x) [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] mul: "Sym(s0*s1)" = primals_4 * primals_5 [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] view: "f16[s0*s1, 128]" = torch.ops.aten.view.default(primals_6, [mul, 128]); primals_6 = mul = None [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] _sparse_semi_structured_linear: "f16[s0*s1, 128]" = torch.ops.aten._sparse_semi_structured_linear.default(view, primals_1, primals_2, bias = primals_3); primals_1 = primals_2 = primals_3 = None [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] slice_1: "f16[s0*s1, 128]" = torch.ops.aten.slice.Tensor(_sparse_semi_structured_linear, 1, 0, 9223372036854775807); _sparse_semi_structured_linear = None [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] view_1: "f16[s0, s1, 128]" = torch.ops.aten.view.default(slice_1, [primals_4, 
primals_5, 128]); slice_1 = None [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] # File: /home/jonch/Desktop/Programming/mlsys/pytorch/test/test_sparse_semi_structured.py:147, code: return torch.nn.functional.relu(x) [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] relu: "f16[s0, s1, 128]" = torch.ops.aten.relu.default(view_1); view_1 = None [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] alias: "f16[s0, s1, 128]" = torch.ops.aten.alias.default(relu) [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] alias_1: "f16[s0, s1, 128]" = torch.ops.aten.alias.default(alias); alias = None [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] le: "b8[s0, s1, 128]" = torch.ops.aten.le.Scalar(alias_1, 0); alias_1 = None [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] return [relu, view, le, primals_4, primals_5] ``` Pull Request resolved: pytorch#114477 Approved by: https://github.com/jcaip
…se_mm + additional stride checking in test. (pytorch#114685) _cslt_sparse_mm + additional stride checking in test. Summary: This PR adds in meta registrations for _cslt_sparse_mm. Based on the work @drisspg did in pytorch#114370. Additionally, it updates the tests by checking that the strides of the sparse result and the result returned by sparse+compile are the same, to avoid errors like those found in pytorch#114477. Test Plan: ``` python test/test_sparse_semi_structured.py -k compile_cusparselt python test/test_sparse_semi_structured.py -k compile_cutlass ``` Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: pytorch#114685 Approved by: https://github.com/alexsamardzic, https://github.com/drisspg
Test was wrong in original PR and merged changes were never tested. Further, the sparse op was never actually compiled due to missing
fullgraph=True
and missing meta registration. When the meta registration is added as per this PR, it gives wrong answers when the input needs to be padded and when the input needs to be reshaped.
Is this something to do with the generated inductor code for:
and
Failing graphs:
Padded:
Reshape:
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler