In [1]:
def predict(x, w, b):
    """
    Predict the target for a single sample with a linear model: x . w + b.

    x: feature values of one sample, e.g. [x1, x2, x3]
    w: weights, indexed in step with x, e.g. [w1, w2, w3]
    b: bias term added to the weighted sum
    Returns sum_i x[i] * w[i] + b. Extra entries in w beyond len(x) are
    ignored; an IndexError is raised if w is shorter than x.
    """
    total = 0
    for idx, feature in enumerate(x):
        total += feature * w[idx]
    return total + b


def gradient(x, y, w, b):
    """
    Compute the gradient of cost(x, y, w, b) with respect to w and b.

    This only evaluates the gradient; the descent step itself
    (w -= lr * gradient_w, b -= lr * gradient_b) is done by the caller.

    x: list of m samples, each a list of n feature values
    y: list of m target values
    w: list of n weights
    b: bias
    Returns (gradient_w, gradient_b) where
        gradient_w[j] = (1/m) * sum_i (predict(x[i], w, b) - y[i]) * x[i][j]
        gradient_b    = (1/m) * sum_i (predict(x[i], w, b) - y[i])
    Raises IndexError if x is empty (x[0] is read to count the features).
    """
    m = len(x)  # number of samples
    n = len(x[0])  # number of features

    # Residual of every sample, computed once and shared by all partial
    # derivatives instead of re-running predict() for each feature.
    errors = [predict(x[i], w, b) - y[i] for i in range(m)]

    # Partial derivative for each weight: mean of residual * feature value.
    gradient_w = [
        sum(errors[i] * x[i][j] for i in range(m)) / m
        for j in range(n)
    ]
    # Partial derivative for the bias: mean residual.
    gradient_b = sum(errors) / m

    return gradient_w, gradient_b


def cost(x, y, w, b):
    """
    Halved mean-squared-error cost of the linear model on the data set.

    cost = (1 / (2m)) * sum_i (predict(x[i], w, b) - y[i]) ** 2

    The extra factor 1/2 cancels the 2 from differentiating the square, so
    the gradient is (1/m) * sum_i error_i * x[i][j], as computed by gradient().

    x: list of m samples, each a list of feature values
    y: list of m target values
    w: list of weights
    b: bias
    Raises ZeroDivisionError if x is empty.
    """
    m = len(x)  # number of samples
    squared_errors = ((predict(x[i], w, b) - y[i]) ** 2 for i in range(m))
    return sum(squared_errors) / (2 * m)


# Toy training set: 4 samples with 3 features each.
# NOTE: every row has the form [t, t + 1, t + 2], so the feature columns are
# perfectly collinear and the fitted (w, b) is not unique: for example
# w = [1, 2, 3], b = 0 reproduces y exactly, and so do many other choices.
# This ill-conditioning likely explains the slow decay of the printed cost.
x = [
    [1.0, 2.0, 3.0],
    [2.0, 3.0, 4.0],
    [3.0, 4.0, 5.0],
    [4.0, 5.0, 6.0]
]
y = [14.0, 20, 26, 32]

# Hyperparameters.
lr = 0.001         # learning rate (step size)
iteration = 10000  # number of gradient-descent steps

# Start from the all-zero model.
n = len(x[0])      # number of features
w = [0.0] * n
b = 0.0

for i in range(iteration):
    gradient_w, gradient_b = gradient(x, y, w, b)
    # Move every parameter against its own partial derivative.
    w = [w_j - lr * g_j for w_j, g_j in zip(w, gradient_w)]
    b -= lr * gradient_b
    # Every 50 steps, report the cost measured *after* this step's update
    # (so the first line is not the cost of the all-zero model).
    if i % 50 == 0:
        print("cost: ", cost(x, y, w, b))

print("w: ", w)
print("b: ", b)

cost:  262.7175845
cost:  3.168176338097849
cost:  0.04529557381603542
cost:  0.007539177060765904
cost:  0.006905272682971362
cost:  0.006722610908951459
cost:  0.006549873980303309
cost:  0.006381636835677789
cost:  0.0062177216931842885
cost:  0.006058016791930836
cost:  0.0059024139810264185
cost:  0.00575080789637713
cost:  0.005603095880320429
cost:  0.005459177912001047
cost:  0.005318956539643737
cost:  0.005182336814564085
cost:  0.0050492262268767235
cost:  0.004919534642852859
cost:  0.004793174243888027
cost:  0.00467005946703682
cost:  0.0045501069470761605
cost:  0.004433235460054354
cost:  0.004319365868292952
cost:  0.004208421066799128
cost:  0.00410032593105588
cost:  0.003995007266151841
cost:  0.003892393757219237
cost:  0.0037924159211438113
cost:  0.003695006059515358
cost:  0.0036000982127873064
cost:  0.003507628115612346
cost:  0.0034175331533269292
cost:  0.003329752319552886
cost:  0.0032442261748871154
cost:  0.0031608968066538312
cost:  0.003079707789689362