# JAX Basics

## Importing libraries

In [None]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
import numpy as np


## 1. Basic Operations (just like NumPy)

In [None]:
# Two equal-length 1-D arrays, built from Python lists the same way
# NumPy arrays are, then added element by element.
x = jnp.array([1.0, 2.0, 3.0, 4.0])
y = jnp.array([2.0, 3.0, 4.0, 5.0])
z = jnp.add(x, y)
print(f"Addition: {z}")

## 2. Automatic Differentiation

In [None]:
def simple_function(x):
    """Evaluate the cubic polynomial x^3 + 2x^2 - 5x + 3 at ``x``."""
    cubic_term = x ** 3
    quadratic_term = 2 * x ** 2
    linear_term = 5 * x
    # Terms are combined left to right, in the order they appear in the
    # polynomial.
    return cubic_term + quadratic_term - linear_term + 3

# Build the derivative of simple_function with respect to its argument,
# then evaluate that derivative at x = 2.0.
grad_fn = jax.grad(simple_function)
gradient_at_2 = grad_fn(2.0)
print(f"Gradient at x=2: {gradient_at_2}")


## 3. Just-In-Time Compilation

In [None]:
@jit  # Just-in-time compile the function for speed
def fast_computation(x):
    """Return the dot product of ``x`` with its own transpose."""
    transposed = jnp.transpose(x)
    return jnp.dot(x, transposed)

# Run the compiled function on a 1000 x 1000 matrix of ones and report
# the shape of the product.
matrix = jnp.ones(shape=(1000, 1000))
result = fast_computation(matrix)
print(f"Matrix operation result shape: {result.shape}")

## 4. Vectorization (Auto-batching)

In [None]:
def compute_norm(vector):
    """Return the square root of the sum of squared entries of ``vector``."""
    squared = vector ** 2
    total = jnp.sum(squared)
    return jnp.sqrt(total)

# Stack three row vectors into one 2-D array, then let vmap map
# compute_norm over the leading axis so each row gets its own norm.
vectors = jnp.array([
    [1.0, 2.0, 3.0],
    [4.0, 5.0, 6.0],
    [7.0, 8.0, 9.0],
])
batched_norm = vmap(compute_norm, in_axes=0)
norms = batched_norm(vectors)
print(f"Norms: {norms}")

## 5. Neural Network Example

In [None]:
def simple_model(params, x):
    """Apply a single affine layer to ``x``.

    ``params`` is unpacked into exactly two items, a weight and a bias;
    the result is ``dot(x, weight) + bias``.
    """
    weights, bias = params
    projection = jnp.dot(x, weights)
    return projection + bias

def loss_fn(params, x, y):
    """Return the mean squared error between the model output and ``y``."""
    residuals = simple_model(params, x) - y
    return jnp.mean(residuals ** 2)

## Initializing parameters

In [None]:
# Seeded PRNG key, so the random weights are the same on every run.
key = jax.random.PRNGKey(0)
# Weights: a (3, 1) draw from jax.random.normal; bias: a length-1 zero vector.
w = jax.random.normal(key, shape=(3, 1))
b = jnp.zeros(shape=(1,))
params = w, b

## Training data

In [None]:
# Two samples with three features each, and one target value per sample.
X_train = jnp.array([
    [1.0, 2.0, 3.0],
    [4.0, 5.0, 6.0],
])
y_train = jnp.array([
    [6.0],
    [15.0],
])


## Computing the loss gradient

In [None]:
# Differentiate the loss with respect to its first argument (params).
# The gradients come back with the same structure as params, so len()
# counts one entry per parameter group.
grad_fn = grad(loss_fn, argnums=0)
gradients = grad_fn(params, X_train, y_train)
print(f"Loss gradient computed: {len(gradients)} parameter groups")