In [1]:
import torch
import numpy as np
# * 在 PyTorch 世界，所有物件都是張量（tensor）

## Use numpy to display a case

In [8]:
def forward(x):
    return w*x

def loss(y, y_pred):
    return ((y_pred-y)**2).mean()

def gradient(x, y, y_pred):
    return np.dot(2*x, y_pred-y).mean()

In [7]:
X = np.array([1,2,3,4], dtype=np.float32)
Y = np.array([2,4,6,8], dtype=np.float32)
w = 0

learning_rate = 0.01
n_iter = 10

In [9]:
print(f'Prediction before training: f(5) = {forward(5):.3f}')

for epoch in range(n_iter):
    Y_pred = forward(X)
    l = loss(Y,Y_pred)
    dw = gradient(X,Y,Y_pred)
    w -= learning_rate*dw
    if epoch%1 == 0:
        print(f'epoch {epoch+1}: w = {w:.3f}, loss = {l:.5f}')
print(f'Prediction after training: f(5) = {forward(5):.3f}')

Prediction before training: f(5) = 0.000
epoch 1: w = 1.200, loss = 30.00000
epoch 2: w = 1.680, loss = 4.80000
epoch 3: w = 1.872, loss = 0.76800
epoch 4: w = 1.949, loss = 0.12288
epoch 5: w = 1.980, loss = 0.01966
epoch 6: w = 1.992, loss = 0.00315
epoch 7: w = 1.997, loss = 0.00050
epoch 8: w = 1.999, loss = 0.00008
epoch 9: w = 1.999, loss = 0.00001
epoch 10: w = 2.000, loss = 0.00000
Prediction after training: f(5) = 9.999


## Use torch to display the same case

In [16]:
X = torch.tensor([1,2,3,4], dtype=torch.float32)
Y = torch.tensor([2,4,6,8], dtype=torch.float32)
w = torch.tensor(0, dtype=torch.float32, requires_grad=True)

learning_rate = 0.01
n_iter = 100

In [17]:
print(f'Prediction before training: f(5) = {forward(5):.3f}')

for epoch in range(n_iter):
    # forward pass
    Y_pred = forward(X)
    l = loss(Y,Y_pred)
    # backward pass
    l.backward()
    # update weights
    with torch.no_grad():
        w -= learning_rate*w.grad
    # zero gradient
    w.grad.zero_()
    
    if epoch%1 == 0:
        print(f'epoch {epoch+1}: w = {w:.3f}, loss = {l:.5f}')
print(f'Prediction after training: f(5) = {forward(5):.3f}')

Prediction before training: f(5) = 0.000
epoch 1: w = 0.300, loss = 30.00000
epoch 2: w = 0.555, loss = 21.67500
epoch 3: w = 0.772, loss = 15.66019
epoch 4: w = 0.956, loss = 11.31449
epoch 5: w = 1.113, loss = 8.17472
epoch 6: w = 1.246, loss = 5.90623
epoch 7: w = 1.359, loss = 4.26725
epoch 8: w = 1.455, loss = 3.08309
epoch 9: w = 1.537, loss = 2.22753
epoch 10: w = 1.606, loss = 1.60939
epoch 11: w = 1.665, loss = 1.16279
epoch 12: w = 1.716, loss = 0.84011
epoch 13: w = 1.758, loss = 0.60698
epoch 14: w = 1.794, loss = 0.43854
epoch 15: w = 1.825, loss = 0.31685
epoch 16: w = 1.851, loss = 0.22892
epoch 17: w = 1.874, loss = 0.16540
epoch 18: w = 1.893, loss = 0.11950
epoch 19: w = 1.909, loss = 0.08634
epoch 20: w = 1.922, loss = 0.06238
epoch 21: w = 1.934, loss = 0.04507
epoch 22: w = 1.944, loss = 0.03256
epoch 23: w = 1.952, loss = 0.02353
epoch 24: w = 1.960, loss = 0.01700
epoch 25: w = 1.966, loss = 0.01228
epoch 26: w = 1.971, loss = 0.00887
epoch 27: w = 1.975, loss = 