
## What is JAX?

JAX is a high-performance numerical computing library developed by Google Research for machine learning research and scientific computing.
It combines NumPy's familiar API with automatic differentiation and hardware acceleration.
The name "JAX" is often expanded as "Just After eXecution", a nod to how it traces a Python function and then compiles it just-in-time; the project's early README described JAX as Autograd and XLA brought together.

### JAX's Core Principles

- "NumPy on steroids": a familiar API with added capabilities
- Functional programming paradigm: emphasizes pure functions and immutability
- Composable function transformations: gradients, vectorization, and JIT compilation
- Designed for modern hardware acceleration (GPUs and TPUs)
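To make "composable function transformations" concrete, here is a minimal sketch that stacks three transformations on one pure function: `jax.grad` differentiates it, `jax.vmap` maps the gradient over a batch, and `jax.jit` compiles the result. The function `loss` and the data are illustrative assumptions, not part of any library.

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    # Pure function: the result depends only on its inputs, with no side effects
    return jnp.sum((w * x) ** 2)

# Gradient w.r.t. w (argument 0), vectorized over the leading axis of x only,
# then JIT-compiled with XLA
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0)))

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 1.0],
                [2.0, 0.5]])

grads = per_example_grads(w, xs)  # one gradient row per example in xs
print(grads)
```

Each row equals the analytic gradient 2·w·x² for that example, which is an easy way to check the composition.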

### Key Differences Between JAX and PyTorch

- Programming model
  - JAX: employs a functional programming approach with immutable arrays.
  - PyTorch: uses an object-oriented approach with mutable tensors.
  - In short: "In JAX, you don't change data, you transform it into new data."

- Automatic differentiation
  - JAX: exposes differentiation as composable function transformations, with reverse mode through `jax.grad` and `jax.vjp` and forward mode through `jax.jvp` and `jax.jacfwd`.
  - PyTorch: records operations on a dynamic computation graph during eager execution and differentiates with `backward()`.
  - Because JAX transformations compose, higher-order derivatives and mixed forward/reverse strategies (for example, forward-over-reverse Hessians) are expressed by nesting functions.

- Compilation and execution
  - JAX: emphasizes just-in-time compilation (`jax.jit`) through XLA (Accelerated Linear Algebra), which can fuse and optimize whole computations.
  - PyTorch: traditionally emphasizes eager execution for flexibility, with TorchScript and, in PyTorch 2.x, `torch.compile` for compilation.
  - JAX's compilation can give better performance on some workloads, but it requires a different mindset: compiled code must be traceable, so, for example, array shapes must be static and Python control flow cannot depend on traced values.
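As an illustration of the differentiation point above, the sketch below computes a reverse-mode gradient with `jax.grad`, a forward-mode directional derivative with `jax.jvp`, and a Hessian by nesting `jax.jacfwd` around `jax.grad`. The function `f` is an arbitrary example chosen for this sketch.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Scalar-valued example function: sum_i x_i * sin(x_i)
    return jnp.sum(x * jnp.sin(x))

x = jnp.array([0.0, 1.0, 2.0])

# Reverse mode: full gradient of a scalar output in one backward pass
g = jax.grad(f)(x)

# Forward mode: derivative of f along one direction v (a Jacobian-vector product)
v = jnp.array([1.0, 1.0, 1.0])
f_x, df_along_v = jax.jvp(f, (x,), (v,))

# Forward-over-reverse composition: the Hessian of f
H = jax.jacfwd(jax.grad(f))(x)

print(g, df_along_v, H, sep="\n")
```

Analytically, the gradient is sin(x) + x·cos(x), the directional derivative is the dot product of that gradient with `v`, and the Hessian is diagonal with entries 2·cos(x) − x·sin(x).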



# Introduction to JAX Arrays

Welcome to this tutorial on JAX Arrays! This notebook covers the fundamental concepts and operations related to arrays in JAX, which are the basic building blocks for numerical computing in the library.


```python
# Install JAX if it is not already installed (jaxlib is pulled in automatically).
# CPU-only:          pip install -U jax
# NVIDIA GPU (CUDA 12): pip install -U "jax[cuda12]"
# !pip install -U jax

# Import the necessary libraries
import jax
import jax.numpy as jnp
import numpy as np

# Check the JAX version
print(f"JAX version: {jax.__version__}")

# List the devices JAX can use (CPU, GPU, or TPU)
print(f"Available devices: {jax.devices()}")
```

```text
JAX version: 0.6.1
Available devices: [CudaDevice(id=0)]
```


## 1. Creating JAX Arrays

JAX arrays are similar to NumPy arrays but are designed to work efficiently with JAX's transformations and hardware acceleration. Let's see how to create JAX arrays.

```python
# Creating JAX arrays from scratch

# Create an array of ones
ones = jnp.ones((2, 3))
print("Array of ones:")
print(ones)
print(f"Shape: {ones.shape}\n")

# Create an identity matrix
identity = jnp.eye(3)
print("3x3 Identity matrix:")
print(identity)
print(f"Shape: {identity.shape}")
```

```text
Array of ones:
[[1. 1. 1.]
 [1. 1. 1.]]
Shape: (2, 3)

3x3 Identity matrix:
[[1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]]
Shape: (3, 3)
```
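Beyond `jnp.ones` and `jnp.eye`, `jax.numpy` mirrors the other common NumPy constructors. A short sketch; the variable names are illustrative, and the default dtypes assume JAX's default 32-bit mode:

```python
import jax.numpy as jnp

zeros = jnp.zeros((2, 2))                       # all zeros, float32 by default
filled = jnp.full((2, 3), 7, dtype=jnp.int32)   # constant fill with an explicit dtype
steps = jnp.arange(0, 10, 2)                    # like range(0, 10, 2)
grid = jnp.linspace(0.0, 1.0, 5)                # 5 evenly spaced points, endpoints included
like = jnp.zeros_like(filled)                   # same shape and dtype as `filled`

for name, value in [("zeros", zeros), ("filled", filled), ("steps", steps),
                    ("grid", grid), ("like", like)]:
    print(f"{name}: shape={value.shape}, dtype={value.dtype}")
```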


```python
# Creating JAX arrays from sequences and NumPy arrays

# From a nested list (2D array)
nested_list_array = jnp.array([[1, 2, 3], [4, 5, 6]])
print("JAX array from nested list:")
print(nested_list_array)
print(f"Shape: {nested_list_array.shape}\n")

# From a NumPy array. NumPy defaults to float64, but JAX uses 32-bit types
# unless 64-bit mode is enabled, so the converted values are rounded to float32.
numpy_array = np.random.rand(3, 3)
jax_array = jnp.array(numpy_array)
print("Original NumPy array:")
print(numpy_array)
print(f"Type: {type(numpy_array)}\n")
print("Converted JAX array:")
print(jax_array)
print(f"Type: {type(jax_array)}")  # the concrete class behind jax.Array
```

```text
JAX array from nested list:
[[1 2 3]
 [4 5 6]]
Shape: (2, 3)

Original NumPy array:
[[0.31940963 0.05591559 0.2763571 ]
 [0.83519235 0.61999807 0.91921544]
 [0.30241709 0.81350188 0.02125688]]
Type: <class 'numpy.ndarray'>

Converted JAX array:
[[0.31940964 0.05591559 0.2763571 ]
 [0.8351924  0.61999804 0.91921544]
 [0.3024171  0.8135019  0.02125688]]
Type: <class 'jaxlib._jax.ArrayImpl'>
```
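The small differences in the last digits above come from the float64-to-float32 conversion. The sketch below makes that explicit and converts back with `np.asarray`; it assumes JAX's default configuration, where 64-bit types are disabled (calling `jax.config.update("jax_enable_x64", True)` at startup enables them).

```python
import numpy as np
import jax.numpy as jnp

np_arr = np.array([0.1, 0.2], dtype=np.float64)

# With 64-bit mode disabled (the default), float64 input becomes float32
jax_arr = jnp.asarray(np_arr)

# np.asarray converts a JAX array back into a NumPy ndarray
back = np.asarray(jax_arr)

print(jax_arr.dtype, type(back), back.dtype)
```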


## 2. Array Properties and Information

Understanding array properties is essential for working effectively with JAX arrays.

```python
# Create a sample array
sample_array = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])

# Display the array
print("Sample array:")
print(sample_array)

# Basic properties
print(f"\nShape: {sample_array.shape}")
print(f"Dimensions: {sample_array.ndim}")
print(f"Size (total number of elements): {sample_array.size}")
print(f"Data type: {sample_array.dtype}")

# Total memory used by the elements: size * itemsize (9 * 4 bytes for float32)
print(f"Byte size: {sample_array.nbytes} bytes")

# Check whether the array is a JAX tracer (arguments inside jit-compiled functions are tracers)
print(f"Is JAX tracer: {isinstance(sample_array, jax.core.Tracer)}")
```

```text
Sample array:
[[1. 2. 3.]
 [4. 5. 6.]
 [7. 8. 9.]]

Shape: (3, 3)
Dimensions: 2
Size (total number of elements): 9
Data type: float32
Byte size: 36 bytes
Is JAX tracer: False
```
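The tracer check prints `False` because `sample_array` is a concrete array. Inside a function transformed by `jax.jit`, the argument is a tracer instead: JAX calls the Python function once with abstract values to record its operations, then reuses the compiled result for later calls with the same shapes and dtypes. A minimal sketch; the `trace_log` list is an illustrative side effect that runs only during tracing, and checking the class name (rather than `jax.core.Tracer`) is a choice made here to avoid depending on that module path:

```python
import jax
import jax.numpy as jnp

trace_log = []  # records the argument's class name each time the function is traced

@jax.jit
def double(x):
    # This Python side effect runs only while JAX traces the function
    trace_log.append(type(x).__name__)
    return x * 2

first = double(jnp.arange(3.0))
second = double(jnp.arange(3.0) + 1)  # same shape and dtype: no retracing

print(trace_log)
print(first, second)
```

Because the second call reuses the cached compilation, `trace_log` ends up with a single entry.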


## 3. Array Indexing and Slicing

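JAX arrays support the same indexing and slicing syntax as NumPy arrays for reading values, with a few important differences: arrays cannot be modified through indexing because they are immutable, out-of-bounds indices on reads are clamped instead of raising an error, and indexing whose result shape depends on the data (such as boolean masks) is not allowed inside `jax.jit`.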

```python
# Create a 3x4 array for demonstration
arr = jnp.arange(12).reshape(3, 4)
print("Original array:")
print(arr)

# Basic indexing
print(f"\nElement at (1, 2): {arr[1, 2]}")

# Slicing along rows and columns (slice end indices are exclusive)
print(f"\nFirst row: {arr[0]}")
print(f"First column: {arr[:, 0]}")
print("\nSubarray (rows 0-1, columns 1-2):")
print(arr[0:2, 1:3])

# Fancy indexing
print("\nUsing fancy indexing (rows 0 and 2):")
print(arr[jnp.array([0, 2])])

# Boolean indexing. The result's shape depends on the data, so this works
# eagerly but not inside jax.jit; use jnp.where there instead.
mask = arr > 5
print("\nMask for elements > 5:")
print(mask)
print("Elements where mask is True:")
print(arr[mask])
```

```text
Original array:
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]]

Element at (1, 2): 6

First row: [0 1 2 3]
First column: [0 4 8]

Subarray (rows 0-1, columns 1-2):
[[1 2]
 [5 6]]

Using fancy indexing (rows 0 and 2):
[[ 0  1  2  3]
 [ 8  9 10 11]]

Mask for elements > 5:
[[False False False False]
 [False False  True  True]
 [ True  True  True  True]]
Elements where mask is True:
[ 6  7  8  9 10 11]
```
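Because boolean-mask indexing produces an array whose length depends on the data, it cannot be used inside `jax.jit`. A common jit-compatible alternative is `jnp.where`, which keeps the output shape fixed by substituting a fill value. A minimal sketch; `zero_out_small` is a hypothetical helper name:

```python
import jax
import jax.numpy as jnp

@jax.jit
def zero_out_small(x, threshold):
    # Keep elements > threshold and replace the rest with 0.
    # The output has the same shape as x, so jit can compile it.
    return jnp.where(x > threshold, x, 0)

arr = jnp.arange(12).reshape(3, 4)
result = zero_out_small(arr, 5)
print(result)
```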


## 4. Basic Array Operations

JAX provides most of NumPy's functionality through the `jax.numpy` module.

```python
# Create two arrays
a = jnp.array([[1, 2], [3, 4]])
b = jnp.array([[5, 6], [7, 8]])

print("Array a:")
print(a)
print("\nArray b:")
print(b)

# Element-wise operations
print("\nElement-wise addition (a + b):")
print(a + b)

print("\nElement-wise subtraction (a - b):")
print(a - b)

print("\nElement-wise multiplication (a * b):")
print(a * b)

print("\nElement-wise division (a / b):")
print(a / b)

# Matrix operations
print("\nMatrix multiplication (a @ b):")
print(jnp.matmul(a, b))  # or a @ b

# Scalar operations
print("\nMultiply array by scalar (a * 2):")
print(a * 2)
```

```text
Array a:
[[1 2]
 [3 4]]

Array b:
[[5 6]
 [7 8]]

Element-wise addition (a + b):
[[ 6  8]
 [10 12]]

Element-wise subtraction (a - b):
[[-4 -4]
 [-4 -4]]

Element-wise multiplication (a * b):
[[ 5 12]
 [21 32]]

Element-wise division (a / b):
[[0.2        0.33333334]
 [0.42857146 0.5       ]]

Matrix multiplication (a @ b):
[[19 22]
 [43 50]]

Multiply array by scalar (a * 2):
[[2 4]
 [6 8]]
```
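`jax.numpy` also provides NumPy's reductions, which accept an `axis` argument to reduce along one dimension instead of over the whole array. A short sketch using the same `a` as above:

```python
import jax.numpy as jnp

a = jnp.array([[1, 2], [3, 4]])

total = jnp.sum(a)        # reduce over all elements
col_sums = a.sum(axis=0)  # reduce down the rows: one value per column
row_max = a.max(axis=1)   # reduce across columns: one value per row
mean = jnp.mean(a)        # integer input gives a floating-point mean

print(total, col_sums, row_max, mean)
```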


## 5. Broadcasting in JAX

Broadcasting is a powerful feature that allows operations between arrays of different shapes. JAX follows NumPy's broadcasting rules.

```python
# Broadcasting examples

# 1. Scalar with array
arr = jnp.array([[1, 2, 3], [4, 5, 6]])
scalar = 10
print("Original array:")
print(arr)
print(f"\nArray + scalar ({scalar}):")
print(arr + scalar)

# 2. Row vector with 2D array
row = jnp.array([10, 20, 30])
print("\nRow vector:")
print(row)
print("\nAdding row vector to each row of 2D array:")
print(arr + row)

# 3. Column vector with 2D array
col = jnp.array([[100], [200]])
print("\nColumn vector:")
print(col)
print("\nAdding column vector to each column of 2D array:")
print(arr + col)
```

```text
Original array:
[[1 2 3]
 [4 5 6]]

Array + scalar (10):
[[11 12 13]
 [14 15 16]]

Row vector:
[10 20 30]

Adding row vector to each row of 2D array:
[[11 22 33]
 [14 25 36]]

Column vector:
[[100]
 [200]]

Adding column vector to each column of 2D array:
[[101 102 103]
 [204 205 206]]
```
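NumPy's rule, which JAX follows, compares shapes from the trailing dimension backward: two dimensions are compatible when they are equal or one of them is 1. Inserting a length-1 axis with `None` is therefore a concise way to build an outer operation. A small sketch:

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])   # shape (3,)
y = jnp.array([10, 20])    # shape (2,)

# (3, 1) * (1, 2) broadcasts to (3, 2): every x times every y
outer = x[:, None] * y[None, :]

# jnp.broadcast_to expands an array to a compatible target shape explicitly
tiled = jnp.broadcast_to(y, (3, 2))

print(outer)
print(tiled)
```

By the same rule, `x + y` with shapes `(3,)` and `(2,)` is incompatible and raises an error.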


## 6. Array Transformations

JAX provides various functions to transform arrays, similar to NumPy.

```python
# Create a sample array
arr = jnp.arange(12).reshape(3, 4)
print("Original array:")
print(arr)

# Reshape the array
reshaped = arr.reshape(4, 3)
print("\nReshaped to 4x3:")
print(reshaped)

# Transpose the array
transposed = arr.T  # or jnp.transpose(arr)
print("\nTransposed:")
print(transposed)

# Flatten the array
flattened = arr.flatten()  # or arr.ravel()
print("\nFlattened:")
print(flattened)

# Stack arrays
a = jnp.array([1, 2, 3])
b = jnp.array([4, 5, 6])
print("\nVertical stack:")
print(jnp.vstack([a, b]))
print("\nHorizontal stack:")
print(jnp.hstack([a, b]))
```

```text
Original array:
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]]

Reshaped to 4x3:
[[ 0  1  2]
 [ 3  4  5]
 [ 6  7  8]
 [ 9 10 11]]

Transposed:
[[ 0  4  8]
 [ 1  5  9]
 [ 2  6 10]
 [ 3  7 11]]

Flattened:
[ 0  1  2  3  4  5  6  7  8  9 10 11]

Vertical stack:
[[1 2 3]
 [4 5 6]]

Horizontal stack:
[1 2 3 4 5 6]
```
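A few related functions round out these transformations: `reshape` accepts `-1` for one dimension to be inferred, `jnp.stack` joins arrays along a new axis, `jnp.concatenate` joins them along an existing one, and `jnp.expand_dims` inserts a length-1 axis. A sketch:

```python
import jax.numpy as jnp

arr = jnp.arange(12)
inferred = arr.reshape(3, -1)          # -1 is inferred as 4

a = jnp.array([1, 2, 3])
b = jnp.array([4, 5, 6])
stacked = jnp.stack([a, b], axis=1)    # new axis at position 1: shape (3, 2)
joined = jnp.concatenate([a, b])       # existing axis 0: shape (6,)
expanded = jnp.expand_dims(a, axis=0)  # shape (1, 3)

print(inferred.shape, stacked.shape, joined.shape, expanded.shape)
```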


## 7. Random Arrays in JAX

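JAX's approach to random number generation differs from NumPy's. Instead of a hidden global state, it uses a stateless, functional design: you create explicit pseudorandom number generator (PRNG) keys, split them, and pass a key to every sampling function. The same key always produces the same numbers.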

```python
# Import JAX's random module
from jax import random

# Create a PRNG key from a seed
key = random.key(42)  # 42 is the seed
print("PRNG key:")
print(key)

# Generate random arrays.
# Note: `key` is used directly here and then split below, which reuses it.
# Idiomatic code splits first and passes only fresh subkeys to samplers.
print("\nRandom uniform array [0, 1):")
uniform = random.uniform(key, shape=(2, 3))
print(uniform)

# Split the key to get a fresh subkey for each additional random operation
key, subkey = random.split(key)
print("\nRandom normal array:")
normal = random.normal(subkey, shape=(2, 3))
print(normal)

# Generate random integers
key, subkey = random.split(key)
print("\nRandom integers in range [0, 10):")
integers = random.randint(subkey, shape=(2, 3), minval=0, maxval=10)
print(integers)
```

```text
PRNG key:
Array((), dtype=key<fry>) overlaying:
[ 0 42]

Random uniform array [0, 1):
[[0.48870957 0.6797972  0.6162715 ]
 [0.5610161  0.4506446  0.58586586]]

Random normal array:
[[ 0.60576403  0.7990441  -0.908927  ]
 [-0.63525754 -1.2226585  -0.83226097]]

Random integers in range [0, 10):
[[7 9 8]
 [3 8 7]]
```
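The idiomatic pattern splits a key before every use and passes each subkey to exactly one sampler. A sketch of that pattern; the seed and shapes are arbitrary choices for this example:

```python
from jax import random
import jax.numpy as jnp

key = random.key(0)

# Split into a carry-forward key plus one subkey per sampling call
key, k_uniform, k_normal = random.split(key, 3)

u = random.uniform(k_uniform, shape=(2, 3))  # values in [0, 1)
n = random.normal(k_normal, shape=(2, 3))

# Reusing the same key reproduces the same numbers exactly
u_again = random.uniform(k_uniform, shape=(2, 3))

print(u)
print(n)
```

That determinism is why key reuse is a bug in practice: two draws that are meant to be independent silently produce identical randomness.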


## 8. JAX's Array Immutability

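One key difference between JAX and NumPy is that JAX arrays are immutable: you cannot assign to an element in place. Instead, the `.at[...]` property returns a new array with the update applied and leaves the original unchanged. Inside `jax.jit`, XLA can often perform such updates in place when the original array is not used afterward, so immutability does not necessarily mean an extra copy.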

```python
# Create a JAX array
jax_arr = jnp.array([1, 2, 3, 4])
print("Original JAX array:")
print(jax_arr)

# .at[...].set() returns a new array; the original is not modified
new_arr = jax_arr.at[0].set(99)
print("\nNew array after setting index 0 to 99:")
print(new_arr)
print("\nOriginal array (unchanged):")
print(jax_arr)

# Update a slice (indices 1 and 2) in one call
updated_arr = jax_arr.at[1:3].set(jnp.array([88, 77]))
print("\nArray after updating indices 1 and 2:")
print(updated_arr)

# Create a 2D array
arr_2d = jnp.arange(9).reshape(3, 3)
print("\nOriginal 2D array:")
print(arr_2d)

# Update a specific element
updated_2d = arr_2d.at[1, 2].set(99)
print("\nUpdated 2D array:")
print(updated_2d)
```

```text
Original JAX array:
[1 2 3 4]

New array after setting index 0 to 99:
[99  2  3  4]

Original array (unchanged):
[1 2 3 4]

Array after updating indices 1 and 2:
[ 1 88 77  4]

Original 2D array:
[[0 1 2]
 [3 4 5]
 [6 7 8]]

Updated 2D array:
[[ 0  1  2]
 [ 3  4 99]
 [ 6  7  8]]
```
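For contrast, NumPy-style item assignment on a JAX array raises a `TypeError`, and `.at[...]` offers update methods beyond `set`. A minimal sketch:

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 4])

error_message = None
try:
    x[0] = 99  # in-place assignment is not supported on JAX arrays
except TypeError as err:
    error_message = str(err)

added = x.at[0].add(10)         # add 10 at index 0
scaled = x.at[::2].multiply(3)  # multiply indices 0 and 2 by 3
capped = x.at[2:].min(3)        # element-wise min with 3 at indices 2 and 3

print(error_message)
print(added, scaled, capped)
```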


## 9. Device Management

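By default, JAX creates arrays on the default device, which is the first accelerator when one is available. You can move arrays between devices (CPU, GPU, TPU) explicitly with `jax.device_put`.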

```python
# Create a JAX array
arr = jnp.ones((3, 3))
print(f"Default device: {arr.addressable_data(0).device}")

# Move to CPU explicitly
cpu_arr = jax.device_put(arr, jax.devices('cpu')[0])
print(f"\nAfter explicit CPU placement: {cpu_arr.addressable_data(0).device}")

# Conditionally move to GPU. jax.devices('gpu') raises a RuntimeError
# (rather than returning an empty list) when no GPU backend is available.
try:
    gpu_devices = jax.devices('gpu')
except RuntimeError:
    gpu_devices = []

if gpu_devices:
    gpu_arr = jax.device_put(arr, gpu_devices[0])
    print(f"\nAfter GPU placement: {gpu_arr.addressable_data(0).device}")
else:
    print("\nNo GPU devices available")
```

```text
Default device: cuda:0

After explicit CPU placement: TFRT_CPU_0

After GPU placement: cuda:0
```
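Two related conveniences: `Array.devices()` returns the set of devices that hold an array, and the `jax.default_device` context manager changes where new arrays are created. A sketch that uses only the CPU backend, which is always present:

```python
import jax
import jax.numpy as jnp

cpu = jax.devices("cpu")[0]

x = jnp.arange(4.0)
print(x.devices())  # set of devices holding x (a GPU if one is the default)

# Explicit transfer to the CPU
x_cpu = jax.device_put(x, cpu)

# Arrays created inside this block are placed on the CPU by default
with jax.default_device(cpu):
    y = jnp.ones(3)

print(x_cpu.devices(), y.devices())
```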


## 10. JAX-Specific Array Features


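### Understanding Asynchronous Execution

JAX dispatches operations asynchronously: a call such as `x @ x` returns an array object immediately, while the accelerator may still be computing its value. Calling `block_until_ready()` waits until the result is actually available, which matters whenever you time code.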

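```python
import time
import jax.numpy as jnp

# Create a large array
large_array = jnp.ones((5000, 5000))
print("Large array created")

# Example 1: Without block_until_ready(), the timer measures dispatch
# (plus any one-time setup), not the computation itself
start_time = time.time()
result1 = large_array @ large_array  # Matrix multiplication
elapsed_without_blocking = time.time() - start_time
print("Time without blocking:", round(elapsed_without_blocking, 6), "seconds")

# Example 2: With block_until_ready(), the timer includes the computation
start_time = time.time()
result2 = large_array @ large_array  # Matrix multiplication
result2.block_until_ready()  # Wait for computation to finish
elapsed_with_blocking = time.time() - start_time
print("Time with blocking:", round(elapsed_with_blocking, 6), "seconds")
```

```text
Large array created
Time without blocking: 1.273312 seconds
Time with blocking: 0.01809 seconds
```

The unblocked measurement is larger here most likely because it ran first: the first `@` on arrays of this shape also pays one-time costs, such as compiling and caching the matrix-multiplication kernel. For meaningful timings, run the operation once to warm up, then time later calls with `block_until_ready()`.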
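The two measurements above mix one-time costs with steady-state cost. A more careful pattern is to compile with `jax.jit`, run once to absorb one-time costs, then time several blocked runs with `time.perf_counter` and keep the fastest. A sketch; the helper name `benchmark`, its `repeats` parameter, and the 1000×1000 size are assumptions for this example:

```python
import time
import jax
import jax.numpy as jnp

def benchmark(fn, *args, repeats=5):
    """Hypothetical helper: best wall-clock time of fn(*args) over `repeats` runs."""
    # Warm-up call: triggers compilation and other one-time setup
    fn(*args).block_until_ready()
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args).block_until_ready()  # wait for the device to finish
        best = min(best, time.perf_counter() - start)
    return best

matmul = jax.jit(lambda m: m @ m)
m = jnp.ones((1000, 1000))
seconds = benchmark(matmul, m)
print(f"Best of 5: {seconds:.6f} seconds")
```

Taking the minimum rather than the mean is a common choice for micro-benchmarks, since it filters out interference from other work on the machine.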
