# In [2]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from ipywidgets import interactive, FloatSlider, IntSlider, VBox, HBox, Layout, Output, Label, Dropdown
from IPython.display import display, clear_output

# For reproducibility: seed NumPy's global random number generator.
# NOTE(review): no random draws appear in the visible code -- confirm whether
# anything else relies on this seed before removing it.
np.random.seed(42)


def plot_description(text):
    """Print ``text`` under a "Description:" heading, framed by blank lines."""
    heading = "\nDescription:\n"
    print("{}{}\n".format(heading, text))


# Dropdown labels for the selectable loss functions. The plotting functions
# map them to the conditioning values 1.0 (A), 0.1 (B) and 10.0 (C).
# NOTE(review): quadratic_loss() also applies an overall factor of 0.5 that
# these labels do not show.
LOSS_A_LABEL = "(A) x^2 + y^2"
LOSS_B_LABEL = "(B) 0.1 * x^2 + y^2"
LOSS_C_LABEL = "(C) 10 * x^2 + y^2"


# Global output widgets: output_2d receives the 2D contour plot and output_3d
# the 3D surface plot; each is laid out at 50% width.
output_2d = Output(layout=Layout(width='50%', height='auto'))
output_3d = Output(layout=Layout(width='50%', height='auto'))


def quadratic_loss(x, y, conditioning=1.0):
    """Evaluate the quadratic loss 0.5 * (conditioning * x^2 + y^2).

    ``conditioning`` weights the x-term relative to the y-term; with the
    default of 1.0 both terms are weighted equally.
    """
    x_term = conditioning * x**2
    y_term = y**2
    return 0.5 * (x_term + y_term)


def gradient_descent(start_x, start_y, step_size, epochs, conditioning):
    """Simulate gradient descent on the quadratic loss function.

    The update uses the analytic gradient of
    0.5 * (conditioning * x^2 + y^2), which is (conditioning * x, y).

    Args:
        start_x: Initial x coordinate.
        start_y: Initial y coordinate.
        step_size: Learning rate multiplied with the gradient at each step.
        epochs: Number of update steps to perform.
        conditioning: Weight of the x-term in the loss.

    Returns:
        A NumPy array built from every visited point, starting point first;
        for scalar start values its shape is ``(epochs + 1, 2)``.
    """
    x, y = start_x, start_y
    path = [(x, y)]
    for _ in range(epochs):
        grad_x = conditioning * x
        grad_y = y
        # Rebind rather than update in place: `x -= ...` would modify NumPy
        # start values in place, overwriting the caller's arrays and every
        # point already stored in `path`.
        x = x - step_size * grad_x
        y = y - step_size * grad_y
        path.append((x, y))
    return np.array(path)


def plot_2d_convergence(step_size, epochs, loss_function=LOSS_A_LABEL):
    """Render the loss contours and the gradient descent path into output_2d.

    ``loss_function`` is one of the LOSS_*_LABEL values; any value other than
    the A or B label is plotted with conditioning 10.0. The descent always
    starts at (1.8, 1.5) and runs ``epochs`` steps of size ``step_size``.
    """
    conditioning = (
        1.0 if loss_function == LOSS_A_LABEL
        else 0.1 if loss_function == LOSS_B_LABEL
        else 10.0
    )

    with output_2d:
        # Replace the previous figure instead of stacking a new one below it.
        clear_output(wait=True)

        axis_values = np.linspace(-2, 2, 100)
        grid_x, grid_y = np.meshgrid(axis_values, axis_values)
        loss_values = quadratic_loss(grid_x, grid_y, conditioning)

        fig, ax = plt.subplots(figsize=(7, 6))
        contour_set = ax.contour(grid_x, grid_y, loss_values, levels=20, cmap="viridis")
        ax.clabel(contour_set, inline=True, fontsize=8)

        # Minimum, trajectory and starting point (drawn in this order so the
        # legend entries keep their order).
        trajectory = gradient_descent(1.8, 1.5, step_size, epochs, conditioning)
        start_x, start_y = trajectory[0]
        ax.plot(0, 0, 'bx', markersize=10, markeredgewidth=3, label="Minimum (x: 0, y: 0)")
        ax.plot(trajectory[:, 0], trajectory[:, 1], "ro-", markersize=4, linewidth=1,
                label="Optimization Path")
        ax.plot(start_x, start_y, "*", markersize=10, color="black", label="Start")

        ax.set_title('2D Loss Contour and Optimization Path ("Top View")')
        ax.set_xlabel("x")
        ax.set_ylabel("y")
        # Fixed limits keep the view comparable across settings.
        ax.set_xlim(-2, 2)
        ax.set_ylim(-2, 2)
        ax.legend()
        ax.grid(True)
        plt.show()


