
Some operations crash autograd if parameter size is changed. #33941

@RobertCsordas

Description

🐛 Bug

Some operations cache metadata for the backward pass, so changing a parameter's size crashes autograd. This only happens if a backward pass has already been run; it never happens on the first one.

Not all operations crash. Cat is just an example I found by ablating my code until something broke. The same code works fine without the cat. Chaining matrix multiplications and removing rows/columns from matrices also works fine.

The crash happens in this line: https://github.com/pytorch/pytorch/blob/v1.4.0/torch/csrc/autograd/engine.cpp#L448

The graph is not retained, so I don't see where or why the metadata is stored.
If I create a new parameter instead of calling set_, everything works fine.
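
For reference, a minimal sketch of that workaround (the repro below, with set_ replaced by re-creating the parameter each round):

import torch

asd = torch.nn.Parameter(torch.ones(16))

for i in range(2):
    print(f"Round {i}")
    # Workaround: build a fresh Parameter instead of shrinking the old one
    # in place with set_. The old autograd leaf is discarded, so no stale
    # metadata from a previous backward pass is left behind.
    asd = torch.nn.Parameter(asd[1:].detach().clone())

    m = torch.cat((asd, asd))
    m.sum().backward()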

This is clearly a bug as:

  • If it is not allowed to change parameter sizes, then set_ should check whether the sizes match and fail if they don't, instead of producing this cryptic error message whose origin is very hard to track down (a hypothetical user-side version of such a check is sketched after this list).
  • If it is allowed to change sizes, then the autograd cache should adapt accordingly.
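
Purely to illustrate the first point, a minimal sketch of such a size check done on the user side; safe_set_ is a hypothetical helper of mine, not an existing PyTorch API:

import torch

def safe_set_(param, source):
    # Hypothetical guard: refuse to resize, which is roughly what set_
    # itself could do if changing parameter sizes is not supposed to be
    # supported.
    if param.shape != source.shape:
        raise RuntimeError(
            f"set_ would change the size from {tuple(param.shape)} "
            f"to {tuple(source.shape)}"
        )
    with torch.no_grad():
        param.set_(source)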

Allowing size changes would be very useful for weight-pruning experiments (I am aware that the optimizer state has to be adjusted accordingly, but that is still much easier with set_ than with re-creating all the parameters).
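
To make the use case concrete, a rough sketch of the kind of pruning step I have in mind, assuming plain SGD with momentum (so the only per-parameter optimizer state is the momentum_buffer); prune_rows_ is a hypothetical helper, not an existing API:

import torch

def prune_rows_(param, keep_mask, optimizer):
    # Shrink the parameter in place with set_ and prune the matching
    # optimizer state, so the parameter object (and its key in
    # optimizer.state) stays the same.
    with torch.no_grad():
        state = optimizer.state.get(param, {})
        buf = state.get("momentum_buffer")
        if buf is not None:
            state["momentum_buffer"] = buf[keep_mask]
        param.grad = None
        param.set_(param[keep_mask].contiguous())

weight = torch.nn.Parameter(torch.randn(16, 4))
opt = torch.optim.SGD([weight], lr=0.1, momentum=0.9)

# ... train for a while, then drop the first row ...
mask = torch.ones(16, dtype=torch.bool)
mask[0] = False
prune_rows_(weight, mask, opt)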

To Reproduce

Run this:

import torch

asd = torch.nn.Parameter(torch.ones(16))

for i in range(2):
    print(f"Round {i}")
    with torch.no_grad():
        asd.set_(asd[1:])
        asd.grad=None

    m = torch.cat((asd, asd))
    m.sum().backward()

Output:

$ python3 fail.py 
Round 0
Round 1
Traceback (most recent call last):
  File "fail.py", line 12, in <module>
    m.sum().backward()
  File "/home/robert/.local/lib/python3.8/site-packages/torch/tensor.py", line 195, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/robert/.local/lib/python3.8/site-packages/torch/autograd/__init__.py", line 97, in backward
    Variable._execution_engine.run_backward(
RuntimeError: Function CatBackward returned an invalid gradient at index 0 - got [14] but expected shape compatible with [15]

Expected behavior

Do not fail.

Environment

PyTorch version: 1.4.0a0+7f73f1d
Is debug build: No
CUDA used to build PyTorch: 10.2

OS: Manjaro Linux
GCC version: (GCC) 9.2.0
CMake version: version 3.16.4

Python version: 3.8
Is CUDA available: Yes
CUDA runtime version: 10.2.89
GPU models and configuration:
GPU 0: GeForce RTX 2080 Ti
GPU 1: GeForce GTX TITAN X

Nvidia driver version: 440.59
cuDNN version: /usr/lib/libcudnn.so.7.6.5

Versions of relevant libraries:
[pip3] numpy==1.14.5
[pip3] torch==1.4.0a0+7f73f1d
[pip3] torchfile==0.1.0
[pip3] torchvision==0.5.0
[conda] Could not collect

Note: PyTorch is built from source, from tags/v1.4.0.

cc @ezyang @gchanan @zou3519 @ssnl @albanD @gqchen

Labels

high priority, module: autograd (related to torch.autograd, and the autograd engine in general), triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
