In [2]:
from types import SimpleNamespace

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

## Let's check what's going on with GC

In [3]:
def print_layer_info(layer, out=None):
    s = ("Layer: {name}\n"
         " w:    {weight}\n"
         " grad: {grad}").format(name=layer, weight=layer.weight.data, grad=layer.weight.grad)
    print(s)
    if out is not None:
        print(f"  out: {out.grad_fn}")

In [4]:
l =nn.Linear(1,1)

In [5]:
print_layer_info(l)

Layer: Linear(in_features=1, out_features=1, bias=True)
 w:    tensor([[-0.2703]])
 grad: None


In [43]:

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1,1, bias=False)
        self.linear2 = nn.Linear(1,1, bias=False)
        self.head = nn.Linear(1,1, bias=False)
    
    def forward(self, x, cp=False, reentrant=True, enable_input_grad=False):
        if cp:
            x = checkpoint(self.linear, x, use_reentrant=reentrant)
            if enable_input_grad:
                x.requires_grad_(True)
            print_layer_info(self.linear, x)
            x = checkpoint(self.linear2, x, use_reentrant=reentrant)
            print_layer_info(self.linear, x)
            
        else:
            x = self.linear(x)
            print_layer_info(self.linear, x)
            x = self.linear2(x)
            print_layer_info(self.linear2, x)
        out = self.head(x)
        print_layer_info(self.head, out)
        return out

    def print(self):
        for l in [self.linear, self.linear2, self.head]:
            print_layer_info(l)
        
torch.manual_seed(42)        
model = Model()
x = torch.rand(1,1)

no gradient checkpointing

In [44]:
model(x)

Layer: Linear(in_features=1, out_features=1, bias=False)
 w:    tensor([[0.7645]])
 grad: None
  out: <MmBackward0 object at 0x7fb733f943d0>
Layer: Linear(in_features=1, out_features=1, bias=False)
 w:    tensor([[0.8300]])
 grad: None
  out: <MmBackward0 object at 0x7fb733f972e0>
Layer: Linear(in_features=1, out_features=1, bias=False)
 w:    tensor([[-0.2343]])
 grad: None
  out: <MmBackward0 object at 0x7fb733f972e0>


tensor([[-0.1426]], grad_fn=<MmBackward0>)

In [45]:
loss = model(x).pow(2).sum()
loss.backward()
loss

Layer: Linear(in_features=1, out_features=1, bias=False)
 w:    tensor([[0.7645]])
 grad: None
  out: <MmBackward0 object at 0x7fb733f972e0>
Layer: Linear(in_features=1, out_features=1, bias=False)
 w:    tensor([[0.8300]])
 grad: None
  out: <MmBackward0 object at 0x7fb736d35930>
Layer: Linear(in_features=1, out_features=1, bias=False)
 w:    tensor([[-0.2343]])
 grad: None
  out: <MmBackward0 object at 0x7fb736d35930>


tensor(0.0203, grad_fn=<SumBackward0>)

In [46]:
model.print()

Layer: Linear(in_features=1, out_features=1, bias=False)
 w:    tensor([[0.7645]])
 grad: tensor([[0.0532]])
Layer: Linear(in_features=1, out_features=1, bias=False)
 w:    tensor([[0.8300]])
 grad: tensor([[0.0490]])
Layer: Linear(in_features=1, out_features=1, bias=False)
 w:    tensor([[-0.2343]])
 grad: tensor([[-0.1736]])


default Gradient Checkpointing with frozen embeddings

In [47]:
torch.manual_seed(42)        
model = Model()
x = torch.rand(1,1)
model.linear.weight.requires_grad_(False);
loss = model(x, cp=True).pow(2).sum()
loss.backward()
loss

Layer: Linear(in_features=1, out_features=1, bias=False)
 w:    tensor([[0.7645]])
 grad: None
  out: None
Layer: Linear(in_features=1, out_features=1, bias=False)
 w:    tensor([[0.7645]])
 grad: None
  out: None
Layer: Linear(in_features=1, out_features=1, bias=False)
 w:    tensor([[-0.2343]])
 grad: None
  out: <MmBackward0 object at 0x7fb733f943d0>


tensor(0.0203, grad_fn=<SumBackward0>)

In [48]:
model.print()

