<a href="https://colab.research.google.com/github/sriharikrishna/siamcse23/blob/main/pytorch_seeding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Custom PyTorch functions and seeding

When using PyTorch out of the box, we usually don't need to think about seeding. This is because we normally have a scalar-valued loss function, and calling `backward()` on its result internally creates a "seed". That seed has size $1$: it is just the single number `1.0`, the derivative of the loss with respect to itself.

Given two successive layers $L_1$ and $L_2$ within the same network, where the output of $L_1$ is used as the input to $L_2$, back-propagation automatically uses the gradient of the loss function with respect to the input of $L_2$ as the seed for the gradient computation of $L_1$. All of this happens in the background, without users having to know about it.
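This hand-off between layers can be reproduced explicitly. In the sketch below, two `torch.nn.Linear` layers (sizes chosen arbitrarily, purely for illustration) play the roles of $L_1$ and $L_2$. We first compute the gradient of the loss with respect to the input of $L_2$ with `torch.autograd.grad`, then pass it as the seed to `backward()` on the output of $L_1$, and compare the result with ordinary back-propagation.

```python
import torch

torch.manual_seed(0)
L1 = torch.nn.Linear(3, 4)  # first layer
L2 = torch.nn.Linear(4, 1)  # second layer, consumes the output of L1
x = torch.randn(5, 3)

# Reference: ordinary back-propagation. On a scalar loss, backward()
# behaves like backward(torch.tensor(1.0)), i.e. it seeds with 1.0.
loss = L2(L1(x)).sum()
loss.backward(torch.tensor(1.0))
ref = L1.weight.grad.clone()

# Manual two-stage version of the same computation
L1.zero_grad()
L2.zero_grad()
h = L1(x)           # output of L1 = input of L2
loss = L2(h).sum()
# Gradient of the loss w.r.t. the input of L2 (seed 1.0 is implicit here)
h_bar, = torch.autograd.grad(loss, h, retain_graph=True)
# Use that gradient as the seed for back-propagating through L1
h.backward(h_bar)

print(torch.allclose(ref, L1.weight.grad))  # True
```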

As soon as we deal with loss functions that are not scalar-valued (that is, loss functions with multiple output values), or as soon as we implement the `backward()` pass of an intermediate function or layer ourselves, we need to understand seeding.
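A minimal sketch of the non-scalar case, using a toy element-wise square instead of a real loss: calling `backward()` without a seed raises a `RuntimeError` (the exact message depends on the PyTorch version), while passing a seed vector works. With a unit vector as the seed, `x.grad` receives the vector-Jacobian product $e_1^\top J$, i.e. one row of the Jacobian, which is the gradient of the single output value `y[1]`.

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x**2  # non-scalar output: three values

try:
    y.backward()  # no seed given, PyTorch cannot create one implicitly
except RuntimeError as err:
    print("RuntimeError:", err)

# Seed with the unit vector e_1: x.grad becomes e_1^T J = gradient of y[1] w.r.t. x
seed = torch.tensor([0.0, 1.0, 0.0])
y.backward(seed)
print(x.grad)  # tensor([0., 4., 0.])
```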

Let's begin by importing torch and numpy.

```python
import torch
import numpy as np
```

Next, we create a small neural network. Note that the constructor takes one of the layers as an argument; we will use this to insert our custom layer later.

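```python
class Net(torch.nn.Module):
    def __init__(self, customlayer):
        super(Net, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 4, 3)
        self.relu = torch.nn.ReLU()
        self.conv2 = torch.nn.Conv2d(4, 5, 3)
        self.custom = customlayer  # swappable layer, e.g. CustomLayer() defined below
        self.max = torch.nn.MaxPool2d(2)
        self.linear = torch.nn.Linear(5*6*6, 7)  # sized for 16x16 input images

    def forward(self, x):
        # Shapes for an input of size (batch, 3, 16, 16):
        x = self.conv1(x)        # (batch, 4, 14, 14)
        x = self.relu(x)
        x = self.conv2(x)        # (batch, 5, 12, 12)
        x = self.custom(x)       # element-wise, so the shape is unchanged
        x = self.max(x)          # (batch, 5, 6, 6)
        x = torch.flatten(x, 1)  # (batch, 180)
        x = self.linear(x)       # (batch, 7)
        return x
```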
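The input size of the final linear layer, `5*6*6`, ties `Net` to a particular input resolution. As a quick sanity check, here is a sketch that pushes a random 16x16 batch through layers with the same hyperparameters as `Net` and prints the shape after each step. `torch.nn.Identity()` stands in for the custom layer, an assumption made so that the snippet runs on its own.

```python
import torch

# Same layer hyperparameters as in Net; Identity() stands in for the custom layer
layers = [
    torch.nn.Conv2d(3, 4, 3),
    torch.nn.ReLU(),
    torch.nn.Conv2d(4, 5, 3),
    torch.nn.Identity(),
    torch.nn.MaxPool2d(2),
]

x = torch.randn(2, 3, 16, 16)
shapes = [tuple(x.shape)]
for layer in layers:
    x = layer(x)
    shapes.append(tuple(x.shape))
    print(type(layer).__name__, tuple(x.shape))

flat = torch.flatten(x, 1)
print("flatten", tuple(flat.shape))  # 5*6*6 = 180 features per sample
```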

Now we are going to create a simple custom layer that can be used within a PyTorch model. The `forward()` method computes the inference step, whose output is simply the element-wise square of the input. The `backward()` method implements back-propagation through the layer. Given that the inference step computes

$$y=x^2,$$

we know the derivative to be

$$\frac{\partial y}{\partial x} = 2\cdot x.$$

This derivative needs to be multiplied (element-wise) with the gradient of the loss with respect to our output, which the subsequent layers pass to us through `grad_output`. The result is what we return as `grad_input`:

$$\bar{x} = \frac{\partial y}{\partial x}\cdot\bar{y} = 2\cdot x\cdot\bar{y}.$$

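Recall that in the above, $\bar{y}$ represents the gradient of the loss function with respect to $y$ (this is the seed that arrives in `grad_output`), whereas $\bar{x}$ represents the gradient of the loss function with respect to $x$ (this is what we return, and it becomes the seed for the preceding layer).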
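To make the element-wise product concrete: because $y = x^2$ is applied element-wise, its Jacobian is diagonal, so multiplying it with $\bar{y}$ reduces to an element-wise product. For a vector of length three:

$$J = \frac{\partial y}{\partial x} = \operatorname{diag}(2x_1,\, 2x_2,\, 2x_3), \qquad \bar{x} = J^\top \bar{y} = \begin{pmatrix} 2x_1\bar{y}_1 \\ 2x_2\bar{y}_2 \\ 2x_3\bar{y}_3 \end{pmatrix}.$$

With the arbitrarily chosen values $x = (1, 2, 3)^\top$ and $\bar{y} = (1, 0.5, -1)^\top$:

$$\bar{x} = \begin{pmatrix} 2\cdot 1\cdot 1 \\ 2\cdot 2\cdot 0.5 \\ 2\cdot 3\cdot(-1) \end{pmatrix} = \begin{pmatrix} 2 \\ 2 \\ -6 \end{pmatrix}.$$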


```python
class CustomLayer(torch.nn.Module):
    class __Func__(torch.autograd.Function):
        @staticmethod
        def forward(ctx, fwdinput):
            # Keep the input around; backward() needs it to evaluate 2*x
            ctx.save_for_backward(fwdinput)
            return fwdinput**2

        @staticmethod
        def backward(ctx, grad_output):
            # grad_output is the seed y_bar: gradient of the loss w.r.t. our output
            fwdinput, = ctx.saved_tensors
            grad_input = 2*fwdinput*grad_output  # x_bar = 2 * x * y_bar
            return grad_input

    def __init__(self):
        super(CustomLayer, self).__init__()

    def forward(self, x):
        return self.__Func__.apply(x)
```
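To see that `grad_output` really is the seed, here is a sketch with a hypothetical stand-alone copy of `CustomLayer.__Func__`, named `Square` here, with a `seen` list added purely for illustration so that we can inspect what `backward()` receives. Seeding the output by hand with $\bar{y} = (1, 0.5, -1)$ should hand exactly that tensor to `backward()` and produce $\bar{x} = 2\cdot x\cdot\bar{y}$.

```python
import torch

class Square(torch.autograd.Function):
    """Hypothetical stand-alone copy of CustomLayer.__Func__ that records its seed."""
    seen = []  # grad_output values received by backward()

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x**2

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        Square.seen.append(grad_output.clone())  # record the incoming seed
        return 2 * x * grad_output

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y_bar = torch.tensor([1.0, 0.5, -1.0])  # the seed, chosen by hand
Square.apply(x).backward(y_bar)

print(Square.seen[0])  # tensor([ 1.0000,  0.5000, -1.0000])
print(x.grad)          # tensor([ 2.,  2., -6.])
```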

To check whether we did this correctly, we can use the `gradcheck` routine that ships with PyTorch. It compares the gradients computed by our `backward()` against finite-difference approximations.

```python
from torch.autograd import gradcheck

# gradcheck's default tolerances are designed for double-precision inputs
x = torch.randn(500, 1, dtype=torch.double, requires_grad=True)
test = gradcheck(CustomLayer.__Func__.apply, x, eps=1e-6, atol=1e-4)
print(test)
```

```
True
```


Feel free to modify the forward or backward pass and see whether `gradcheck` still succeeds. Note that by default, a failing check raises an exception rather than returning `False`.
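As one such experiment, here is a sketch of a hypothetical `BuggySquare` whose `backward()` drops the factor 2. Passing `raise_exception=False` makes `gradcheck` return `False` instead of raising. The input is a deterministic `linspace` away from zero (an arbitrary choice), so the error is far larger than the tolerance.

```python
import torch
from torch.autograd import gradcheck

class BuggySquare(torch.autograd.Function):
    """Hypothetical, deliberately wrong version of the square function."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x**2

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # Bug: the factor 2 from d(x^2)/dx = 2x is missing
        return x * grad_output

# Deterministic double-precision input, away from zero so the bug is visible
x = torch.linspace(0.5, 2.0, 10, dtype=torch.double, requires_grad=True)
ok = gradcheck(BuggySquare.apply, x, eps=1e-6, atol=1e-4, raise_exception=False)
print(ok)  # False
```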

Now let's create a neural network that uses this custom layer, and compute its gradients for a random input batch.

```python
n_batches = 2
nninput = torch.randn(n_batches, 3, 16, 16)  # random batch of 3-channel 16x16 inputs
tgt = torch.randn((n_batches, 7))            # random targets, one 7-vector per sample

net = Net(CustomLayer())
net.zero_grad()
out = net(nninput)
loss = torch.nn.MSELoss(reduction='sum')
l = loss(out, tgt)  # scalar loss
l.backward()        # seeds back-propagation implicitly with 1.0

print(f"grad_conv1_bias:\n{net.conv1.bias.grad}")
print(f"grad_conv2_bias:\n{net.conv2.bias.grad}")
#print(f"grad_conv1_weight:\n{net.conv1.weight.grad}")
#print(f"grad_conv2_weight:\n{net.conv2.weight.grad}")
```

```
grad_conv1_bias:
tensor([ 0.2400, -1.0586,  0.9492, -2.2431])
grad_conv2_bias:
tensor([ 0.2569, -0.1768,  2.1923, -1.3565,  3.4123])
```


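Feel free to un-comment the print statements for the gradients with respect to each layer's weights, but beware that the output is a little lengthy. Since the weights, inputs, and targets are all random, the printed values will differ from run to run.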
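Since `MSELoss(reduction='sum')` computes $\sum (\mathrm{out}-\mathrm{tgt})^2$, its gradient with respect to the network output is $2(\mathrm{out}-\mathrm{tgt})$. Seeding `backward()` on the non-scalar output with that tensor should therefore reproduce what the implicit seed `1.0` produced when calling `backward()` on the loss. The sketch below checks this on a small stand-in model (an assumption: a two-layer `torch.nn.Sequential` instead of `Net`, so that the snippet runs on its own).

```python
import torch

torch.manual_seed(0)
# Stand-in model (assumption): any network with a non-scalar output works here
model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 7))
x = torch.randn(2, 4)
tgt = torch.randn(2, 7)
loss_fn = torch.nn.MSELoss(reduction='sum')

# (a) Scalar loss: backward() seeds with 1.0 implicitly
model.zero_grad()
loss_fn(model(x), tgt).backward()
grad_auto = model[0].bias.grad.clone()

# (b) Skip the loss and seed the (2, 7) output with dL/d(out) = 2*(out - tgt)
model.zero_grad()
out = model(x)
seed = 2 * (out.detach() - tgt)
out.backward(seed)
grad_seeded = model[0].bias.grad.clone()

print(torch.allclose(grad_auto, grad_seeded))  # True
```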