Sparse CSR tensors crash state_dict() (1.11 nightly) #71652

@Linux-cpp-lisp

🐛 Describe the bug

This code works:

import torch

class SparseTensorModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("a", torch.eye(3).to_sparse())

    def forward(self):
        pass

s = SparseTensorModule()
s.a
print("hi")
print(s.state_dict())

producing:

hi
OrderedDict([('a', tensor(indices=tensor([[0, 1, 2],
                       [0, 1, 2]]),
       values=tensor([1., 1., 1.]),
       size=(3, 3), nnz=3, layout=torch.sparse_coo))])

but modifying

self.register_buffer("a", torch.eye(3).to_sparse())

to

self.register_buffer("a", torch.eye(3).to_sparse_csr())

gives

hi
Segmentation fault (core dumped)
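Because the failure is a hard crash rather than a Python exception, it can help to run the reproduction in a child process so that the calling interpreter survives. The following is a minimal sketch, not part of the original report: `run_isolated` and `REPRO` are hypothetical names introduced here, and the signal interpretation assumes a POSIX system.

```python
import subprocess
import sys
import textwrap

# The reproduction script from this report, using the CSR variant of the buffer.
REPRO = textwrap.dedent('''
    import torch

    class SparseTensorModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("a", torch.eye(3).to_sparse_csr())

        def forward(self):
            pass

    s = SparseTensorModule()
    print("hi")
    print(s.state_dict())
''')


def run_isolated(source):
    """Run `source` in a fresh interpreter and classify how it exited.

    Hypothetical helper: returns (status, stdout, stderr).
    """
    proc = subprocess.run(
        [sys.executable, "-c", source],
        capture_output=True,
        text=True,
    )
    if proc.returncode == 0:
        status = "ok"
    elif proc.returncode < 0:
        # On POSIX, a negative return code means the child was killed by
        # that signal number (for example SIGSEGV).
        status = "killed by signal %d" % -proc.returncode
    else:
        status = "exited with code %d" % proc.returncode
    return status, proc.stdout, proc.stderr
```

Calling `run_isolated(REPRO)` on an affected build would be expected to report that the child was killed by a signal after printing `hi`, matching the output above. On a build where the problem does not occur, the status should be `ok`.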

Versions

PyTorch version 1.11.0.dev20220119
CUDA available
Ubuntu Linux 21.10

cc @nikitaved @pearu @cpuhrsch

Labels

module: crash (Problem manifests as a hard crash, as opposed to a RuntimeError)
module: sparse (Related to torch.sparse)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
