From cef79c0df4a340e812efd7fd8534a7fe3ba0109b Mon Sep 17 00:00:00 2001 From: Jon Chuang <9093549+jon-chuang@users.noreply.github.com> Date: Tue, 28 Nov 2023 19:35:01 +0000 Subject: [PATCH] [inductor] `_sparse_semi_structured_linear` fallback - no meta registration; not on testing path (#114477) Test was wrong in original PR and merged changes were never tested. Further, the sparse op was never actually compiled due to missing `fullgraph=True` and missing meta registration. When meta is added as per this PR, it gives wrong answers when input needs to be padded and when input needs to be reshaped. Is this something to do with the generated inductor code for: ``` constant_pad_nd: "f16[32, 128]" = torch.ops.aten.constant_pad_nd.default(primals_3, [0, 0, 0, 31], 0.0) ... slice_1: "f16[1, 128]" = torch.ops.aten.slice.Tensor(_sparse_semi_structured_linear, 0, 0, 1); _sparse_semi_structured_linear = None ``` and ``` [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] mul: "Sym(s0*s1)" = primals_4 * primals_5 [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] view: "f16[s0*s1, 128]" = torch.ops.aten.view.default(primals_6, [mul, 128]); primals_6 = mul = None ... [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] view_1: "f16[s0, s1, 128]" = torch.ops.aten.view.default(slice_1, [primals_4, primals_5, 128]); slice_1 = None ``` Failing graphs: Padded: ``` [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] TRACED GRAPH [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] ===== Forward graph 5 ===== [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] .66 class GraphModule(torch.nn.Module): [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] def forward(self, primals_1: "f16[128, 64]", primals_2: "i16[128, 8]", primals_3: "f16[1, 128]"): [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] # File: /home/jonch/Desktop/Programming/mlsys/pytorch/test/test_sparse_semi_structured.py:145, code: x = self.linear(x) [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] constant_pad_nd: "f16[32, 128]" = torch.ops.aten.constant_pad_nd.default(primals_3, [0, 0, 0, 31], 0.0) [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] _sparse_semi_structured_linear: "f16[32, 128]" = torch.ops.aten._sparse_semi_structured_linear.default(constant_pad_nd, primals_1, primals_2); constant_pad_nd = primals_1 = primals_2 = None [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] slice_1: "f16[1, 128]" = torch.ops.aten.slice.Tensor(_sparse_semi_structured_linear, 0, 0, 1); _sparse_semi_structured_linear = None [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] slice_2: "f16[1, 128]" = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 9223372036854775807); slice_1 = None [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] # File: /home/jonch/Desktop/Programming/mlsys/pytorch/test/test_sparse_semi_structured.py:147, code: return torch.nn.functional.relu(x) [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] relu: "f16[1, 128]" = torch.ops.aten.relu.default(slice_2); slice_2 = None [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] alias: "f16[1, 128]" = torch.ops.aten.alias.default(relu) [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] alias_1: "f16[1, 128]" = torch.ops.aten.alias.default(alias); alias = None [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] le: "b8[1, 128]" = torch.ops.aten.le.Scalar(alias_1, 0); alias_1 = None [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] # File: /home/jonch/Desktop/Programming/mlsys/pytorch/test/test_sparse_semi_structured.py:145, code: x = self.linear(x) [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] permute: "f16[128, 1]" = torch.ops.aten.permute.default(primals_3, [1, 0]); primals_3 = None [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] return [relu, le, permute] ``` Reshape: ``` [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] .69 class GraphModule(torch.nn.Module): [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] def forward(self, primals_1: "f16[128, 64]", primals_2: "i16[128, 8]", primals_3: "f16[128]", primals_4: "Sym(s0)", primals_5: "Sym(s1)", primals_6: "f16[s0, s1, 128]"): [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] # File: /home/jonch/Desktop/Programming/mlsys/pytorch/test/test_sparse_semi_structured.py:145, code: x = self.linear(x) [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] mul: "Sym(s0*s1)" = primals_4 * primals_5 [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] view: "f16[s0*s1, 128]" = torch.ops.aten.view.default(primals_6, [mul, 128]); primals_6 = mul = None [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] _sparse_semi_structured_linear: "f16[s0*s1, 128]" = torch.ops.aten._sparse_semi_structured_linear.default(view, primals_1, primals_2, bias = primals_3); primals_1 = primals_2 = primals_3 = None [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] slice_1: "f16[s0*s1, 128]" = torch.ops.aten.slice.Tensor(_sparse_semi_structured_linear, 1, 0, 9223372036854775807); _sparse_semi_structured_linear = None [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] view_1: "f16[s0, s1, 128]" = torch.ops.aten.view.default(slice_1, [primals_4, primals_5, 128]); slice_1 = None [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] # File: /home/jonch/Desktop/Programming/mlsys/pytorch/test/test_sparse_semi_structured.py:147, code: return torch.nn.functional.relu(x) [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] relu: "f16[s0, s1, 128]" = torch.ops.aten.relu.default(view_1); view_1 = None [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] alias: "f16[s0, s1, 128]" = torch.ops.aten.alias.default(relu) [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] alias_1: "f16[s0, s1, 128]" = torch.ops.aten.alias.default(alias); alias = None [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] le: "b8[s0, s1, 128]" = torch.ops.aten.le.Scalar(alias_1, 0); alias_1 = None [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] return [relu, view, le, primals_4, primals_5] ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/114477 Approved by: https://github.com/jcaip --- test/test_sparse_semi_structured.py | 8 ++++---- torch/_meta_registrations.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/test/test_sparse_semi_structured.py b/test/test_sparse_semi_structured.py index 31d294f248786..9af77eab605f9 100644 --- a/test/test_sparse_semi_structured.py +++ b/test/test_sparse_semi_structured.py @@ -160,14 +160,14 @@ def _test_mlp_contiguous_relu_compile(backend, dense_input_shape): dense_result = model(input) mod_linear.weight = nn.Parameter(to_sparse_semi_structured(mod_linear.weight)) - model = torch.compile(model) + model = torch.compile(model, backend="inductor", fullgraph=True) sparse_result = model(input) assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3) - for backend in SEMI_STRUCTURED_SUPPORTED_BACKENDS: - for dense_input_shape in [(1, 128), (64, 128), (128, 128), (64, 128, 128)]: - _test_mlp_contiguous_relu_compile(backend, dense_input_shape) + for backend in SEMI_STRUCTURED_SUPPORTED_BACKENDS: + for dense_input_shape in [(128, 128), (64, 128), (1, 128), (64, 128, 128)]: + _test_mlp_contiguous_relu_compile(backend, dense_input_shape) class TestSparseSemiStructured(TestCase): diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index bb10a34c4c064..112b641ca84f1 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -394,6 +394,35 @@ def meta_unsqueeze_(self, dim): return self +@register_meta(aten._sparse_semi_structured_linear) +def meta_sparse_structured_linear( + input: Tensor, + weight: Tensor, + _meta: Tensor, + bias: Optional[Tensor] = None, + _activation_opt: Optional[str] = None, +): + output_sizes = list(input.shape) + if bias is not None: + assert weight.size(0) == bias.size(0), "output size mismatch" + assert weight.size(1) == input.size(-1) / 2 + output_sizes[-1] = weight.size(0) + + # see: https://github.com/pytorch/pytorch/pull/114477#issuecomment-1830121375 + # We assume that we have already squashed the inputs into a 2-D tensor + # Then, as the output is transposed, we need to propagate the transposed + # stride information to the output tensor + assert len(input.shape) == 2, "we can only handle the squashed input case" + transposed_strides = (1, input.size(0)) + + output = input.new_empty( + output_sizes, + dtype=input.dtype if input.dtype != torch.int8 else torch.int32, + ).as_strided(output_sizes, transposed_strides) + + return output + + @register_meta(aten.index_reduce.default) def meta_index_reduce( self: Tensor,