<a href="https://colab.research.google.com/github/yahya94812/JAX-Tutorial/blob/main/JAX_basics.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Introduction To JAX**
This notebook provides a comprehensive introduction to JAX basics through Python code examples. JAX is a high-performance numerical computing library that combines NumPy's familiar API with automatic differentiation and acceleration on hardware such as GPUs and TPUs.

---
## Table of Contents
1. Basic JAX Operations
2. Automatic Differentiation
3. Just-In-Time Compilation (jit)
4. Vectorized Operations (vmap)
5. Combining Transformations
6. Gradient-Based Optimization

# 1. Basic JAX Operations :
Shows how to create and manipulate JAX arrays, which are similar to NumPy arrays but designed for acceleration and transformation.

In [None]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
import numpy as np
import time

In [None]:
# Creating JAX arrays: one from a Python list, one from a constructor helper.
x = jnp.array([2, 4, 6, 8, 10])  # integer literals -> integer array
y = jnp.ones(5)                  # five ones, floating point
print(f"jax array: {x}")
print(f"jax ones: {y}")

jax array: [ 2  4  6  8 10]
jax onse: [1. 1. 1. 1. 1.]


In [None]:
# Basic element-wise arithmetic plus a reduction (dot product).
# x holds integers and y holds floats, so the results are promoted to floats.
operations = (
    ("Addition:", x + y),
    ("Multiplication:", x * y),
    ("Dot product:", jnp.dot(x, y)),
)
for label, value in operations:
    print(label, value)

Addition: [ 3.  5.  7.  9. 11.]
Multiplication: [ 2.  4.  6.  8. 10.]
Dot product: 30.0


In [None]:
# JAX vs NumPy: the same list literal produces two different array classes.
numpy_array, jax_array = np.array([1, 2, 3, 4]), jnp.array([1, 2, 3, 4])

for label, arr in (("NumPy array type:", numpy_array), ("JAX array type:", jax_array)):
    print(label, type(arr))

NumPy array type: <class 'numpy.ndarray'>
JAX array type: <class 'jaxlib.xla_extension.ArrayImpl'>


# 2. Automatic Differentiation :
 Demonstrates JAX's powerful automatic differentiation capabilities using the grad function, which computes gradients of functions automatically.

In [None]:
# Define a simple function
def square(x):
    """Return ``x ** 2``. Its derivative is ``2 * x``."""
    return x ** 2

# grad(square) is a new function that evaluates d(square)/dx at a point.
# grad needs a floating-point input, hence 3.0 rather than 3.
square_grad = grad(square)

# Evaluate the gradient at x=3
x_value = 3.0
grad_value = square_grad(x_value)

print("Function: f(x) = x^2")
print("Derivative: f'(x) = 2x")
print(f"Gradient at x={x_value}: {grad_value}")
print(f"Expected gradient: {2 * x_value}")

Function: f(x) = x^2
Derivative: f'(x) = 2x
Gradient at x=3.0: 6.0
Expected gradient: 6.0


In [None]:
# A more complex function
def tanh(x):
    """Hyperbolic tangent written out from its exponential definition.

    tanh(x) = (e^x - e^-x) / (e^x + e^-x)

    Evaluating this formula directly overflows for large |x|. In float32,
    jnp.exp(100.0) is inf, and inf / inf is nan. tanh(x) already rounds to
    exactly +/-1 (in float32 and float64) once |x| >= 20. So the input is
    clipped to [-20, 20] first. For |x| < 20 the value and its gradient are
    unchanged; outside that range the result is +/-1 and the gradient is 0
    instead of nan.
    """
    x = jnp.clip(x, -20.0, 20.0)
    return (jnp.exp(x) - jnp.exp(-x)) / (jnp.exp(x) + jnp.exp(-x))

# Compute the gradient of tanh
tanh_grad = grad(tanh)

# Evaluate the gradient at x=2.0
x_value = 2.0
grad_value = tanh_grad(x_value)

print("\nFunction: tanh(x)")
print(f"Gradient at x={x_value}: {grad_value}")
# Analytic derivative: d/dx tanh(x) = 1 - tanh(x)^2
print(f"Expected gradient: {1 - tanh(x_value)**2}")


Function: tanh(x)
Gradient at x=2.0: 0.07065093517303467
Expected gradient: 0.07065081596374512


# 3. Just-In-Time Compilation (jit) :
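Shows how to use JAX's JIT compiler, which can significantly speed up function execution. When timing jitted code, keep two things in mind:
- The first call to a jitted function includes compilation time, so do a warm-up call first.
- JAX dispatches work asynchronously, so wait for the result (e.g. with `.block_until_ready()`) before stopping the timer.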

In [None]:
def slow_function(x):
    """Add 1e-7 to ``x`` ten thousand times, one Python-level step at a time.

    The loop is ordinary Python. Called eagerly, it performs 10000 separate
    additions. Under ``jit``, the loop is unrolled while tracing, so each
    iteration becomes its own operation in the traced program.
    """
    increment = 1e-7
    for _ in range(10000):
        x += increment
    return x


# jit returns a wrapper that compiles slow_function on its first call
# (per input shape/dtype) and reuses the compiled version afterwards.
fast_function = jit(slow_function)

In [None]:
# Measure execution time.
#
# JAX dispatches work asynchronously: a jitted call returns once the
# computation is enqueued, not when it has finished. .block_until_ready()
# waits for the result, so the timer measures the real computation.
# time.perf_counter() is the high-resolution clock meant for short intervals.

x_value = 5.0

# Warm-up: the first call to a jitted function traces and compiles it for
# the given input shape/dtype. Exclude that one-time cost from the timing.
_ = slow_function(x_value)
_ = fast_function(x_value).block_until_ready()

# Time the slow (eager, pure-Python) function
start_time = time.perf_counter()
result_slow = slow_function(x_value)
slow_time = time.perf_counter() - start_time

# Time the fast (jitted) function, waiting until the result is computed
start_time = time.perf_counter()
result_fast = fast_function(x_value).block_until_ready()
fast_time = time.perf_counter() - start_time

# The two results differ slightly: slow_function runs on a Python float
# (double precision), while the jitted version computes in JAX's default
# 32-bit floats.
print(f"Slow function result: {result_slow}")
print(f"Fast function result: {result_fast}")
print(f"Slow function time: {slow_time:.6f} seconds")
print(f"Fast function time: {fast_time:.6f} seconds")
print(f"Speedup: {slow_time / fast_time:.2f}x")

Slow function result: 5.001000000002804
Fast function result: 5.000999927520752
Slow function time: 0.000373 seconds
Fast function time: 0.000371 seconds
Speedup: 1.01x


# 4. Vectorized Operations with vmap :
 Illustrates how to vectorize functions with vmap, allowing them to operate efficiently on batches of inputs.

In [None]:
def scalar_function(x, y):
    """Compute ``x * y + sin(x)`` for scalar inputs."""
    return x * y + jnp.sin(x)


# in_axes=(0, None): map over axis 0 of the first argument (x) and pass the
# second argument (y) unchanged to every call.
vector_function = vmap(scalar_function, in_axes=(0, None))

x_values = jnp.array([1.0, 2.0, 3.0, 4.0])
y_value = 2.0

# Same values as calling scalar_function(xi, 2.0) for each xi in x_values.
result = vector_function(x_values, y_value)
print(f"Input vector: {x_values}")
print(f"Scalar: {y_value}")
print(f"Vectorized result: {result}")

# Reference: evaluate element by element with a Python loop, then compare.
manual_result = jnp.array([scalar_function(xi, y_value) for xi in x_values])
print(f"Manual result: {manual_result}")
print(f"Results match: {jnp.allclose(result, manual_result)}")
print(f"Results match: {jnp.allclose(result, manual_result)}")


Input vector: [1. 2. 3. 4.]
Scalar: 2.0
Vectorized result: [2.841471  4.9092975 6.14112   7.2431974]
Manual result: [2.841471  4.9092975 6.14112   7.2431974]
Results match: True


# 5. Combining JAX Transformations :
Demonstrates how to combine JAX transformations like jit and vmap for maximum performance.

In [None]:
# Define a function for a simple neural network layer
def layer(params, x):
    """Affine layer for a single input vector: ``x @ w + b``.

    ``params`` is a ``(w, b)`` pair. Here ``x`` is one input row; the vmap
    below adds the batch dimension.
    """
    w, b = params
    return jnp.dot(x, w) + b

# Map over axis 0 of `x` only; `params` (in_axes None) is shared by the batch.
batch_layer = vmap(layer, in_axes=(None, 0))

# JIT-compile the batch version
fast_batch_layer = jit(batch_layer)

# Create some test data
batch_size = 256
input_dim = 128
output_dim = 64

W = jnp.ones((input_dim, output_dim))
b = jnp.zeros(output_dim)
params = (W, b)
inputs = jnp.ones((batch_size, input_dim))

# Warm-up: the first call compiles. Block so compilation is finished before
# timing starts.
_ = fast_batch_layer(params, inputs).block_until_ready()

# Time the execution. JAX dispatches asynchronously, so wait for the result
# before reading the clock. perf_counter is the high-resolution timer.
start_time = time.perf_counter()
result = fast_batch_layer(params, inputs).block_until_ready()
batch_time = time.perf_counter() - start_time

# Sanity check: vmapping the per-row layer must equal one batched matmul.
assert jnp.allclose(result, jnp.dot(inputs, W) + b)

print(f"Batch size: {batch_size}")
print(f"Input dimension: {input_dim}")
print(f"Output dimension: {output_dim}")
print(f"Output shape: {result.shape}")
print(f"Execution time: {batch_time:.6f} seconds")

Batch size: 256
Input dimension: 128
Output dimension: 64
Output shape: (256, 64)
Execution time: 0.000320 seconds


# 6. Gradient-Based Optimization :
Shows a simple gradient-descent optimization example using JAX's automatic differentiation.

In [19]:
def loss(params, x, y):
    """Mean squared error of the linear model ``x @ params`` against ``y``."""
    residual = jnp.dot(x, params) - y
    return jnp.mean(residual ** 2)


# Gradient of the loss with respect to its first argument (params), jitted.
loss_grad = jit(grad(loss))

# Synthetic regression data: y = x @ true_params + small Gaussian noise.
# Inputs and noise are drawn from two independently seeded keys.
key = jax.random.PRNGKey(42)
noise_key = jax.random.PRNGKey(0)
x_data = jax.random.normal(key, (100, 5))
true_params = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
noise = jax.random.normal(noise_key, (100,)) * 0.1
y_data = jnp.dot(x_data, true_params) + noise

# Start from all-zero parameters.
params = jnp.zeros(5)

# Gradient-descent hyperparameters
n_steps = 100
step_size = 0.1

print(f"Initial loss: {loss(params, x_data, y_data):.4f}")

for step in range(n_steps):
    grads = loss_grad(params, x_data, y_data)
    params = params - step_size * grads

    # Report every 20 steps and after the last one (loss measured post-update).
    is_last = step == n_steps - 1
    if is_last or step % 20 == 0:
        current_loss = loss(params, x_data, y_data)
        print(f"Step {step}, Loss: {current_loss:.4f}")

print(f"True parameters: {true_params}")
print(f"Learned parameters: {params}")

Initial loss: 49.8177
Step 0, Loss: 33.2453
Step 20, Loss: 0.0267
Step 40, Loss: 0.0085
Step 60, Loss: 0.0085
Step 80, Loss: 0.0085
Step 99, Loss: 0.0085
True parameters: [1. 2. 3. 4. 5.]
Learned parameters: [1.008449  1.9874883 2.9950392 4.0158367 4.9924345]


# **Thank You**