Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes bug in sspaddmm (#45113) #45963

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion aten/src/ATen/native/sparse/SparseTensorMath.cpp
Expand Up @@ -1084,7 +1084,8 @@ SparseTensor& _sspaddmm_out_cpu(
"sspaddmm: Argument #1: Expected dim 1 size ", dim_k, ", got ", t.size(1));

int64_t nnz = sparse._nnz();
LongTensor indices = sparse._indices();
// We have to make indices contiguous as we use indices.data_ptr in _to_csr which assumes row-contiguous storage
vfdev-5 marked this conversation as resolved.
Show resolved Hide resolved
LongTensor indices = sparse._indices().contiguous();
Tensor values = sparse._values();

LongTensor csr = _to_csr(indices.data_ptr<int64_t>(), dim_i, nnz);
Expand Down
40 changes: 40 additions & 0 deletions test/test_sparse.py
Expand Up @@ -1080,6 +1080,46 @@ def test_shape(di, dj, dk, nnz):
test_shape(1000, 0, 100, 0)
test_shape(1000, 100, 0, 0)

@cpu_only
def test_sspaddmm(self):
    """Test Tensor.sspaddmm: beta * sparse_t + alpha * (sparse_x @ dense_y),
    validated against the dense torch.addmm reference."""

    def test_shape(di, dj, dk, nnz):
        # Random sparse operands and a dense RHS with compatible shapes:
        # x is (di, dj) sparse, t is (di, dk) sparse, y is (dj, dk) dense.
        x = self._gen_sparse(2, nnz, [di, dj])[0]
        t = self._gen_sparse(2, nnz, [di, dk])[0]
        y = torch.randn(dj, dk)
        alpha = random.random()
        beta = random.random()

        # Explicit alpha/beta must match the dense addmm reference.
        res = t.sspaddmm(x, y, beta=beta, alpha=alpha)
        expected = torch.addmm(self.safeToDense(t), self.safeToDense(x), y, beta=beta, alpha=alpha)
        self.assertEqual(self.safeToDense(res), expected)

        # Default scalars (alpha=1, beta=1) must match as well.
        res = t.sspaddmm(x, y)
        expected = torch.addmm(self.safeToDense(t), self.safeToDense(x), y)
        self.assertEqual(self.safeToDense(res), expected)

    test_shape(7, 5, 3, 20)
    test_shape(1000, 100, 100, 20)
    test_shape(3000, 64, 300, 20)
    # Degenerate shapes: empty dims and zero nnz.
    test_shape(0, 100, 100, 0)
    test_shape(1000, 0, 100, 0)
    test_shape(1000, 100, 0, 0)

    # Regression test for https://github.com/pytorch/pytorch/issues/45113:
    # sspaddmm read indices via data_ptr assuming row-contiguous storage,
    # which breaks when the sparse indices are non-contiguous (as produced
    # here by torch.cat of sparse tensors).
    batch_size, input_size, hidden_size = 5, 3, 7
    weight = torch.randn(hidden_size, input_size).to_sparse()
    bias = torch.randn((hidden_size, 1)).to_sparse()
    bias = torch.cat([bias] * batch_size, dim=1)
    if not self.is_uncoalesced:
        weight = weight.coalesce()
        bias = bias.coalesce()

    x = torch.randn(batch_size, input_size)
    y = bias.sspaddmm(weight, x.t())

    # Dense reference of the same computation; results must agree to 1e-6.
    y_bis = (bias.to_dense() + torch.matmul(weight.to_dense(), x.t())).to_sparse()
    self.assertLess((y.to_dense() - y_bis.to_dense()).abs().max().item(), 1e-6)

def test_sparse_addmm(self):
def test_shape(m, n, p, nnz, broadcast):
if broadcast:
Expand Down