In [22]:
import torch
import torch.nn as nn

In [23]:
class ExampleDeepNeuralNetwork(nn.Module):
    """Stack of ``Linear -> GELU`` layers with optional shortcut connections.

    Args:
        layer_sizes: Sequence of feature widths. Each consecutive pair
            ``(layer_sizes[i], layer_sizes[i + 1])`` gives the input and
            output width of layer ``i``, so ``len(layer_sizes) - 1`` layers
            are created. A sequence with fewer than two entries creates no
            layers, in which case ``forward`` returns its input unchanged.
        use_shortcut: If True, a layer's input is added to its output
            whenever the two tensors have the same shape.
    """

    def __init__(self, layer_sizes, use_shortcut):
        super().__init__()
        self.use_shortcut = use_shortcut
        # Build the layers in order from consecutive size pairs instead of
        # hard-coding five of them. For a six-element ``layer_sizes`` this
        # creates the same modules, in the same order, under the same
        # parameter names (``layers.<i>.0.weight`` / ``layers.<i>.0.bias``).
        self.layers = nn.ModuleList([
            nn.Sequential(nn.Linear(in_features, out_features), nn.GELU())
            for in_features, out_features in zip(layer_sizes[:-1], layer_sizes[1:])
        ])

    def forward(self, x):
        """Pass ``x`` through every layer in order and return the final tensor.

        When ``use_shortcut`` is set and a layer's output has the same shape
        as its input, the input is added to the output; otherwise the output
        replaces the input.
        """
        for layer in self.layers:
            # Compute the output of the current layer
            layer_output = layer(x)
            # The shortcut is only applied when the shapes match, so a layer
            # that changes the feature width falls back to the plain output.
            if self.use_shortcut and x.shape == layer_output.shape:
                x = x + layer_output
            else:
                x = layer_output
        return x

In [24]:
def print_gradients(model, x):
    """Run one forward/backward pass and print each weight's mean absolute gradient.

    The loss is the mean squared error between ``model(x)`` and an all-zero
    target of the same shape. Existing gradients on ``model`` are cleared
    before the backward pass, so repeated calls on the same model report the
    gradients of the current pass only rather than an accumulated sum.

    Args:
        model: The ``nn.Module`` to inspect. Its parameters' ``.grad``
            attributes are overwritten by this call.
        x: Input tensor passed to ``model``.
    """
    # backward() accumulates into param.grad; clear stale gradients first so
    # that calling this function twice does not inflate the printed values.
    model.zero_grad()

    # Forward pass
    output = model(x)
    # All-zero target matching the output's shape, dtype and device. This
    # avoids the broadcast warning a fixed (1, 1) target causes for other
    # output shapes while producing the same loss value.
    target = torch.zeros_like(output)

    # Calculate loss based on how close the target
    # and output are
    loss_fn = nn.MSELoss()
    loss = loss_fn(output, target)

    # Backward pass to calculate the gradients
    loss.backward()

    for name, param in model.named_parameters():
        if 'weight' in name:
            # Print the mean absolute gradient of the weights
            print(f"{name} has gradient mean of {param.grad.abs().mean().item()}")

In [25]:
# Widths of the network: a 3-feature input, four 3-wide layers, then a
# single-unit output layer.
layer_sizes = [3, 3, 3, 3, 3, 1]
sample_input = torch.tensor([[1.0, 0.0, -1.0]])

# The seed is reset right before each construction so that both networks are
# initialized from the same random state and differ only in `use_shortcut`.
torch.manual_seed(123)
model_without_shortcut = ExampleDeepNeuralNetwork(layer_sizes, use_shortcut=False)

torch.manual_seed(123)
model_with_shortcut = ExampleDeepNeuralNetwork(layer_sizes, use_shortcut=True)

In [26]:
# Print the mean absolute weight gradient of each layer for the network
# that was built with use_shortcut=False.
print_gradients(model_without_shortcut, sample_input)

layers.0.0.weight has gradient mean of 0.00020174118981231004
layers.1.0.weight has gradient mean of 0.00012011769285891205
layers.2.0.weight has gradient mean of 0.0007152436301112175
layers.3.0.weight has gradient mean of 0.00139885104727
layers.4.0.weight has gradient mean of 0.005049602594226599


In [27]:

# Print the mean absolute weight gradient of each layer for the network
# that was built with use_shortcut=True.
print_gradients(model_with_shortcut, sample_input)

layers.0.0.weight has gradient mean of 0.22186797857284546
layers.1.0.weight has gradient mean of 0.20709273219108582
layers.2.0.weight has gradient mean of 0.3292388319969177
layers.3.0.weight has gradient mean of 0.2667771875858307
layers.4.0.weight has gradient mean of 1.3268061876296997
