In [14]:
import numpy as np

In [15]:
# 采样数据
data = []
for i in range(100):
    # 随机采样输入x
    x = np.random.uniform(-10., 10.)
    # 采样高斯噪声
    eps = np.random.normal(0., 0.1)
    # 得到模型的输出
    y = 1.477 * x + 0.089 + eps
    data.append([x, y])
data = np.array(data)

In [16]:
# 计算误差
def mse(b, w, points):
    # 根据当前的w,b来计算均方差损失
    total_error = 0
    for i in range(0, len(points)):
        x = points[i, 0]
        y = points[i, 1]
        total_error += (y - (w * x + b)) ** 2
    # 将累加的误差求平均值，得到均方差
    return total_error / float(len(points))

In [17]:
# 计算梯度
def step_gradient(b_current, w_current, points, lr):
    # 计算误差函数在所有点上的导数，并更新w，b
    b_gradient = 0
    w_gradient = 0
    M = float(len(points)) # 总样本数
    for i in range(0, len(points)):
        x = points[i, 0]
        y = points[i, 1]
        #误差函数对b的导数：grad_b = 2(wx+b-y)
        b_gradient += (2/M) * ((w_current * x + b_current) - y)
        # 误差函数对w的导数：grad_w = 2(wx+b-y)*x
        b_gradient += (2/M) * x * ((w_current * x + b_current) - y)
    # 根据梯度下降算法更新w',b',其中lr为学习率
    new_b = b_current - (lr * b_gradient)
    new_w = w_current - (lr * w_gradient)
    return [new_b, new_w]
        

In [18]:
# 梯度更新
def gradient_descent(points, starting_b, starting_w, lr, num_iterations):
    b = starting_b
    w = starting_w
    for step in range(num_iterations):
        b, w = step_gradient(b, w, np.array(points), lr)
        loss = mse(b, w, points)    # 计算当前的均方差，用于监控训练进度
        if step % 50 == 0:
            print(f"iteration:{step}, loss:{loss}, w:{w}, b:{b}")
    return [b, w]

In [19]:
# 主训练函数
def main():
    lr = 0.01
    initial_b = 0
    initial_w = 0
    num_iterations = 1000
    # 训练优化1000次，返回最优w*,b*和训练Loss的下降过程
    [b, w] = gradient_descent(data, initial_b, initial_w, lr, num_iterations)
    loss = mse(b, w, data)
    print(f"Final loss:{loss}, w:{w}, b:{b}")

In [20]:
main()

iteration:0, loss:73.78169078200182, w:0.0, b:0.8949891929029106
iteration:50, loss:3927.741761635103, w:0.0, b:59.988262323326786
iteration:100, loss:26256.689030749556, w:0.0, b:159.6728097350446
iteration:150, loss:108963.31441324728, w:0.0, b:327.83084983458656
iteration:200, loss:376648.45685027615, w:0.0, b:611.4969463794081
iteration:250, loss:1192924.3720723707, w:0.0, b:1090.013788519953
iteration:300, loss:3607755.134333578, w:0.0, b:1897.2246995213648
iteration:350, loss:10634692.29983371, w:0.0, b:3258.91021914692
iteration:400, loss:30892622.739690714, w:0.0, b:5555.939940979648
iteration:450, loss:88981005.53505802, w:0.0, b:9430.803458977449
iteration:500, loss:255024519.12837255, w:0.0, b:15967.317610448972
iteration:550, loss:728780972.8679037, w:0.0, b:26993.774943445474
iteration:600, loss:2079040229.7218266, w:0.0, b:45594.325404889976
iteration:650, loss:5924965107.160482, w:0.0, b:76971.62701865508
iteration:700, loss:16875106528.565136, w:0.0, b:129902.0528467480