In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

In [2]:
# Define shared parameter MLP class
class SharedParameterMLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear -> ReLU.

    Note: despite the class name, the two layers are independent
    ``nn.Linear`` modules; no weights are tied between them.

    Args:
        input_size: Number of features in each input sample.
        hidden_size: Number of features produced by the first layer.
        output_size: Number of features produced by the second layer.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Creation order is significant under a fixed RNG seed:
        # ``shared_layer`` is initialised before ``hidden_layer``.
        self.shared_layer = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.hidden_layer = nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, x):
        """Apply both linear layers, each followed by a ReLU.

        Args:
            x: Tensor whose last dimension equals ``input_size``.

        Returns:
            Tensor whose last dimension equals ``output_size``. Because a
            ReLU follows the final layer too, every element is >= 0.
        """
        hidden = self.shared_layer(x).relu()
        return self.hidden_layer(hidden).relu()

In [3]:
# Pin PyTorch's global RNG so that subsequent random draws are repeatable
SEED = 42
torch.manual_seed(SEED)

<torch._C.Generator at 0x10a1e1f10>

In [4]:
# Layer widths for the network: input -> hidden -> output
input_size, hidden_size, output_size = 10, 20, 5

# Build the network from the widths chosen above
model = SharedParameterMLP(
    input_size=input_size,
    hidden_size=hidden_size,
    output_size=output_size,
)

# Mean-squared-error objective, minimised with plain SGD (learning rate 0.01)
criterion = nn.MSELoss()
optimizer = optim.SGD(params=model.parameters(), lr=0.01)

In [5]:
# Random stand-in data: one small batch of inputs and matching targets.
# The input is drawn before the target, which fixes the values under a seed.
batch_size = 2
input_tensor = torch.randn(batch_size, input_size)
target_tensor = torch.randn(batch_size, output_size)

In [6]:
# Run a few full-batch gradient-descent updates on the dummy data
epochs = 5
for epoch in range(epochs):
    # Clear gradients left over from the previous iteration
    optimizer.zero_grad()

    # Forward pass, then the loss against the targets
    output = model(input_tensor)
    loss = criterion(output, target_tensor)

    # Backpropagate and apply the SGD update
    loss.backward()
    optimizer.step()

    # Report every parameter tensor (values are post-update) together with
    # the gradient computed in this iteration's backward pass
    print(f"Epoch {epoch+1}/{epochs}:")
    for name, param in model.named_parameters():
        print(
            f"Layer: {name}",
            f"   Parameters: {param.data}",
            f"   Gradients: {param.grad}",
            sep="\n",
        )

Epoch 1/5:
Layer: shared_layer.weight
   Parameters: tensor([[ 0.2418,  0.2625, -0.0741,  0.2905, -0.0693,  0.0638, -0.1540,  0.1857,
          0.2788, -0.2320],
        [ 0.2749,  0.0592,  0.2336,  0.0428,  0.1525, -0.0446,  0.2438,  0.0467,
         -0.1476,  0.0806],
        [-0.1457, -0.0371, -0.1284,  0.2098, -0.2496, -0.1458, -0.0893, -0.1901,
          0.0298, -0.3123],
        [ 0.2854, -0.2686,  0.2441,  0.0528, -0.1027,  0.1943,  0.0496,  0.2554,
          0.0351, -0.0999],
        [ 0.0850, -0.0858,  0.1331,  0.2823,  0.1828, -0.1382,  0.1825,  0.0566,
          0.1606, -0.1927],
        [-0.3130, -0.1222, -0.2426,  0.2595,  0.0911,  0.1310,  0.1000, -0.0055,
          0.2475, -0.2247],
        [ 0.0199, -0.2158,  0.0975, -0.1089,  0.0969, -0.0659,  0.2623, -0.1874,
         -0.1886, -0.1886],
        [ 0.2844,  0.1054,  0.3043, -0.2610, -0.3137, -0.2474, -0.2127,  0.1281,
          0.1132,  0.2628],
        [-0.1633, -0.2156,  0.1678, -0.1278,  0.1919, -0.0750,  0.1809, -0.