Gradient checkpointing fails to backprop in some cases #33962

@stephenroller

🐛 Bug

I'm attempting to use torch.utils.checkpoint.checkpoint. I've found that in some cases it fails to call CheckpointFunction.backward, so some parameters never receive a gradient.
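The failure can be reduced to a single layer. The sketch below shows what appears to be the underlying cause. It runs on CPU rather than CUDA and assumes a later PyTorch release whose `checkpoint` accepts the `use_reentrant` keyword. 1.4.0 does not accept it; reentrant checkpointing is the only mode there, so drop the keyword on that version. With reentrant checkpointing, the autograd Function receives only `x` as a tensor input. When `x` does not require grad, the output is therefore not connected to the layer's parameters, even though a plain forward pass would be.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
layer = nn.Linear(256, 256)  # weight and bias have requires_grad=True
x = torch.randn(17, 256)     # input does NOT require grad, as in the repro

# Reentrant checkpointing (the only mode in 1.4.0): the Function only sees `x`
# as a tensor input, so its output does not require grad.
out_reentrant = checkpoint(layer, x, use_reentrant=True)

# Non-reentrant checkpointing (newer releases only) also tracks the parameters.
out_nonreentrant = checkpoint(layer, x, use_reentrant=False)

# Plain forward pass for comparison.
out_plain = layer(x)

print(out_reentrant.requires_grad)     # False
print(out_nonreentrant.requires_grad)  # True
print(out_plain.requires_grad)         # True
```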

To Reproduce

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class Model(nn.Module):
    def __init__(self, n: int, use_cp: bool, allow_first: bool, allow_last: bool):
        super().__init__()
        self.layers = nn.ModuleList()
        self.n = n
        self.use_cp = use_cp
        self.allow_first = allow_first
        self.allow_last = allow_last
        for i in range(self.n):
            self.layers.append(nn.Linear(256, 256))

    def forward(self, x):
        for i in range(self.n):
            if (
                not self.use_cp
                or (i == 0 and not self.allow_first)
                or (i == self.n - 1 and not self.allow_last)
            ):
                print("No checkpoint", i)
                x = self.layers[i](x)
            else:
                print("Checkpointing", i)
                x = checkpoint(self.layers[i], x)
        return x


def test(use_cp, first, last):
    model = Model(4, use_cp, first, last).cuda()
    x = torch.randn(17, 256).cuda()
    loss = model(x).sum()
    try:
        loss.backward()
    except RuntimeError:
        return "RuntimeError"
    return sum([p.grad is None for p in model.parameters()])


print("None grads with NO grad checkpoint:", test(False, False, False))
print()
print("None grads with ALL grad checkpoint (1..n):", test(True, True, True))
print()
print("None grads with grad checkpoint (no first; 2..n):", test(True, False, True))
print()
print("None grads with grad checkpoint (no last; 1..n-1):", test(True, True, False))
print()
print("None grads with grad checkpoint (neither; 2..n-1):", test(True, False, False))
print()

Produces the output:

No checkpoint 0
No checkpoint 1
No checkpoint 2
No checkpoint 3
None grads with NO grad checkpoint: 0

Checkpointing 0
/private/home/roller/.conda/envs/chat/lib/python3.7/site-packages/torch/utils/checkpoint.py:25: UserWarning: None of the inputs have requires_grad=True. Gradients will be None
  warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
Checkpointing 1
Checkpointing 2
Checkpointing 3
None grads with ALL grad checkpoint (1..n): RuntimeError

No checkpoint 0
Checkpointing 1
Checkpointing 2
Checkpointing 3
None grads with grad checkpoint (no first; 2..n): 0

Checkpointing 0
Checkpointing 1
Checkpointing 2
No checkpoint 3
None grads with grad checkpoint (no last; 1..n-1): 6

No checkpoint 0
Checkpointing 1
Checkpointing 2
No checkpoint 3
None grads with grad checkpoint (neither; 2..n-1): 0
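Each `nn.Linear` owns two parameters, a weight and a bias. The 6 None grads in the fourth case therefore correspond to three whole layers: every checkpointed layer is skipped, and only the final, non-checkpointed layer is trained. The per-layer check below is a CPU sketch of that case. It assumes a PyTorch release that accepts `use_reentrant`; omit the keyword on 1.4.0, where reentrant checkpointing is the only mode. The `missing` dict is illustrative only.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
layers = nn.ModuleList(nn.Linear(256, 256) for _ in range(4))
x = torch.randn(17, 256)  # input does not require grad, as in the repro

# Fourth case from the report: checkpoint layers 0..2, run layer 3 normally.
for i, layer in enumerate(layers):
    if i < 3:
        x = checkpoint(layer, x, use_reentrant=True)
    else:
        x = layer(x)
x.sum().backward()

# For each layer, list the parameter names that did not receive a gradient.
missing = {
    i: [name for name, p in layer.named_parameters() if p.grad is None]
    for i, layer in enumerate(layers)
}
print(missing)
# {0: ['weight', 'bias'], 1: ['weight', 'bias'], 2: ['weight', 'bias'], 3: []}
print(sum(len(names) for names in missing.values()))  # 6
```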

Expected behavior

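All test cases should produce 0 None grads, or at least raise a RuntimeError. The fourth case is VERY concerning: 6 of the 8 parameters (the weight and bias of each of the three checkpointed layers) end up with no gradient and no error is raised, so only the last layer actually receives gradients.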
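Two possible workarounds are sketched below, on CPU. The first marks the input as requiring grad so that the reentrant Function is connected to the graph. The second uses non-reentrant checkpointing, which also tracks the parameters. The sketch assumes a PyTorch release that accepts the `use_reentrant` keyword, and non-reentrant checkpointing is not available in 1.4.0. `count_none_grads` is a hypothetical helper that re-runs the fourth case of the repro.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


def count_none_grads(input_requires_grad: bool, use_reentrant: bool) -> int:
    """Hypothetical helper: re-run the fourth case and count None grads."""
    torch.manual_seed(0)
    layers = nn.ModuleList(nn.Linear(256, 256) for _ in range(4))
    x = torch.randn(17, 256, requires_grad=input_requires_grad)
    # Checkpoint layers 0..2, run layer 3 normally.
    for i, layer in enumerate(layers):
        if i < 3:
            x = checkpoint(layer, x, use_reentrant=use_reentrant)
        else:
            x = layer(x)
    x.sum().backward()
    return sum(p.grad is None for p in layers.parameters())


print(count_none_grads(False, True))   # 6, matches the report
print(count_none_grads(True, True))    # 0, input marked as requiring grad
print(count_none_grads(False, False))  # 0, non-reentrant checkpointing
```

Setting `requires_grad=True` on the input works with the reentrant variant as well. The cost is that autograd also computes a gradient for the input, which the model does not need.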

Environment

$ python -m torch.utils.collect_env
Collecting environment information...
PyTorch version: 1.4.0
Is debug build: No
CUDA used to build PyTorch: 10.1

OS: Ubuntu 18.04.4 LTS
GCC version: (Ubuntu 7.4.0-1ubuntu1~18.04.1) 7.4.0
CMake version: version 3.10.2

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 10.1.105
GPU models and configuration:
GPU 0: Quadro GP100
GPU 1: Quadro GP100

Nvidia driver version: 418.116.00
cuDNN version: Could not collect

Versions of relevant libraries:
[pip] numpy==1.17.2
[pip] numpy-stubs==0.0.1
[pip] pytorch-pretrained-bert==0.6.2
[pip] pytorch-transformers==1.1.0
[pip] torch==1.4.0
[pip] torchtext==0.4.0
[pip] torchvision==0.5.0
[conda] blas                      1.0                         mkl
[conda] mkl                       2019.4                      243
[conda] mkl-service               2.0.2            py37h7b6447c_0
[conda] mkl_fft                   1.0.14           py37ha843d7b_0
[conda] mkl_random                1.0.2            py37hd81dba3_0
[conda] pytorch                   1.3.1           py3.7_cuda10.0.130_cudnn7.6.3_0    pytorch
[conda] pytorch-pretrained-bert   0.6.1                     <pip>
[conda] pytorch-pretrained-bert   0.6.2                     <pip>
[conda] pytorch-transformers      1.1.0                     <pip>
[conda] torch                     1.4.0                     <pip>
[conda] torchtext                 0.4.0                     <pip>
[conda] torchvision               0.4.2                py37_cu100    pytorch
[conda] torchvision               0.5.0                     <pip>
