In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

class LinearRegressionAnimation:
    """Fit ``y ~ a * x + b`` by batch gradient descent and animate the fit.

    :meth:`fit` records the slope ``a``, intercept ``b``, RMSE and R^2 after
    every update in ``self.history``; :meth:`animate_gradient_descent`
    replays that history as an HTML/JavaScript animation.

    Parameters
    ----------
    learning_rate : float
        Step size of each gradient-descent update.
    n_iterations : int
        Number of updates performed by :meth:`fit`.
    init : {"zeros", "ols"}
        Starting parameters. ``"zeros"`` (default) starts from ``a = b = 0``
        so the descent is visible. ``"ols"`` warm-starts from the closed-form
        least-squares solution, which leaves gradient descent essentially
        nothing to do.

    Raises
    ------
    ValueError
        If ``init`` is not one of the supported choices.
    """

    _INIT_CHOICES = ("zeros", "ols")

    def __init__(self, learning_rate=0.01, n_iterations=1000, init="zeros"):
        if init not in self._INIT_CHOICES:
            raise ValueError(
                f"init must be one of {self._INIT_CHOICES}, got {init!r}"
            )
        self.learning_rate = learning_rate
        self.n_iterations = n_iterations
        self.init = init
        # One entry per gradient-descent update, (re)filled by fit().
        self.history = {'a': [], 'b': [], 'rmse': [], 'r_squared': []}

    def compute_gradients(self, x, y):
        """Return ``(dMSE/da, dMSE/db)`` at the current ``self.a``, ``self.b``.

        For ``MSE = mean((a*x + b - y)**2)`` the gradients are
        ``2/N * sum(x * residual)`` and ``2/N * sum(residual)``.
        Requires ``self.a`` and ``self.b``, which :meth:`fit` sets.
        """
        x = np.asarray(x, dtype=float)
        y = np.asarray(y, dtype=float)
        residual = self.a * x + self.b - y
        n = x.size
        grad_a = (2 / n) * np.sum(x * residual)
        grad_b = (2 / n) * np.sum(residual)
        return grad_a, grad_b

    def _initial_parameters(self, x, y):
        """Return the starting ``(a, b)`` selected by ``self.init``."""
        if self.init == "ols":
            x_mean, y_mean = x.mean(), y.mean()
            sxx = np.sum((x - x_mean) ** 2)
            # Constant x has no defined slope; fall back to a flat line.
            a = np.sum((x - x_mean) * (y - y_mean)) / sxx if sxx > 0 else 0.0
            return a, y_mean - a * x_mean
        return 0.0, 0.0

    def fit(self, x, y):
        """Run ``n_iterations`` gradient-descent updates on ``(x, y)``.

        ``self.history`` is reset, then receives ``a``, ``b``, RMSE and R^2
        after every update. R^2 is recorded as NaN when ``y`` is constant,
        because it is undefined there.

        Raises
        ------
        ValueError
            If ``x`` and ``y`` differ in shape or are empty.
        """
        x = np.asarray(x, dtype=float)
        y = np.asarray(y, dtype=float)
        if x.shape != y.shape:
            raise ValueError(
                f"x and y must have the same shape, got {x.shape} and {y.shape}"
            )
        if x.size == 0:
            raise ValueError("x and y must not be empty")

        self.a, self.b = self._initial_parameters(x, y)
        # Start a fresh history so repeated fits don't concatenate runs.
        self.history = {'a': [], 'b': [], 'rmse': [], 'r_squared': []}
        # Total sum of squares depends only on y; compute it once.
        ss_tot = np.sum((y - y.mean()) ** 2)

        for _ in range(self.n_iterations):
            grad_a, grad_b = self.compute_gradients(x, y)
            # Update parameters using gradients
            self.a -= self.learning_rate * grad_a
            self.b -= self.learning_rate * grad_b

            # Compute RMSE and R-squared for evaluation
            residual = y - (self.a * x + self.b)
            ss_res = np.sum(residual ** 2)
            rmse = np.sqrt(ss_res / x.size)
            r_squared = 1 - ss_res / ss_tot if ss_tot > 0 else float('nan')

            self.history['a'].append(self.a)
            self.history['b'].append(self.b)
            self.history['rmse'].append(rmse)
            self.history['r_squared'].append(r_squared)

    def animate_gradient_descent(self, x, y, *, frame_step=1, interval=100):
        """Return an ``IPython.display.HTML`` animation of the fitted line.

        Parameters
        ----------
        x, y : array-like
            Points to scatter behind the line (normally the data given to fit).
        frame_step : int, keyword-only
            Show every ``frame_step``-th recorded update; larger values shrink
            the embedded HTML for long runs.
        interval : int, keyword-only
            Delay between frames in milliseconds.

        Raises
        ------
        RuntimeError
            If there is no recorded history (``fit`` not called yet).
        ValueError
            If ``frame_step`` is smaller than 1.
        """
        n_frames = len(self.history['a'])
        if n_frames == 0:
            raise RuntimeError("No history to animate; call fit() first.")
        if frame_step < 1:
            raise ValueError("frame_step must be >= 1")
        x = np.asarray(x, dtype=float)
        y = np.asarray(y, dtype=float)

        fig, ax = plt.subplots()
        ax.scatter(x, y, color='blue')
        line, = ax.plot([], [], lw=2, color='red')
        # Axes-relative coordinates keep the label inside the visible (and
        # blitted) axes area whatever the data range is.
        annotation = ax.text(0.02, 0.95, '', transform=ax.transAxes, va='top')
        # The x-grid used to draw the line is identical for every frame.
        x_vals = np.linspace(x.min(), x.max(), 100)

        def init():
            line.set_data([], [])
            annotation.set_text('')
            return line, annotation

        def animate(i):
            a = self.history['a'][i]
            b = self.history['b'][i]
            line.set_data(x_vals, a * x_vals + b)
            # history[i] holds the parameters after update number i + 1.
            annotation.set_text(
                f"Iteration: {i + 1}\nRMSE: {self.history['rmse'][i]:.4f}"
            )
            return line, annotation

        anim = FuncAnimation(
            fig,
            animate,
            init_func=init,
            frames=range(0, n_frames, frame_step),
            interval=interval,
            blit=True,
        )
        # Close this figure so the notebook doesn't also show a static copy.
        plt.close(fig)
        return HTML(anim.to_jshtml())


In [None]:
# Demo on noiseless data lying exactly on the line y = 2x; the animation
# shows gradient descent moving the fitted line toward these points.
x = np.arange(1, 6)
y = 2 * x

model = LinearRegressionAnimation(learning_rate=0.01, n_iterations=1000)
model.fit(x, y)

# Kept as the cell's final expression so the notebook renders the returned HTML.
model.animate_gradient_descent(x, y)