In [1]:
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import Dropdown, FloatSlider, IntSlider
from ipywidgets import interactive, VBox, Output, Label

# For reproducibility
np.random.seed(42)


def plot_description(text):
    print(f"\nDescription:\n{text}\n")


# --- Define the parabola function ---
def f(x):
    return x**2  # simple convex function


def df(x):
    return 2*x  # derivative


# --- Optimization routine ---
def optimize(optimizer="GD", lr=0.1, epochs=20, momentum=0.9):
    x = -8
    v = 0  # momentum buffer
    history = [x]

    for i in range(epochs):
        grad = df(x)

        if optimizer == "GD":
            x = x - lr * grad

        elif optimizer == "Momentum":
            v = momentum * v - lr * grad
            x = x + v

        history.append(x)

    return np.array(history)

# --- Plotting function ---
def train_and_plot(optimizer, lr, epochs, momentum):
    xs = np.linspace(-11, 11, 500)
    ys = f(xs)

    path = optimize(optimizer, lr, epochs, momentum)
    path_y = f(path)

    plt.figure(figsize=(8,6))
    plt.plot(xs, ys, label="f(x) = xÂ²", color="blue", alpha=0.5)
    plt.plot(path, path_y, color="red", linestyle="--", alpha=0.7, zorder=3)
    plt.scatter(path[0], path_y[0], s=80, label="Start", marker="*", color="black", zorder=4)
    plt.scatter(path[1:-1], path_y[1:-1], s=60, label="Steps", marker="o", facecolors='none', edgecolors='black', zorder=4)
    plt.scatter(path[-1], path_y[-1], color="red", s=80, label="Final Step", marker="x", zorder=5)

    plt.title(f"Params: LR={lr:.2f}, Epochs={epochs}, Momentum={momentum:.2f}" if optimizer=="Momentum" else f"{optimizer}: LR={lr:.2f}, Epochs={epochs}")
    plt.xlabel("x")
    plt.ylabel("f(x)")
    plt.legend()
    plt.xlim([-12.5, 12.5])
    plt.ylim([-5, 105])
    plt.grid(True)
    plt.show()

# --- Widgets ---
# Define widgets with default values
optimizer_widget = Dropdown(options=["GD", "Momentum"], value="GD", description="Optimizer")

# Store defaults for reset
defaults = {
    "lr": 0.05,
    "epochs": 10,
    "momentum": 0.75
}

lr_widget = FloatSlider(value=defaults["lr"], min=0.01, max=1.1, step=0.01, description="Learning rate", readout_format=".2f")
epochs_widget = IntSlider(value=defaults["epochs"], min=5, max=100, step=5, description="Epochs")
momentum_widget = FloatSlider(value=defaults["momentum"], min=0.1, max=0.99, step=0.05, description="Momentum", readout_format=".2f")
momentum_widget.layout.visibility = 'hidden'

def on_optimizer_change(change):
    if change["name"] == "value":
        lr_widget.value = defaults["lr"]
        epochs_widget.value = defaults["epochs"]
        momentum_widget.value = defaults["momentum"]

        if change["new"] == "Momentum":
            momentum_widget.layout.visibility = 'visible'
        else:
            momentum_widget.layout.visibility = 'hidden'

optimizer_widget.observe(on_optimizer_change, names="value")

label = Label(value="ðŸ“Š Controls", layout=Layout(margin="0 0 0 0"))

# Add left margin (indent) to the widgets you want to indent
margin_left = 40
optimizer_widget.layout.margin = f"0 0 0 {margin_left}px"
lr_widget.layout.margin = f"0 0 0 {margin_left}px"
epochs_widget.layout.margin = f"0 0 0 {margin_left}px"
momentum_widget.layout.margin = f"0 0 0 {margin_left}px"

ui = VBox([label, optimizer_widget, lr_widget, epochs_widget, momentum_widget])
out = Output()

def wrapped_train_and_plot(optimizer, lr, epochs, momentum):
    with out:
        out.clear_output(wait=True)
        train_and_plot(optimizer, lr, epochs, momentum)

def gradient_descent_interact():
    plot_description("Gradient descent (GD) demonstration. Use the dropdown to switch between standard GD and GD+Momentum."
                    " Adjust learning rate, number of epochs and momentum. Notice how the process diverges when learning rate > 1.")
    
    interactive_plot = interactive(
        wrapped_train_and_plot,
        optimizer=optimizer_widget,
        lr=lr_widget,
        epochs=epochs_widget,
        momentum=momentum_widget,
    )
    
    display(ui, out)

NameError: name 'Layout' is not defined