Layer: Linear(in_features=1, out_features=1, bias=False)
 w:    tensor([[0.7645]])
 grad: None
Layer: Linear(in_features=1, out_features=1, bias=False)
 w:    tensor([[0.8300]])
 grad: None
Layer: Linear(in_features=1, out_features=1, bias=False)
 w:    tensor([[-0.2343]])
 grad: tensor([[-0.1736]])


This is wrong, we should have gradients on Linear layer!

New Gradient Checkpointing behavior

In [49]:
torch.manual_seed(42)        
model = Model()
x = torch.rand(1,1)
model.linear.weight.requires_grad_(False);
loss = model(x, cp=True, reentrant=False).pow(2).sum()
loss.backward()
loss

Layer: Linear(in_features=1, out_features=1, bias=False)
 w:    tensor([[0.7645]])
 grad: None
  out: None
Layer: Linear(in_features=1, out_features=1, bias=False)
 w:    tensor([[0.7645]])
 grad: None
  out: <MmBackward0 object at 0x7fb731af7fa0>
Layer: Linear(in_features=1, out_features=1, bias=False)
 w:    tensor([[-0.2343]])
 grad: None
  out: <MmBackward0 object at 0x7fb731af7fa0>


tensor(0.0203, grad_fn=<SumBackward0>)

In [50]:
model.print()

Layer: Linear(in_features=1, out_features=1, bias=False)
 w:    tensor([[0.7645]])
 grad: None
Layer: Linear(in_features=1, out_features=1, bias=False)
 w:    tensor([[0.8300]])
 grad: tensor([[0.0490]])
Layer: Linear(in_features=1, out_features=1, bias=False)
 w:    tensor([[-0.2343]])
 grad: tensor([[-0.1736]])


In [51]:
torch.manual_seed(42)        
model = Model()
x = torch.rand(1,1)
model.linear.weight.requires_grad_(False);
loss = model(x, cp=True, reentrant=False, enable_input_grad=True).pow(2).sum()
loss.backward()
loss

Layer: Linear(in_features=1, out_features=1, bias=False)
 w:    tensor([[0.7645]])
 grad: None
  out: None
Layer: Linear(in_features=1, out_features=1, bias=False)
 w:    tensor([[0.7645]])
 grad: None
  out: <MmBackward0 object at 0x7fb731af45e0>
Layer: Linear(in_features=1, out_features=1, bias=False)
 w:    tensor([[-0.2343]])
 grad: None
  out: <MmBackward0 object at 0x7fb731af45e0>


tensor(0.0203, grad_fn=<SumBackward0>)

In [52]:
model.print()

Layer: Linear(in_features=1, out_features=1, bias=False)
 w:    tensor([[0.7645]])
 grad: None
Layer: Linear(in_features=1, out_features=1, bias=False)
 w:    tensor([[0.8300]])
 grad: tensor([[0.0490]])
Layer: Linear(in_features=1, out_features=1, bias=False)
 w:    tensor([[-0.2343]])
 grad: tensor([[-0.1736]])


In [53]:
torch.manual_seed(42)        
model = Model()
x = torch.rand(1,1)
model.linear.weight.requires_grad_(False);
loss = model(x, cp=True, reentrant=True, enable_input_grad=True).pow(2).sum()
loss.backward()
loss

Layer: Linear(in_features=1, out_features=1, bias=False)
 w:    tensor([[0.7645]])
 grad: None
  out: None
Layer: Linear(in_features=1, out_features=1, bias=False)
 w:    tensor([[0.7645]])
 grad: None
  out: <torch.autograd.function.CheckpointFunctionBackward object at 0x7fb7341aae40>
Layer: Linear(in_features=1, out_features=1, bias=False)
 w:    tensor([[-0.2343]])
 grad: None
  out: <MmBackward0 object at 0x7fb731af4d90>


tensor(0.0203, grad_fn=<SumBackward0>)

In [54]:
model.print()

Layer: Linear(in_features=1, out_features=1, bias=False)
 w:    tensor([[0.7645]])
 grad: None
Layer: Linear(in_features=1, out_features=1, bias=False)
 w:    tensor([[0.8300]])
 grad: tensor([[0.0490]])
Layer: Linear(in_features=1, out_features=1, bias=False)
 w:    tensor([[-0.2343]])
 grad: tensor([[-0.1736]])
