In [9]:
import random

def dot(v1, v2):
    """Return the dot product of two sequences of numbers.

    Elements are paired with ``zip``, so if the sequences differ in length
    the extra elements of the longer one are ignored. Two empty sequences
    give 0.
    """
    products = (left * right for left, right in zip(v1, v2))
    return sum(products)

def fuzz(std):
    """Draw one Gaussian noise sample with mean 0 and standard deviation ``std``.

    The sample comes from the module-level ``random`` generator, so it
    advances that generator's state.
    """
    noise = random.gauss(0, std)
    return noise

def make_data(true_params, n=100, std=0.7):
    """Generate a noisy synthetic linear-regression data set.

    Each of the ``n`` rows has ``len(true_params) - 1`` features drawn as
    ``random.random() * 3`` followed by a constant ``1``, so the last entry
    of ``true_params`` acts as the intercept. Each target is the dot product
    of the row with ``true_params`` plus noise from ``fuzz(std)``.

    All rows are generated before any target, which fixes the order in which
    the shared ``random`` generator is consumed.

    Returns:
        A tuple ``(xs, ys)``: the list of feature rows and the list of targets.
    """
    n_features = len(true_params) - 1

    xs = []
    for _ in range(n):
        row = [random.random() * 3 for _ in range(n_features)]
        row.append(1)  # constant column paired with the last parameter
        xs.append(row)

    ys = []
    for row in xs:
        ys.append(dot(row, true_params) + fuzz(std))

    return xs, ys

def vector_subtract(v1, v2):
    """Return the element-wise difference ``v1 - v2`` as a new list.

    Elements are paired with ``zip``, so the result is as long as the
    shorter input.
    """
    difference = []
    for left, right in zip(v1, v2):
        difference.append(left - right)
    return difference

def scalar_product(a, v):
    """Return a new list with every element of ``v`` multiplied by ``a``."""
    scaled = []
    for component in v:
        scaled.append(a * component)
    return scaled

def mse(xs, ys, params):
    """Mean squared error of the linear model ``dot(x, params)`` on (xs, ys).

    A prediction is computed for every row of ``xs``. Predictions and targets
    are paired with ``zip``, and the summed squared differences are divided
    by ``len(ys)``.

    Raises:
        ZeroDivisionError: if ``ys`` is empty.
    """
    predictions = [dot(row, params) for row in xs]
    residuals = (predicted - target for predicted, target in zip(predictions, ys))
    total = sum([residual ** 2 for residual in residuals])
    return total / len(ys)

def grad_func_i(xs, yx, params, loss_func, i, h=0.00001):
    """Estimate the partial derivative of the loss with respect to ``params[i]``.

    Uses a forward finite difference:
    ``(loss(params + h * e_i) - loss(params)) / h``.

    Args:
        xs: inputs, passed through unchanged to ``loss_func``.
        yx: targets, passed through unchanged to ``loss_func``.
        params: current parameter values.
        loss_func: callable invoked as ``loss_func(xs, targets, params)``.
        i: index of the parameter to differentiate with respect to.
        h: step size of the finite difference.

    Returns:
        The approximate partial derivative as a number.
    """
    # Copy of params in which only the i-th component is shifted by h.
    params_nudged = [pi + (h if i == j else 0) for j, pi in enumerate(params)]
    # Evaluate the loss on the targets passed in as `yx`, never on a
    # module-level variable, so the gradient matches the caller's data.
    return (loss_func(xs, yx, params_nudged) - loss_func(xs, yx, params)) / h
    
def grad_func(xs, yx, params, loss_func):
    """Return the numerical gradient of ``loss_func`` with respect to ``params``.

    Calls ``grad_func_i`` once per parameter index with the same ``xs``,
    ``yx``, ``params`` and ``loss_func``, and collects the results in a list
    ordered like ``params``.
    """
    gradient = []
    for index in range(len(params)):
        gradient.append(grad_func_i(xs, yx, params, loss_func, index))
    return gradient

def sgd(xs, ys, params_0, loss_func, grad_func, lr_0, patience=100, lr_decay=0.9):
    """Minimise ``loss_func`` by gradient steps with a decaying learning rate.

    On every iteration the complete ``xs`` and ``ys`` are handed to
    ``loss_func`` and ``grad_func``; this function does no mini-batch
    sampling itself. The best parameters seen so far are remembered. Whenever
    the loss fails to improve on the best loss, the learning rate is
    multiplied by ``lr_decay``. The search stops after ``patience``
    consecutive iterations without improvement.

    Args:
        xs: inputs, passed through to ``loss_func`` and ``grad_func``.
        ys: targets, passed through to ``loss_func`` and ``grad_func``.
        params_0: initial parameter values (not modified).
        loss_func: callable invoked as ``loss_func(xs, ys, params)``.
        grad_func: callable invoked as ``grad_func(xs, ys, params, loss_func)``
            and returning one gradient component per parameter.
        lr_0: initial learning rate.
        patience: number of consecutive non-improving iterations tolerated
            before stopping.
        lr_decay: factor applied to the learning rate after each
            non-improving iteration.

    Returns:
        The parameters with the lowest observed loss. This is ``None`` if no
        evaluated loss was ever below infinity (for example a NaN loss) or if
        ``patience`` is not positive. It can be the ``params_0`` object
        itself when no later iterate improves on it.
    """
    params = params_0
    lr = lr_0
    best_loss = float('inf')
    best_params = None
    stale_iterations = 0

    while stale_iterations < patience:
        loss = loss_func(xs, ys, params)
        if loss < best_loss:
            best_loss = loss
            best_params = params
            stale_iterations = 0
        else:
            stale_iterations += 1
            # Shrink the step after every non-improving evaluation.
            lr *= lr_decay

        grad = grad_func(xs, ys, params, loss_func)
        # Builds a fresh list each step, so best_params is never mutated
        # by later updates.
        params = [p - lr * g for p, g in zip(params, grad)]

    return best_params


In [10]:
# Ground-truth parameters of the synthetic problem. make_data appends a
# constant 1 to every feature row, so the last entry (-5) is the intercept.
true_params = [23, 3, -5]
# NOTE: no random seed is set, so the generated data, the starting point and
# the printed numbers differ from run to run.
xs, ys = make_data(true_params)
# Random starting point: one value in [0, 1) per parameter.
params_0 = [random.random() for _ in range(len(true_params))]
# Initial learning rate handed to sgd.
lr_0 = 0.1
# Fit by gradient descent on the mean squared error with a numerical gradient.
params = sgd(xs, ys, params_0, mse, grad_func, lr_0)
# Show the true, initial and fitted parameters for comparison.
print(f"true params: {true_params}")
print(f"init params: {params_0}")
print(f"final params: {params}")

true params: [23, 3, -5]
init params: [0.6362455908836219, 0.6588238402608223, 0.35956321815800296]
final params: [22.90951375344178, 3.0720713997029, -4.950189421925546]
