In [None]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import itertools

plt.rcParams['figure.figsize'] = [14, 6]

In [None]:
class Func:
    """Base class for a parametric function optimised by gradient descent.

    The trainable parameters live in ``self.params`` as a float64 NumPy
    array. Subclasses are expected to override :meth:`grad`.
    """

    def __init__(self, params) -> None:
        """Store ``params`` as a fresh float64 array.

        ``np.array`` copies its input, so the caller's sequence is never
        aliased, and the float64 dtype allows in-place updates such as
        ``params -= step`` even when the initial values are integers.
        """
        self.params = np.array(params, dtype=np.float64)

    def grad(self, X, Y=None):
        """Return the gradient of the function with respect to ``self.params``.

        ``Y`` is optional so that this abstract signature accepts both the
        one-argument form and the ``grad(X, Y)`` form used by ``LinearMSE``
        and by the training loops; previously calling the base with targets
        raised ``TypeError`` instead of the intended ``NotImplementedError``.

        Raises:
            NotImplementedError: always; subclasses must provide the gradient.
        """
        raise NotImplementedError()

class LinearMSE(Func):
    """Half mean-squared-error loss of the line ``params[0] + params[1] * X``.

    ``params[0]`` is the intercept and ``params[1]`` the slope. ``X`` and
    ``Y`` are treated as flat collections of scalar samples.
    """

    def __call__(self, X, Y):
        """Return ``mean((Y - pred(X)) ** 2) / 2`` over all samples."""
        # Flatten both inputs so that e.g. a column-vector Y cannot broadcast
        # against a 1-D X into an (n, n) residual matrix and silently produce
        # a wrong loss.
        residual = np.ravel(Y) - self.pred(np.ravel(X))
        return np.mean(np.square(residual)) / 2

    def pred(self, X):
        """Evaluate the line at ``X`` (output has the broadcast shape of ``X``)."""
        return self.params[0] + self.params[1] * X

    def grad(self, X, Y):
        """Return the gradient of the loss w.r.t. ``(params[0], params[1])``.

        The result is a length-2 array: ``[mean(err), mean(err * X)]`` where
        ``err = pred(X) - Y``.
        """
        # Normalise X *and* Y to 1-D; previously only X was reshaped, so a
        # non-flat Y produced a wrongly shaped (or silently wrong) gradient.
        X = np.ravel(X)
        err = self.pred(X) - np.ravel(Y)

        # Row 0: d(loss)/d(intercept) terms, row 1: d(loss)/d(slope) terms.
        per_sample = np.stack([err, err * X])

        return np.mean(per_sample, axis=1)

In [None]:
def fit(func: LinearMSE, X, Y, num_epochs: int, learning_rate: float):
    """Run full-batch gradient descent on ``func``, updating its params in place.

    Each epoch computes the gradient on all of ``X``/``Y``, takes one step of
    size ``learning_rate`` against it, and records the result.

    Returns:
        A ``(params, grads, losses)`` tuple of arrays. ``params`` starts with
        the initial parameters and gains one entry per epoch; ``grads`` and
        ``losses`` hold one entry per epoch, with the loss evaluated *after*
        that epoch's update.
    """
    snapshots = [func.params.copy()]
    gradients, losses = [], []

    for _epoch in tqdm(range(num_epochs)):
        step_direction = func.grad(X, Y)
        func.params -= step_direction * learning_rate

        # Copy: func.params is mutated in place on the next iteration.
        snapshots.append(func.params.copy())
        gradients.append(step_direction)
        losses.append(func(X, Y))

    return tuple(np.array(series) for series in (snapshots, gradients, losses))

def fit_sgd(func: LinearMSE, X, Y, batch_size: int, num_epochs: int, learning_rate: float, rng=None):
    """Run mini-batch stochastic gradient descent on ``func`` in place.

    Every epoch reshuffles the samples and then walks over them in
    consecutive slices of ``batch_size`` (the last slice may be shorter),
    taking one gradient step per slice.

    Args:
        func: model whose ``params`` are updated in place.
        X, Y: samples and targets with the same length along axis 0; any
            array-like is accepted. The caller's data is never modified.
        batch_size: positive number of samples per update.
        num_epochs: number of passes over the data.
        learning_rate: step size applied to each mini-batch gradient.
        rng: optional object with a ``permutation(n)`` method (for example
            ``np.random.default_rng(seed)``) used for shuffling. When omitted
            the global ``np.random`` state is used.

    Returns:
        A ``(params, grads, losses)`` tuple of arrays. ``params`` starts with
        the initial parameters and gains one entry per update; ``grads`` and
        ``losses`` hold one entry per update, with the loss evaluated on the
        current mini-batch *after* the update.

    Raises:
        ValueError: if ``batch_size`` is less than 1 or ``X`` and ``Y`` differ
            in length.
    """
    X = np.asarray(X)
    Y = np.asarray(Y)
    num_samples = X.shape[0]

    if Y.shape[0] != num_samples:
        raise ValueError(
            f'X and Y must have the same number of samples, got {num_samples} and {Y.shape[0]}'
        )
    # A negative step would make range() yield nothing and silently skip training.
    if batch_size < 1:
        raise ValueError(f'batch_size must be a positive integer, got {batch_size}')

    permutation = np.random.permutation if rng is None else rng.permutation

    # Fancy indexing below always builds new arrays, so no defensive copy is
    # needed to keep the caller's X and Y untouched.
    x_shuffled, y_shuffled = X, Y

    params_history = [func.params.copy()]
    grad_history = []
    loss_history = []

    for _ in tqdm(range(num_epochs)):
        perm = permutation(num_samples)
        x_shuffled, y_shuffled = x_shuffled[perm], y_shuffled[perm]

        for batch_start in range(0, num_samples, batch_size):
            batch_stop = batch_start + batch_size
            x = x_shuffled[batch_start:batch_stop]
            y = y_shuffled[batch_start:batch_stop]

            grad = func.grad(x, y)

            func.params -= grad * learning_rate

            params_history.append(func.params.copy())
            grad_history.append(grad)
            loss_history.append(func(x, y))

    return np.array(params_history), np.array(grad_history), np.array(loss_history)

