In [None]:
%pip install torch
%pip install sklearn

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import ipywidgets as widgets
from IPython.display import display, HTML, clear_output

# Generate data
def generate_data(n_samples=100):
    """Build a cubic toy dataset on the interval [-10, 10].

    Args:
        n_samples: Number of evenly spaced input points.

    Returns:
        A pair ``(X, y)`` of float32 tensors, each of shape
        ``(n_samples, 1)``, where ``y`` is the element-wise cube of ``X``.
    """
    # The cube is taken in NumPy's float64 before the cast to float32.
    inputs = np.linspace(-10, 10, n_samples)[:, np.newaxis]
    targets = inputs ** 3
    features = torch.from_numpy(inputs).float()
    labels = torch.from_numpy(targets).float()
    return features, labels


X, y = generate_data()

# Define the model
class LinearRegression(nn.Module):
    """One-input, one-output linear model ``w * x + b``.

    The weight keeps ``nn.Linear``'s default initialization; the bias is
    overwritten with -400 at construction time.
    """

    def __init__(self):
        super().__init__()
        layer = nn.Linear(in_features=1, out_features=1)
        # Initialize b at -400
        nn.init.constant_(layer.bias, -400.0)
        self.linear = layer

    def forward(self, x):
        """Return the linear layer's output for ``x``."""
        prediction = self.linear(x)
        return prediction

# Calculate loss for each w and b
def _as_flat_float64(values):
    """Return ``values`` as a 1-D float64 NumPy array (tensors are detached first)."""
    if torch.is_tensor(values):
        values = values.detach().cpu().numpy()
    return np.asarray(values, dtype=np.float64).reshape(-1)


def calculate_loss_grid(X, y, w_range=(-10, 70), b_range=(-500, 500), n_points=100):
    """Evaluate the mean-squared error of ``y ~ w * X + b`` on a (w, b) grid.

    Args:
        X: Inputs, a tensor or array-like; flattened to one value per sample.
        y: Targets with the same number of elements as ``X``.
        w_range: ``(low, high)`` bounds of the weight axis.
        b_range: ``(low, high)`` bounds of the bias axis.
        n_points: Number of grid points along each axis.

    Returns:
        ``(W, B, Z)``, three ``(n_points, n_points)`` float64 arrays. ``W``
        and ``B`` come from ``np.meshgrid`` (weights vary along columns,
        biases along rows) and ``Z[i, j]`` is the MSE at
        ``(W[i, j], B[i, j])``.

    Raises:
        ValueError: If ``X`` and ``y`` hold a different number of elements.
    """
    x_flat = _as_flat_float64(X)
    y_flat = _as_flat_float64(y)
    # Flattening both sides prevents an (n, 1) vs (n,) pair from silently
    # broadcasting into an (n, n) residual matrix.
    if x_flat.shape != y_flat.shape:
        raise ValueError(
            f"X and y must have the same number of elements, "
            f"got {x_flat.size} and {y_flat.size}"
        )

    w = np.linspace(w_range[0], w_range[1], n_points)
    b = np.linspace(b_range[0], b_range[1], n_points)
    W, B = np.meshgrid(w, b)

    # Residuals for every weight with a zero bias: (n_points, n_samples).
    # Computed once because it does not depend on the bias.
    residual_without_bias = y_flat - np.outer(w, x_flat)

    # One vectorized pass per bias row keeps peak memory at
    # n_points * n_samples instead of n_points**2 * n_samples.
    Z = np.zeros_like(W)
    for row, bias in enumerate(b):
        Z[row] = np.mean((residual_without_bias - bias) ** 2, axis=1)

    return W, B, Z

# Evaluate the loss surface once up front; visualize() reuses W, B and Z
# for its contour plot on every redraw.
W, B, Z = calculate_loss_grid(X, y)

# Initialize the model and optimizer
model = LinearRegression()
# Adam with learning rate 0.1 over this model's parameters.
optimizer = optim.Adam(model.parameters(), lr=0.1)
# Mean-squared-error loss; train() picks up `optimizer` and `criterion`
# from module scope.
criterion = nn.MSELoss()

