In [111]:
import random
import torch
import matplotlib.pyplot as plt

In [112]:
def synthetic_data(w, b, num_examples, x_std=2.0, noise_std=0.31):
    """Generate a noisy linear dataset ``y = Xw + b + noise``.

    Args:
        w: Weight tensor; ``len(w)`` sets the number of feature columns in X.
        b: Bias added to every target.
        num_examples: Number of rows to generate.
        x_std: Standard deviation of the zero-mean normal distribution the
            features are drawn from (default matches the original constant 2).
        noise_std: Standard deviation of the zero-mean normal noise added to
            the targets (default matches the original constant 0.31).
            Pass 0.0 for noise-free targets.

    Returns:
        Tuple ``(X, y)`` where X has shape ``(num_examples, len(w))`` and y is
        reshaped into a single column.
    """
    # Draw order (features first, then noise) is kept so that a seeded call
    # with the default arguments yields the same tensors as before.
    X = torch.normal(0, x_std, (num_examples, len(w)))
    y = torch.matmul(X, w) + b
    y += torch.normal(0, noise_std, y.shape)
    return X, y.reshape((-1, 1))


# Seed both Python's and PyTorch's global RNGs before the first random draw so
# that Restart Kernel -> Run All regenerates the same dataset, and so that the
# later shuffling and weight initialisation are repeatable as well.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Ground-truth parameters of the linear model that the training loop should recover.
true_w = torch.tensor([-4.1, -2.6])
true_b = 5.9
X_train, y_train = synthetic_data(true_w, true_b, 1000)

# Bare last expression: shows the (summarised) feature tensor as the cell output.
X_train

tensor([[ 2.9308, -5.2420],
        [ 1.7078, -0.4133],
        [-1.4748,  1.5687],
        ...,
        [ 0.7888,  2.3998],
        [ 0.1581, -0.2064],
        [-0.0521, -1.2057]])

In [113]:
# Generator function.
# Advantage of a generator (it uses `yield` instead of `return`): batches are produced lazily, one at a time, so memory usage stays small.
def data_iter(batch_size, X_train, y_train):
    """Yield randomly ordered minibatches of ``(features, targets)``.

    Each call shuffles the row indices once with Python's ``random`` module and
    then walks through them in steps of ``batch_size``. The final batch is
    shorter when ``len(X_train)`` is not a multiple of ``batch_size``.

    Args:
        batch_size: Number of rows per batch.
        X_train: Feature tensor indexed along its first dimension.
        y_train: Target tensor indexed with the same row indices as X_train.

    Yields:
        Tuple ``(X_batch, y_batch)`` selected by the same index tensor.
    """
    total = len(X_train)
    order = list(range(total))
    random.shuffle(order)
    for start in range(0, total, batch_size):
        # List slicing clamps at the end, so no explicit bound is needed.
        picked = torch.tensor(order[start : start + batch_size])
        yield X_train[picked], y_train[picked]

In [114]:
def linreg(X, w, b):
    """Linear regression model: return ``X @ w + b``.

    Args:
        X: Input features.
        w: Weights multiplied with X via ``torch.matmul``.
        b: Bias added (with broadcasting) to the product.
    """
    weighted = torch.matmul(X, w)
    return weighted + b

In [115]:
def squared_loss(y_hat, y):
    """Element-wise halved squared error, ``(y_hat - y)^2 / 2``.

    ``y`` is reshaped to the shape of ``y_hat`` before subtracting, so the
    result has the same shape as ``y_hat``. No reduction is applied.
    """
    residual = y_hat - y.reshape(y_hat.shape)
    return residual ** 2 / 2

In [122]:
def sgd(params, lr, batch_size):
    """Apply one minibatch SGD step in place, then clear the gradients.

    Each parameter is updated as ``param -= lr * param.grad / batch_size``.
    The division turns a gradient of a summed loss into a per-example average.

    Args:
        params: Iterable of tensors whose ``.grad`` has been populated by
            ``backward()``.
        lr: Learning rate.
        batch_size: Number of examples the accumulated gradient was summed over.
    """
    # The update must not be recorded by autograd.
    with torch.no_grad():
        for param in params:
            if param.grad is None:
                # backward() never reached this tensor: there is nothing to
                # apply or reset, so skip it instead of failing on None.
                continue
            param -= lr * param.grad / batch_size
            # Gradients accumulate across backward() calls; reset for the next step.
            param.grad.zero_()

In [123]:
# Hyperparameters.
batch_size = 10
lr = 0.01
num_epochs = 10

# Trainable parameters: small random weights (one row per feature column of
# X_train, rather than a hard-coded 2) and a zero bias.
w = torch.normal(0, 0.01, size=(X_train.shape[1], 1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)

# Model and loss are bound to names so they can be swapped without touching the loop.
net = linreg
loss = squared_loss

for epoch in range(num_epochs):
    for X, y in data_iter(batch_size, X_train, y_train):
        l = loss(net(X, w, b), y)  # per-example losses for this minibatch
        l.sum().backward()
        # Normalise by the number of rows actually in this batch: the final
        # batch is shorter whenever the dataset size is not a multiple of
        # batch_size, and dividing by the nominal size would shrink its step.
        sgd([w, b], lr, len(X))
    # Evaluate on the full training set without tracking gradients.
    with torch.no_grad():
        train_l = loss(net(X_train, w, b), y_train)
        print(f"epoch {epoch + 1}, loss {float(train_l.mean()):.5f}")

epoch 1, loss 2.58357
epoch 2, loss 0.38989
epoch 3, loss 0.09243
epoch 4, loss 0.05311
epoch 5, loss 0.04784
epoch 6, loss 0.04691
epoch 7, loss 0.04691
epoch 8, loss 0.04689
epoch 9, loss 0.04699
epoch 10, loss 0.04692
