# 阶段二：寻找函数

## 用 PyTorch 来实现刚才的梯度下降

In [None]:
import torch
import numpy as np

In [2]:
from torch.utils.data import Dataset, DataLoader

In [3]:
# Dataset whose samples are produced by a generating function.
class SyntheticData(Dataset):
    """Map-style dataset of (x, y) points sampled on a fixed grid.

    The x values are ``np.arange(-4.5, 4.5, 0.2)``; each y value is the
    result of calling ``generater_func`` on the corresponding x. Both
    arrays are stored as float32.
    """

    def __init__(self, generater_func):
        """Evaluate ``generater_func`` once per grid point and cache the results."""
        self.generater_func = generater_func
        grid = np.arange(-4.5, 4.5, 0.2)
        self.pts_x = grid.astype(np.float32)
        samples = list(map(generater_func, self.pts_x))
        self.pts_y = np.array(samples).astype(np.float32)

    def __len__(self):
        """Return the number of sampled points."""
        return len(self.pts_x)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict with keys ``'x'`` and ``'y'``."""
        x_val = self.pts_x[idx]
        y_val = self.pts_y[idx]
        return {'x': x_val, 'y': y_val}

In [4]:
# Ground-truth coefficients of the quadratic y = a*x^2 + b*x + c that the
# fit is expected to recover.
a_ground_truth = np.double(1.2)
b_ground_truth = np.double(-3.7)
c_ground_truth = np.double(4.9)


def target_func(x):
    """Return the noise-free quadratic a*x*x + b*x + c evaluated at ``x``."""
    return a_ground_truth * x * x + b_ground_truth * x + c_ground_truth


def data_generate_func(x):
    """Return ``target_func(x)`` plus uniform noise drawn from [-5, 5).

    ``np.random.rand()`` yields a value in [0, 1); shifting by 0.5 and
    scaling by 10 centres the noise on zero.
    """
    return target_func(x) + 10.0 * (np.double(np.random.rand()) - 0.5)

In [5]:
# Build the dataset; data_generate_func is evaluated once per grid point
# inside SyntheticData.__init__, so the noisy samples are fixed from here on.
sd = SyntheticData(data_generate_func)

In [6]:
# Serve the dataset in shuffled mini-batches of 4 samples; num_workers=0
# keeps data loading in the main process.
data_loader = DataLoader(sd, batch_size=4, shuffle=True, num_workers=0)

In [7]:
# Assume the function being fitted is again y = a * x * x + b * x + c.
class HypoFunc(torch.nn.Module):
    """Quadratic hypothesis y = a*x^2 + b*x + c with learnable a, b, c.

    The coefficients are registered as ``torch.nn.Parameter`` so that the
    module reports them through ``parameters()`` / ``state_dict()`` and moves
    them with ``.to(...)``. They remain reachable as ``self.a``, ``self.b``
    and ``self.c``.
    """

    def __init__(self):
        super().__init__()
        # Each coefficient starts at a random scalar in [0, 1).
        self.a = torch.nn.Parameter(torch.tensor(np.random.rand()))
        self.b = torch.nn.Parameter(torch.tensor(np.random.rand()))
        self.c = torch.nn.Parameter(torch.tensor(np.random.rand()))

    def forward(self, x):
        """Evaluate the quadratic element-wise at ``x`` and return the prediction."""
        y_predict = self.a * x * x + self.b * x + self.c
        return y_predict

In [8]:
# Instantiate the hypothesis; its a, b, c start at random values.
hypo_func = HypoFunc()

In [9]:
# Define the loss function: squared error between prediction and target.
# With reduction='sum' the per-sample squared errors are summed over the
# batch rather than averaged.
criterion = torch.nn.MSELoss(reduction='sum')

In [10]:
# PyTorch provides an optimizer that updates a, b and c for us; here plain
# SGD with learning rate 1e-3 over the three coefficient tensors.
optimizer = torch.optim.SGD([hypo_func.a, hypo_func.b, hypo_func.c], lr=1e-3)

In [23]:
# Run the training.
# The original loop counted i from 0 to 9,999,999 but did all of its work
# (logging plus one epoch of SGD) only when i % 500000 == 0; every other
# iteration was a no-op. Stepping the range by 500000 visits exactly the same
# values of i -- 20 logged epochs -- without the idle iterations.
# NOTE(review): if the intent was to train on every iteration and only log
# every 500000th, dedent the inner loop and iterate range(10000000) instead.
for i in range(0, 10000000, 500000):
    # Log the current estimates of a, b, c before this epoch's updates.
    print([hypo_func.a.data, hypo_func.b.data, hypo_func.c.data])
    # One epoch: a full pass over the mini-batches from data_loader.
    for batch in data_loader:
        y_pred = hypo_func(batch['x'])
        loss = criterion(y_pred, batch['y'])
        # Clear gradients accumulated by the previous step before backprop.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

[tensor(1.0674), tensor(-3.9428), tensor(4.5898)]
[tensor(1.2058), tensor(-4.0300), tensor(4.6029)]
[tensor(0.9202), tensor(-3.9367), tensor(4.5876)]
[tensor(1.0701), tensor(-3.9649), tensor(4.6031)]
[tensor(1.1447), tensor(-3.9746), tensor(4.6069)]
[tensor(1.0469), tensor(-3.9676), tensor(4.6046)]
[tensor(1.0018), tensor(-3.9247), tensor(4.6083)]
[tensor(1.2261), tensor(-3.9238), tensor(4.6201)]
[tensor(1.0006), tensor(-3.9832), tensor(4.6009)]
[tensor(1.1956), tensor(-4.0478), tensor(4.6073)]
[tensor(1.1082), tensor(-4.0738), tensor(4.6010)]
[tensor(1.0548), tensor(-4.0584), tensor(4.5953)]
[tensor(1.2350), tensor(-3.9337), tensor(4.6015)]
[tensor(1.0651), tensor(-4.0728), tensor(4.5896)]
[tensor(0.9786), tensor(-3.9532), tensor(4.5932)]
[tensor(1.0009), tensor(-3.9529), tensor(4.5905)]
[tensor(0.9143), tensor(-3.9770), tensor(4.5842)]
[tensor(1.1785), tensor(-4.0488), tensor(4.6034)]
[tensor(1.0238), tensor(-3.9615), tensor(4.5916)]
[tensor(1.0542), tensor(-3.9894), tensor(4.5941)]


In [12]:
# Show the true coefficients for comparison with the fitted values.
# The array is built from the ground-truth variables rather than re-typed
# literals so this check cannot drift from the values used to generate the data.
print('Checking the ground truth of [a,b,c]')
np.array([a_ground_truth, b_ground_truth, c_ground_truth], dtype=np.float64)

Checking the ground truth of [a,b,c]


array([ 1.2, -3.7,  4.9])