In [1]:
import torch

In [2]:
import torch.nn as nn

In [3]:
X = torch.tensor([[1],[2],[3],[4]], dtype = torch.float32)

In [4]:
Y = torch.tensor([[2],[4],[6],[8]], dtype = torch.float32) # since y = 2x

In [5]:
n_samples, n_features = X.shape

### Forward pass

In [6]:
class LinearRegression(nn.Module):
    
    def __init__(self, input_dim, output_dim):
        super(LinearRegression, self).__init__()
        self.lin = nn.Linear(input_dim, output_dim)
    
    def forward(self, x):
        return self.lin(x)

In [7]:
X_test = torch.tensor([5], dtype = torch.float32)
input_size = n_features
output_size = n_features
model = LinearRegression(input_size, output_size)

### Loss MSE

In [8]:
learning_rate = 0.01
n_iters = 200

loss = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr = learning_rate)

### Training

In [9]:

print(f"prediction after training: f(X_test) = {model(X_test).item():.3f}")

for epoch in range(n_iters):
    
    # prediction = Forward pass
    y_pred = model(X)
    
    # loss
    l = loss(Y,y_pred)
    
    # gradients
    l.backward() # dl/dw
    
    # update weights
    optimizer.step()
    
    # zero gradients
    optimizer.zero_grad()
    
    if epoch % 10 == 0:
        [w,b] = model.parameters()
        print(f"epoch {epoch + 1}: w = {w[0][0].item():.3f}, loss = {l:.8f}")
        
print(f"prediction after training: f(X_test) = {model(X_test).item():.3f}")

prediction after training: f(X_test) = -0.439
epoch 1: w = 0.144, loss = 31.19141388
epoch 11: w = 1.430, loss = 0.95677459
epoch 21: w = 1.645, loss = 0.16581404
epoch 31: w = 1.687, loss = 0.13713953
epoch 41: w = 1.701, loss = 0.12866515
epoch 51: w = 1.711, loss = 0.12116350
epoch 61: w = 1.720, loss = 0.11411076
epoch 71: w = 1.728, loss = 0.10746887
epoch 81: w = 1.736, loss = 0.10121370
epoch 91: w = 1.744, loss = 0.09532255
epoch 101: w = 1.751, loss = 0.08977429
epoch 111: w = 1.759, loss = 0.08454892
epoch 121: w = 1.766, loss = 0.07962776
epoch 131: w = 1.773, loss = 0.07499303
epoch 141: w = 1.779, loss = 0.07062800
epoch 151: w = 1.786, loss = 0.06651714
epoch 161: w = 1.792, loss = 0.06264545
epoch 171: w = 1.798, loss = 0.05899919
epoch 181: w = 1.804, loss = 0.05556513
epoch 191: w = 1.810, loss = 0.05233096
prediction after training: f(X_test) = 9.619
