
torch.compile Fails with AssertionError on Models Using to_sparse() → torch.sparse.mm() → to_dense() Sequence #164697

@LiSsHhUuAaIi

Description


🐛 Describe the bug

When a model that performs the sequence dense.to_sparse() → torch.sparse.mm() → sparse.to_dense() is compiled with torch.compile (default inductor backend), compilation fails with an internal AssertionError raised inside AOTAutograd. The identical model runs correctly in eager mode.
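The op sequence can be checked step by step in eager mode. The sketch below uses the reproducer's shapes but passes fixed weights instead of drawing them inside `forward` (an assumption, so the result can be compared against a plain dense matmul). It records the tensor layout after each step; `inspect_sparse_path` is a hypothetical helper name.

```python
import torch


def inspect_sparse_path(x: torch.Tensor, weights: torch.Tensor) -> dict:
    """Run the reported op sequence eagerly and record the layout after each step."""
    x_sparse = x.to_sparse()                      # dense -> sparse COO
    mm_res = torch.sparse.mm(x_sparse, weights)   # sparse COO @ dense
    dense_res = mm_res.to_dense()                 # convert "back" to dense
    return {
        "x_sparse": x_sparse.layout,
        "mm_res": mm_res.layout,
        "dense_res": dense_res.layout,
        # The whole sequence should be numerically equivalent to a dense matmul.
        "matches_dense_matmul": torch.allclose(dense_res, x @ weights, rtol=1e-4, atol=1e-4),
    }


if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(32, 64)
    weights = torch.randn(64, 128)  # fixed weights (assumption; the reproducer draws new ones per call)
    for step, info in inspect_sparse_path(x, weights).items():
        print(f"{step}: {info}")
```

A sparse COO × dense product in eager mode is already a strided (dense) tensor, so the trailing `to_dense()` is applied to a tensor that is already dense. Whether that is related to the compile failure is not established here; dropping that call may simply be a useful variant to try when narrowing the reproducer.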

To reproduce

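```python
import torch
import torch.nn as nn


class TestModel(nn.Module):
    def forward(self, x):
        # Convert the dense input to a sparse COO tensor.
        x_sparse = x.to_sparse()
        # New random weights are drawn on every call, so eager and compiled
        # outputs would not match numerically even if compilation succeeded.
        weights = torch.randn(64, 128)
        # Sparse x dense matrix multiply.
        mm_res = torch.sparse.mm(x_sparse, weights)
        # Convert the result to a dense tensor.
        dense_res = mm_res.to_dense()
        return dense_res


x = torch.randn(32, 64)
model = TestModel()
print("Eager output:", model(x))
# Raises BackendCompilerFailed (AssertionError) on the reported nightly.
print("Compiled output:", torch.compile(model)(x))
```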

Error logs

```
Eager output: tensor([[  6.6635, -11.0075,  -9.1316,  ...,   2.3962,   0.5756,  -8.6663],
        [-20.1618,  -7.0493,  -8.6446,  ...,  -4.0912,  -8.2432,   3.1041],
        [-10.4155,   3.1859,  -4.0174,  ...,  -0.5116,  -2.2745,   8.8236],
        ...,
        [  4.6425,  23.3313,  -8.6384,  ...,  -1.0809,  -9.2113,   4.0540],
        [  7.3613,   8.1121,  -1.0245,  ...,   1.3643,  -9.7374,   1.0857],
        [-10.9416, -14.5704,   0.8874,  ...,   2.3209,  -0.5671,   3.1823]])
Traceback (most recent call last):
  File "E:\DL_Compiler_Test\torch_code\test.py", line 24, in <module>
    print("Compiled output:", torch.compile(model)(x))
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\eval_frame.py", line 414, in __call__
    return super().__call__(*args, **kwargs)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\nn\modules\module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\nn\modules\module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\eval_frame.py", line 845, in compile_wrapper
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\output_graph.py", line 2196, in _call_user_compiler
    raise BackendCompilerFailed(
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\output_graph.py", line 2171, in _call_user_compiler
    compiled_fn = compiler_fn(gm, example_inputs)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\repro\after_dynamo.py", line 156, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\__init__.py", line 2380, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_inductor\compile_fx.py", line 2681, in compile_fx
    return aot_autograd(
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\backends\common.py", line 117, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_functorch\aot_autograd.py", line 1096, in aot_module_simplified
    aot_state = create_aot_state(
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_functorch\aot_autograd.py", line 567, in create_aot_state
    fw_metadata = run_functionalized_fw_and_collect_metadata(
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_functorch\_aot_autograd\collect_metadata_analysis.py", line 834, in inner
    fw_graph_outs = pytree.tree_map(from_fun, f_fw_graph_outs)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\utils\_pytree.py", line 1376, in tree_map
    return treespec.unflatten(map(func, *flat_args))
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\utils\_pytree.py", line 1193, in unflatten
    leaves = list(leaves)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_functorch\_aot_autograd\functional_utils.py", line 72, in from_fun
    assert not torch._is_functional_tensor(t)  # type: ignore[attr-defined]
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AssertionError:

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
```

Versions

Collecting environment information...
PyTorch version: 2.10.0.dev20251005+cpu
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 11
GCC version: Could not collect
Clang version: Could not collect
CMake version: version 4.0.2
Libc version: N/A

Python version: 3.10.10 (tags/v3.10.10:aad5f6a, Feb 7 2023, 17:20:36) [MSC v.1929 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.26100-SP0
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

cc @alexsamardzic @nikitaved @pearu @cpuhrsch @amjames @bhosmer @jcaip @peterjc123 @mszhanyi @skyline75489 @nbcsm @iremyux @Blackhex @chauhang @penguinwu

Metadata


Assignees: No one assigned

Labels: module: sparse (Related to torch.sparse), module: windows (Windows support for PyTorch), oncall: pt2, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)


Development: No branches or pull requests
