In [120]:
import numpy as np
import torch
from torch.utils import data

# Ground-truth parameters of the linear model used to synthesise the dataset.
true_w = torch.tensor([2.0, -3.4])
true_b = 4.2

In [121]:
# First, build a synthetic linear-regression dataset by hand.
def create_dataset(
    W: torch.Tensor,
    b: float,
    num_sample: int,
    noise_std: float = 0.01,
) -> "tuple[torch.Tensor, torch.Tensor]":
    """Generate synthetic data following ``y = X @ W + b + noise``.

    Args:
        W: Weight vector; its first dimension sets the number of features.
            The notebook passes a 1-D tensor of shape ``(2,)``.
        b: Scalar bias added to every label.
        num_sample: Number of samples (rows) to generate.
        noise_std: Standard deviation of the zero-mean Gaussian noise added
            to the labels. Defaults to ``0.01``.

    Returns:
        A ``(features, labels)`` pair. ``features`` has shape
        ``(num_sample, W.shape[0])``; ``labels`` is reshaped to a column of
        shape ``(-1, 1)`` so it lines up with the output of a linear layer.
    """
    # Named `inputs` rather than `data` so the local does not shadow the
    # module-level `from torch.utils import data` import.
    inputs = torch.normal(0, 1, (num_sample, W.shape[0]))
    # For a 1-D W: (num_sample, d) @ (d,) -> (num_sample,)
    targets = inputs @ W + b
    targets += torch.normal(0, noise_std, targets.shape)
    print(inputs.shape, targets.shape)
    # A column vector avoids accidental broadcasting when the loss compares
    # labels with (num_sample, 1) predictions.
    return inputs, targets.reshape(-1, 1)

In [122]:
# Draw 1000 samples from the ground-truth linear model.
features, labels = create_dataset(W=true_w, b=true_b, num_sample=1000)

torch.Size([1000, 2]) torch.Size([1000])


In [123]:
def load_array(data_arrays, batch_size, is_train=True):
    """Wrap tensors in a ``DataLoader`` that yields mini-batches.

    Args:
        data_arrays: Iterable of tensors, unpacked into ``TensorDataset``.
        batch_size: Number of samples per mini-batch.
        is_train: Passed as ``shuffle``; when true the samples are shuffled.

    Returns:
        A ``DataLoader`` over the zipped tensors.
    """
    tensor_dataset = data.TensorDataset(*data_arrays)
    loader = data.DataLoader(
        dataset=tensor_dataset,
        batch_size=batch_size,
        shuffle=is_train,
    )
    return loader

# Mini-batch size used for every pass over the training data.
batch_size = 10
data_iter = load_array(data_arrays=(features, labels), batch_size=batch_size)

# Peek at one mini-batch; as the cell's final expression it gets displayed.
next(iter(data_iter))

[tensor([[-0.1184, -0.0703],
         [ 3.1284,  0.9363],
         [-0.4152, -0.8787],
         [ 1.6363, -1.3469],
         [-0.6095,  0.9064],
         [-0.9099,  0.9007],
         [ 1.0098, -0.4392],
         [ 1.8620,  0.4801],
         [ 0.7616, -0.0541],
         [ 0.4931,  1.0736]]),
 tensor([[ 4.1957],
         [ 7.2681],
         [ 6.3605],
         [12.0565],
         [-0.1015],
         [-0.6797],
         [ 7.7148],
         [ 6.2954],
         [ 5.8920],
         [ 1.5287]])]

In [124]:
from torch import nn

# Linear regression as a one-layer network: 2 input features -> 1 output.
net = nn.Sequential(nn.Linear(in_features=2, out_features=1))

In [125]:
# Re-initialise the linear layer in place: N(0, 0.01) weights, zero bias.
net[0].weight.data.normal_(mean=0, std=0.01)
net[0].bias.data.zero_()

tensor([0.])

In [126]:
# Mean squared error, averaged over all elements (the default reduction).
loss = nn.MSELoss(reduction='mean')

In [127]:
# Plain stochastic gradient descent over all of the network's parameters.
trainer = torch.optim.SGD(params=net.parameters(), lr=0.03)

In [128]:
# Train with mini-batch SGD, reporting the full-dataset loss after each epoch.
num_epochs = 10
for epoch in range(num_epochs):
    for X, y in data_iter:
        l = loss(net(X), y)
        # Clear gradients left over from the previous step before
        # back-propagating this batch's loss.
        trainer.zero_grad()
        l.backward()
        trainer.step()
    # The epoch-level loss is only reported, never back-propagated, so skip
    # building an autograd graph over the whole dataset.
    with torch.no_grad():
        l = loss(net(features), labels)
    print(f'epoch {epoch + 1}, loss {l:f}')

epoch 1, loss 0.000205
epoch 2, loss 0.000103
epoch 3, loss 0.000102
epoch 4, loss 0.000103
epoch 5, loss 0.000101
epoch 6, loss 0.000101
epoch 7, loss 0.000102
epoch 8, loss 0.000103
epoch 9, loss 0.000102
epoch 10, loss 0.000102
