Open
Labels
module: sparse, triaged
Description
🐛 Bug
A single SparseAdam step on a model whose parameter is a sparse matrix raises a RuntimeError.
To Reproduce
# %%
import torch
from torch import nn


class TrainNet(nn.Module):
    def __init__(self, in_features, out_features):
        super(TrainNet, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = torch.nn.Parameter(torch.randn(out_features, in_features).to_sparse().requires_grad_(True))

    def forward(self, input):
        x = torch.sparse.mm(self.weight, input)
        return x


model = TrainNet(10, 20)
opt = torch.optim.SparseAdam(model.parameters(), 0.01)
inp = torch.ones((10, 30), dtype=torch.float32)
model.train()
model.zero_grad()
out = model(inp)
out.sum().backward()
opt.step()
gives me the error:
Traceback (most recent call last):
  File "<input>", line 25, in <module>
  File "/Users/antonio/miniconda3/envs/python3_6/lib/python3.6/site-packages/torch/optim/sparse_adam.py", line 86, in step
    old_exp_avg_values = exp_avg.sparse_mask(grad)._values()
RuntimeError: Cannot access data pointer of Tensor that doesn't have storage
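For comparison, here is a minimal sketch of a setup where SparseAdam does run: a dense parameter that receives a sparse gradient. It assumes `nn.Embedding(..., sparse=True)` as the source of the sparse gradient; the sizes, the indices, and the names `emb`, `idx`, `before` and `changed_rows` are only illustrative and are not part of the report above.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Dense weight matrix; sparse=True makes backward() produce a sparse gradient.
emb = nn.Embedding(num_embeddings=20, embedding_dim=10, sparse=True)
opt = torch.optim.SparseAdam(emb.parameters(), lr=0.01)

idx = torch.tensor([0, 3, 3, 7])  # illustrative lookup indices
before = emb.weight.detach().clone()

opt.zero_grad()
emb(idx).sum().backward()
grad_is_sparse = emb.weight.grad.is_sparse  # gradient layout, not parameter layout
opt.step()

# Only the rows that appeared in the gradient should have moved.
changed = (emb.weight.detach() != before).any(dim=1)
changed_rows = changed.nonzero().flatten().tolist()
print(grad_is_sparse, changed_rows)
```

The difference from the reproduction above is that the parameter itself stays dense here and only its gradient is sparse.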
Environment
PyTorch version: 1.3.1
Is debug build: No
CUDA used to build PyTorch: None
OS: Mac OSX 10.14.6
GCC version: Could not collect
CMake version: version 3.15.4
Python version: 3.6
Is CUDA available: No
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Versions of relevant libraries:
[pip] numpy==1.17.2
[pip] torch==1.3.1
[pip] torchvision==0.4.2
[conda] _tflow_select 2.3.0 mkl
[conda] blas 1.0 mkl
[conda] mkl 2019.4 233
[conda] mkl-service 2.3.0 py36hfbe908c_0
[conda] mkl_fft 1.0.14 py36h5e564d8_0
[conda] mkl_random 1.1.0 py36ha771720_0
[conda] pytorch 1.3.1 py3.6_0 pytorch
[conda] tensorflow 1.14.0 mkl_py36h933f829_0
[conda] tensorflow-base 1.14.0 mkl_py36h655c25b_0
[conda] torchvision 0.4.2 py36_cpu pytorch
Additional context
I managed to get SparseAdam to run on this minimal example with two changes:
- Doing as proposed in https://discuss.pytorch.org/t/pytorch-sparse-adam-how-to-run/39871/2
and changing, in sparse_adam.py:
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data)
to:
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data.to_dense())
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data.to_dense())
- Similar to "SGD fails on sparse matrix" (#29814), changing, also in sparse_adam.py:
p.data.add_(make_sparse(-step_size * numer.div_(denom)))
to:
with torch.no_grad():
    p.add_(make_sparse(-step_size * numer.div_(denom)))
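To illustrate why dense state buffers are enough, here is a hand-written sketch of one lazy Adam step that keeps `exp_avg` and `exp_avg_sq` dense and touches only the entries present in a sparse gradient. This is not the code in sparse_adam.py; `lazy_adam_step` is a hypothetical helper, it assumes a 2-D dense parameter for simplicity, and the default hyperparameters are the usual Adam ones.

```python
import math
import torch


def lazy_adam_step(param, grad, exp_avg, exp_avg_sq, step,
                   lr=0.01, betas=(0.9, 0.999), eps=1e-8):
    """Hypothetical helper: update only the entries listed in the sparse `grad`."""
    beta1, beta2 = betas
    grad = grad.coalesce()              # merge duplicate indices
    rows, cols = grad.indices()         # assumes a 2-D sparse gradient
    g = grad.values()

    with torch.no_grad():
        # Moment estimates are dense buffers, updated only at the gradient's indices.
        exp_avg[rows, cols] = beta1 * exp_avg[rows, cols] + (1 - beta1) * g
        exp_avg_sq[rows, cols] = beta2 * exp_avg_sq[rows, cols] + (1 - beta2) * g * g

        # Standard Adam bias correction folded into the step size.
        step_size = lr * math.sqrt(1 - beta2 ** step) / (1 - beta1 ** step)
        denom = exp_avg_sq[rows, cols].sqrt() + eps
        param[rows, cols] -= step_size * exp_avg[rows, cols] / denom


param = torch.zeros(3, 4, requires_grad=True)
grad = torch.tensor([[0., 2., 0., 0.],
                     [0., 0., 0., 0.],
                     [1., 0., 0., 3.]]).to_sparse()
exp_avg = torch.zeros(3, 4)
exp_avg_sq = torch.zeros(3, 4)

lazy_adam_step(param, grad, exp_avg, exp_avg_sq, step=1)
print(param)
```

The `torch.no_grad()` block plays the same role as in the second change above: it allows the in-place update of a leaf tensor that requires grad.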
cc @vincentqb