In [None]:
# Sampling grid for the synthetic data: WINDOW_COUNT evenly spaced x values
# from WINDOW_START to WINDOW_END.
WINDOW_START, WINDOW_END = 0, 5
WINDOW_COUNT = 100

# Ground-truth line used to generate the targets: y = DATA_B_1 * x + DATA_B_0.
DATA_B_0, DATA_B_1 = 8, -4

# Parameter range shown on both axes of the loss-surface visualisation.
VIZ_PARAM_MIN, VIZ_PARAM_MAX = -10, 10

In [None]:
# Noise-free samples of the ground-truth line on an evenly spaced grid.
X = np.linspace(start=WINDOW_START, stop=WINDOW_END, num=WINDOW_COUNT)
Y = DATA_B_0 + DATA_B_1 * X

In [None]:
# Start from (intercept, slope) = (-7, 1), away from the data's (8, -4).
func = LinearMSE(params=[-7, 1])

# Full-batch gradient descent; the commented line is the mini-batch alternative.
ph, gh, lh = fit(func, X, Y, learning_rate=0.1, num_epochs=100)
# ph, gh, lh = fit_sgd(func, X, Y, batch_size=10, num_epochs=10, learning_rate=1e-1)

# Predictions of the trained model on the training grid.
f_series = func.pred(X)


In [None]:
# Draw seven evenly spaced snapshots of the fitted line over the training
# history, fading from faint (early) to opaque (final), against the target.
MIN_LINE_ALPHA = 0.15

ph_len = ph.shape[0] - 1
for pi in np.linspace(0, ph_len, 7):
    line_idx = int(pi)
    print(f'Line #{line_idx} params:', ph[line_idx])

    # Guard the ratio when the history only holds the initial parameters, and
    # floor the opacity so the starting line (progress 0) is still visible
    # instead of being drawn fully transparent.
    progress = pi / ph_len if ph_len else 1.0
    plt.plot(X, LinearMSE(ph[line_idx]).pred(X), 'b', alpha=max(progress, MIN_LINE_ALPHA))

plt.plot(X, Y, 'r--', label='target')
plt.xlabel('X')
plt.ylabel('Y')
plt.legend()
plt.show()

In [None]:
# 2x2 grid: top row is parameter 0 and its gradient, bottom row parameter 1
# and its gradient, each plotted against the update index.
panels = [
    (ph, 0, 'Param 0'),
    (gh, 0, 'Gradient 0'),
    (ph, 1, 'Param 1'),
    (gh, 1, 'Gradient 1'),
]
for slot, (history, column, panel_title) in enumerate(panels, start=1):
    plt.subplot(2, 2, slot)
    plt.plot(history[:, column])
    plt.title(panel_title)
plt.show()

In [None]:
# Loss recorded after each parameter update.
plt.title('Loss')
plt.plot(lh)
plt.show()

In [None]:
# Evaluate the loss on a 200 x 200 grid of (param 0, param 1) values.
p0s = np.linspace(VIZ_PARAM_MIN, VIZ_PARAM_MAX, 200)
p1s = np.linspace(VIZ_PARAM_MIN, VIZ_PARAM_MAX, 200)

# NOTE(review): this iterator is not consumed anywhere in this cell; it is
# kept only in case a later cell relies on the name.
ps = itertools.product(p0s, p1s)

# f[i][j] is the loss at (p0s[i], p1s[j]): rows index param 0, columns param 1.
f = [[LinearMSE([p0, p1])(X, Y) for p1 in p1s] for p0 in p0s]

# imshow maps columns to the x axis and rows to the y axis, so x is param 1
# and y is param 0. origin='lower' places row 0 (param 0 = VIZ_PARAM_MIN) at
# the bottom, matching the tick labels implied by `extent`; with the default
# origin='upper' the image was vertically flipped relative to its y ticks.
plt.xlabel('Param 1')
plt.ylabel('Param 0')
plt.title('Loss surface')
plt.imshow(f, origin='lower', extent=[VIZ_PARAM_MIN, VIZ_PARAM_MAX, VIZ_PARAM_MIN, VIZ_PARAM_MAX])