In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

import torch.nn as nn

In [None]:
# Inputs: 50 evenly spaced values 1.0, 2.0, ..., 50.0 laid out as a (50, 1)
# column, i.e. (n_samples, in_features) as nn.Linear expects.
X = torch.linspace(start=1.0, end=50.0, steps=50).unsqueeze(1)
X

In [None]:
# Seed first so the synthetic noise (and every number derived from it) is reproducible.
torch.manual_seed(71)
# Integer-valued noise in [-8, 8] (randint's `high` is exclusive), stored as
# float so it can be added directly to the float inputs.
e = torch.randint(low=-8, high=9, size=(50, 1), dtype=torch.float)
e

In [None]:
# Ground-truth line y = 2x + 1, corrupted by the integer noise `e`.
y = X.mul(2).add(1).add(e)

In [None]:
# Visualize the raw synthetic data before any model is involved.
fig, ax = plt.subplots()
ax.scatter(X.numpy(), y.numpy())
ax.set(title='Synthetic data: y = 2x + 1 + noise', xlabel='X', ylabel='y')
plt.show()

In [None]:
class Model(nn.Module):
    """Single-layer linear regression model wrapping one ``nn.Linear``.

    Args:
        in_features: number of input features per sample.
        out_features: number of outputs per sample.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        # The only learnable layer; named_parameters() exposes its tensors
        # as `linear.weight` and `linear.bias`.
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x):
        """Return the affine map of ``x`` computed by ``self.linear``."""
        return self.linear(x)
    

In [None]:
# Seed right before construction so nn.Linear's random initial weight and
# bias are reproducible.
torch.manual_seed(59)
model = Model(in_features=1, out_features=1)

In [None]:
# Show the initial (untrained) parameters; .item() works because each is a
# single-element tensor for this 1-in/1-out model.
for param_name, param_value in model.named_parameters():
    print(param_name, '\t', param_value.item())

In [None]:
# Sanity-check the untrained model on a single input.
# Call the module itself rather than `.forward()` so nn.Module hooks run, and
# disable autograd because this is inference only.
x = torch.tensor([2.0])
with torch.no_grad():
    untrained_pred = model(x)
untrained_pred
# Expected: weight * 2.0 + bias, e.g. for the seed-59 initialization
# 2.0 * 0.10597813129425049 + 0.9637961387634277 = 1.1758

In [None]:
# Evaluate the untrained line y = w*x + b by hand with NumPy, using the scalar
# weight and bias pulled out of the linear layer.
x1 = np.linspace(start=0.0, stop=50.0, num=50)
w1, b1 = (p.item() for p in (model.linear.weight, model.linear.bias))
y1 = x1 * w1 + b1
y1

In [None]:
# Compare the untrained model's line against the data it will be fitted to.
fig, ax = plt.subplots()
ax.plot(x1, y1, 'r', label='non-trained model predictions')
ax.scatter(X.numpy(), y.numpy(), label='input data')
ax.set(title='Before training', xlabel='X', ylabel='y')
ax.legend()
plt.show()

In [None]:
# Training setup: mean-squared-error loss (default 'mean' reduction)
# minimized with plain stochastic gradient descent.
criterion = nn.MSELoss(reduction='mean')
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.001)

In [None]:
# Full-batch gradient descent: every epoch feeds all of X through the model.
epochs = 50
losses = []  # one plain-float MSE value per epoch

# Drive the loop from `epochs` (1-based for the log) instead of a separate
# hard-coded 50, so changing `epochs` really changes the training length.
for i in range(1, epochs + 1):
    # Call the module itself rather than .forward() so nn.Module hooks run.
    y_pred = model(X)
    loss = criterion(y_pred, y)

    # Record a detached Python float. Appending the tensor itself would keep
    # every epoch's autograd graph alive and make the list fail to plot
    # ("Can't call numpy() on Tensor that requires grad").
    loss_value = loss.item()
    losses.append(loss_value)

    # weight/bias shown here are the values *before* this epoch's update.
    print(f"epoch {i}, loss:{loss_value}, weight:{model.linear.weight.item()}, bias:{model.linear.bias.item()}")

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [None]:
# Training curve. float() accepts plain floats as well as 0-d loss tensors
# (including ones still attached to the autograd graph), so this cell works
# regardless of how `losses` was filled.
loss_values = [float(value) for value in losses]
fig, ax = plt.subplots()
# Epochs numbered from 1 to match the training log; length taken from the data.
ax.plot(range(1, len(loss_values) + 1), loss_values)
ax.set(title='Training loss', xlabel='Epoch', ylabel='MSE LOSS')
plt.show()

In [None]:
# Evaluate the learned line y = w*x + b on a grid spanning the data range.
w2 = model.linear.weight.item()
b2 = model.linear.bias.item()
x2 = np.linspace(start=0.0, stop=50.0, num=50)
predicted_y = b2 + w2 * x2

In [None]:
# Compare the trained model's line against the data.
fig, ax = plt.subplots()
ax.plot(x2, predicted_y, 'r', label='well-trained model predictions')
ax.scatter(X.numpy(), y.numpy(), label='input data')
ax.set(title='After training', xlabel='X', ylabel='y')
ax.legend()
plt.show()