In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

# =========================
# Function and gradients
# =========================
def f(x):
    """Quartic loss f(x) = x^4 - 4x^2 + x.

    The curve has two local minima, and the left one is the global minimum.
    It works element-wise on NumPy arrays as well as on plain scalars.
    """
    quartic_term = x**4
    quadratic_term = 4 * x**2
    return quartic_term - quadratic_term + x

def df(x):
    """Exact derivative of f: f'(x) = 4x^3 - 8x + 1.

    It works element-wise on NumPy arrays as well as on plain scalars.
    """
    cubic_term = 4 * x**3
    linear_term = 8 * x
    return cubic_term - linear_term + 1

def sgd_grad(x, noise_std=0.5):
    """Return a simulated stochastic gradient of f at ``x``.

    The exact derivative ``df(x)`` is perturbed by zero-mean Gaussian noise
    with standard deviation ``noise_std``. The noise comes from NumPy's
    global random state, so the results depend on ``np.random.seed``.
    """
    exact = df(x)
    perturbation = np.random.normal(0, noise_std)
    return exact + perturbation


# =========================
# Domain and loss curve
# =========================
# Dense sample of the loss curve. It is used for the background plot and
# the y-axis limits.
x_plot = np.linspace(-3, 3, 500)
y_plot = f(x_plot)

# =========================
# Hyperparameters
# =========================
eta = 0.01      # step size shared by both optimizers
n_frames = 300  # number of animation frames

# =========================
# Initial points
# =========================
# Both optimizers start from the same point so the comparison is like-for-like.
x_start = 2.5
x_bgd = x_sgd = x_start

# update() appends every iterate here to draw the trajectories.
x_bgd_hist, x_sgd_hist = [], []

# =========================
# Figure setup
# =========================
fig, ax = plt.subplots(figsize=(8, 5))

# Static background: the loss curve itself.
ax.plot(x_plot, y_plot, label="Loss function f(x)")

# Animated artists are created empty and filled in each frame.
# First, a marker for each optimizer's current iterate...
bgd_point, = ax.plot([], [], "ro", label="BGD")
sgd_point, = ax.plot([], [], "go", label="SGD")
# ...then a dashed trail of every iterate visited so far.
bgd_path, = ax.plot([], [], "r--", alpha=0.6)
sgd_path, = ax.plot([], [], "g--", alpha=0.6)

# Fixed view: x covers the sampled domain, and y is padded by one unit
# around the curve.
ax.set(
    xlim=(-3, 3),
    ylim=(y_plot.min() - 1, y_plot.max() + 1),
    xlabel="x",
    ylabel="f(x)",
    title="SGD vs BGD Convergence (2D Projection)",
)
ax.legend()
ax.grid(True)

# =========================
# Animation update
# =========================
def update(frame):
    """Advance both optimizers by one step and redraw their markers and trails.

    FuncAnimation calls this once per frame. ``frame`` (the frame index) is
    unused. The optimizer state lives in the module-level ``x_bgd`` / ``x_sgd``
    and their history lists, so every call advances the iterates, whatever
    frame number is passed.

    Returns the four modified artists, which ``blit=True`` requires.
    """
    global x_bgd, x_sgd

    # ----- BGD update: exact gradient -----
    x_bgd = x_bgd - eta * df(x_bgd)
    x_bgd_hist.append(x_bgd)

    # ----- SGD update: noisy gradient estimate -----
    x_sgd = x_sgd - eta * sgd_grad(x_sgd)
    x_sgd_hist.append(x_sgd)

    # Line2D.set_data needs sequences. Scalar x/y was deprecated in
    # Matplotlib 3.7 and is rejected in later releases, so wrap the point.
    bgd_point.set_data([x_bgd], [f(x_bgd)])
    sgd_point.set_data([x_sgd], [f(x_sgd)])

    # Redraw each full trajectory on top of the loss curve.
    bgd_path.set_data(x_bgd_hist, f(np.array(x_bgd_hist)))
    sgd_path.set_data(x_sgd_hist, f(np.array(x_sgd_hist)))

    return bgd_point, sgd_point, bgd_path, sgd_path


# =========================
# Run animation
# =========================
def init_artists():
    """Start every animated artist empty.

    Without an ``init_func``, FuncAnimation draws its initial frame by calling
    ``update`` with the first frame value. ``update`` mutates the optimizer
    globals on every call, so that would take an extra optimizer step.
    """
    for artist in (bgd_point, sgd_point, bgd_path, sgd_path):
        artist.set_data([], [])
    return bgd_point, sgd_point, bgd_path, sgd_path


ani = FuncAnimation(
    fig,
    update,
    frames=n_frames,
    init_func=init_artists,
    interval=50,
    blit=True,
    # Stop after n_frames. update() is stateful, so looping would keep
    # stepping from the current point and grow the history lists forever.
    repeat=False,
)

# NOTE(review): under a non-interactive notebook backend, plt.show() may
# render only a static image. Use an interactive backend or
# ani.to_jshtml() to view the animation there.
plt.show()


: 