# JAX Autograd Tutorial - Video 2
## Automatic Differentiation in JAX

Welcome to the second video in our JAX series! Today we'll explore automatic differentiation with `jax.grad`, `jax.value_and_grad`, and more.

In [1]:
# Cell 1: Setup and imports
import jax
import jax.numpy as jnp
from jax import grad, value_and_grad
import numpy as np

# Print the installed JAX version so readers can compare their environment
# with the outputs recorded in this notebook.
print(f"JAX version: {jax.__version__}")
print("Let's explore automatic differentiation in JAX!")

JAX version: 0.6.1
Let's explore automatic differentiation in JAX!


In [2]:
# Simple scalar function - our first gradient
def f(x):
    """Return the square of x, i.e. f(x) = x^2 (analytic derivative: 2x)."""
    return x * x

# grad(f) builds a new function that evaluates df/dx
df_dx = grad(f)

# Differentiate at the point x = 3.0
x = 3.0
gradient_value = df_dx(x)

# Report the function value, the autodiff derivative, and the hand-derived
# result 2x so the two can be compared.
print("f(x) = x^2")
print(f"f({x}) = {f(x)}")
print(f"f'({x}) = {gradient_value}")
print(f"Expected: 2 * {x} = {2 * x}")

f(x) = x^2
f(3.0) = 9.0
f'(3.0) = 6.0
Expected: 2 * 3.0 = 6.0


In [3]:
# Function with multiple inputs - partial derivatives
def h(x, y):
    """Return x^2 + 3xy + y^2.

    The analytic partial derivatives are 2x + 3y (with respect to x) and
    3x + 2y (with respect to y).
    """
    x_squared = x**2
    cross_term = 3*x*y
    y_squared = y**2
    return x_squared + cross_term + y_squared

# argnums selects which positional argument to differentiate with respect to.
# grad() defaults to the first argument (argnums=0); it is spelled out here
# so both partial derivatives read the same way.
dh_dx = grad(h, argnums=0)  # ∂h/∂x
dh_dy = grad(h, argnums=1)  # ∂h/∂y

# Evaluate both partial derivatives at the point (x, y) = (2.0, 1.0)
x = 2.0
y = 1.0
print("h(x,y) = x² + 3xy + y²")
print(f"∂h/∂x = 2x + 3y = {dh_dx(x, y)}")
print(f"∂h/∂y = 3x + 2y = {dh_dy(x, y)}")

h(x,y) = x² + 3xy + y²
∂h/∂x = 2x + 3y = 7.0
∂h/∂y = 3x + 2y = 8.0


In [4]:
# value_and_grad - getting both function value and gradient
# In practice we usually want the function value as well as its gradient
def h(x, y):
    """Return x^2 + 3xy + y^2 (same polynomial as the partial-derivatives example)."""
    leading_terms = x**2 + 3*x*y
    return leading_terms + y**2

# Evaluation point shared by both approaches
x = 2.0
y = 1.0

# Approach 1: two independent calls - one for the value, one for the
# gradient with respect to both arguments.
value_separate = h(x, y)
grad_h = grad(h, argnums=(0, 1))
grad_separate = grad_h(x, y)

# Approach 2: a single call that returns (value, gradient) together.
value_and_grad_h = value_and_grad(h, argnums=(0, 1))
value_together, grad_together = value_and_grad_h(x, y)

print("Separate calls:")
print(f"  Value: {value_separate}")
print(f"  Grad: {grad_separate}")
print("value_and_grad:")
print(f"  Value: {value_together}")
print(f"  Grad: {grad_together}")
print("\nSame results, but value_and_grad is more efficient!")

Separate calls:
  Value: 11.0
  Grad: (Array(7., dtype=float32, weak_type=True), Array(8., dtype=float32, weak_type=True))
value_and_grad:
  Value: 11.0
  Grad: (Array(7., dtype=float32, weak_type=True), Array(8., dtype=float32, weak_type=True))

Same results, but value_and_grad is more efficient!


In [5]:
# value_and_grad in machine learning context
def loss_function(params, x_data, y_data):
    """Mean squared error of the linear model w * x + b.

    Args:
        params: pair (w, b) holding the slope and the intercept.
        x_data: model inputs.
        y_data: targets the predictions are compared against.

    Returns:
        The mean of the squared differences between predictions and y_data.
    """
    slope, intercept = params
    residuals = (slope * x_data + intercept) - y_data
    return jnp.mean(residuals**2)

# Toy regression data: the targets follow y = 2x with a little noise
x_data = jnp.array([1.0, 2.0, 3.0, 4.0])
y_data = jnp.array([2.1, 3.9, 6.1, 7.9])

# Initial guess for the parameters: (w, b) = (1, 0)
params = (1.0, 0.0)

# value_and_grad differentiates with respect to the first argument (params),
# so one call yields the loss together with the pair (dL/dw, dL/db).
loss_and_grad = value_and_grad(loss_function)
current_loss, gradients = loss_and_grad(params, x_data, y_data)

