### Checkpointing

##### Example 1

In [None]:
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint, checkpoint_sequential

In [None]:
class CheckPoint(nn.Sequential):
    """An ``nn.Sequential`` whose forward pass is routed through ``checkpoint``.

    The contained layers are executed by ``torch.utils.checkpoint.checkpoint``
    instead of being called directly, so the container behaves like a regular
    ``nn.Sequential`` from the caller's point of view.
    """

    def forward(self, *inputs):
        """Run the wrapped layers under ``checkpoint``.

        Args:
            *inputs: Positional inputs handed on to ``nn.Sequential.forward``.

        Returns:
            The value produced by ``nn.Sequential.forward`` for ``inputs``.
        """
        # ``use_reentrant`` is passed explicitly because relying on the
        # implicit default is deprecated by PyTorch. ``True`` is the value the
        # old default resolved to, so existing behaviour is preserved.
        return checkpoint(super().forward, *inputs, use_reentrant=True)

In [None]:
class Echo(nn.Module):
    """Pass-through module that prints a fixed message on every forward call.

    Handy as a debugging probe: dropping it into a model shows when (and how
    often) the surrounding layers are actually executed.
    """

    def __init__(self, msg: str):
        """Store ``msg``, the text printed after the word ``forward``."""
        super().__init__()
        self.msg = msg

    def forward(self, x):
        """Print ``forward <msg>`` and return ``x`` unchanged."""
        label = self.msg
        print("forward", label)
        return x

In [None]:
# Two identical checkpointed stages (Linear -> ReLU -> Echo), differing only in
# the message their Echo prints; they are built in order and chained.
model = nn.Sequential(
    *(
        CheckPoint(nn.Linear(1000, 1000), nn.ReLU(), Echo(done_msg))
        for done_msg in ("layer1 done", "layer2 done")
    )
)

In [None]:
# Random batch of 16 samples with 1000 features each, flagged so that autograd
# tracks gradients with respect to the inputs themselves.
inputs = torch.randn(16, 1000).requires_grad_()

In [None]:
# Forward pass through both checkpointed stages; each Echo prints its message
# once, in stage order, as its stage completes.
output = model(inputs)

forward layer1 done
forward layer2 done


In [None]:
# Display the forward result, including the grad_fn attached to it.
output

tensor([[0.2257, 0.1056, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.1090, 0.0000,  ..., 0.0000, 0.0000, 0.2023],
        [0.0000, 0.1253, 0.0000,  ..., 0.0000, 0.0000, 0.4055],
        ...,
        [0.0000, 0.0421, 0.0000,  ..., 0.0000, 0.5109, 0.0000],
        [0.0000, 0.2375, 0.0000,  ..., 0.1900, 0.0582, 0.2539],
        [0.0000, 0.2023, 0.0000,  ..., 0.0000, 0.0000, 0.2167]],
       grad_fn=<CheckpointFunctionBackward>)

In [None]:
# Reduce the output to a scalar (its norm) and backpropagate from it. The Echo
# messages are printed again during this call, last stage first — presumably
# because checkpoint re-runs each stage's forward when its gradients are needed.
output.norm().backward()

forward layer2 done
forward layer1 done
