<a href="https://colab.research.google.com/github/scottmunson/Portfolio/blob/main/JAX_Demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# JAX

What is JAX? (Quoted from the JAX GitHub page)

JAX is Autograd and XLA (accelerated linear algebra), brought together for high-performance machine learning research.

JAX can automatically differentiate native Python and NumPy functions. It can differentiate through loops, branches, recursion, and closures, and it can take derivatives of derivatives of derivatives. It supports reverse-mode differentiation (a.k.a. backpropagation) via grad as well as forward-mode differentiation, and the two can be composed arbitrarily to any order.

JAX uses XLA to compile and run your NumPy programs on GPUs and TPUs. Compilation happens under the hood by default, with library calls getting just-in-time compiled and executed. But JAX also lets you just-in-time compile your own Python functions into XLA-optimized kernels using a one-function API, jit.
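To make the "composed arbitrarily" claim concrete, here is a small sketch of our own (not taken from the JAX README). Nesting `jax.grad` (reverse mode) inside `jax.jacfwd` (forward mode) gives the Hessian of a scalar function. The function `f` below is a hypothetical example chosen so the answer is easy to check by hand.

```python
import jax
import jax.numpy as jnp

# Hypothetical scalar function of a vector: f(x) = sum(x_i ** 3)
def f(x):
    return jnp.sum(x ** 3)

# Reverse mode (grad) nested inside forward mode (jacfwd) gives the Hessian of f
hessian = jax.jacfwd(jax.grad(f))

x = jnp.array([1.0, 2.0, 3.0])
H = hessian(x)
print(H)  # a diagonal matrix with entries 6 * x_i
```

The gradient of `f` is `3 * x**2`, so its Jacobian is the diagonal matrix `diag(6 * x)`.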

GitHub:

* https://github.com/google/jax
* https://github.com/google/flax

## Dependencies

In [None]:
!pip install -q flax jax numpy optax scikit-learn


# Basics

To illustrate the basic concepts of JAX, we first start with NumPy.
NumPy runs only on the CPU, while JAX can run on the CPU, GPU, or TPU and uses an accelerator automatically when one is available.

For these basics, we are going to time a large matrix multiplication in each library.

In [None]:
import numpy as np
import time

# Size of the matrices
size = 5000

# Generating two large matrices
matrix1 = np.random.rand(size, size)
matrix2 = np.random.rand(size, size)

start = time.time()

# Performing a matrix multiplication operation
result = np.dot(matrix1, matrix2)

end = time.time()

print(f'Time taken for matrix multiplication of size {size}x{size} is {end - start} seconds.')

...and now for JAX:

In [None]:
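import jax.numpy as jnp # This is literally as simple as it gets
from jax import random, jit
import time

# Size of the matrices
size = 5000

key = random.PRNGKey(0) # JAX uses explicit PRNG keys instead of a global seed
key1, key2 = random.split(key) # Split the key so the two matrices are different

# Generating two large matrices (float32 by default, whereas np.random.rand gives float64)
matrix1 = random.normal(key1, (size, size))
matrix2 = random.normal(key2, (size, size))

# JIT compile the dot operation for efficiency
@jit
def matmul_jit(A, B):
    return jnp.dot(A, B)

# Warm-up call: the first call includes compilation time, so keep it out of the measurement
matmul_jit(matrix1, matrix2).block_until_ready()

start = time.time()

# Performing a matrix multiplication operation.
# JAX dispatches work asynchronously, so wait for the result before stopping the clock
result = matmul_jit(matrix1, matrix2).block_until_ready()

end = time.time()

print(f'Time taken for matrix multiplication of size {size}x{size} is {end - start} seconds.')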


# Terminology

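**JAX** = Numerical computing library with a NumPy-like API that runs on CPU, GPU, and TPU<br>
**Flax** = Neural network library built on top of JAX<br>
**JIT** = Just-in-time (JIT) compilation is a concept from the world of computer programming, where code (such as bytecode) is converted into machine code at runtime, right before execution. This approach contrasts with ahead-of-time (AOT) compilation, where code is compiled into machine code before the program runs. In JAX, `jax.jit` traces a Python function the first time it is called and compiles it with XLA; later calls with the same input shapes and dtypes reuse the compiled code. (*General definition adapted from GPT)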
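A quick way to see the "trace once, reuse the compiled code" behaviour is a Python side effect inside a jitted function: it only runs while JAX traces the function, not on cached calls. This is a sketch; `scaled_sum` and the `trace_count` list are hypothetical names used only for illustration.

```python
import jax
import jax.numpy as jnp

trace_count = []  # hypothetical counter, appended to only while JAX is tracing

@jax.jit
def scaled_sum(x):
    trace_count.append(1)  # Python side effect: runs during tracing, not on cached calls
    return jnp.sum(x * 2.0)

a = jnp.ones(3)
print(scaled_sum(a))           # first call: trace + compile for shape (3,)
print(scaled_sum(a + 1))       # same shape and dtype: reuses the compiled code
print(scaled_sum(jnp.ones(5))) # new shape (5,): triggers a second trace
print(len(trace_count))        # 2: one trace per distinct input shape
```

This is also why changing input shapes (for example, a smaller final batch) causes JAX to compile again.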

# Similarities to Numpy

Here we illustrate the numpy implementation for initialising two matrices and then multiplying them.

In [None]:
import numpy as np

# Initialize two arrays
x = np.array([[1, 2, 3], [4, 5, 6]])
y = np.array([[7, 8], [9, 10], [11, 12]])

# Perform matrix multiplication
z = np.dot(x, y)

print(z)


Here we illustrate the JAX implementation for initialising two matrices and then multiplying them.

You will see that, apart from the import, the code is exactly the same.

In [None]:
import jax.numpy as jnp # The only difference

# Initialize two arrays
x = jnp.array([[1, 2, 3], [4, 5, 6]])
y = jnp.array([[7, 8], [9, 10], [11, 12]])

# Perform matrix multiplication
z = jnp.dot(x, y)

print(z)
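The APIs match, but the default precision does not: JAX uses 32-bit types unless 64-bit mode (`jax_enable_x64`) is switched on. A minimal check, assuming a default JAX configuration:

```python
import numpy as np
import jax.numpy as jnp

# Same call in both libraries, different default floating-point precision
np_floats = np.array([1.0, 2.0])
jnp_floats = jnp.array([1.0, 2.0])

print(np_floats.dtype)   # float64
print(jnp_floats.dtype)  # float32 (unless jax_enable_x64 is enabled)
```

This is why timings and numerical results can differ slightly between the NumPy and JAX versions of the same code.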


# Differences

#### Device Management

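JAX can run on the CPU, GPU, or TPU. By default it places new arrays on the first available accelerator, and converting a NumPy array with `jnp.array` copies it to that device automatically. You only need `jax.device_put` and `jax.device_get` when you want explicit control over where data lives. This differs from PyTorch, where tensors start on the CPU and must be moved with `.to(device)`.

Take note: if you have a NumPy application and want to move to JAX, the change that usually takes the most work is not device placement but array immutability (covered in the next section). Fortunately, the error is simple and recognisable:

```sh
TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)``
```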

In [None]:
import jax

# Print the list of available devices
print(jax.devices())

# Place a scalar value on the first available device (a CPU, GPU, or TPU)
x = jax.device_put(1.0, device=jax.devices()[0])

# Check which device the array is on (`x.device_buffer` was removed in newer JAX releases)
print(x.devices())
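Moving data in the other direction is just as short. The sketch below shows a NumPy array being copied to JAX's default device by `jnp.asarray`, and a result being brought back to host memory as a NumPy array with `jax.device_get`. The variable names are illustrative only.

```python
import numpy as np
import jax
import jax.numpy as jnp

host_array = np.arange(4.0)              # lives in host (CPU) memory
device_array = jnp.asarray(host_array)   # copied to JAX's default device automatically
print(device_array.devices())

# Explicitly transfer a result back to the host; this returns a plain NumPy array
back_on_host = jax.device_get(device_array * 2)
print(type(back_on_host), back_on_host)
```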

#### Immutability

JAX arrays are immutable: once an array is created, its contents cannot be changed in place.

In NumPy, you can modify arrays in-place, for example:

In [None]:
import numpy as np

# Initialize an array
x = np.array([1, 2, 3])

# Modify an element in-place
x[0] = 10

print(x)  # prints: [10  2  3]


...however, you cannot do this in JAX:

In [None]:
import jax.numpy as jnp

# Initialize an array
x = jnp.array([1, 2, 3])

# Attempt to modify an element in-place
# x[0] = 10  # raises TypeError: JAX arrays are immutable

...instead, do this:

In [None]:
import jax.numpy as jnp

# Initialize an array
x = jnp.array([1, 2, 3])

# "Modify" the array
x = x.at[0].set(10) # This is the correct way: .at[].set() returns a new array, so reassign it

print(x)  # prints: [10  2  3]
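`.at[]` supports more than `set`, and it never touches the original array. A short sketch showing that the source array is unchanged and that other update methods such as `add` work with slices too:

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])
y = x.at[0].set(10)   # new array; x itself is untouched
z = x.at[1:].add(5)   # other .at[] methods include add, multiply, min and max

print(x)  # still 1, 2, 3
print(y)  # 10, 2, 3
print(z)  # 1, 7, 8
```

Inside a jitted function, XLA can often turn these copies into true in-place updates, so the functional style does not necessarily cost extra memory.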

#### Features NumPy Doesn't Have

*   Automatic differentiation (`grad`)
*   Just-in-time compilation with XLA (`jit`)
*   Automatic vectorisation (`vmap`)



In [None]:
from jax import grad
import jax.numpy as jnp

# Define a simple function
def f(x):
    return jnp.sin(x)

# Compute its gradient
df = grad(f) # grad is a higher-order function: it returns a new function that evaluates the derivative at a given input

# Evaluate the gradient at x = 1.0
print(df(1.0))  # cos(1.0) ≈ 0.5403
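Because `grad` returns an ordinary Python function, it composes with itself and with other JAX transformations. The sketch below takes a second derivative by applying `grad` twice, and uses `vmap` to evaluate the derivative over a whole batch of inputs without writing a loop.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x)

d2f = jax.grad(jax.grad(f))         # second derivative: -sin(x)
batched_df = jax.vmap(jax.grad(f))  # derivative evaluated for every element of a batch

xs = jnp.linspace(0.0, 1.0, 5)
print(d2f(1.0))         # approximately -sin(1.0)
print(batched_df(xs))   # approximately cos(xs), element by element
```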


# Machine Learning with JAX

#### Dependencies

In [None]:
import jax
import optax # Used for optimisers
from jax import numpy as jnp
from jax import grad, jit, vmap
from flax import linen as nn
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

#### Load data

In [None]:
# Load and preprocess data
iris = datasets.load_iris()
X = iris.data
y = iris.target
X = StandardScaler().fit_transform(X)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
num_classes = len(jnp.unique(y_train))

#### Convert to JAX arrays

In [None]:
# Convert to JAX arrays
X_train = jnp.array(X_train)
y_train = onehot(y_train, num_classes)
X_test = jnp.array(X_test)
y_test = onehot(y_test, num_classes)
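`onehot` comes from Flax's `common_utils`; core JAX provides `jax.nn.one_hot`, which should give the same 0/1 encoding for integer labels. A tiny sketch with three hypothetical labels shows what the encoded targets look like:

```python
import jax
import jax.numpy as jnp

labels = jnp.array([0, 2, 1])        # integer class ids, like y_train
encoded = jax.nn.one_hot(labels, 3)  # one row per label, 1.0 in that label's column
print(encoded)
```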

#### Model

In [None]:
# Define a simple Feedforward Neural Network
class FFNN(nn.Module):
    hidden_units: int
    output_units: int

    def setup(self):
        self.hidden = nn.Dense(self.hidden_units)
        self.out = nn.Dense(self.output_units)

    def __call__(self, x):
        x = self.hidden(x)
        x = nn.relu(x)
        x = self.out(x)
        return nn.log_softmax(x)

#### Setup model, optimiser, loss

In [None]:
# Initialize network and optimizer
learning_rate = 0.01
hidden_units = 64
output_units = num_classes
model = FFNN(hidden_units, output_units)
params = model.init(jax.random.PRNGKey(0), X_train[0])
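optim = optax.adam(learning_rate=learning_rate)
# Note: `tx` here holds the full TrainState (apply function, params and optimiser state), not just the optimiser
tx = train_state.TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=optim
)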

#### Metrics

In [None]:
def cross_entropy_loss(logits, labels):
    return -jnp.mean(jnp.sum(labels * logits, axis=-1))

In [None]:
def compute_metrics(logits, labels):
    loss = cross_entropy_loss(logits=logits, labels=labels)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == jnp.argmax(labels, -1))
    metrics = {'loss': loss, 'accuracy': accuracy}
    return metrics


#### Train step

In [None]:
# Define a training step
def train_step(state, batch):
    def loss_fn(params):
        X, y = batch
        logits = state.apply_fn(params, X)
        loss = cross_entropy_loss(logits, y)  # reuse the loss defined in the metrics cell
        return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, logits), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    metrics = compute_metrics(logits, batch[1])
    return state, metrics
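`jax.value_and_grad(..., has_aux=True)` expects the function to return a pair `(value, aux)`: it differentiates only `value` and passes `aux` (here, the logits) through untouched. A toy sketch with a hypothetical one-parameter loss makes the return structure explicit:

```python
import jax

def loss_fn(w):
    pred = w * 3.0
    loss = (pred - 6.0) ** 2
    return loss, pred  # (value to differentiate, auxiliary output)

(loss, pred), grad_w = jax.value_and_grad(loss_fn, has_aux=True)(1.0)
print(loss, pred, grad_w)  # d/dw (3w - 6)^2 = 6 * (3w - 6), which is -18 at w = 1
```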

#### JIT

In [None]:
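# Compile train_step with XLA. The 120 training samples split into batches of 32, 32, 32 and 24,
# so JAX compiles a second version of train_step for the smaller final batch.
train_step = jax.jit(train_step)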

#### Training loop

In [None]:
# Training loop
num_epochs = 20
batch_size = 32

print(f"Expected initial loss: {-jnp.log(1/num_classes)}")

for epoch in range(num_epochs):
    # Shuffle the training set with a different key each epoch
    permutation = jax.random.permutation(jax.random.PRNGKey(epoch), len(X_train))
    for i in range(0, len(X_train), batch_size):
        indices = permutation[i:i+batch_size]
        batch = X_train[indices], y_train[indices]
        tx, metrics = train_step(tx, batch)

    # Report the metrics of the last batch once per epoch
    print(f"Epoch {epoch}: accuracy = {metrics['accuracy']*100:.2f}%, loss: {metrics['loss']:.4f}")

# Evaluate on test set
logits = model.apply(tx.params, X_test)
metrics = compute_metrics(logits, y_test)
print(f"Test accuracy: {metrics['accuracy']*100:.2f}%")

# Questions?