Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[inductor]
_sparse_semi_structured_linear
fallback - no meta regist…
…ration; not on testing path (#114477) Test was wrong in original PR and merged changes were never tested. Further, the sparse op was never actually compiled due to missing `fullgraph=True` and missing meta registration. When meta is added as per this PR, it gives wrong answers when input needs to be padded and when input needs to be reshaped. Is this something to do with the generated inductor code for: ``` constant_pad_nd: "f16[32, 128]" = torch.ops.aten.constant_pad_nd.default(primals_3, [0, 0, 0, 31], 0.0) ... slice_1: "f16[1, 128]" = torch.ops.aten.slice.Tensor(_sparse_semi_structured_linear, 0, 0, 1); _sparse_semi_structured_linear = None ``` and ``` [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] mul: "Sym(s0*s1)" = primals_4 * primals_5 [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] view: "f16[s0*s1, 128]" = torch.ops.aten.view.default(primals_6, [mul, 128]); primals_6 = mul = None ... [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] view_1: "f16[s0, s1, 128]" = torch.ops.aten.view.default(slice_1, [primals_4, primals_5, 128]); slice_1 = None ``` Failing graphs: Padded: ``` [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] TRACED GRAPH [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] ===== Forward graph 5 ===== [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] <eval_with_key>.66 class GraphModule(torch.nn.Module): [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] def forward(self, primals_1: "f16[128, 64]", primals_2: "i16[128, 8]", primals_3: "f16[1, 128]"): [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] # File: /home/jonch/Desktop/Programming/mlsys/pytorch/test/test_sparse_semi_structured.py:145, code: x = self.linear(x) [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] constant_pad_nd: "f16[32, 128]" = torch.ops.aten.constant_pad_nd.default(primals_3, [0, 0, 0, 31], 0.0) [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] _sparse_semi_structured_linear: "f16[32, 128]" = torch.ops.aten._sparse_semi_structured_linear.default(constant_pad_nd, primals_1, primals_2); constant_pad_nd = primals_1 = primals_2 = None [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] slice_1: "f16[1, 128]" = torch.ops.aten.slice.Tensor(_sparse_semi_structured_linear, 0, 0, 1); _sparse_semi_structured_linear = None [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] slice_2: "f16[1, 128]" = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 9223372036854775807); slice_1 = None [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] # File: /home/jonch/Desktop/Programming/mlsys/pytorch/test/test_sparse_semi_structured.py:147, code: return torch.nn.functional.relu(x) [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] relu: "f16[1, 128]" = torch.ops.aten.relu.default(slice_2); slice_2 = None [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] alias: "f16[1, 128]" = torch.ops.aten.alias.default(relu) [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] alias_1: "f16[1, 128]" = torch.ops.aten.alias.default(alias); alias = None [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] le: "b8[1, 128]" = torch.ops.aten.le.Scalar(alias_1, 0); alias_1 = None [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] # File: /home/jonch/Desktop/Programming/mlsys/pytorch/test/test_sparse_semi_structured.py:145, code: x = self.linear(x) [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] permute: "f16[128, 1]" = torch.ops.aten.permute.default(primals_3, [1, 0]); primals_3 = None [2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] return [relu, le, permute] ``` Reshape: ``` [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] <eval_with_key>.69 class GraphModule(torch.nn.Module): [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] def forward(self, primals_1: "f16[128, 64]", primals_2: "i16[128, 8]", primals_3: "f16[128]", primals_4: "Sym(s0)", primals_5: "Sym(s1)", primals_6: "f16[s0, s1, 128]"): [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] # File: /home/jonch/Desktop/Programming/mlsys/pytorch/test/test_sparse_semi_structured.py:145, code: x = self.linear(x) [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] mul: "Sym(s0*s1)" = primals_4 * primals_5 [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] view: "f16[s0*s1, 128]" = torch.ops.aten.view.default(primals_6, [mul, 128]); primals_6 = mul = None [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] _sparse_semi_structured_linear: "f16[s0*s1, 128]" = torch.ops.aten._sparse_semi_structured_linear.default(view, primals_1, primals_2, bias = primals_3); primals_1 = primals_2 = primals_3 = None [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] slice_1: "f16[s0*s1, 128]" = torch.ops.aten.slice.Tensor(_sparse_semi_structured_linear, 1, 0, 9223372036854775807); _sparse_semi_structured_linear = None [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] view_1: "f16[s0, s1, 128]" = torch.ops.aten.view.default(slice_1, [primals_4, primals_5, 128]); slice_1 = None [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] # File: /home/jonch/Desktop/Programming/mlsys/pytorch/test/test_sparse_semi_structured.py:147, code: return torch.nn.functional.relu(x) [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] relu: "f16[s0, s1, 128]" = torch.ops.aten.relu.default(view_1); view_1 = None [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] alias: "f16[s0, s1, 128]" = torch.ops.aten.alias.default(relu) [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] alias_1: "f16[s0, s1, 128]" = torch.ops.aten.alias.default(alias); alias = None [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] le: "b8[s0, s1, 128]" = torch.ops.aten.le.Scalar(alias_1, 0); alias_1 = None [2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] return [relu, view, le, primals_4, primals_5] ``` Pull Request resolved: #114477 Approved by: https://github.com/jcaip
- Loading branch information