[inductor] _sparse_semi_structured_linear fallback - no meta registration; not on testing path #114477


@jon-chuang (Collaborator) commented Nov 23, 2023:

The test was wrong in the original PR, and the merged changes were never tested. Further, the sparse op was never actually compiled, due to a missing fullgraph=True and a missing meta registration.

When the meta registration is added, as in this PR, it gives wrong answers when the input needs to be padded and when the input needs to be reshaped.

Is this something to do with the generated Inductor code for:

 constant_pad_nd: "f16[32, 128]" = torch.ops.aten.constant_pad_nd.default(primals_3, [0, 0, 0, 31], 0.0)
...
slice_1: "f16[1, 128]" = torch.ops.aten.slice.Tensor(_sparse_semi_structured_linear, 0, 0, 1);  _sparse_semi_structured_linear = None

and

[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         mul: "Sym(s0*s1)" = primals_4 * primals_5
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         view: "f16[s0*s1, 128]" = torch.ops.aten.view.default(primals_6, [mul, 128]);  primals_6 = mul = None
...
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         view_1: "f16[s0, s1, 128]" = torch.ops.aten.view.default(slice_1, [primals_4, primals_5, 128]);  slice_1 = None

Failing graphs:
Padded:

[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] TRACED GRAPH
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]  ===== Forward graph 5 =====
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]  <eval_with_key>.66 class GraphModule(torch.nn.Module):
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]     def forward(self, primals_1: "f16[128, 64]", primals_2: "i16[128, 8]", primals_3: "f16[1, 128]"):
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         # File: /home/jonch/Desktop/Programming/mlsys/pytorch/test/test_sparse_semi_structured.py:145, code: x = self.linear(x)
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         constant_pad_nd: "f16[32, 128]" = torch.ops.aten.constant_pad_nd.default(primals_3, [0, 0, 0, 31], 0.0)
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         _sparse_semi_structured_linear: "f16[32, 128]" = torch.ops.aten._sparse_semi_structured_linear.default(constant_pad_nd, primals_1, primals_2);  constant_pad_nd = primals_1 = primals_2 = None
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         slice_1: "f16[1, 128]" = torch.ops.aten.slice.Tensor(_sparse_semi_structured_linear, 0, 0, 1);  _sparse_semi_structured_linear = None
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         slice_2: "f16[1, 128]" = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 9223372036854775807);  slice_1 = None
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         # File: /home/jonch/Desktop/Programming/mlsys/pytorch/test/test_sparse_semi_structured.py:147, code: return torch.nn.functional.relu(x)
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         relu: "f16[1, 128]" = torch.ops.aten.relu.default(slice_2);  slice_2 = None
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         alias: "f16[1, 128]" = torch.ops.aten.alias.default(relu)
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         alias_1: "f16[1, 128]" = torch.ops.aten.alias.default(alias);  alias = None
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         le: "b8[1, 128]" = torch.ops.aten.le.Scalar(alias_1, 0);  alias_1 = None
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         # File: /home/jonch/Desktop/Programming/mlsys/pytorch/test/test_sparse_semi_structured.py:145, code: x = self.linear(x)
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         permute: "f16[128, 1]" = torch.ops.aten.permute.default(primals_3, [1, 0]);  primals_3 = None
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         return [relu, le, permute]

Reshape:

[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]  <eval_with_key>.69 class GraphModule(torch.nn.Module):
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]     def forward(self, primals_1: "f16[128, 64]", primals_2: "i16[128, 8]", primals_3: "f16[128]", primals_4: "Sym(s0)", primals_5: "Sym(s1)", primals_6: "f16[s0, s1, 128]"):
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         # File: /home/jonch/Desktop/Programming/mlsys/pytorch/test/test_sparse_semi_structured.py:145, code: x = self.linear(x)
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         mul: "Sym(s0*s1)" = primals_4 * primals_5
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         view: "f16[s0*s1, 128]" = torch.ops.aten.view.default(primals_6, [mul, 128]);  primals_6 = mul = None
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         _sparse_semi_structured_linear: "f16[s0*s1, 128]" = torch.ops.aten._sparse_semi_structured_linear.default(view, primals_1, primals_2, bias = primals_3);  primals_1 = primals_2 = primals_3 = None
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         slice_1: "f16[s0*s1, 128]" = torch.ops.aten.slice.Tensor(_sparse_semi_structured_linear, 1, 0, 9223372036854775807);  _sparse_semi_structured_linear = None
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         view_1: "f16[s0, s1, 128]" = torch.ops.aten.view.default(slice_1, [primals_4, primals_5, 128]);  slice_1 = None
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         # File: /home/jonch/Desktop/Programming/mlsys/pytorch/test/test_sparse_semi_structured.py:147, code: return torch.nn.functional.relu(x)
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         relu: "f16[s0, s1, 128]" = torch.ops.aten.relu.default(view_1);  view_1 = None
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         alias: "f16[s0, s1, 128]" = torch.ops.aten.alias.default(relu)
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         alias_1: "f16[s0, s1, 128]" = torch.ops.aten.alias.default(alias);  alias = None
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         le: "b8[s0, s1, 128]" = torch.ops.aten.le.Scalar(alias_1, 0);  alias_1 = None
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         return [relu, view, le, primals_4, primals_5]

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler

@jon-chuang changed the title from "[inductor] torch.compile + sparse_semi_structure fallback - no meta registration; not on testing path; gives wrong answers" to "[inductor] _sparse_semi_structured_linear fallback - no meta registration; not on testing path; gives wrong answers" Nov 23, 2023
@jon-chuang (Collaborator, Author) commented Nov 23, 2023:

@jcaip your PR #111049 was never tested and doesn't give the right result.

Could you please look into fixing this? The (1, 128) and (64, 64, 128) cases need to be fixed.

@jcaip (Contributor) commented Nov 27, 2023:

Thanks for flagging this.

@alexsamardzic Would you happen to know if there are any special cases in the case of padded/reshaped inputs passed to _sparse_semi_structured_linear?

I don't see an issue for the (1, 128) and (64, 128, 128) test cases for _cslt_sparse_mm when I add in the meta registrations used here: #114370.

@alexsamardzic (Collaborator) commented:

I can confirm there is an issue with this PR applied and the CUTLASS backend used; a simpler reproducer is as follows:

Reproducer script
import torch

from torch.sparse.semi_structured import (
    SparseSemiStructuredTensor,
    to_sparse_semi_structured,
)

@torch.compile(backend="inductor", fullgraph=True)
def my_linear(input, weight):
    return torch.nn.functional.linear(input, weight)

SparseSemiStructuredTensor._FORCE_CUTLASS = True

m, n, k = 1, 32, 64
dtype = torch.half
device = "cuda"

torch.manual_seed(0)

input = torch.rand((m, k), dtype=dtype, device=device)
weight = torch.rand((n, k), dtype=dtype, device=device)

mask = torch.Tensor([1, 0, 0, 1]).to(dtype).to(device).tile((n, k // 4))

dense_weight = weight * mask
dense_result = torch.nn.functional.linear(input, dense_weight)

sparse_weight = to_sparse_semi_structured(dense_weight)
sparse_result = my_linear(input, sparse_weight)

assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

When I comment out the @torch.compile line in the script above, the results for the dense and sparse cases are the same; with this line, they differ. However, I added printouts in the C++ code of the _sparse_semi_structured_linear() function, which calculates the linear operator, and in both cases the inputs and outputs are the same. So after examining the generated code, I strongly suspect that the problem is in the line:

return (reinterpret_tensor(buf2, (1, 32), (32, 1), 0), )

that gets generated below the aten._sparse_semi_structured_linear() call. So I'm not sure where the fix should be made?

(The above was all for the case of a (1, 64) input. For the other problematic input size, with a batch dimension, it seems to be the same kind of problem, though in this case not because of padding, but because the CUTLASS version internally squashes the batch dimensions of the input tensor together with the row dimension. However, it also internally sets the dimensions of the result back properly, so the tensor produced by _sparse_semi_structured_linear() has the correct sizes, and again the reinterpret_tensor call should not be needed.)

@jon-chuang (Collaborator, Author) commented Nov 28, 2023:

Generated code for this PR:

@pointwise(
    size_hints=[2048], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=(2,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_constant_pad_nd_0', 'mutated_arg_names': []},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2048
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x1 = (xindex // 64)
    x0 = xindex % 64
    x2 = xindex
    tmp0 = x1
    tmp1 = tl.full([1], 1, tl.int64)
    tmp2 = tmp0 < tmp1
    tmp3 = tl.load(in_ptr0 + (x0), tmp2, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp4 = tl.full(tmp3.shape, 0.0, tmp3.dtype)
    tmp5 = tl.where(tmp2, tmp3, tmp4)
    tl.store(out_ptr0 + (x2), tmp5, None)
''')

import triton
import triton.language as tl
from torch._inductor.triton_heuristics import grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream


async_compile.wait(globals())
del async_compile

def call(args):
    arg0_1, arg1_1, arg2_1 = args
    args.clear()
    assert_size_stride(arg0_1, (1, 64), (64, 1))
    assert_size_stride(arg1_1, (32, 32), (32, 1))
    assert_size_stride(arg2_1, (32, 4), (4, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0) # no-op to ensure context
        buf0 = empty((32, 64), device='cuda', dtype=torch.float16)
        # Source Nodes: [linear], Original ATen: [aten.constant_pad_nd]
        stream0 = get_cuda_stream(0)
        triton_poi_fused_constant_pad_nd_0.run(arg0_1, buf0, 2048, grid=grid(2048), stream=stream0)
        del arg0_1
        # Source Nodes: [linear], Original ATen: [aten._sparse_semi_structured_linear, aten.constant_pad_nd]
        buf1 = aten._sparse_semi_structured_linear(buf0, arg1_1, arg2_1)
        del arg1_1
        del arg2_1
        del buf0
        buf2 = buf1
        return (reinterpret_tensor(buf2, (1, 32), (32, 1), 0), )

Generated code for main:

@pointwise(
    size_hints=[2048], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=(2,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_constant_pad_nd_0', 'mutated_arg_names': []},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2048
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x1 = (xindex // 64)
    x0 = xindex % 64
    x2 = xindex
    tmp0 = x1
    tmp1 = tl.full([1], 1, tl.int64)
    tmp2 = tmp0 < tmp1
    tmp3 = tl.load(in_ptr0 + (x0), tmp2, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp4 = tl.full(tmp3.shape, 0.0, tmp3.dtype)
    tmp5 = tl.where(tmp2, tmp3, tmp4)
    tl.store(out_ptr0 + (x2), tmp5, None)
''')

import triton
import triton.language as tl
from torch._inductor.triton_heuristics import grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream


async_compile.wait(globals())
del async_compile

def call(args):
    arg0_1, arg1_1, arg2_1 = args
    args.clear()
    assert_size_stride(arg0_1, (1, 64), (64, 1))
    assert_size_stride(arg1_1, (32, 32), (32, 1))
    assert_size_stride(arg2_1, (32, 4), (4, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0) # no-op to ensure context
        buf0 = empty((32, 64), device='cuda', dtype=torch.float16)
        # Source Nodes: [linear], Original ATen: [aten.constant_pad_nd]
        stream0 = get_cuda_stream(0)
        triton_poi_fused_constant_pad_nd_0.run(arg0_1, buf0, 2048, grid=grid(2048), stream=stream0)
        del arg0_1
        # Source Nodes: [linear], Original ATen: [aten._sparse_semi_structured_linear, aten.constant_pad_nd]
        buf1 = aten._sparse_semi_structured_linear(buf0, arg1_1, arg2_1)
        del arg1_1
        del arg2_1
        del buf0
        buf2 = buf1
        return (reinterpret_tensor(buf2, (1, 32), (1, 32), 0), )

So it seems like yes, it's the stride information.

- return (reinterpret_tensor(buf2, (1, 32), (32, 1), 0), )
+ return (reinterpret_tensor(buf2, (1, 32), (1, 32), 0), )

Perhaps I did not capture this in the meta_registration somehow?

@jon-chuang (Collaborator, Author) commented:

I was able to fix the repro with a change to the meta registration, adding a transpose.

    return (
        input.new_empty(
            output_sizes,
            dtype=input.dtype if input.dtype != torch.int8 else torch.int32,
        )
        .transpose(-1, -2)
        .reshape(output_sizes)
    )

But the original tests still fail.

@alexsamardzic (Collaborator) commented Nov 28, 2023:

I'm still not particularly familiar with Inductor; could you quickly explain the purpose of this method, i.e. how its result is used? Edit: Is it just there to calculate the output tensor shape, according to the inputs?

@jon-chuang (Collaborator, Author) commented:

> Edit: Is it just there to calculate the output tensor shape, according to the inputs?

Yes, and also other metadata like strides; hence the name "meta registration".
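
For illustration, a minimal sketch of the idea (toy_linear_meta is a made-up name for this example, not the real registration): tensors on the "meta" device carry only metadata, and a meta function returns an empty tensor whose shape, dtype, and strides the compiler then assumes for the real op's output.

```python
import torch

# Toy meta function for a linear-like op: it computes output metadata only.
def toy_linear_meta(input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    out_shape = (*input.shape[:-1], weight.shape[0])
    return input.new_empty(out_shape)  # no data is allocated on the meta device

x = torch.empty(1, 128, device="meta", dtype=torch.half)
w = torch.empty(128, 64, device="meta", dtype=torch.half)
out = toy_linear_meta(x, w)
print(out.shape, out.stride(), out.device)  # torch.Size([1, 128]) (128, 1) meta
```

If the strides the meta function reports differ from what the real kernel produces, any downstream view or reinterpretation of the output buffer reads the wrong elements, which is exactly what this thread runs into below.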

@jon-chuang (Collaborator, Author) commented Nov 28, 2023:

Actually, this seems to have something to do with dynamic shapes.

The strides are not correctly updated for dynamic shapes, even after a transpose!

INPUT STRIDE (128, 1)
OUTPUT STRIDE (s0*((128//s0)), 1)

@alexsamardzic (Collaborator) commented:

It seems to me that the weight input there is not of the correct shape: for the test script I provided above, the shape of the weight tensor is (32, 64), but if I print it from within the meta_sparse_structured_linear() method, it says (32, 32).

@jon-chuang (Collaborator, Author) commented Nov 28, 2023:

@alexsamardzic that's expected: the K dimension for the weight matrix should be 2:4 sparse, so the compressed weight stores only half of the K elements ((32, 32) instead of (32, 64)).

@alexsamardzic (Collaborator) commented Nov 28, 2023:

OK. Let me quickly explain how the CUTLASS backend calculates this; maybe it will help. The operands for the linear operator are tensors input of size (*b, m, k) (where *b stands for the batch size(s)) and weight of size (n, k); the operator calculates input @ weight.T, which is of size (*b, m, n) - all as expected. But the CUTLASS backend is unusual: CUTLASS doesn't support a sparse matrix as the second operand, only as the first, so the result is actually calculated as (weight @ input.T).T. The _sparse_semi_structured_linear() function, in C++ code, accepts the operands as described and performs the calculation, also as described. But before actually doing the matrix multiplication, the batch dimensions of the input tensor are squashed into its row dimension (as CUTLASS accepts only 2D tensors here), and after the final transpose the proper shape of the output tensor is restored through a reshape() call.

So maybe how the operator works regarding the presence of batch dimension(s), taken together with the padding in the (1, K) input case, is the cause of the issues.

Edit: Part of my point here is that the output in the non-compiled case may be non-contiguous.

Edit 2: Yep, adding .contiguous() at the end of line 705 in aten/src/ATen/native/sparse/cuda/SparseSemiStructuredLinear.cu makes all the problems disappear - so it's all about strides here. But for performance reasons, I don't think we actually want to fix it this way.
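
The layout effect is easy to see with a plain dense stand-in (a sketch that ignores the 2:4 compression entirely; only the trailing transpose matters here):

```python
import torch

# Dense stand-in for the CUTLASS path, which computes (weight @ input.T).T.
# Sizes match the padded repro: input (32, 64), weight (32, 64).
inp = torch.rand(32, 64)
weight = torch.rand(32, 64)

out = (weight @ inp.t()).t()               # shape (32, 32)
print(out.stride(), out.is_contiguous())   # (1, 32) False
print(out.contiguous().stride())           # (32, 1): the copy mentioned in Edit 2
```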

@jon-chuang (Collaborator, Author) commented Nov 28, 2023:

@alexsamardzic this code solves the (1, 32, 64) case but doesn't solve the (1, 128, 128) case. It also seems to avoid the dynamic shapes issue. Based on what you understand about how the strides should be set, can you spot any mistakes?

@register_meta(aten._sparse_semi_structured_linear)
def meta_sparse_structured_linear(
    input: Tensor,
    weight: Tensor,
    _meta: Tensor,
    bias: Optional[Tensor] = None,
    _activation_opt: Optional[str] = None,
):
    output_sizes = list(input.shape)
    if bias is not None:
        assert weight.size(0) == bias.size(0), "output size mismatch"
    assert weight.size(1) == input.size(-1) / 2
    output_sizes[-1] = weight.size(0)

    transposed_strides = input.new_empty(output_sizes).transpose(-1, -2).stride()

    return input.new_empty(
        output_sizes,
        dtype=input.dtype if input.dtype != torch.int8 else torch.int32,
    ).as_strided(output_sizes, transposed_strides)

What I am seeing is that this is getting the strides wrong by a factor of 2.

@jon-chuang (Collaborator, Author) commented Nov 28, 2023:

I've included the suggested changes, but would appreciate an explanation; I will include the explanation as inline comments in the meta registration.

@alexsamardzic (Collaborator) commented:

As usual, it appears dead simple now: apparently, in the compiled case, something squashes the batch dimensions before passing the input to the C++ code, and un-squashes them afterwards. So the C++ code always gets a 2D input tensor, and it will always produce a contiguous output that just gets transposed at the end. So the size of the output will be (input.size(0), weight.size(0)), but as transpose() is called at the end of the C++ code, the strides will be exactly as above.

This should work fine for now, but it's kind of fragile: it will break if someone changes this squashing/un-squashing logic down the road, or perhaps adds padding for other dimensions. But the test case should catch it then.
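
A quick dense check of that claim (again ignoring sparsity; the sizes are arbitrary): for a squashed 2-D input with M rows, the transposed result comes out with shape (M, n) and strides (1, M).

```python
import torch

M, n, k = 64, 32, 16
out = (torch.rand(n, k) @ torch.rand(M, k).t()).t()  # (weight @ input.T).T
assert out.shape == (M, n)
assert out.stride() == (1, M)
```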

@jcaip (Contributor) commented Nov 28, 2023:

@alexsamardzic Thanks for the help, I forgot about the non-contiguous output. Yes, I added the 2D flattening for cuSPARSELt, and CUTLASS uses the same code path.

@jon-chuang This works because the output of _sparse_semi_structured_linear is non-contiguous, as it's really the output of (weight @ input.T).T. When you added the meta registration initially, you were returning a contiguous tensor, created by new_empty.

Now this doesn't error in the case where we don't have padding: although the stride information is not set properly, allclose doesn't care about it and just checks that all the values are close, which they are. (This is still a silent error, IMO.) However, when we pad the dense matrix, we additionally need to select just the non-padded values out of the resulting matrix, which we do here. From what I can understand, this is what reinterpret_tensor is doing; but since the strides of the tensor returned from the meta registration and of the output of _sparse_semi_structured_linear are different, we "select" the wrong values, and that's why the test fails.
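
That selection failure is easy to reproduce in isolation, using as_strided as a stand-in for reinterpret_tensor (the values here are purely illustrative): the same storage read with the kernel's actual strides (1, 32) versus the contiguous strides a wrong meta registration would advertise (32, 1) yields a different "row 0".

```python
import torch

storage = torch.arange(32 * 32, dtype=torch.float32)

real = storage.as_strided((32, 32), (1, 32))     # layout the kernel actually wrote
assumed = storage.as_strided((32, 32), (32, 1))  # layout a contiguous meta implies

print(real[0, :4])     # tensor([ 0., 32., 64., 96.]) -- the true first row
print(assumed[0, :4])  # tensor([0., 1., 2., 3.])     -- different storage elements
```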

@jon-chuang (Collaborator, Author) commented Nov 28, 2023:

OK, I will stick to asserting on the squashed case. It is complicated and error-prone to handle the strides for the unsquashed case.

@jon-chuang changed the title from "[inductor] _sparse_semi_structured_linear fallback - no meta registration; not on testing path; gives wrong answers" to "[inductor] _sparse_semi_structured_linear fallback - no meta registration; not on testing path" Nov 28, 2023
@jon-chuang (Collaborator, Author) commented:

@jcaip ready for final review

@jcaip (Contributor) left a comment:

lgtm, thank you!

@alexsamardzic (Collaborator) commented:

> @alexsamardzic Would this be correct even in the non-squashed case:
>
>     transposed_strides = (1, reduce(operator.mul, input.size()[:-1], 1))

Not exactly; there are more strides in this case. It's not complicated to calculate them, but we'd have to differentiate between the batched and non-batched cases, etc. So, as we can't test it at the moment, let's indeed put an assert there.

In any case, thank you for this work! This debugging session was actually very useful for me: I'm currently working on adding some sparse MM related features to CUTLASS, and at some point afterwards I hope to add this same thing for the CUTLASS backend of Inductor, so I would have hit this anyway.

jcaip added a commit that referenced this pull request Nov 28, 2023:

_cslt_sparse_mm + additional stride checking in test.

Summary:

This PR adds in meta registrations for _cslt_sparse_mm, based on the work @drisspg did in #114370.

Additionally, it updates the tests by checking that the strides of the sparse result and the result returned by sparse+compile are the same, to avoid errors like those found in #114477.

Test Plan:
```
python test/test_sparse_semi_structred -k compile_cusparselt
python test/test_sparse_semi_structred -k compile_cutlass
```
@jon-chuang (Collaborator, Author) commented:

@pytorchbot merge

pytorch-bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label Nov 28, 2023
@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).


pytorchmergebot pushed a commit that referenced this pull request Nov 29, 2023. Pull Request resolved: #114685. Approved by: https://github.com/alexsamardzic, https://github.com/drisspg
vfdev-5 pushed a commit to vfdev-5/pytorch that referenced this pull request Nov 29, 2023. Pull Request resolved: pytorch#114477. Approved by: https://github.com/jcaip