In [1]:
import torch
import torch.nn as nn # All neural network modules, nn.Linear, nn.Conv2d, BatchNorm, Loss functions and so on

In [2]:
# Training pipeline in PyTorch
# 1. Design model (input, output size, forward pass)
# 2. Construct loss and optimizer
# 3. Training loop
#  - forward pass: compute prediction
#  - backward pass: gradients
#  - update weights

In [19]:
# Design model
# Toy dataset for y = 2 * x: one sample per row, one feature per column.
X = torch.tensor([[1], [2], [3], [4]], dtype=torch.float32)
Y = torch.tensor([[2], [4], [6], [8]], dtype=torch.float32)
# Single unseen input used to probe the model before and after training.
X_test = torch.tensor([5], dtype=torch.float32)

n_samples, n_features = X.shape
print(f"no. of samples: {n_samples}, no. of features: {n_features}")

# Inputs and targets must pair up row by row.
assert Y.shape[0] == n_samples, "X and Y must contain the same number of samples"

# The layer's input width comes from X and its output width from Y; the two
# only coincide here because both tensors happen to have a single column.
input_size = n_features
output_size = Y.shape[1]

# model = nn.Linear(input_size, output_size)

# Custom model
class LinearRegression(nn.Module):
    """Linear regression model built from a single fully connected layer."""

    def __init__(self, input_dim, output_dim):
        """Create the model's only layer.

        Args:
            input_dim: number of input features per sample.
            output_dim: number of output values per sample.
        """
        super().__init__()
        # The one trainable layer; it holds the weight and bias parameters.
        self.lin = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer to ``x`` and return the result."""
        prediction = self.lin(x)
        return prediction
    
# Build the regression model with the layer sizes chosen above.
model = LinearRegression(input_dim=input_size, output_dim=output_size)

# Baseline: what the freshly initialised (untrained) model predicts for X_test.
initial_prediction = model(X_test).item()
print(f"prediction before training: f(5) = {initial_prediction:.3f}")

no. of samples: 4, no. of features: 1
prediction before training: f(5) = 3.803


In [20]:
# Construct loss and optimizer
n_iters = 100          # number of passes of the training loop
learning_rate = 0.01   # step size handed to the optimizer

# Mean Squared Error
loss = nn.MSELoss()

# Stochastic Gradient Descent over every parameter of the model
optimizer = torch.optim.SGD(params=model.parameters(), lr=learning_rate)

In [21]:
# Training loop: each iteration runs one forward/backward pass over the whole
# of X and applies a single optimizer step.
for epoch in range(n_iters):
    # prediction = forward pass
    y_pred = model(X)

    # loss: nn.MSELoss takes (input, target), so the prediction goes first.
    # MSE is symmetric, but keeping the order right matters for other losses.
    l = loss(y_pred, Y)

    # gradients = backward pass
    l.backward()  # dl/dw

    # update weights
    optimizer.step()

    # zero gradients so the next backward pass does not accumulate onto them
    optimizer.zero_grad()

    if epoch % 10 == 0:
        # params; note w is read after the step, while l was computed before it
        w, b = model.parameters()
        print(f'epoch {epoch + 1}: w = {w[0][0].item():.3f}, loss = {l.item():.8f}')

# Evaluation only: no autograd graph is needed for the final prediction.
with torch.no_grad():
    print(f'Prediction after training: f(5) = {model(X_test).item():.3f}')

epoch 1: w = 0.950, loss = 11.56701088
epoch 11: w = 1.732, loss = 0.31913894
epoch 21: w = 1.861, loss = 0.02697339
epoch 31: w = 1.884, loss = 0.01832498
epoch 41: w = 1.891, loss = 0.01707521
epoch 51: w = 1.895, loss = 0.01607661
epoch 61: w = 1.898, loss = 0.01514076
epoch 71: w = 1.901, loss = 0.01425949
epoch 81: w = 1.904, loss = 0.01342951
epoch 91: w = 1.907, loss = 0.01264785
Prediction after training: f(5) = 9.813