def plot_3d_loss(loss_function=LOSS_A_LABEL):
    """Render the 3D surface of the selected quadratic loss into output_3d.

    ``loss_function`` is one of the LOSS_*_LABEL values; any value other than
    the A or B label is plotted with conditioning 10.0.
    """
    conditioning = (
        1.0 if loss_function == LOSS_A_LABEL
        else 0.1 if loss_function == LOSS_B_LABEL
        else 10.0
    )

    with output_3d:
        # Replace the previous figure instead of stacking a new one below it.
        clear_output(wait=True)

        axis_values = np.linspace(-2, 2, 100)
        grid_x, grid_y = np.meshgrid(axis_values, axis_values)
        loss_values = quadratic_loss(grid_x, grid_y, conditioning)

        fig = plt.figure(figsize=(7, 6))
        ax = fig.add_subplot(111, projection='3d')
        ax.plot_surface(grid_x, grid_y, loss_values, cmap="viridis", alpha=0.8)
        ax.plot(0, 0, "ro", markersize=12, markeredgewidth=3, label="Minimum (x: 0, y: 0)")

        ax.set_title(f"3D Loss Function [{loss_function}]")
        ax.set_xlabel("x")
        ax.set_ylabel("y")
        ax.set_zlabel("Loss")
        ax.set_box_aspect(None, zoom=0.87)  # avoid cutting off z-axis label
        fig.tight_layout()
        ax.legend()
        plt.show()


def update_2d_plot(change):
    """Redraw the 2D plot from the current control values on a 'value' change."""
    if change['name'] != 'value':
        return
    plot_2d_convergence(
        step_slider.value,
        epochs_slider.value,
        loss_function_dropdown.value,
    )


def update_3d_plot(change):
    """Redraw the 3D surface for the selected loss on a 'value' change."""
    if change['name'] != 'value':
        return
    plot_3d_loss(loss_function_dropdown.value)


# Heading shown above the controls
label = Label(value="📊 Controls", layout=Layout(margin="0 0 8px 0"))

# Controls: learning rate, number of epochs and loss function selection
step_slider = FloatSlider(value=0.1, min=0.01, max=0.5, step=0.01, description="Learning Rate", layout=Layout(width="300px"))
epochs_slider = IntSlider(value=20, min=1, max=100, step=1, description="Epochs", layout=Layout(width="300px"))
loss_function_dropdown = Dropdown(options=[LOSS_A_LABEL, LOSS_B_LABEL, LOSS_C_LABEL], value=LOSS_A_LABEL, description="Loss func.", layout=Layout(width="300px"))

# Indent the controls relative to the heading
margin_left = 40
step_slider.layout.margin = f"0 0 0 {margin_left}px"
epochs_slider.layout.margin = f"0 0 0 {margin_left}px"
loss_function_dropdown.layout.margin = f"0 0 0 {margin_left}px"

# Redraw on value changes only. Every control affects the 2D plot; the 3D
# surface depends on the selected loss function alone.
step_slider.observe(update_2d_plot, names="value")
epochs_slider.observe(update_2d_plot, names="value")
loss_function_dropdown.observe(update_3d_plot, names="value")
loss_function_dropdown.observe(update_2d_plot, names="value")

# Render both plots once so that neither panel is empty before the first
# interaction with a control.
plot_2d_convergence(step_slider.value, epochs_slider.value, loss_function_dropdown.value)
plot_3d_loss(loss_function_dropdown.value)

# Layout: heading and controls stacked above the two plots placed side by side
ui = VBox([label, step_slider, epochs_slider, loss_function_dropdown,
           HBox([output_2d, output_3d])])


def gradient_descent_3D_convergence_interact():
    """Print the usage description of the demo, then display its widget UI."""
    description = (
        "\nDescription:\n"
        "Gradient descent example in 3D. Convergence and Loss Function Visualization.\n"
        "Use the sliders to explore how learning rate, number of epochs, and conditioning of the loss function affect the optimization path.\n\n"
        "- Learning Rate: Controls the magnitude of each update. Too large may cause divergence; too small may slow convergence.\n"
        "- Epochs: Determines the number of optimization steps taken.\n"
        "- Loss function: Adjusts the shape of the loss function. Perfect conditioning (A) results in a circular contour and a straight optimization path, while poor conditioning (B or C) creates an elliptical contour, leading to zig-zag convergence.\n\n"
        "- Left Plot (2D): Shows the contour of the loss function and the optimization path of gradient descent.\n"
        "- Right Plot (3D): Displays the 3D surface of the loss function, highlighting how conditioning changes its shape.\n\n"
        "Notice how the optimization path and loss landscape interact and how it takes more epochs to converge to the function minimum (x: 0, y:0) for ill-conditioned loss functions.\n"
    )
    print(description)
    display(ui)
