
SparseAdam fails to run #29816

@antonior92

Description


🐛 Bug

A single SparseAdam step on a model with a sparse parameter matrix raises a RuntimeError.

To Reproduce

# %%
import torch
from torch import nn

class TrainNet(nn.Module):
    def __init__(self, in_features, out_features):
        super(TrainNet, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # The parameter itself is a sparse tensor (not a dense tensor with sparse gradients).
        self.weight = torch.nn.Parameter(torch.randn(out_features, in_features).to_sparse().requires_grad_(True))

    def forward(self, input):
        x = torch.sparse.mm(self.weight, input)
        return x


model = TrainNet(10, 20)
opt = torch.optim.SparseAdam(model.parameters(), 0.01)
inp = torch.ones((10, 30), dtype=torch.float32)
model.train()

model.zero_grad()
out = model(inp)
out.sum().backward()
opt.step()

gives me the error:

Traceback (most recent call last):
  File "<input>", line 25, in <module>
  File "/Users/antonio/miniconda3/envs/python3_6/lib/python3.6/site-packages/torch/optim/sparse_adam.py", line 86, in step
    old_exp_avg_values = exp_avg.sparse_mask(grad)._values()
RuntimeError: Cannot access data pointer of Tensor that doesn't have storage
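
For comparison, SparseAdam does step without error in the layout it is usually used with, i.e. a dense parameter whose gradient is sparse (for example an nn.Embedding created with sparse=True). A minimal sketch for contrast:

import torch
from torch import nn

# Dense parameter, sparse gradients: the layout SparseAdam is typically used with.
emb = nn.Embedding(10, 20, sparse=True)
opt = torch.optim.SparseAdam(emb.parameters(), 0.01)

idx = torch.tensor([1, 2, 3])
emb.zero_grad()
emb(idx).sum().backward()  # emb.weight.grad is a sparse tensor
opt.step()                 # completes without the RuntimeError above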

Environment

PyTorch version: 1.3.1
Is debug build: No
CUDA used to build PyTorch: None

OS: Mac OSX 10.14.6
GCC version: Could not collect
CMake version: version 3.15.4

Python version: 3.6
Is CUDA available: No
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA

Versions of relevant libraries:
[pip] numpy==1.17.2
[pip] torch==1.3.1
[pip] torchvision==0.4.2
[conda] _tflow_select             2.3.0                       mkl  
[conda] blas                      1.0                         mkl  
[conda] mkl                       2019.4                      233  
[conda] mkl-service               2.3.0            py36hfbe908c_0  
[conda] mkl_fft                   1.0.14           py36h5e564d8_0  
[conda] mkl_random                1.1.0            py36ha771720_0  
[conda] pytorch                   1.3.1                   py3.6_0    pytorch
[conda] tensorflow                1.14.0          mkl_py36h933f829_0  
[conda] tensorflow-base           1.14.0          mkl_py36h655c25b_0  
[conda] torchvision               0.4.2                  py36_cpu    pytorch

Additional context

I did manage to get SparseAdam to run on this minimal example by:

  1. Following the suggestion in https://discuss.pytorch.org/t/pytorch-sparse-adam-how-to-run/39871/2 and changing, in sparse_adam.py (see the zeros_like check after this list):
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data)

to:

# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data.to_dense())
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data.to_dense())
  2. Similarly to SGD fails on sparse matrix #29814, changing:
p.data.add_(make_sparse(-step_size * numer.div_(denom)))

to:

with torch.no_grad():
    p.add_(make_sparse(-step_size * numer.div_(denom)))

also in sparse_adam.py.
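
The first change matters because torch.zeros_like preserves the sparse layout of its input, so with a sparse parameter the exp_avg / exp_avg_sq state starts out sparse as well. A quick check (an illustration, not code from sparse_adam.py):

import torch

w = torch.randn(3, 4).to_sparse()
print(torch.zeros_like(w).is_sparse)             # True  -> sparse optimizer state
print(torch.zeros_like(w.to_dense()).is_sparse)  # False -> dense state, as in the change above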

cc @vincentqb

Labels: module: sparse, triaged
