Skip to content

Commit

Permalink
[inductor] _sparse_semi_structured_linear fallback - no meta regist…
Browse files Browse the repository at this point in the history
…ration; not on testing path (#114477)

Test was wrong in original PR and merged changes were never tested. Further, the sparse op was never actually compiled due to missing `fullgraph=True` and missing meta registration.

When meta is added as per this PR, it gives wrong answers when input needs to be padded and when input needs to be reshaped.

Is this something to do with the generated inductor code for:
```
 constant_pad_nd: "f16[32, 128]" = torch.ops.aten.constant_pad_nd.default(primals_3, [0, 0, 0, 31], 0.0)
...
slice_1: "f16[1, 128]" = torch.ops.aten.slice.Tensor(_sparse_semi_structured_linear, 0, 0, 1);  _sparse_semi_structured_linear = None
```
and

```
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         mul: "Sym(s0*s1)" = primals_4 * primals_5
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         view: "f16[s0*s1, 128]" = torch.ops.aten.view.default(primals_6, [mul, 128]);  primals_6 = mul = None
...
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         view_1: "f16[s0, s1, 128]" = torch.ops.aten.view.default(slice_1, [primals_4, primals_5, 128]);  slice_1 = None
```

Failing graphs:
Padded:
```
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] TRACED GRAPH
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]  ===== Forward graph 5 =====
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]  <eval_with_key>.66 class GraphModule(torch.nn.Module):
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]     def forward(self, primals_1: "f16[128, 64]", primals_2: "i16[128, 8]", primals_3: "f16[1, 128]"):
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         # File: /home/jonch/Desktop/Programming/mlsys/pytorch/test/test_sparse_semi_structured.py:145, code: x = self.linear(x)
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         constant_pad_nd: "f16[32, 128]" = torch.ops.aten.constant_pad_nd.default(primals_3, [0, 0, 0, 31], 0.0)
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         _sparse_semi_structured_linear: "f16[32, 128]" = torch.ops.aten._sparse_semi_structured_linear.default(constant_pad_nd, primals_1, primals_2);  constant_pad_nd = primals_1 = primals_2 = None
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         slice_1: "f16[1, 128]" = torch.ops.aten.slice.Tensor(_sparse_semi_structured_linear, 0, 0, 1);  _sparse_semi_structured_linear = None
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         slice_2: "f16[1, 128]" = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 9223372036854775807);  slice_1 = None
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         # File: /home/jonch/Desktop/Programming/mlsys/pytorch/test/test_sparse_semi_structured.py:147, code: return torch.nn.functional.relu(x)
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         relu: "f16[1, 128]" = torch.ops.aten.relu.default(slice_2);  slice_2 = None
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         alias: "f16[1, 128]" = torch.ops.aten.alias.default(relu)
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         alias_1: "f16[1, 128]" = torch.ops.aten.alias.default(alias);  alias = None
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         le: "b8[1, 128]" = torch.ops.aten.le.Scalar(alias_1, 0);  alias_1 = None
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         # File: /home/jonch/Desktop/Programming/mlsys/pytorch/test/test_sparse_semi_structured.py:145, code: x = self.linear(x)
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         permute: "f16[128, 1]" = torch.ops.aten.permute.default(primals_3, [1, 0]);  primals_3 = None
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         return [relu, le, permute]

```

Reshape:

```
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]  <eval_with_key>.69 class GraphModule(torch.nn.Module):
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]     def forward(self, primals_1: "f16[128, 64]", primals_2: "i16[128, 8]", primals_3: "f16[128]", primals_4: "Sym(s0)", primals_5: "Sym(s1)", primals_6: "f16[s0, s1, 128]"):
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         # File: /home/jonch/Desktop/Programming/mlsys/pytorch/test/test_sparse_semi_structured.py:145, code: x = self.linear(x)
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         mul: "Sym(s0*s1)" = primals_4 * primals_5
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         view: "f16[s0*s1, 128]" = torch.ops.aten.view.default(primals_6, [mul, 128]);  primals_6 = mul = None
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         _sparse_semi_structured_linear: "f16[s0*s1, 128]" = torch.ops.aten._sparse_semi_structured_linear.default(view, primals_1, primals_2, bias = primals_3);  primals_1 = primals_2 = primals_3 = None
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         slice_1: "f16[s0*s1, 128]" = torch.ops.aten.slice.Tensor(_sparse_semi_structured_linear, 1, 0, 9223372036854775807);  _sparse_semi_structured_linear = None
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         view_1: "f16[s0, s1, 128]" = torch.ops.aten.view.default(slice_1, [primals_4, primals_5, 128]);  slice_1 = None
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         # File: /home/jonch/Desktop/Programming/mlsys/pytorch/test/test_sparse_semi_structured.py:147, code: return torch.nn.functional.relu(x)
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         relu: "f16[s0, s1, 128]" = torch.ops.aten.relu.default(view_1);  view_1 = None
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         alias: "f16[s0, s1, 128]" = torch.ops.aten.alias.default(relu)
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         alias_1: "f16[s0, s1, 128]" = torch.ops.aten.alias.default(alias);  alias = None
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         le: "b8[s0, s1, 128]" = torch.ops.aten.le.Scalar(alias_1, 0);  alias_1 = None
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         return [relu, view, le, primals_4, primals_5]

```

Pull Request resolved: #114477
Approved by: https://github.com/jcaip
  • Loading branch information
jon-chuang authored and pytorchmergebot committed Nov 28, 2023
1 parent ddf1cb7 commit cef79c0
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 4 deletions.
8 changes: 4 additions & 4 deletions test/test_sparse_semi_structured.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,14 +160,14 @@ def _test_mlp_contiguous_relu_compile(backend, dense_input_shape):
dense_result = model(input)
mod_linear.weight = nn.Parameter(to_sparse_semi_structured(mod_linear.weight))

model = torch.compile(model)
model = torch.compile(model, backend="inductor", fullgraph=True)
sparse_result = model(input)

assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

for backend in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
for dense_input_shape in [(1, 128), (64, 128), (128, 128), (64, 128, 128)]:
_test_mlp_contiguous_relu_compile(backend, dense_input_shape)
for backend in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
for dense_input_shape in [(128, 128), (64, 128), (1, 128), (64, 128, 128)]:
_test_mlp_contiguous_relu_compile(backend, dense_input_shape)

class TestSparseSemiStructured(TestCase):

Expand Down
29 changes: 29 additions & 0 deletions torch/_meta_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,35 @@ def meta_unsqueeze_(self, dim):
return self


@register_meta(aten._sparse_semi_structured_linear)
def meta_sparse_structured_linear(
input: Tensor,
weight: Tensor,
_meta: Tensor,
bias: Optional[Tensor] = None,
_activation_opt: Optional[str] = None,
):
output_sizes = list(input.shape)
if bias is not None:
assert weight.size(0) == bias.size(0), "output size mismatch"
assert weight.size(1) == input.size(-1) / 2
output_sizes[-1] = weight.size(0)

# see: https://github.com/pytorch/pytorch/pull/114477#issuecomment-1830121375
# We assume that we have already squashed the inputs into a 2-D tensor
# Then, as the output is transposed, we need to propagate the transposed
# stride information to the output tensor
assert len(input.shape) == 2, "we can only handle the squashed input case"
transposed_strides = (1, input.size(0))

output = input.new_empty(
output_sizes,
dtype=input.dtype if input.dtype != torch.int8 else torch.int32,
).as_strided(output_sizes, transposed_strides)

return output


@register_meta(aten.index_reduce.default)
def meta_index_reduce(
self: Tensor,
Expand Down

0 comments on commit cef79c0

Please sign in to comment.