In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

# Objective function minimised by the optimiser below.
def function(x):
    """Evaluate f(x) = sin(x) + x**2 element-wise on the tensor ``x``.

    Built only from differentiable torch ops, so when ``x`` requires grad
    the result can be back-propagated to it.
    """
    sine_term = torch.sin(x)
    square_term = x.pow(2)
    return sine_term + square_term

# Learnable starting point: a one-element float tensor tracked by autograd,
# initialised to 2.0. The optimiser updates it in place.
x = torch.full((1,), 2.0, requires_grad=True)

# Stochastic gradient descent with momentum over that single parameter.
optimizer = torch.optim.SGD(params=[x], lr=0.1, momentum=0.9)

# Create a figure and axis for the animation, with fixed axis limits.
fig, ax = plt.subplots()
ax.set_xlim(-10, 10)
ax.set_ylim(-1, 6)
# Red circle markers ('ro' format) for the points visited by the optimiser.
line, = ax.plot([], [], 'ro', label='Optimization Path')
# Reference curve of the objective. The label must describe what `function`
# actually computes: sin(x) + x**2 (it previously said "sin(x) + 0.5x").
func_line, = ax.plot([], [], label='Function: f(x) = sin(x) + x^2')

# Initialisation callback passed to FuncAnimation as init_func.
def init():
    """Clear the path markers and the reference curve.

    Returns:
        The (path, curve) artists, the iterable FuncAnimation expects from
        init_func when blitting is enabled.
    """
    for artist in (line, func_line):
        artist.set_data([], [])
    return line, func_line

# Animation update function and the optimisation path it accumulates.
# Seed the path with the parameter's actual starting value instead of a
# duplicated literal, so it cannot drift from the initialisation of `x`.
x_values = [x.item()]


def animate(i):
    """Advance the optimiser by one step and update the plotted artists.

    Args:
        i: Frame index supplied by FuncAnimation (not used).

    Returns:
        The (path, curve) artists that were updated, as FuncAnimation requires
        when ``blit=True``.
    """
    optimizer.zero_grad()  # clear gradients from the previous iteration
    loss = function(x)  # objective at the current point
    loss.backward()  # populate x.grad
    optimizer.step()  # momentum-SGD update of x, in place
    x_values.append(x.item())

    # Evaluate f at every visited point with one vectorised call instead of
    # building one tensor per point on every frame.
    path_y = function(torch.tensor(x_values)).tolist()
    line.set_data(x_values, path_y)

    # The reference curve never changes. Compute it only while it is empty,
    # i.e. on the first frame or after init() has cleared it.
    if len(func_line.get_xdata()) == 0:
        curve_x = np.linspace(-6, 6, 400)
        curve_y = function(torch.from_numpy(curve_x)).numpy()
        func_line.set_data(curve_x, curve_y)
    return line, func_line

# Build the animation: one optimiser step per frame, 200 ms between frames,
# with blitting so only the artists returned by init/animate are redrawn.
# The object is kept in `anim` so it stays referenced while the window is open.
num_iterations = 100
anim = FuncAnimation(
    fig,
    animate,
    frames=num_iterations,
    init_func=init,
    interval=200,
    blit=True,
)

# Label and decorate the single axes created above, then display the window.
ax.set_xlabel('x')
ax.set_ylabel('f(x)')
ax.set_title('Optimization using SGD with Momentum in PyTorch (Animation)')
ax.legend()
ax.grid(True)
plt.show()
