In [105]:
from types import SimpleNamespace

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

## Let's check what's going on with GC

In [232]:
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(2, 4)
        self.linear = nn.Sequential(nn.Linear(4,4), nn.ReLU())
        self.head = nn.Linear(4,2)
    
    def forward(self, x, cp=False, reentrant=True):
        x = self.embed(x)
        print(f"Embedding: {x.grad_fn}")
        if cp:
            x = checkpoint(self.linear, x, use_reentrant=reentrant)
            print(f"Linear: {x.grad_fn}")
        else:
            x = self.linear(x)
            print(f"Linear: {x.grad_fn}")
        out = self.head(x)
        print(f"Head: {out.grad_fn}")
        return out
        
model = Model().cuda()
x = torch.randint(0, 2, size=(1,4)).cuda()

no gradient checkpointing

In [233]:
loss = model(x).pow(2).sum()
loss.backward()
loss

Embedding: <EmbeddingBackward0 object at 0x7ff979056bc0>
Linear: <ReluBackward0 object at 0x7ff979277670>
Head: <ViewBackward0 object at 0x7ff979277670>


tensor(0.3714, device='cuda:0', grad_fn=<SumBackward0>)

default Gradient Checkpointing with frozen embeddings

In [234]:
model.embed.weight.requires_grad_(False);
loss = model(x, cp=True).pow(2).sum()
loss.backward()
loss

Embedding: None
Linear: None
Head: <ViewBackward0 object at 0x7ff979130bb0>


tensor(0.3714, device='cuda:0', grad_fn=<SumBackward0>)

This is wrong, we should have gradients on Linear layer!

New Gradient Checkpointing behavior

In [235]:
loss = model(x, cp=True, reentrant=False).pow(2).sum()
loss.backward()
loss

Embedding: None
Linear: <ReluBackward0 object at 0x7ff9792762f0>
Head: <ViewBackward0 object at 0x7ff9792762f0>


tensor(0.3714, device='cuda:0', grad_fn=<SumBackward0>)