🐛 Bug
I'm attempting to use torch.utils.checkpoint.checkpoint. I've found that it fails to properly call CheckpointFunction.backward in some cases.
To Reproduce
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class Model(nn.Module):
    def __init__(self, n: int, use_cp: bool, allow_first: bool, allow_last: bool):
        super().__init__()
        self.layers = nn.ModuleList()
        self.n = n
        self.use_cp = use_cp
        self.allow_first = allow_first
        self.allow_last = allow_last
        for i in range(self.n):
            self.layers.append(nn.Linear(256, 256))

    def forward(self, x):
        for i in range(self.n):
            if (
                not self.use_cp
                or (i == 0 and not self.allow_first)
                or (i == self.n - 1 and not self.allow_last)
            ):
                print("No checkpoint", i)
                x = self.layers[i](x)
            else:
                print("Checkpointing", i)
                x = checkpoint(self.layers[i], x)
        return x


def test(use_cp, first, last):
    model = Model(4, use_cp, first, last).cuda()
    x = torch.randn(17, 256).cuda()
    loss = model(x).sum()
    try:
        loss.backward()
    except RuntimeError:
        return "RuntimeError"
    return sum([p.grad is None for p in model.parameters()])


print("None grads with NO grad checkpoint:", test(False, False, False))
print()
print("None grads with ALL grad checkpoint (1..n):", test(True, True, True))
print()
print("None grads with grad checkpoint (no first; 2..n):", test(True, False, True))
print()
print("None grads with grad checkpoint (no last; 1..n-1):", test(True, True, False))
print()
print("None grads with grad checkpoint (neither; 2..n-1):", test(True, False, False))
print()
Produces the output:
No checkpoint 0
No checkpoint 1
No checkpoint 2
No checkpoint 3
None grads with NO grad checkpoint: 0
Checkpointing 0
/private/home/roller/.conda/envs/chat/lib/python3.7/site-packages/torch/utils/checkpoint.py:25: UserWarning: None of the inputs have requires_grad=True. Gradients will be None
warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
Checkpointing 1
Checkpointing 2
Checkpointing 3
None grads with ALL grad checkpoint (1..n): RuntimeError
No checkpoint 0
Checkpointing 1
Checkpointing 2
Checkpointing 3
None grads with grad checkpoint (no first; 2..n): 0
Checkpointing 0
Checkpointing 1
Checkpointing 2
No checkpoint 3
None grads with grad checkpoint (no last; 1..n-1): 6
No checkpoint 0
Checkpointing 1
Checkpointing 2
No checkpoint 3
None grads with grad checkpoint (neither; 2..n-1): 0
Expected behavior
All test cases should produce 0 None grads, or at least raise a RuntimeError. The fourth case is especially concerning: backward() completes without any error, yet the three checkpointed layers silently end up with None gradients (6 of the 8 parameters), and only the final non-checkpointed layer receives gradients.
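For what it's worth, a possible workaround sketch, based on my assumption from the warning above rather than a verified fix: force the input of the first checkpointed segment to require grad so that CheckpointFunction actually participates in the backward pass.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

layers = nn.ModuleList(nn.Linear(256, 256) for _ in range(4)).cuda()

x = torch.randn(17, 256).cuda()
x.requires_grad_()               # make the checkpoint input require grad

out = x
for layer in layers:
    out = checkpoint(layer, out)  # every segment checkpointed
loss = out.sum()
loss.backward()

# With the input requiring grad, all parameters should receive gradients.
print(sum(p.grad is None for p in layers.parameters()))  # expected: 0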
Environment
$ python -m torch.utils.collect_env
Collecting environment information...
PyTorch version: 1.4.0
Is debug build: No
CUDA used to build PyTorch: 10.1
OS: Ubuntu 18.04.4 LTS
GCC version: (Ubuntu 7.4.0-1ubuntu1~18.04.1) 7.4.0
CMake version: version 3.10.2
Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 10.1.105
GPU models and configuration:
GPU 0: Quadro GP100
GPU 1: Quadro GP100
Nvidia driver version: 418.116.00
cuDNN version: Could not collect
Versions of relevant libraries:
[pip] numpy==1.17.2
[pip] numpy-stubs==0.0.1
[pip] pytorch-pretrained-bert==0.6.2
[pip] pytorch-transformers==1.1.0
[pip] torch==1.4.0
[pip] torchtext==0.4.0
[pip] torchvision==0.5.0
[conda] blas 1.0 mkl
[conda] mkl 2019.4 243
[conda] mkl-service 2.0.2 py37h7b6447c_0
[conda] mkl_fft 1.0.14 py37ha843d7b_0
[conda] mkl_random 1.0.2 py37hd81dba3_0
[conda] pytorch 1.3.1 py3.7_cuda10.0.130_cudnn7.6.3_0 pytorch
[conda] pytorch-pretrained-bert 0.6.1 <pip>
[conda] pytorch-pretrained-bert 0.6.2 <pip>
[conda] pytorch-transformers 1.1.0 <pip>
[conda] torch 1.4.0 <pip>
[conda] torchtext 0.4.0 <pip>
[conda] torchvision 0.4.2 py37_cu100 pytorch
[conda] torchvision 0.5.0 <pip>