In [8]:
import random

def fuzz(k):
    return k * random.gauss(0,1)

def dot(v1, v2):
    return sum((v1i * v2i) for v1i, v2i in zip(v1, v2))

def make_data(true_params, n=100, k=0.7):
    x = [[random.uniform(-3,3) for _ in range(len(true_params)-1)] + [1] for _ in range(n)]
    y = [dot(xi, true_params) + fuzz(k) for xi in x]
    return x, y

def vector_subtract(v1, v2):
    return [(v1i - v2i) for v1i, v2i in zip(v1, v2)]

def scalar_multiply(a, v):
    return [a * vi for vi in v]

def mse(x, y, params):
    yhat = [dot(xi, params) for xi in x]
    se = [(yi - yhati)**2 for yi, yhati in zip(y, yhat)]
    return sum(se)/len(se)

def grad_est_i(x, y, params, loss_fn, i, h=0.000001):
    params_nudged = [pi + (h if i == ix else 0) for ix, pi in enumerate(params)]
    return (loss_fn(x, y, params_nudged) - loss_fn(x, y, params)) / h
    
def grad_est(x, y, params, loss_fn):
    return [grad_est_i(x, y, params, loss_fn, i) for i, _ in enumerate(params)]

def sgd(x, y, params_0, loss_fn, grad_fn, lr_0):
    params = params_0
    lr = lr_0
    min_loss = float('inf')
    min_params = None
    iterations_without_improvement = 0
    while iterations_without_improvement < 100:
        loss = loss_fn(x, y, params)
        if loss < min_loss:
            min_loss = loss
            min_params = params
            iterations_without_improvement = 0
        else:
            iterations_without_improvement += 1
            lr *= 0.9
        grad = grad_fn(x, y, params, loss_fn)
        params = vector_subtract(params, scalar_multiply(lr, grad))
    return params


In [9]:
true_params = [23, 3, -5]
x, y = make_data(true_params)
params_0 = [random.random() for _ in range(len(true_params))]
params_f = sgd(x, y, params_0, mse, grad_est, 0.01)
print(f"true params: {true_params}")
print(f"init params: {params_0}")
print(f"final params: {params_f}")

true params: [23, 3, -5]
init params: [0.44599619397123724, 0.33438895173115224, 0.7787017800836846]
final params: [22.979281765174903, 3.025691445578364, -4.855225595926771]
