In [26]:
"""
Minimal linear-regression example with JAX.

Requires:  pip install --upgrade "jax[cpu]"
(or the GPU/CUDA build if you have CUDA installed.)
"""

import jax
import jax.numpy as jnp
from jax import grad, jit

# -------------------------------------------------------------------
# Synthetic "dataset": one 5-D input mapped to one 3-D target
# -------------------------------------------------------------------
x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])      # shape (5,)
y_true = jnp.array([2.0, -1.0, 0.5])          # shape (3,)

# -------------------------------------------------------------------
# Model parameters: weight matrix W (n_out × n_in) and bias b (n_out,)
# -------------------------------------------------------------------
SEED = 0            # PRNG seed used for the weight initialisation
INIT_SCALE = 0.1    # multiplier applied to the standard-normal draws

# Derive the layer dimensions from the data so the parameter shapes cannot
# drift out of sync with x / y_true if the example data is edited.
n_in, n_out = x.shape[0], y_true.shape[0]

key = jax.random.PRNGKey(SEED)
W_init = jax.random.normal(key, (n_out, n_in)) * INIT_SCALE
b_init = jnp.zeros(n_out)
params = (W_init, b_init)


In [27]:
# Sanity check: apply the untrained linear map W_init·x + b_init directly.
y = W_init @ x + b_init
y

Array([ 1.0946617 , -0.80353975, -1.1520827 ], dtype=float32)

In [28]:

# -------------------------------------------------------------------
# Forward pass and loss
# -------------------------------------------------------------------
@jax.jit
def predict(params, x):
    """Apply the affine map ``W x + b`` to one input or to a batch of inputs.

    Args:
        params: Pair ``(W, b)`` with ``W`` of shape (n_out, n_in) and ``b`` of
            shape (n_out,).
        x: Either a single input vector of shape (n_in,), or a batch of row
            vectors of shape (..., n_in).

    Returns:
        Predictions of shape (n_out,) for a single input, or (..., n_out) for
        a batch.
    """
    W, b = params
    # Shapes are static under jit, so branching on the rank in Python is fine.
    if jnp.ndim(x) <= 1:
        # Single example: same computation as before.
        return jnp.dot(W, x) + b                # linear map n_in -> n_out
    # Batch of rows: contract over the trailing feature axis so that b
    # broadcasts along the output axis of every example.
    return jnp.matmul(x, W.T) + b

@jax.jit
def mse_loss(params, x, y_true):
    """Mean squared error between ``predict(params, x)`` and ``y_true``.

    The squared residuals are averaged over every element, yielding a scalar.
    """
    residual = predict(params, x) - y_true
    return jnp.mean(jnp.square(residual))


In [29]:
# Baseline: loss at the initial parameters, before any update.
initial_loss = mse_loss(params, x, y_true)
initial_loss

Array(1.1958704, dtype=float32)

In [30]:

# -------------------------------------------------------------------
# Optimiser setup for vanilla gradient descent
# -------------------------------------------------------------------
learning_rate = 1e-2

# Differentiate the loss with respect to its first argument, i.e. the
# parameter tuple, so the result has the same (W, b) structure: ∇loss w.r.t. (W, b).
grad_loss = jax.jit(jax.grad(mse_loss, argnums=0))


In [31]:
# Gradient of the loss at the initial parameters, as a (dW, db) tuple.
initial_grads = grad_loss(params, x, y_true)
initial_grads

(Array([[-0.6035589 , -1.2071178 , -1.8106767 , -2.4142356 , -3.0177946 ],
        [ 0.1309735 ,  0.261947  ,  0.3929205 ,  0.523894  ,  0.65486753],
        [-1.1013885 , -2.202777  , -3.3041654 , -4.405554  , -5.5069423 ]],      dtype=float32),
 Array([-0.6035589,  0.1309735, -1.1013885], dtype=float32))

In [32]:

NUM_STEPS = 11      # number of gradient-descent updates to perform

# Restart from the initial parameters so that re-running this cell reproduces
# the same trajectory instead of continuing from already-trained values.
params = (W_init, b_init)

for step in range(NUM_STEPS):
    grads = grad_loss(params, x, y_true)    # gradients with the same (W, b) structure

    # Compact equivalent of the update below:
    #   params = tuple(p - learning_rate * g for p, g in zip(params, grads))

    # Explicit per-parameter update, spelled out for readability:
    W, b = params
    W_grad, b_grad = grads
    W_update = W - learning_rate * W_grad
    b_update = b - learning_rate * b_grad
    params = (W_update, b_update)

    # The loss is evaluated *after* the update, so "step k" reports the loss
    # of the parameters produced by that step.
    print(f"step {step:4d} | loss = {mse_loss(params, x, y_true):.10f}")

# -------------------------------------------------------------------
# Results: report the trained parameters and compare against the target
# -------------------------------------------------------------------
W_opt, b_opt = params
final_prediction = predict(params, x)

# (label, value) pairs, printed in order with print's default separator.
report = (
    ("\nOptimized weights W*:\n", W_opt),
    ("\nOptimized bias b*:\n", b_opt),
    ("\nPrediction after training:", final_prediction),
    ("Target:", y_true),
)
for label, value in report:
    print(label, value)

step    0 | loss = 0.4696317315
step    1 | loss = 0.1844295710
step    2 | loss = 0.0724275336
step    3 | loss = 0.0284430813
step    4 | loss = 0.0111699272
step    5 | loss = 0.0043865554
step    6 | loss = 0.0017226525
step    7 | loss = 0.0006765013
step    8 | loss = 0.0002656693
step    9 | loss = 0.0001043317
step   10 | loss = 0.0000409718

Optimized weights W*:
 [[ 0.0245547   0.22312075  0.07777813  0.17638329  0.113793  ]
 [-0.08609448  0.0578374  -0.11399366 -0.09219912 -0.06283305]
 [ 0.09230851  0.14018224  0.05519871  0.0049707  -0.01943018]]

Optimized bias b*:
 [ 0.01607213 -0.00348768  0.0293288 ]

Prediction after training: [ 1.9947008  -0.9988501   0.49032986]
Target: [ 2.  -1.   0.5]
