# Optimization Methods

Optimization methods find values that maximize or minimize an objective function, making them useful across disciplines such as engineering, economics, and data science.
More fundamentally, the principle of stationary action in physics can be viewed as an optimization problem: the path a system follows makes the action integral stationary, often (but not always) a minimum.

## Gradient Descent Methods

Gradient descent is one of the most widely used optimization techniques and is particularly effective for high-dimensional problems such as those in machine learning.
The method iteratively seeks a minimum of a function by taking steps proportional to the negative of its gradient, moving the search toward lower function values.
Because it requires only first derivatives of a differentiable objective, gradient descent is a standard tool for minimizing error functions, from training machine learning models to refining physical models in computational astrophysics.

For a function $f(x)$, the gradient $\nabla f(x)$ points in the direction of steepest ascent.
Moving in the opposite direction, along the negative gradient, decreases the function's value for a sufficiently small step. The algorithm updates the parameters iteratively according to:
\begin{align}
x_{n+1} = x_n - \alpha \nabla f(x_n)
\end{align}
where $\alpha$ is the learning rate, controlling the step size.
The choice of $\alpha$ is critical for convergence:
a large $\alpha$ may cause the updates to overshoot the minimum and oscillate or diverge, while a very small $\alpha$ leads to slow convergence, requiring many iterations to make meaningful progress.
A well-tuned $\alpha$ lets the algorithm converge efficiently without unnecessary oscillation or divergence.
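As a worked example of this trade-off, consider a one-dimensional quadratic $f(x) = \tfrac{k}{2}(x - c)^2$ with curvature $k > 0$ and minimum at $x = c$ (a model problem chosen here for illustration). Its gradient is $k(x - c)$, so the error $e_n = x_n - c$ evolves as
\begin{align}
e_{n+1} = x_n - \alpha k (x_n - c) - c = (1 - \alpha k)\, e_n
\quad\Longrightarrow\quad
e_n = (1 - \alpha k)^n\, e_0 .
\end{align}
The iteration converges if and only if $|1 - \alpha k| < 1$, that is, $0 < \alpha < 2/k$.
When $1/k < \alpha < 2/k$ the steps overshoot and the error alternates in sign while still shrinking, and $\alpha = 1/k$ lands on the minimum in a single step.
For a general function, $k$ plays the role of the local curvature near the minimum (the largest Hessian eigenvalue in several dimensions), so this bound is a rule of thumb rather than a guarantee.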

In [None]:
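def gd(df, x, alpha, imax=1000):
    """Fixed-step gradient descent: apply x <- x - alpha * df(x) imax times."""
    for _ in range(imax):
        x = x - alpha * df(x)  # step against the gradient; not in-place, so array inputs are not mutated
    return x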
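The loop above always runs `imax` iterations, even after the iterate has stopped moving. A common refinement, sketched below as a hypothetical helper `gd_tol`, stops once the update becomes smaller than a tolerance `tol` and reports how many iterations were used. The stopping rule and the default `tol=1e-10` are illustrative assumptions, not part of the original code.

```python
def gd_tol(df, x, alpha, tol=1e-10, imax=1000):
    """Gradient descent that stops early once the update is tiny.

    Hypothetical variant of `gd`: returns (x, number_of_iterations).
    Stops when |alpha * df(x)| < tol, or after imax iterations.
    """
    for i in range(imax):
        step = alpha * df(x)
        x = x - step              # same update rule as gd
        if abs(step) < tol:       # scalar criterion; use a vector norm for array x
            return x, i + 1
    return x, imax                # tolerance never reached (slow or divergent run)

# Example: f(x) = (x - 3)**2 + 4 has gradient 2(x - 3) and its minimum at x = 3
xmin, n = gd_tol(lambda x: 2 * (x - 3), 0.0, alpha=0.1)
print("xmin =", xmin, "after", n, "iterations")
```

Returning the iteration count makes it easy to see that a larger (but still stable) learning rate reaches the tolerance in fewer steps.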

In [None]:
# Define the function and its gradient
def f(x):
    return (x - 3)**2 + 4

def df(x):
    return 2 * (x - 3)

# Parameters for gradient descent
x0    = 0.0  # Starting point for optimization
alpha = 0.1

# Run gradient descent
xmin = gd(df, x0, alpha)
print("Approximate minimum:")
print("  xmin  = ",   xmin )
print("f(xmin) = ", f(xmin))

In [None]:
def gd_hist(df, x, alpha, imax=1000):
    X = [x]
    for _ in range(imax):
        X.append(X[-1] - alpha * df(X[-1]))
    return X

In [None]:
import numpy as np
from matplotlib import pyplot as plt

X = np.linspace(0, 6, 6001)
plt.plot(X, f(X))

alpha = 0.1

X = np.array(gd_hist(df, x0, alpha))
print(X[-1])

plt.plot(X, f(X), '-o')
plt.xlim(2.5, 3.5)
plt.ylim(3.95,4.3)

```{exercise}
What will happen if we change the learning rate $\alpha$?

Comment out the plot limits `plt.xlim(2.5, 3.5)` and `plt.ylim(3.95,4.3)` and then try $\alpha = 0.1$, $0.5$, $0.9$, $1.0$, and $1.1$.
```

As in our implementation of the Newton-Raphson method, we can use `JAX` to obtain the derivative automatically.
Here is an updated version of gradient descent that uses automatic differentiation:

In [None]:
from jax import grad

def autogd_hist(f, x, alpha, imax=1000):
    df = grad(f)
    X  = [x]
    for _ in range(imax):
        X.append(X[-1] - alpha * df(X[-1]))
    return X
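Automatic differentiation computes derivatives of the code you write exactly, up to floating-point rounding. A quick way to sanity-check any derivative, whether hand-coded or produced by JAX, is to compare it with a central finite difference. The sketch below uses plain NumPy; `central_diff` and its step size `h` are hypothetical choices for illustration, not part of JAX.

```python
import numpy as np

def central_diff(f, x, h=1e-5):
    """Central finite-difference approximation of f'(x); truncation error is O(h**2)."""
    return (f(x + h) - f(x - h)) / (2 * h)

def f(x):
    return (x - 3)**2 + 4

def df(x):
    return 2 * (x - 3)  # hand-coded derivative to compare against

for x in np.linspace(-2.0, 8.0, 5):
    print(f"x = {x:5.2f}   df = {df(x):8.4f}   central_diff = {central_diff(f, x):8.4f}")
```

In the notebook, the same comparison can be made against `grad(f)` from JAX.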

In [None]:
# Define the function; JAX computes its gradient automatically
def f(x):
    return (x - 3)**2 + 4

# Parameters for gradient descent
x0    = 0.0  # Starting point for optimization
alpha = 0.9

# Run gradient descent
Xmin = np.array(autogd_hist(f, x0, alpha))
print("Approximate minimum:")
print("  xmin  = ",   Xmin[-1] )
print("f(xmin) = ", f(Xmin[-1]))

X = np.linspace(0, 6, 6001)
plt.plot(X,    f(X))
plt.plot(Xmin, f(Xmin), '-o')
plt.xlim(2.5, 3.5)
plt.ylim(3.95,4.3)

## Gradient Descent with JAX for Multiple Dimensions

Multidimensional gradient descent is essential for optimizing functions of many parameters, making it the workhorse behind applications such as model fitting and deep learning.

In astrophysics, gradient descent refines models by iteratively adjusting parameters to minimize discrepancies between observed data and theoretical predictions.
For example, in galaxy modeling, each parameter may correspond to a physical property—such as brightness, size, or position—and gradient descent enables efficient optimization to achieve the best fit to observational data.

In deep learning, multidimensional gradient descent is fundamental, as modern neural networks can have millions of parameters.
During training, the algorithm minimizes a loss function that quantifies the difference between the model’s predictions and actual outcomes.
Automatic differentiation with JAX streamlines gradient calculations, allowing practitioners to train complex models without manually computing derivatives.
This capability is particularly valuable for architectures such as convolutional and recurrent neural networks, where gradients must be computed across vast numbers of interconnected parameters.

The following example demonstrates how to use JAX to perform gradient descent on the multivariable function
\begin{align}
f(x, y) = (x - 3)^2 + 2(y + 4)^2,
\end{align}
whose minimum is at $(x, y) = (3, -4)$.
By tracking each update step, we can visualize the optimization path as it approaches the minimum.

In [None]:
from jax import numpy as jnp
from jax import grad, jit

# Gradient descent that records every iterate
def autogd_hist(f, X, alpha, imax):
    df = jit(grad(f))  # JAX computes the gradient; jit compiles it for speed
    Xs = [np.array(X)]
    for _ in range(imax):
        Xs.append(Xs[-1] - alpha * df(Xs[-1]))  # gradient descent update
    return jnp.array(Xs)

In [None]:
# Define a multivariable function
def f(X):
    x, y = X
    return (x - 3)**2 + 2 * (y + 4)**2

# Parameters for gradient descent
X0    = jnp.array([0.0, 0.0]) # Starting point for optimization
alpha = 0.1                   # Learning rate
imax  = 100                   # Number of iterations

# Run gradient descent with history tracking
Xs = autogd_hist(f, X0, alpha, imax)
print("Approximate minimum:")
print("  xmin  =",   Xs[-1] )
print("f(xmin) =", f(Xs[-1]))

# Plot the function and gradient descent path
x_vals = jnp.linspace(-1, 7, 100)
y_vals = jnp.linspace(-8, 0, 100)
X, Y   = jnp.meshgrid(x_vals, y_vals)
Z      = f([X, Y])

plt.contour(X, Y, Z, levels=20)
plt.plot(Xs[:,0], Xs[:,1], '-o', color='red')
plt.xlabel('x')
plt.ylabel('y')
plt.gca().set_aspect('equal')
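For the function used in the code above, $f(x, y) = (x - 3)^2 + 2(y + 4)^2$, the gradient is easy to write by hand: $\nabla f = \big(2(x - 3),\ 4(y + 4)\big)$. Each update therefore multiplies the $x$-error by $1 - 2\alpha$ and the $y$-error by $1 - 4\alpha$, so the iteration is stable only if both factors have magnitude below one, i.e. $0 < \alpha < 0.5$. The steepest direction limits the learning rate, even though the $x$-direction alone would tolerate $\alpha$ up to $1$. The plain-NumPy sketch below (with hypothetical helpers `grad_f` and `run_gd`) checks this for one learning rate on each side of the limit.

```python
import numpy as np

def grad_f(X):
    """Hand-coded gradient of f(x, y) = (x - 3)**2 + 2*(y + 4)**2."""
    x, y = X
    return np.array([2 * (x - 3), 4 * (y + 4)])

def run_gd(lr, imax=100, X0=(0.0, 0.0)):
    """Return the final iterate of fixed-step gradient descent."""
    X = np.array(X0, dtype=float)
    for _ in range(imax):
        X = X - lr * grad_f(X)
    return X

X_star = np.array([3.0, -4.0])  # known minimum
dists = {}
for lr in (0.45, 0.55):         # just below and just above the limit alpha = 0.5
    dists[lr] = np.linalg.norm(run_gd(lr) - X_star)
    print(f"alpha = {lr}: distance from minimum after 100 steps = {dists[lr]:.3e}")
```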

Because $f(x,y)$ is the quantity being minimized, it plays the role of a loss function.
Hence we can plot how the loss evolves during the optimization:

In [None]:
plt.loglog(f(Xs.T))
plt.xlabel('Step')
plt.ylabel('Loss f(x,y)')
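The loss decays geometrically. With the gradient $\big(2(x - 3),\ 4(y + 4)\big)$ and $\alpha = 0.1$, each step multiplies the $x$-error by $0.8$ and the $y$-error by $0.6$, so the slower $x$-direction eventually dominates and the ratio of successive losses approaches $0.8^2 = 0.64$. On a semi-logarithmic plot (`plt.semilogy`), the late-time loss is therefore a straight line. The NumPy sketch below checks this ratio using a hand-coded gradient rather than JAX. It runs in float64; JAX defaults to float32, so in the notebook the loss stops following this trend once the iterates are within float32 rounding of the minimum.

```python
import numpy as np

def f(X):
    x, y = X
    return (x - 3)**2 + 2 * (y + 4)**2

def grad_f(X):
    """Hand-coded gradient of f."""
    x, y = X
    return np.array([2 * (x - 3), 4 * (y + 4)])

lr, nsteps = 0.1, 60
X = np.array([0.0, 0.0])
losses = [f(X)]
for _ in range(nsteps):
    X = X - lr * grad_f(X)
    losses.append(f(X))

losses = np.array(losses)
ratios = losses[1:] / losses[:-1]            # f_{n+1} / f_n
print("last few loss ratios:", ratios[-3:])  # approaches 0.64
```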

To demonstrate a more complex optimization scenario, let's consider fitting a multi-parameter model to noisy data.
We will use polynomial regression as our example, where we fit a polynomial curve to data points by optimizing the coefficients.
The problem becomes harder as the polynomial degree increases: the number of parameters grows, and the monomials $1, x, x^2, \dots$ look increasingly alike on $[-1, 1]$, so the optimization becomes both high-dimensional and poorly conditioned.
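One way to quantify why higher degrees are harder is to look at the curvature of the objective. The model is linear in the coefficients, $y = V c$ with the Vandermonde matrix $V_{ij} = x_i^j$, so the mean squared residual has the constant Hessian $H = 2 V^\top V / N$. The usable step size is limited by the largest eigenvalue of $H$, while progress along the flattest direction is governed by the smallest, so the condition number $\kappa = \lambda_{\max} / \lambda_{\min}$ roughly sets how many iterations gradient descent needs. The NumPy sketch below computes $\kappa$ for increasing degree on a grid of 1000 points in $[-1, 1]$, matching the data used in this section.

```python
import numpy as np

x = np.linspace(-1, 1, 1_000)

conds = []
for degree in range(7):
    V = np.vander(x, degree + 1, increasing=True)  # columns 1, x, ..., x**degree
    H = 2 * V.T @ V / len(x)                        # Hessian of mean((V @ c - y)**2)
    conds.append(np.linalg.cond(H))
    print(f"degree {degree}: condition number = {conds[-1]:.3g}")
```

Adding a basis function can only widen the spread of eigenvalues, so $\kappa$ never decreases with degree.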

In [None]:
groundtruth = np.array([1.2, -3, 0.5, 1.0, -1.8, 2.0, -0.1])

Xdata = np.linspace(-1, 1, 1_000)
Ytrue = sum(c * Xdata**i for i, c in enumerate(groundtruth))
Ydata = Ytrue + np.random.default_rng(0).normal(scale=0.1, size=Xdata.shape)  # fixed seed for reproducibility

In [None]:
plt.plot(Xdata, Ydata, label='Noisy data')
plt.plot(Xdata, Ytrue, label='True polynomial')
plt.legend()

In [None]:
# Define polynomial model
def model(Xs, Cs):
    return sum(c * Xs**i for i, c in enumerate(Cs))

# Objective: mean squared residual (proportional to chi^2 when all points share the same noise level)
def chi2(Cs):
    Ymodel = model(Xdata, Cs)
    return jnp.mean((Ymodel - Ydata)**2)

# Parameters for gradient descent
C0    = jnp.zeros(len(groundtruth)) # Start with zeros as initial coefficients
alpha = 0.1                         # Learning rate
imax  = 1000                        # Number of iterations

Cs = autogd_hist(chi2, C0, alpha, imax)
%timeit -r1 Cs = autogd_hist(chi2, C0, alpha, imax)

print("Optimized coefficients:", Cs[-1])
print("True coefficients:",      groundtruth)
print("Mean squared coefficient error:", np.mean((groundtruth - Cs[-1])**2))
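Because the model is linear in the coefficients, the exact minimizer of this objective can also be computed directly by linear least squares. Comparing `Cs[-1]` with that solution separates two effects: how far 1000 gradient-descent steps are from the minimum, and how far the minimum itself is from the ground truth (which is limited by the noise). The self-contained sketch below regenerates data the same way under new names, with a fixed random seed added as an assumption for reproducibility, and solves for the coefficients with `np.linalg.lstsq`.

```python
import numpy as np

groundtruth = np.array([1.2, -3, 0.5, 1.0, -1.8, 2.0, -0.1])

# Separate names so the notebook's Xdata/Ydata are not overwritten
rng    = np.random.default_rng(0)  # fixed seed: an assumption for reproducibility
x_lsq  = np.linspace(-1, 1, 1_000)
yt_lsq = sum(c * x_lsq**i for i, c in enumerate(groundtruth))
y_lsq  = yt_lsq + rng.normal(scale=0.1, size=x_lsq.shape)

# Design matrix with columns 1, x, x**2, ...; V @ C reproduces model(x_lsq, C)
V = np.vander(x_lsq, len(groundtruth), increasing=True)

# Exact minimizer of mean((V @ C - y)**2)
C_lsq, *_ = np.linalg.lstsq(V, y_lsq, rcond=None)

print("Least-squares coefficients:", C_lsq)
print("True coefficients:         ", groundtruth)
print("Mean squared coefficient error:", np.mean((C_lsq - groundtruth)**2))
```

In the notebook, passing `Xdata` and `Ydata` instead gives the exact minimizer for the data that gradient descent was run on, so `Cs[-1]` can be compared with it directly.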

In [None]:
skip = 20
plt.scatter(Xdata[::skip], Ydata[::skip], color='blue', label='Noisy Data', alpha=0.5)
plt.plot(Xdata, Ytrue, 'g--', label='True Polynomial')
for i, Ci in enumerate(Cs[::skip]):
    Yfit = model(Xdata, Ci)
    plt.plot(Xdata, Yfit, 'r', alpha=skip*i/imax, label='Fitted Polynomial' if skip*i == imax else '')
plt.xlabel("x")
plt.ylabel("y")
plt.legend()

Let's also plot the objective $\chi^2$ as a function of step:

In [None]:
Chi2 = [chi2(Ci) for Ci in Cs]
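The list comprehension evaluates the objective once per stored step. Since the model is linear in the coefficients, the whole history can instead be evaluated in one matrix product: with the design matrix $V$ (columns $1, x, x^2, \dots$) and the coefficient history stacked as rows of $C$, row $k$ of $C V^\top$ holds the model predictions for step $k$. The NumPy sketch below uses stand-in random data and a hypothetical random coefficient history in place of `Cs`, and checks the vectorized result against a per-step loop. In the notebook, `jax.vmap(chi2)(Cs)` should be another way to vectorize the same computation.

```python
import numpy as np

rng = np.random.default_rng(1)
x = np.linspace(-1, 1, 1_000)
y = rng.normal(size=x.shape)     # stand-in data
C = rng.normal(size=(50, 7))     # hypothetical history: 50 steps, 7 coefficients

V = np.vander(x, C.shape[1], increasing=True)  # V[i, j] = x[i]**j

# Vectorized: row k of C @ V.T is the model evaluated with coefficients C[k]
chi2_all = np.mean((C @ V.T - y)**2, axis=1)

# Reference: one evaluation per step, mirroring the list comprehension
chi2_loop = np.array([np.mean((sum(c * x**i for i, c in enumerate(Ck)) - y)**2) for Ck in C])

print(np.allclose(chi2_all, chi2_loop))
```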

In [None]:
plt.loglog(Chi2)
plt.xlabel('Step')
plt.ylabel('Chi2')