In [1]:
# 1. Design the model: input size, output size, forward pass
# 2. Construct the loss and the optimizer
# 3. Training loop: forward pass, backward pass, weight update

In [2]:
import torch
import torch.nn as nn

In [3]:
# Fit f(x) = w * x; the targets below are y = 2 * x, so training should
# drive w toward 2. Each row is one sample and each column one feature.
x = torch.tensor([[1], [2], [3], [4]], dtype=torch.float)
y = torch.tensor([[2], [4], [6], [8]], dtype=torch.float)
x_test = torch.tensor([5], dtype=torch.float)

# model
n_samples, n_features = x.shape

intput_size = n_features
# The output width is determined by the targets, not by the inputs; the two
# only coincide here because x and y each have a single column.
output_size = y.shape[1]
# A bare nn.Linear(intput_size, output_size) would be an equivalent model;
# the LinearRegression subclass defined next is used instead, so no throwaway
# layer is built here (it would only consume random-init state).

class LinearRegression(nn.Module):
    """Linear regression model wrapping a single ``nn.Linear`` layer.

    Args:
        input_dim: number of input features (``in_features`` of the layer).
        output_dim: number of outputs (``out_features`` of the layer).
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # The only learnable layer, registered as submodule "lin" so its
        # parameters are named "lin.weight" and "lin.bias".
        layer = nn.Linear(input_dim, output_dim)
        self.lin = layer

    def forward(self, x):
        """Apply the linear layer to ``x`` and return its output."""
        out = self.lin(x)
        return out

model = LinearRegression(intput_size, output_size)

# Inference only: no_grad() avoids building an autograd graph for this prediction.
with torch.no_grad():
    print(f'prediction before training : f(5) = {model(x_test).item():.3f}')

# training
learning_rate = 0.01
n_iters = 100

loss = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

for epoch in range(n_iters):
    # prediction = forward pass
    y_pred = model(x)

    # loss: PyTorch criteria take (input, target), so the prediction goes first
    l = loss(y_pred, y)

    # gradients = backward pass
    l.backward()  # dl/dw, dl/db

    # update weights
    optimizer.step()

    # zero gradients: backward() accumulates into .grad, so they must be
    # cleared before the next iteration
    optimizer.zero_grad()

    if epoch % 10 == 0:
        w, b = model.parameters()
        # w is read after this epoch's step; l is the loss from before it
        print(f'[epoch {epoch+1:2d}] w = {w[0][0].item():.3f} / loss = {l.item():.8f}')

with torch.no_grad():
    print(f'prediction after training : f(5) = {model(x_test).item():.3f}')

prediction before training : f(5) = -2.441
[epoch  1] w = -0.054 / loss = 47.82585907
[epoch 11] w = 1.534 / loss = 1.27479768
[epoch 21] w = 1.793 / loss = 0.06823235
[epoch 31] w = 1.839 / loss = 0.03496389
[epoch 41] w = 1.850 / loss = 0.03217083
[epoch 51] w = 1.855 / loss = 0.03027873
[epoch 61] w = 1.860 / loss = 0.02851580
[epoch 71] w = 1.864 / loss = 0.02685604
[epoch 81] w = 1.868 / loss = 0.02529286
[epoch 91] w = 1.872 / loss = 0.02382073
prediction after training : f(5) = 9.743