print(f"Current parameters: w={params[0]}, b={params[1]}")
print(f"Loss: {current_loss:.4f}")
print(f"Gradients: dL/dw={gradients[0]:.4f}, dL/db={gradients[1]:.4f}")
print("\nThis is exactly what you need for gradient descent!")

Current parameters: w=1.0, b=0.0
Loss: 7.4100
Gradients: dL/dw=-14.9000, dL/db=-5.0000

This is exactly what you need for gradient descent!


In [6]:
# Matrix operations and gradients
def matrix_function(W):
    """Return trace(W @ W.T) for a matrix W.

    For a real 2-D W this equals the sum of its squared entries, so the
    gradient with respect to W is 2 * W.
    """
    gram = W @ W.T
    return jnp.trace(gram)

# value_and_grad also works when the argument is a matrix: the function must
# return a scalar, and the gradient comes back with the shape of the input.
val_grad_matrix = value_and_grad(matrix_function)

W = jnp.array([
    [1.0, 2.0],
    [3.0, 4.0],
])
value, gradient_W = val_grad_matrix(W)

print(f"Input matrix W:\n{W}")
print(f"Function value: {value}")
print(f"Gradient matrix:\n{gradient_W}")
print("\nGradient has the same shape as input!")

Input matrix W:
[[1. 2.]
 [3. 4.]]
Function value: 30.0
Gradient matrix:
[[2. 4.]
 [6. 8.]]

Gradient has the same shape as input!


In [7]:
# Higher-order derivatives with value_and_grad and nested grad calls
def f(x):
    """Return the polynomial x^4 - 2x^3 + x^2."""
    quartic = x**4
    cubic = 2*x**3
    quadratic = x**2
    return quartic - cubic + quadratic

# f and f' together from a single call
f_val_grad = value_and_grad(f)

# Second derivative: grad composes, so differentiate the derivative again
f_prime = grad(f)
f_double_prime = grad(f_prime)

# Evaluate everything at x = 2.0
x = 2.0
value, first_deriv = f_val_grad(x)
second_deriv = f_double_prime(x)

print("f(x) = x⁴ - 2x³ + x²")
print(f"At x={x}:")
print(f"f({x}) = {value}")
print(f"f'({x}) = {first_deriv}")
print(f"f''({x}) = {second_deriv}")
print(f"\nExpected: f'(x) = 4x³ - 6x² + 2x = {4*x**3 - 6*x**2 + 2*x}")

f(x) = x⁴ - 2x³ + x²
At x=2.0:
f(2.0) = 4.0
f'(2.0) = 12.0
f''(2.0) = 26.0

Expected: f'(x) = 4x³ - 6x² + 2x = 12.0


In [8]:
# Complete gradient descent with value_and_grad
def mse_loss(params, x_data, y_data):
    """Mean squared error for linear regression.

    Args:
        params: pair (w, b) with the slope and intercept of the line.
        x_data: inputs.
        y_data: targets.

    Returns:
        The mean of (w * x_data + b - y_data) ** 2.
    """
    w, b = params
    predictions = w * x_data + b
    return jnp.mean((predictions - y_data)**2)

# Create the loss and gradient function; value_and_grad differentiates with
# respect to the first argument, i.e. params.
loss_and_grad = value_and_grad(mse_loss)

def gradient_descent_step(params, x_data, y_data, learning_rate=0.1):
    """One step of gradient descent using value_and_grad.

    Args:
        params: current parameters, in any container accepted by the loss
            behind loss_and_grad (a (w, b) tuple in this notebook).
        x_data: inputs forwarded to the loss.
        y_data: targets forwarded to the loss.
        learning_rate: step size α applied to the gradient.

    Returns:
        (new_params, loss). new_params has the same container structure as
        params; loss is the value at the incoming params, i.e. before the
        update is applied.
    """
    loss, grads = loss_and_grad(params, x_data, y_data)
    # Update parameters: θ = θ - α∇L. grads mirrors the structure of params,
    # so tree_map pairs each parameter with its own gradient instead of
    # relying on hard-coded positions 0 and 1.
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, grads
    )
    return new_params, loss

# Noise-free training data generated from y = 2x
x_data = jnp.array([1.0, 2.0, 3.0, 4.0])
y_data = jnp.array([2.0, 4.0, 6.0, 8.0])

# Initial (w, b) guess; every iteration feeds the updated parameters back in
params = (1.5, 0.5)
num_steps = 5

print("Gradient descent with value_and_grad:")
for i in range(num_steps):
    params, loss = gradient_descent_step(params, x_data, y_data)
    # loss is the value computed before this step's update was applied
    print(f"Step {i+1}: w={params[0]:.3f}, b={params[1]:.3f}, loss={loss:.4f}")

print("\nTarget: w=2.0, b=0.0")

Gradient descent with value_and_grad:
Step 1: w=2.000, b=0.650, loss=0.8750
Step 2: w=1.675, b=0.520, loss=0.4225
Step 3: w=1.903, b=0.578, loss=0.2176
Step 4: w=1.760, b=0.512, loss=0.1239
Step 5: w=1.864, b=0.529, loss=0.0803

Target: w=2.0, b=0.0
