## Explanation of the MLP Forward and Backward Pass

### 1. Define Model Architecture and Initialize Parameters

- We define the dimensions of the input, hidden, and output layers:
  - `input_size = 10`
  - `hidden_size = 20`
  - `output_size = 1`

In [2]:
import torch
import torch.nn.functional as F

# Define the dimensions of the input, hidden, and output layers
input_size = 10
hidden_size = 20
output_size = 1

- We initialize the weights and biases for the MLP using PyTorch tensors:
  - `w1`: Weight matrix for the first layer (size: `input_size x hidden_size`)
  - `b1`: Bias vector for the first layer (size: `1 x hidden_size`)
  - `w2`: Weight matrix for the second layer (size: `hidden_size x output_size`)
  - `b2`: Bias vector for the second layer (size: `1 x output_size`)

In [3]:
# Initialize weights and biases for the MLP
w1 = torch.randn(input_size, hidden_size, requires_grad=True)  # Weight matrix for first layer
b1 = torch.randn(1, hidden_size, requires_grad=True)  # Bias vector for first layer
w2 = torch.randn(hidden_size, output_size, requires_grad=True)  # Weight matrix for second layer
b2 = torch.randn(1, output_size, requires_grad=True)  # Bias vector for second layer

### 2. Forward Pass

- We define the input tensor (`input_tensor`) and target output tensor (`target_output`).
- We perform the forward pass to compute the predicted output tensor (`output_tensor`):
  - Linear transformation and ReLU activation for the first hidden layer
  - Linear transformation for the output layer
- We compute the mean squared error loss (`loss`) between the predicted output and the target output.


In [6]:
input_tensor

tensor([[-0.8218, -0.5395, -0.3298,  0.6342,  0.2875,  1.3541, -1.7218, -0.9154,
          0.4241,  0.3175]])

In [7]:
# Forward pass
hidden_output = torch.matmul(input_tensor, w1) + b1
hidden_output = F.relu(hidden_output)
output_tensor = torch.matmul(hidden_output, w2) + b2
loss = F.mse_loss(output_tensor, target_output)
loss

tensor(5.7385, grad_fn=<MseLossBackward0>)

### 3. Backward Pass

- We compute the gradients of the loss with respect to each parameter using the chain rule and manual differentiation:
  - Gradient of the loss with respect to the output tensor (`output_grad`)
  - Gradients of the loss with respect to `w2` and `b2`
  - Gradient of the loss with respect to the hidden layer output
  - Gradient of the loss with respect to the hidden layer activation (ReLU derivative)
  - Gradients of the loss with respect to `w1` and `b1`
- We print the computed gradients.

In [5]:
# Backward pass
loss_grad = torch.tensor(1.0)  # Gradient of the loss with respect to itself

# Gradient of loss with respect to output tensor
output_grad = 2 * (output_tensor - target_output) / output_tensor.size(0)

# Gradient of loss with respect to w2 and b2
w2_grad = torch.matmul(hidden_output.t(), output_grad)
b2_grad = output_grad.sum(0, keepdim=True)

# Gradient of loss with respect to hidden layer output
hidden_output_grad = torch.matmul(output_grad, w2.t())

# Gradient of loss with respect to hidden layer activation
hidden_output_grad[hidden_output < 0] = 0  # ReLU derivative (ReLU'(x) = 0 if x < 0)

# Gradient of loss with respect to w1 and b1
w1_grad = torch.matmul(input_tensor.t(), hidden_output_grad)
b1_grad = hidden_output_grad.sum(0, keepdim=True)

# Print gradients
print("Gradients of w1:")
print(w1_grad)
print("Gradients of b1:")
print(b1_grad)
print("Gradients of w2:")
print(w2_grad)
print("Gradients of b2:")
print(b2_grad)

Gradients of w1:
tensor([[  4.1775,   1.5801,   9.7490,  -0.4912,  -4.4792,  -4.2421,  -1.9519,
          -4.9478,  -0.9522,  -4.6351,   3.7296,   3.1289,   0.6668,  -2.6608,
          -1.2238,   3.9295,   2.8678,   7.4332,  -3.9013,   0.9898],
        [  2.7425,   1.0373,   6.4000,  -0.3225,  -2.9405,  -2.7849,  -1.2814,
          -3.2481,  -0.6251,  -3.0429,   2.4484,   2.0541,   0.4377,  -1.7467,
          -0.8034,   2.5796,   1.8827,   4.8797,  -2.5611,   0.6498],
        [  1.6765,   0.6341,   3.9124,  -0.1971,  -1.7976,  -1.7024,  -0.7833,
          -1.9856,  -0.3821,  -1.8601,   1.4968,   1.2557,   0.2676,  -1.0678,
          -0.4911,   1.5770,   1.1509,   2.9830,  -1.5657,   0.3972],
        [ -3.2241,  -1.2195,  -7.5240,   0.3791,   3.4569,   3.2739,   1.5064,
           3.8185,   0.7349,   3.5772,  -2.8784,  -2.4148,  -0.5146,   2.0535,
           0.9445,  -3.0327,  -2.2133,  -5.7367,   3.0109,  -0.7639],
        [ -1.4613,  -0.5527,  -3.4103,   0.1718,   1.5669,   1.4839,   