# Weight Initialization

## Our test model for this practical task

In [None]:
# Use the below functionality to execute your model (that you will adjust later step by step)
# This block of code provides you the functionality to train a model. Results are printed after each epoch

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import tqdm


def load_mnist_data(root_path='./data', batch_size=4):
    """
    Build the MNIST train and test DataLoaders.

    The dataset is requested with ``download=True`` and stored below ``root_path``.
    You can change ``root_path`` to point to an already existing copy if you want to
    save a little bit of disk space :)

    Args:
        root_path: Directory that holds (or will receive) the MNIST files.
        batch_size: Number of samples per batch, used for both loaders.

    Returns:
        Tuple ``(trainloader, testloader)``; only the train loader is shuffled.
    """
    # Normalize expects one mean/std entry per channel. MNIST has a single channel,
    # so one-element tuples are used (note the trailing comma: "(0.5)" is just a float).
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])

    trainset = torchvision.datasets.MNIST(root=root_path, train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

    testset = torchvision.datasets.MNIST(root=root_path, train=False, download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

    return trainloader, testloader


def train_model(model, batch_size: int = 4, epochs: int = 10):
    """
    Train ``model`` on the MNIST training split and print statistics after each epoch.

    The images are flattened to shape ``(batch, -1)`` before they are passed to the
    model. Training uses cross-entropy loss and plain SGD with a learning rate of 0.001.
    Note that ``model.to(device)`` moves the passed module itself, so after this call
    the model lives on the GPU whenever CUDA is available.

    Args:
        model: The network to train; it receives the flattened images.
        batch_size: Number of samples per training batch.
        epochs: Number of passes over the training data.

    Returns:
        None. The mean batch loss and the sample accuracy are printed per epoch.
    """
    # we only consider the mnist train data for this example
    train_loader, _ = load_mnist_data(root_path='./data', batch_size=batch_size)

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    model = model.to(device=device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    for epoch in range(epochs):
        running_loss = 0.0
        # Count samples instead of averaging per-batch accuracies, so a smaller
        # last batch does not get the same weight as a full one.
        correct_samples = 0
        seen_samples = 0
        for imgs, targets in tqdm.tqdm(train_loader, desc=f'Training iteration {epoch + 1}'):
            imgs, targets = imgs.to(device=device), targets.to(device=device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(imgs.reshape(imgs.shape[0], -1))

            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # accumulate statistics
            running_loss += loss.item()

            # Calculate the Accuracy (how many of all samples are correctly classified?)
            predictions = torch.max(outputs.detach(), dim=1).indices
            correct_samples += (predictions == targets).sum().item()
            seen_samples += targets.shape[0]

        epoch_accuracy = correct_samples / seen_samples
        print(f'Epoch {epoch + 1} finished with loss: {running_loss / len(train_loader):.3f} and accuracy {epoch_accuracy:.3f}')

## Training progress with different weight settings

In [None]:
# You can use this model for your tests (of course you can change the architecture a little, but it should not be necessary.)
import torch
import torch.nn as nn


# Three stacked linear layers without activation functions: 784 -> 32 -> 32 -> 10
model = nn.Sequential(
    # input layer (do not change the in_features size of this layer - we need it later)
    nn.Linear(in_features=784, out_features=32),
    nn.Linear(in_features=32, out_features=32),
    # you can change the in_features of this layer but leave the out_features at size 10 here - we need it later
    nn.Linear(in_features=32, out_features=10),
)

### Weight settings

In [None]:
# Find out how to change the weights of the layers from your neural network.
# ATTENTION: Write your code inside the "with torch.no_grad():" section! This is necessary for changing the weights of the layers

#### Zero weights

In [None]:
# Set all weights and biases of your network to zero

with torch.no_grad():
    # Code here
    # Only the linear layers carry parameters we want to reset
    linear_layers = [module for module in model if isinstance(module, nn.Linear)]
    for linear in linear_layers:
        # zero_() overwrites the parameter in place
        linear.weight.zero_()
        linear.bias.zero_()

In [None]:
# Train the network with your new settings and take a look at the results.
# Runs 3 epochs on the MNIST training data with batches of 4 samples;
# train_model prints the loss and the accuracy after every epoch.
train_model(model=model, batch_size=4, epochs=3)

# What can you observe?

# The model does not learn at all. Loss stays constant and accuracy remains around 0.1 (random guessing for 10 classes).
# This is because all neurons compute the same output (symmetry problem) and receive identical gradients,
# so they all update in the same way and remain identical throughout training.

#### Constant weights

In [None]:
# Set all weights and biases to constant numbers (e.g. 0.5)
# How does the training progress?

with torch.no_grad():
    # Code here
    for module in model:
        if not isinstance(module, nn.Linear):
            continue
        # Give every weight and every bias of this layer the same constant value
        nn.init.constant_(module.weight, 0.5)
        nn.init.constant_(module.bias, 0.5)

In [None]:
# Train the network with your new settings and take a look at the results.
# Runs 3 epochs on the MNIST training data with batches of 4 samples;
# train_model prints the loss and the accuracy after every epoch.
train_model(model=model, batch_size=4, epochs=3)

# What can you observe?

# Similar to zero weights, the model suffers from the symmetry problem. All neurons in each layer
# have identical weights, so they compute the same outputs and receive the same gradients.
# Training is ineffective because neurons cannot learn different features. The network essentially
# behaves like a single neuron per layer rather than multiple neurons.

In [None]:
# Let us also take a look at the gradient of the output layer
# Access the gradients at the output layer of your model and analyze them

# train_model moves the model to the GPU when CUDA is available (nn.Module.to works
# in place), so the probe input and target have to be created on the model's device.
probe_device = next(model.parameters()).device

# The .grad buffers may still hold the gradients of the last training batch and
# backward() accumulates into them - clear them so only the probe gradients are shown.
model.zero_grad()

# We first input some random values
# forward + backward
outputs = model(torch.randn(size=(1, 784), device=probe_device))
loss = nn.CrossEntropyLoss()(outputs, torch.tensor([1], device=probe_device))
loss.backward()


# Code here

# Print the gradients of every linear layer (the output layer is the last one)
for i, layer in enumerate(model):
    if isinstance(layer, nn.Linear):
        print(f"Layer {i} weight gradients:")
        print(layer.weight.grad)
        print(f"Layer {i} bias gradients:")
        print(layer.bias.grad)

        print()

# What can you observe?

# All gradients for weights in each row are identical because all neurons receive the same input
# and produce the same output due to constant initialization. This confirms the symmetry problem:
# all neurons in a layer have identical gradients and will update identically, preventing the
# network from learning diverse features.

#### Unusual weights

In [None]:
# Set some weights (around 50%) of every layer of the model to some weird value, e.g. extremely high (> 10.0) or extremely low (< 1e-7).
# How does the training progress?
# Can your model also diverge instead of converge because the weights were way too high or too low?

with torch.no_grad():
    # Code here
    for layer in model:
        if isinstance(layer, nn.Linear):
            # Treat weight and bias the same way. The *_like helpers create the random
            # tensors with the parameter's shape, dtype and device, so this also works
            # after train_model has moved the model to the GPU.
            for param in (layer.weight, layer.bias):
                # Roughly half of the entries become extremely large ...
                extreme_mask = torch.rand_like(param) < 0.5
                huge_values = torch.randn_like(param) * 100.0
                # ... and the remaining entries become extremely small
                tiny_values = torch.randn_like(param) * 1e-8

                param.copy_(torch.where(extreme_mask, huge_values, tiny_values))

In [None]:
# Train the network with your new settings and take a look at the results.
# Runs 5 epochs on the MNIST training data with batches of 4 samples;
# train_model prints the loss and the accuracy after every epoch.
train_model(model=model, batch_size=4, epochs=5)

# What can you observe?

# The model likely experiences exploding gradients from extremely high weights (causing NaN/inf values)
# or vanishing gradients from extremely low weights (causing no learning). The loss may diverge or remain
# high, and accuracy stays poor. Extreme weight values lead to unstable training where the huge logits saturate
# the softmax inside the cross-entropy loss (this model has no other activation functions) or produce numerical overflow/underflow.

## Weight initialization techniques

In [None]:
# We now take a closer look to the sigmoid activation function.
# Where does the sigmoid function create small gradients and where are the biggest gradients?

# Explanation here

# The sigmoid function has the largest gradients near x=0 (around 0.25) and smallest gradients at extreme values (x >> 0 or x << 0)
# This is because sigmoid derivative is sigma(x) * (1 - sigma(x)), which peaks at x=0


# Now let's plot the sigmoid activation function together with some weight initialization methods
# Use matplotlib and plot the sigmoid activation function into the plot.
# Create 1000 sample points from x-values [-5.0, 5.0] and create y = Sigmoid(x) and plot the result. (The result should simply be the sigmoid curve)
# You can use the Sigmoid function from PyTorch here!

import matplotlib.pyplot as plt

# Code here

# 1000 evenly spaced sample points on [-5.0, 5.0] and their sigmoid values
x = torch.linspace(start=-5.0, end=5.0, steps=1000)
y_sigmoid = x.sigmoid()
plt.plot(
    x.numpy(),
    y_sigmoid.numpy(),
    color='blue',
    label='Sigmoid curve',
)


# Now let's plot the kaiming normal weight initialization into the plot
# Create 1000 points (x) sampled from the kaiming_normal_ (pytorch function) and create y = Sigmoid(kaiming_normal(1000)) and plot the result into the same plot as before.
# Use a different color for plotting the results


# Code here

# kaiming_normal_ derives its standard deviation from the fan-in of the tensor and
# raises a ValueError for tensors with fewer than 2 dimensions. We therefore draw the
# 1000 samples as a (1000, 1) weight matrix (fan_in = 1, fan_out = 1000) and flatten it.
# With the default settings this gives std = sqrt(2) / sqrt(fan_in) ~= 1.41; a different
# layer shape would give a different spread.
kaiming_weights = nn.init.kaiming_normal_(torch.empty(1000, 1)).flatten()
kaiming_weights_sorted = torch.sort(kaiming_weights)[0]
y_kaiming = torch.sigmoid(kaiming_weights_sorted)
plt.scatter(kaiming_weights_sorted.numpy(), y_kaiming.numpy(), label='Kaiming Normal', color='red', alpha=0.5, s=5)


# Now plot a random normal (torch.randn) weight initialization into the plot
# Create 1000 points (x) sampled from the randn (pytorch function) and create y = Sigmoid(randn(1000)) and plot the result into the same plot as before.
# Use a different color for plotting the results


# Code here

# 1000 samples from the standard normal distribution, ordered by value
randn_weights = torch.randn(1000)
randn_weights_sorted = randn_weights.sort().values
y_randn = randn_weights_sorted.sigmoid()
plt.scatter(
    randn_weights_sorted.numpy(),
    y_randn.numpy(),
    s=5,
    alpha=0.5,
    color='green',
    label='Random Normal',
)


# Now plot a xavier_normal weight initialization into the plot
# Create 1000 points (x) sampled from the xavier_normal_ (pytorch function) and create y = Sigmoid(xavier_normal_(1000)) and plot the result into the same plot as before.
# Use a different color for plotting the results


# Code here
# xavier_normal_ needs fan-in and fan-out and raises a ValueError for tensors with
# fewer than 2 dimensions. We therefore draw the 1000 samples as a (1000, 1) weight
# matrix (fan_in = 1, fan_out = 1000) and flatten it. With the default gain this gives
# std = sqrt(2 / (fan_in + fan_out)) ~= 0.045; a different layer shape would give a
# different spread.
xavier_weights = nn.init.xavier_normal_(torch.empty(1000, 1)).flatten()
xavier_weights_sorted = torch.sort(xavier_weights)[0]
y_xavier = torch.sigmoid(xavier_weights_sorted)
plt.scatter(xavier_weights_sorted.numpy(), y_xavier.numpy(), label='Xavier Normal', color='orange', alpha=0.5, s=5)



# Show the labels of all curves/point clouds and render the figure
plt.legend()
plt.show()


# Which weight initialization technique is best when using sigmoid activation function?

# Answer here

# Xavier (Glorot) initialization is best for sigmoid activation functions because it keeps the variance of activations
# and gradients stable across layers. It's specifically designed for activation functions with gradients centered around 0.
# Xavier keeps weights in the linear region of sigmoid where gradients are strongest, preventing vanishing gradients.
# Kaiming is designed for ReLU activations, while random normal can cause gradient issues.

## Be creative and test some other weight initialization techniques! - There is so much to explore!