Source: [PyTorch autograd tutorial — Computational Graph](https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#computational-graph)

In a forward pass, autograd does two things simultaneously:

- run the requested operation to compute a resulting tensor, and
- maintain the operation’s _gradient function_ in the DAG.

The backward pass kicks off when `.backward()` is called on the DAG root. `autograd` then:

- computes the gradients from each `.grad_fn`,
- accumulates them in the respective tensor’s `.grad` attribute, and
- using the chain rule, propagates all the way to the leaf tensors.


In [1]:
import torch as t
import torch.nn as nn

class MLP(nn.Module):
    """A stack of ``nn.Linear`` layers applied in sequence, with no activations.

    Layer ``i`` maps ``layer_sizes[i]`` features to ``layer_sizes[i + 1]``
    features, so ``len(layer_sizes) - 1`` layers are built. No nonlinearity is
    inserted between layers, so the whole stack is a composition of linear maps.

    Args:
        layer_sizes: Feature widths from input to output. Any iterable of ints
            is accepted; the default is a tuple so it cannot be mutated and
            shared between instances.
        bias: Forwarded to every ``nn.Linear`` as ``bias``.
        device: Forwarded to every ``nn.Linear`` as ``device``.
    """

    def __init__(self, layer_sizes=(16, 128, 128, 16), bias=False, device=None):
        super().__init__()
        sizes = list(layer_sizes)  # materialize so iterables/generators work too
        self.model = nn.Sequential(
            *[
                nn.Linear(n_in, n_out, bias=bias, device=device)
                # Pair each width with the next one: (s0, s1), (s1, s2), ...
                for n_in, n_out in zip(sizes, sizes[1:])
            ]
        )

    def forward(self, x):
        """Apply the linear layers to ``x`` in order.

        ``nn.Linear`` acts on the last dimension, so ``x.shape[-1]`` must equal
        the first entry of ``layer_sizes``.
        """
        return self.model(x)

def mem(gpu):
    """Return the memory currently occupied by tensors on device ``gpu``, in GiB.

    Based on ``torch.cuda.memory_allocated``, which counts bytes held by live
    tensors only. Memory that PyTorch's caching allocator has reserved but not
    assigned to a tensor is not included.
    """
    return t.cuda.memory_allocated(gpu) / 2**30  # bytes -> GiB


def mems(gpu):
    """Return a one-line, human-readable report of ``mem(gpu)`` (2 decimals)."""
    return f'{gpu} memory usage: {mem(gpu):.2f} GiB'

In [2]:
# Layer widths. The two 2**20 x 2**10 weight matrices dominate memory use:
# 2 * 2**30 values, i.e. 8 GiB assuming the default float32 dtype (step 2).
layer_sizes = [16, 2**5, 2**10, 2**20, 2**10, 2**5, 16]
gpu = 'cuda:0'

# A single unbatched input sample and its target, created directly on the GPU.
X = t.randn((layer_sizes[0],), device=gpu)
y = t.randn((layer_sizes[-1],), device=gpu)
print('1', mems(gpu))

mlp = MLP(layer_sizes, device=gpu)  # parameters are allocated on `gpu`
print('2', mems(gpu))

y_pred = mlp(X)  # activations for one sample are tiny next to the weights
print('3', mems(gpu))

loss_fn = nn.MSELoss()
# nn.MSELoss takes (input, target): prediction first, ground truth second.
loss = loss_fn(y_pred, y)
print('4', mems(gpu))
# Before backward(): p.grad is None for every p in mlp.parameters().

# backward() allocates p.grad (same shape as p) for every parameter,
# which is why usage doubles between steps 4 and 5.
loss.backward()
print('5', mems(gpu))



1 cuda:0 memory usage: 0.00 GiB
2 cuda:0 memory usage: 8.00 GiB
3 cuda:0 memory usage: 8.00 GiB
4 cuda:0 memory usage: 8.00 GiB
5 cuda:0 memory usage: 16.00 GiB


In [3]:
# Each gradient mirrors its parameter's shape; Linear weights are (out, in).
[param.grad.size() for param in mlp.parameters()]

[torch.Size([32, 16]),
 torch.Size([1024, 32]),
 torch.Size([1048576, 1024]),
 torch.Size([1024, 1048576]),
 torch.Size([32, 1024]),
 torch.Size([16, 32])]

In [6]:
# Drop every Python reference to the GPU tensors so PyTorch can free them.
del X, y, mlp, y_pred, loss_fn, loss
# `del` only returns the memory to PyTorch's caching allocator; empty_cache()
# releases the unused cached blocks back to the driver so other processes
# (and tools like nvidia-smi) see them as free.
t.cuda.empty_cache()
mems(gpu)

'cuda:0 memory usage: 0.00 GiB'