# Install JAX

To run this notebook, JAX must be installed. If running locally, JAX can be installed for CPU with:

```bash
pip install jax jaxlib
```

For GPU support, see the [JAX installation guide](https://github.com/google/jax#installation) for the correct pip command for your CUDA version.
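After installing, you can check that the install worked, and whether a GPU was detected, by asking JAX which backend and devices it sees. The sketch below uses only `jax.__version__`, `jax.default_backend()` and `jax.devices()`. The device names it prints depend on your hardware.

```python
import jax

# Report what this JAX installation can see.
print("jax version:", jax.__version__)
print("default backend:", jax.default_backend())  # typically "cpu", "gpu", or "tpu"
print("devices:", jax.devices())
```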

```python
import jax
import jax.numpy as jnp
from jax import random
```

# Define the Least Squares Problem

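We want to solve the linear system $Ax = b$ in the least-squares sense, i.e., find the $x$ that minimizes $\|Ax - b\|^2$. With 5 equations and only 2 unknowns, the system is overdetermined and generally has no exact solution, so least squares finds the best fit. For this demonstration we generate a random $5 \times 2$ matrix $A$, pick a known $x_{\text{true}}$, and form $b$ from $A x_{\text{true}}$ plus a small amount of noise.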

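```python
key = random.PRNGKey(0)
A = random.normal(key, (5, 2))  # 5 equations, 2 unknowns
x_true = jnp.array([2.0, -3.0])
# Caution: `key` is reused here, so this "noise" is not independent of A.
# In real code, use random.split(key) to get independent subkeys.
b = A @ x_true + 0.1 * random.normal(key, (5,))
print("A =\n", A)
print("b =", b)
```
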
```text
A =
 [[ 1.6226422   2.0252647 ]
 [-0.43359444 -0.07861735]
 [ 0.1760909  -0.97208923]
 [-0.49529874  0.4943786 ]
 [ 0.6643493  -0.9501635 ]]
b = [-2.6682458  -0.42881033  3.22509    -2.481595    4.1967983 ]
```
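The cell above passes the same `key` to both `random.normal` calls. JAX's PRNG is deterministic in the key, so reusing one does not give fresh randomness. The printed values show the consequence: up to float32 rounding, each entry of `b - A @ x_true` equals 0.1 times the corresponding entry of `A.ravel()[:5]`. In other words, the "noise" is a scaled copy of A's first five entries. A minimal sketch of the usual fix with `random.split` follows. It draws different numbers, so its `A` and `b` will not match the values printed in this notebook.

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
# Split the root key into two independent subkeys: one for A, one for the noise.
key_A, key_noise = random.split(key)

A = random.normal(key_A, (5, 2))  # 5 equations, 2 unknowns
x_true = jnp.array([2.0, -3.0])
noise = 0.1 * random.normal(key_noise, (5,))
b = A @ x_true + noise

# With independent keys, the noise is no longer a rescaled copy of A's entries.
print("A =\n", A)
print("b =", b)
```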


# Solve the Least Squares Problem Using JAX

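We compute the least-squares solution with the Moore-Penrose pseudo-inverse, $x = A^+ b$, using `jnp.linalg.pinv`. When $A$ has full column rank, as it does here, $A^+ = (A^\top A)^{-1} A^\top$.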

```python
# Compute the least squares solution using the pseudo-inverse
x_lstsq = jnp.linalg.pinv(A) @ b
print("Estimated x:", x_lstsq)
print("True x:", x_true)
print("Residual norm:", jnp.linalg.norm(A @ x_lstsq - b))
```

```text
Estimated x: [ 2.022512  -2.9543445]
True x: [ 2. -3.]
Residual norm: 0.22407445
```
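`pinv` is one of several routes to the same answer. The sketch below compares it with `jnp.linalg.lstsq` and with solving the normal equations $A^\top A x = A^\top b$ using `jnp.linalg.solve`. To run on its own, it hard-codes the `A` and `b` printed above, rounded as displayed. Because $A$ has full column rank, all three results should agree to float32 precision.

```python
import jax.numpy as jnp

# A and b as printed earlier in this notebook (rounded as displayed).
A = jnp.array([
    [1.6226422, 2.0252647],
    [-0.43359444, -0.07861735],
    [0.1760909, -0.97208923],
    [-0.49529874, 0.4943786],
    [0.6643493, -0.9501635],
])
b = jnp.array([-2.6682458, -0.42881033, 3.22509, -2.481595, 4.1967983])

# 1) Pseudo-inverse, as in the cell above.
x_pinv = jnp.linalg.pinv(A) @ b

# 2) Dedicated least-squares solver; returns (solution, residuals, rank, singular values).
x_lstsq, residuals, rank, sing_vals = jnp.linalg.lstsq(A, b)

# 3) Normal equations A^T A x = A^T b. Fine for this small, well-conditioned A,
#    but forming A^T A squares the condition number of A.
x_normal = jnp.linalg.solve(A.T @ A, A.T @ b)

print("pinv:  ", x_pinv)
print("lstsq: ", x_lstsq)
print("normal:", x_normal)
print("rank of A:", rank)
```

Because $\kappa(A^\top A) = \kappa(A)^2$, the normal-equations route loses accuracy first when $A$ is ill-conditioned. `lstsq` or `pinv` is usually the safer default.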


# Solving Least Squares with Gradient Descent

Instead of using the pseudo-inverse, we can minimize the loss $L(x) = \|Ax - b\|^2$ directly with gradient descent. `jax.grad` differentiates the loss automatically, so we never have to write out its gradient by hand.
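For reference, the gradient that `grad` computes here has a closed form. That closed form also shows where gradient descent converges and how large the learning rate $\eta$ may be:

$$
\begin{aligned}
L(x) &= \|Ax - b\|^2 = (Ax - b)^\top (Ax - b), \\
\nabla L(x) &= 2A^\top (Ax - b), \\
x_{k+1} &= x_k - \eta \, \nabla L(x_k) = x_k - 2\eta \, A^\top (A x_k - b).
\end{aligned}
$$

Setting $\nabla L(x) = 0$ gives the normal equations $A^\top A x = A^\top b$. For full-column-rank $A$, their solution is the same $x = A^+ b$ that `pinv` produced. The Hessian is the constant matrix $2A^\top A$, so the error evolves as $e_{k+1} = (I - 2\eta A^\top A)\, e_k$. Gradient descent therefore converges from any starting point exactly when $0 < \eta < 1/\lambda_{\max}(A^\top A)$.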

```python
from jax import grad

# Define the loss function L(x) = ||Ax - b||^2
loss = lambda x: jnp.sum((A @ x - b) ** 2)

# grad(loss) returns a new function that evaluates the gradient of the loss
loss_grad = grad(loss)

# Gradient descent loop, starting from x = 0
x_gd = jnp.zeros_like(x_true)
learning_rate = 0.05
num_steps = 100

for i in range(num_steps):
    # JAX arrays are immutable, so this rebinds x_gd to a new array
    x_gd -= learning_rate * loss_grad(x_gd)
    if i % 20 == 0:
        # Loss and x are reported after the update at step i
        print(f"Step {i}, loss: {loss(x_gd):.4f}, x: {x_gd}")

print("\nGradient Descent Solution:", x_gd)
print("True x:", x_true)
print("Residual norm:", jnp.linalg.norm(A @ x_gd - b))
```

```text
Step 0, loss: 15.1927, x: [ 0.04415018 -1.3719759 ]
Step 20, loss: 0.0508, x: [ 2.0087342 -2.9464455]
Step 40, loss: 0.0502, x: [ 2.0224242 -2.9542947]
Step 60, loss: 0.0502, x: [ 2.022511  -2.9543443]
Step 80, loss: 0.0502, x: [ 2.022511  -2.9543443]

Gradient Descent Solution: [ 2.022511  -2.9543443]
True x: [ 2. -3.]
Residual norm: 0.22407433
```
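The sketch below runs three checks on the gradient-descent setup. It hard-codes the `A` and `b` printed earlier, rounded as displayed, so it runs on its own. First, it confirms that `jax.grad` agrees with the closed-form gradient $2A^\top(Ax - b)$. Second, it computes the step-size bound $1/\lambda_{\max}(A^\top A)$, which shows that `learning_rate = 0.05` lies inside it. Third, it wraps the update in `jax.jit`. The helper name `gd_step` is our own; it is not part of the notebook or of JAX.

```python
import jax
import jax.numpy as jnp

# A and b as printed earlier in this notebook (rounded as displayed).
A = jnp.array([
    [1.6226422, 2.0252647],
    [-0.43359444, -0.07861735],
    [0.1760909, -0.97208923],
    [-0.49529874, 0.4943786],
    [0.6643493, -0.9501635],
])
b = jnp.array([-2.6682458, -0.42881033, 3.22509, -2.481595, 4.1967983])


def loss(x):
    return jnp.sum((A @ x - b) ** 2)


# 1) Autodiff gradient vs. the closed form 2 A^T (A x - b) at an arbitrary point.
x0 = jnp.array([0.5, 0.5])
g_auto = jax.grad(loss)(x0)
g_closed = 2.0 * A.T @ (A @ x0 - b)
print("autodiff grad:   ", g_auto)
print("closed-form grad:", g_closed)

# 2) Step-size bound for this quadratic: eta < 1 / lambda_max(A^T A).
lam_max = jnp.linalg.eigvalsh(A.T @ A).max()
print("learning-rate bound:", 1.0 / lam_max)


# 3) A jit-compiled gradient-descent step (hypothetical helper name).
@jax.jit
def gd_step(x, lr):
    return x - lr * jax.grad(loss)(x)


x = jnp.zeros(2)
for _ in range(100):
    x = gd_step(x, 0.05)
print("jitted GD solution:", x)
```

`jax.jit` compiles the step once and reuses the compiled version on every iteration. For a $5 \times 2$ problem the speedup is likely small, but the same pattern matters for larger models.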