# Training function
def train(model, X, y, max_epochs, optimizer=None, criterion=None):
    """Run full-batch training and record the loss and parameter history.

    Args:
        model: Module to train. It must expose a ``linear`` layer holding a
            single weight and a single bias, because ``.item()`` is read
            from both.
        X: Input tensor passed to ``model``.
        y: Target tensor compared against the model output.
        max_epochs: Number of optimization steps to run.
        optimizer: Optimizer that updates ``model``'s parameters. When
            omitted, the module-level ``optimizer`` is used.
        criterion: Loss callable ``criterion(prediction, target)``. When
            omitted, the module-level ``criterion`` is used.

    Returns:
        ``(losses, w_history, b_history)``, three lists of length
        ``max_epochs``. ``losses[i]`` is the loss measured before step
        ``i``; ``w_history[i]`` and ``b_history[i]`` are the parameter
        values after that step.
    """
    # Fall back to the notebook's module-level objects so existing
    # four-argument calls behave exactly as before. Passing them explicitly
    # is required when training a model other than the one the global
    # optimizer was built for.
    if optimizer is None:
        optimizer = globals()["optimizer"]
    if criterion is None:
        criterion = globals()["criterion"]

    losses = []
    w_history = []
    b_history = []

    for _ in range(max_epochs):
        optimizer.zero_grad()
        y_pred = model(X)
        loss = criterion(y_pred, y)
        loss.backward()
        optimizer.step()

        losses.append(loss.item())
        w_history.append(model.linear.weight.item())
        b_history.append(model.linear.bias.item())

    return losses, w_history, b_history

# Train the model once for the maximum number of epochs
# The full history is kept so visualize() can show any prefix of it by
# slicing, without retraining.
max_epochs = 6000
all_losses, all_w_history, all_b_history = train(model, X, y, max_epochs)

# Visualization function
def visualize(epoch):
    """Draw the training state after the first ``epoch`` recorded steps.

    The left panel shows the loss contours over (w, b) with the parameter
    path so far; the right panel shows the loss curve on a log-scale
    y-axis. Axis limits on the right panel are fixed from the full training
    run so the view does not rescale as ``epoch`` changes.

    Args:
        epoch: Number of recorded training steps to include.
    """
    shown_losses = all_losses[:epoch]
    shown_w = all_w_history[:epoch]
    shown_b = all_b_history[:epoch]

    fig, axes = plt.subplots(1, 2, figsize=(20, 8))
    contour_ax, loss_ax = axes

    # Contour plot
    loss_levels = contour_ax.contour(W, B, Z, levels=50)
    contour_ax.plot(shown_w, shown_b, 'ro-', label='Training Path')
    contour_ax.set(xlabel='w', ylabel='b', title='Loss Contour and Training Path')
    contour_ax.legend()
    plt.colorbar(loss_levels, ax=contour_ax, label='Loss')

    # Loss curve with log-scale y-axis
    epoch_numbers = range(1, epoch + 1)
    loss_ax.semilogy(epoch_numbers, shown_losses)
    loss_ax.set(xlabel='Epoch', ylabel='Loss (log scale)', title='Training Loss')
    loss_ax.set_xlim(0, max_epochs)
    loss_ax.set_ylim(min(all_losses), max(all_losses))

    # Replace whatever was shown before with the new figure, then close the
    # figure so it is not kept open by pyplot.
    clear_output(wait=True)

    display(fig)
    plt.close(fig)

# Create the epoch slider and train button
# The slider spans 1..max_epochs, matching the length of the recorded
# training history.
epoch_slider = widgets.IntSlider(value=1, min=1, max=max_epochs, step=1, description='Epochs:')
train_button = widgets.Button(description='Train')

# Function to handle the train button click
def on_train_button_click(b):
    """Replay training at fixed milestone epochs, then restore the view.

    Moves the slider to each milestone and draws it, pauses for three
    seconds, clears the output, puts the slider back to the value it had
    when the button was clicked, and displays a fresh interactive output
    together with the slider and button.

    Args:
        b: The clicked button, as passed by ``Button.on_click``. Unused.
    """
    import time

    epoch_initial = epoch_slider.value

    milestones = [100, 200, 300, 400, 600, 1000, 1500, 2000, 3000, 4000, 5000, 6000]
    # Only replay milestones that exist in the recorded history. Beyond
    # max_epochs the slider would clamp and visualize() would be asked to
    # plot more epochs than there are losses.
    for i in (m for m in milestones if m <= max_epochs):
        epoch_slider.value = i
        visualize(i)

    # Wait for 3 seconds before clearing output
    time.sleep(3)

    clear_output(wait=True)

    epoch_slider.value = epoch_initial

    # Create the interactive output
    # NOTE(review): every click builds another interactive_output on the same
    # slider; it looks like earlier ones stay attached, so redraw work may
    # grow with each click. Confirm before relying on repeated clicks.
    out = widgets.interactive_output(visualize, {'epoch': epoch_slider})

    # Display the widgets and output
    display(widgets.VBox([out, widgets.HBox([epoch_slider, train_button])]))


# Connect the button click event to the handler function
train_button.on_click(on_train_button_click)

# Create the interactive output
# visualize() receives the slider's current value as its `epoch` argument.
out = widgets.interactive_output(visualize, {'epoch': epoch_slider})

# Display the widgets and output
# The plot output sits above a row holding the slider and the button.
display(widgets.VBox([out, widgets.HBox([epoch_slider, train_button])]))
