In [20]:
import math, copy
import numpy as np


In [21]:
# Training data: one input feature per sample and its matching target value.
x_train = np.asarray([1.0, 2.0], dtype=float)      # features
y_train = np.asarray([300.0, 500.0], dtype=float)  # target values

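定义代价计算函数：$J(w,b) = \frac{1}{2m}\sum_{i=0}^{m-1}\left(f_{w,b}(x^{(i)}) - y^{(i)}\right)^2$，其中 $f_{w,b}(x) = wx + b$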

In [22]:
def compute_cost(x, y, w, b):
    """Return the squared-error cost of the line ``w * x + b`` on the data.

    Computes ``sum((w * x[i] + b - y[i]) ** 2 for i < m) / (2 * m)`` where
    ``m = x.shape[0]``.

    Args:
        x: Input features; must expose ``.shape`` (e.g. a 1-D NumPy array).
            Only the first ``x.shape[0]`` entries of ``y`` are used.
        y: Target values, indexed in step with ``x``.
        w: Slope (weight) of the linear model.
        b: Intercept (bias) of the linear model.

    Returns:
        The cost value.

    Raises:
        ZeroDivisionError: if ``x`` has no samples (``m == 0``).
    """
    n_samples = x.shape[0]
    squared_error_sum = 0
    for idx in range(n_samples):
        residual = w * x[idx] + b - y[idx]
        squared_error_sum += residual ** 2
    return squared_error_sum / (2 * n_samples)

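定义偏导数（梯度）计算函数：$\frac{\partial J}{\partial w} = \frac{1}{m}\sum_{i=0}^{m-1}\left(f_{w,b}(x^{(i)}) - y^{(i)}\right)x^{(i)}$，$\frac{\partial J}{\partial b} = \frac{1}{m}\sum_{i=0}^{m-1}\left(f_{w,b}(x^{(i)}) - y^{(i)}\right)$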

In [23]:
def compute_gradient(x, y, w, b):
    """Return the partial derivatives of the squared-error cost w.r.t. w and b.

    For each sample the prediction error ``w * x[i] + b - y[i]`` is computed
    once and accumulated into both gradients, which are then averaged over
    ``m = x.shape[0]`` samples.

    Args:
        x: Input features; must expose ``.shape`` (e.g. a 1-D NumPy array).
        y: Target values, indexed in step with ``x``.
        w: Current slope (weight).
        b: Current intercept (bias).

    Returns:
        Tuple ``(dj_dw, dj_db)``.

    Raises:
        ZeroDivisionError: if ``x`` has no samples (``m == 0``).
    """
    n_samples = x.shape[0]
    grad_w = 0
    grad_b = 0
    for idx in range(n_samples):
        error = w * x[idx] + b - y[idx]
        grad_w += error * x[idx]
        grad_b += error
    grad_w /= n_samples
    grad_b /= n_samples
    return grad_w, grad_b

定义梯度下降函数，用于寻找使代价 $J(w,b)$ 最小的 $w$ 和 $b$

In [24]:
def gradient_descent(x, y, w_in, b_in, alpha, num_iter, cost_function, gradient_function,
                     history_limit=100000):
    """Run batch gradient descent to fit the parameters ``w`` and ``b``.

    Each iteration calls ``gradient_function(x, y, w, b)`` and steps both
    parameters by ``alpha`` times the returned gradients in the descending
    direction. The cost and ``[w, b]`` after each of the first
    ``history_limit`` iterations are recorded. A progress line is printed
    every ``ceil(num_iter / 10)`` iterations, starting at iteration 0.

    Args:
        x: Input features, passed through to the cost/gradient functions.
        y: Target values, passed through to the cost/gradient functions.
        w_in: Initial value of ``w`` (deep-copied before use).
        b_in: Initial value of ``b``.
        alpha: Learning rate (step size).
        num_iter: Number of iterations to run.
        cost_function: Callable ``(x, y, w, b) -> cost``.
        gradient_function: Callable ``(x, y, w, b) -> (dj_dw, dj_db)``.
        history_limit: Maximum number of iterations whose cost and parameters
            are stored, to bound memory on long runs. Defaults to 100000.

    Returns:
        ``(w, b, j_history, p_history)``: final parameters, recorded costs,
        and recorded ``[w, b]`` pairs.
    """
    w = copy.deepcopy(w_in)
    b = b_in
    j_history = []  # cost after each recorded iteration
    p_history = []  # [w, b] after each recorded iteration

    # Loop-invariant progress interval (about ten progress lines in total).
    print_every = math.ceil(num_iter / 10)

    for i in range(num_iter):
        dj_dw, dj_db = gradient_function(x, y, w, b)
        w = w - alpha * dj_dw
        b = b - alpha * dj_db

        recorded = i < history_limit
        if recorded:
            j_history.append(cost_function(x, y, w, b))
            p_history.append([w, b])

        if i % print_every == 0:
            # Past the history limit, j_history[-1] belongs to an older
            # iteration; compute the current cost instead of printing a stale one.
            cost = j_history[-1] if recorded else cost_function(x, y, w, b)
            print(f"Iteration {i:4}: Cost {cost:0.2e} ",
                  f"dj_dw: {dj_dw: 0.3e}, dj_db: {dj_db: 0.3e}  ",
                  f"w: {w: 0.3e}, b:{b: 0.5e}")

    return w, b, j_history, p_history

In [27]:
# Fit the linear model starting from w = b = 0.
w_init, b_init = 0, 0

# Gradient-descent hyperparameters.
iterations = 10000   # number of update steps
tmp_alpha = 1.0e-2   # learning rate

w_final, b_final, J_hist, p_hist = gradient_descent(
    x_train, y_train, w_init, b_init,
    alpha=tmp_alpha,
    num_iter=iterations,
    cost_function=compute_cost,
    gradient_function=compute_gradient,
)
print(f"(w,b) found by gradient descent: ({w_final:8.4f},{b_final:8.4f})")

Iteration    0: Cost 7.93e+04  dj_dw: -6.500e+02, dj_db: -4.000e+02   w:  6.500e+00, b: 4.00000e+00
Iteration 1000: Cost 3.41e+00  dj_dw: -3.712e-01, dj_db:  6.007e-01   w:  1.949e+02, b: 1.08228e+02
Iteration 2000: Cost 7.93e-01  dj_dw: -1.789e-01, dj_db:  2.895e-01   w:  1.975e+02, b: 1.03966e+02
Iteration 3000: Cost 1.84e-01  dj_dw: -8.625e-02, dj_db:  1.396e-01   w:  1.988e+02, b: 1.01912e+02
Iteration 4000: Cost 4.28e-02  dj_dw: -4.158e-02, dj_db:  6.727e-02   w:  1.994e+02, b: 1.00922e+02
Iteration 5000: Cost 9.95e-03  dj_dw: -2.004e-02, dj_db:  3.243e-02   w:  1.997e+02, b: 1.00444e+02
Iteration 6000: Cost 2.31e-03  dj_dw: -9.660e-03, dj_db:  1.563e-02   w:  1.999e+02, b: 1.00214e+02
Iteration 7000: Cost 5.37e-04  dj_dw: -4.657e-03, dj_db:  7.535e-03   w:  1.999e+02, b: 1.00103e+02
Iteration 8000: Cost 1.25e-04  dj_dw: -2.245e-03, dj_db:  3.632e-03   w:  2.000e+02, b: 1.00050e+02
Iteration 9000: Cost 2.90e-05  dj_dw: -1.082e-03, dj_db:  1.751e-03   w:  2.000e+02, b: 1.00024e+02
