Install JAX and Flax

In [1]:
!pip install -q jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
!pip install -q flax optax transformers datasets

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m129.7/129.7 MB[0m [31m8.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m69.2/69.2 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m485.4/485.4 kB[0m [31m11.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m5.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m183.9/183.9 kB[0m [31m9.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.7/1.7 MB[0m [31m42.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.5/143.5 kB[0m [31m8.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.8/194.8 kB[0m [31m9.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

Introduction to JAX for ML

1️⃣ Import Required Libraries

In [1]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
import numpy as np


2️⃣ Check Available Devices

In [2]:
jax.devices()


[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
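The listing above shows eight TPU cores: four chips (one per distinct coords value) with two cores each (core_on_chip 0 and 1). Here is a minimal sketch for querying the backend and placing an array on a chosen device. It uses only standard jax calls. The printed values depend on the runtime, so a CPU-only machine reports 'cpu' and a single device. The choice of the last device is arbitrary and only for illustration.

```python
import jax
import jax.numpy as jnp

# Which platform JAX picked by default, e.g. 'tpu', 'gpu' or 'cpu'
print("backend:", jax.default_backend())

# Total number of devices visible to this process (8 for the TPU listing above)
print("device count:", jax.device_count())

# New arrays are created on the default device, jax.devices()[0]
x = jnp.arange(4.0)

# device_put copies the array to an explicitly chosen device (here: the last one listed)
y = jax.device_put(x, jax.devices()[-1])
print(y)
```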

3️⃣ Implement a Simple Function and Compute the Gradient

In [3]:
def loss_fn(x):
    return x**2 + 3*x + 2  # Quadratic function

grad_loss = grad(loss_fn)  # grad returns a new function that computes d(loss_fn)/dx

x = jnp.array(5.0)
print("Loss:", loss_fn(x))
print("Gradient at x=5:", grad_loss(x))


Loss: 42.0
Gradient at x=5: 13.0


Explanation for grad (Automatic Differentiation in JAX)

What is grad?
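grad is a JAX function transformation: it takes a function and returns a new function that computes the gradient (derivative) with respect to the first argument by default. The function being differentiated must return a scalar.
It is the basic building block for training neural networks and for gradient-based optimization in machine learning.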

Explanation:
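The function loss_fn(x) = x² + 3x + 2 is differentiable, with derivative 2x + 3.
grad(loss_fn) returns another function that computes ∂(loss_fn)/∂x.
The gradient at x = 5 is computed by automatic differentiation, not by finite differences, and agrees with the analytic result: 2·5 + 3 = 13, matching the output above.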
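Here is a short sketch that checks grad against the hand-derived derivative 2x + 3. It also shows two related parts of the JAX API: jax.value_and_grad, which returns the loss and the gradient together, and the argnums parameter for differentiating with respect to several inputs. The mse function and its data are hypothetical, chosen only for illustration.

```python
import jax
import jax.numpy as jnp

def loss_fn(x):
    return x**2 + 3*x + 2

grad_loss = jax.grad(loss_fn)

# Compare automatic differentiation with the analytic derivative 2x + 3
x = jnp.array(5.0)
print(grad_loss(x), 2 * x + 3)  # both equal 13.0

# value_and_grad returns (loss, gradient) in one call, as training loops usually need
value, g = jax.value_and_grad(loss_fn)(x)
print(value, g)  # 42.0 and 13.0

# Hypothetical mean-squared-error loss with parameters w and b
def mse(w, b, xs, ys):
    preds = w * xs + b
    return jnp.mean((preds - ys) ** 2)

xs = jnp.array([1.0, 2.0, 3.0])
ys = jnp.array([2.0, 4.0, 6.0])

# argnums=(0, 1) differentiates w.r.t. w and b; the data arguments are not differentiated
dw, db = jax.grad(mse, argnums=(0, 1))(1.0, 0.0, xs, ys)
print(dw, db)  # dw = -28/3 (about -9.33), db = -4.0
```

At w = 1 and b = 0, the residuals are [-1, -2, -3]. That gives dL/dw = mean(2·r·x) = -28/3 and dL/db = mean(2·r) = -4, which the sketch reproduces.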

4️⃣ Use jit for Faster Execution

In [4]:
@jit
def compute_square(x):
    return x ** 2

x = jnp.array([2.0, 3.0, 4.0])
print(compute_square(x))  # JIT-optimized execution


[ 4.  9. 16.]


Explanation for jit (Just-In-Time Compilation in JAX)

What is jit?
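jit traces a Python function the first time it is called and compiles it with XLA into optimized code specialized to the input shapes and dtypes. Later calls with the same shapes and dtypes reuse the compiled code, while a new shape or dtype triggers recompilation.
It speeds up numerical code by fusing operations and removing per-operation Python overhead, which matters most in deep learning workloads with many array operations.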

Explanation:
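The function compute_square(x) = x² is decorated with @jit, so JAX compiles it with XLA (Accelerated Linear Algebra).
For a function this small the gain is minor, and the first call also pays the compilation cost. The benefit grows with the amount of work inside the function, especially on GPUs/TPUs.
JIT compilation is therefore typically applied to whole training steps when training large models.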
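Here is a minimal sketch of how jit caching behaves. The global trace_count is a hypothetical counter added only to make tracing visible. It relies on the fact that Python side effects inside a jitted function run while JAX traces it, not when the cached compiled code executes. The timing part follows the usual JAX benchmarking advice: dispatch is asynchronous, so block_until_ready() must be called before stopping the clock, and the first (compiling) call should be excluded. The measured time depends on the hardware, so no value is shown.

```python
import time
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, used only to observe tracing

@jax.jit
def compute_square(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX traces the function
    return x ** 2

a = jnp.array([2.0, 3.0, 4.0])
compute_square(a)                      # 1st call: trace + XLA compile for shape (3,)
compute_square(a + 1.0)                # same shape and dtype: cached, no retrace
compute_square(jnp.array([1.0, 2.0]))  # new shape (2,): triggers a retrace
print("traces:", trace_count)          # traces: 2

# Timing sketch: warm up first, then wait for the asynchronous result before stopping the clock
big = jnp.ones((2000, 2000))
compute_square(big).block_until_ready()  # warm-up call compiles for shape (2000, 2000)
start = time.perf_counter()
compute_square(big).block_until_ready()
print(f"cached call: {time.perf_counter() - start:.6f} s")
```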

5️⃣ JAX’s NumPy Replacement

In [5]:
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.sin(x)  # Compute sin(x) using jax.numpy
print(y)


[0.84147096 0.9092974  0.14112003]


What is jax.numpy?
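jax.numpy (jnp) mirrors most of the NumPy API and adds support for automatic differentiation (grad) and GPU/TPU acceleration.
It is close to, but not exactly, a drop-in replacement: JAX arrays are immutable, and floating-point values default to 32-bit precision. Functions in jax.numpy can be JIT-compiled and automatically differentiated.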
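Here is a short sketch of the most common difference in practice, immutability. JAX rejects in-place item assignment, and the documented alternative is the functional .at[...].set(...) update, which returns a new array. The variable names are illustrative.

```python
import numpy as np
import jax.numpy as jnp

# NumPy arrays are mutable: in-place item assignment works
np_x = np.array([1.0, 2.0, 3.0])
np_x[0] = 10.0

# JAX arrays are immutable: the same assignment raises TypeError
jnp_x = jnp.array([1.0, 2.0, 3.0])
raised = False
try:
    jnp_x[0] = 10.0
except TypeError:
    raised = True
print("item assignment rejected:", raised)

# Functional update: .at[index].set(value) returns a NEW array
jnp_y = jnp_x.at[0].set(10.0)
print(jnp_x)  # original array is unchanged
print(jnp_y)  # copy with element 0 replaced by 10.0

# Precision differs too: JAX defaults to float32 (unless 64-bit mode is enabled), NumPy to float64
print(jnp_x.dtype, np_x.dtype)
```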


Why Use jax.numpy Instead of NumPy?
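Arrays are placed on the available GPU or TPU automatically, with no device-specific code.
Works directly with grad and jit, the core transformations for machine learning.
Supports vectorized computation, and vmap extends this by batching functions written for a single example.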
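vmap is imported in the first cell but never used. Here is a minimal sketch of how it provides the vectorization described above. It turns a function written for one example into one that maps over a leading batch axis. in_axes=None broadcasts an argument instead of mapping it, and vmap composes with grad to give per-example gradients. The dot function and the inputs are illustrative.

```python
import jax
import jax.numpy as jnp

def dot(a, b):
    # Written for two single vectors of shape (n,)
    return jnp.dot(a, b)

# vmap maps over axis 0 of every argument by default
batched_dot = jax.vmap(dot)

xs = jnp.arange(6.0).reshape(3, 2)  # rows: [0, 1], [2, 3], [4, 5]
ys = jnp.ones((3, 2))
print(batched_dot(xs, ys))          # one dot product per row: 1, 5, 9

# in_axes=(0, None): map over xs, reuse the same vector w for every row
w = jnp.array([1.0, 1.0])
shared = jax.vmap(dot, in_axes=(0, None))(xs, w)
print(shared)                       # also 1, 5, 9

# vmap composes with grad: per-example derivatives of x^2 + 3x + 2
def loss_fn(x):
    return x**2 + 3*x + 2

per_example_grads = jax.vmap(jax.grad(loss_fn))(jnp.array([0.0, 1.0, 5.0]))
print(per_example_grads)            # 2x + 3 at each point: 3, 5, 13
```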