In [None]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
import matplotlib.pyplot as plt
import time

# Example 1: Basic array operations

In [None]:
def basic_operations():
    """Print a few elementwise and reduction results on small JAX arrays.

    ``x`` is an integer array and ``y`` a float array of ones, so every
    result below is promoted to float.
    """
    x = jnp.array([1, 2, 3, 4, 5])
    y = jnp.ones((5,))

    print("Addition:", x + y)
    # For two 1-D inputs jnp.dot is the inner (dot) product and returns a
    # scalar; it is not a matrix multiplication.
    print("Dot product:", jnp.dot(x, y))
    print("Element-wise multiplication:", x * y)

In [None]:
# Run Example 1: prints the results of the basic array operations.
basic_operations()

# Example 2: Automatic differentiation

In [None]:
def square(x):
    """Return ``x`` raised to the second power.

    Written with plain arithmetic so it accepts Python scalars and JAX
    arrays alike; ``grad(square)`` therefore evaluates to ``2 * x``.
    """
    return pow(x, 2)

In [None]:
# grad(square) returns a new function computing d/dx (x ** 2) = 2x.
d_square = jax.grad(square)

In [None]:
x = 3.0
print("d/dx(x^2) at x = {} is {}".format(x, d_square(x)))

# Example 3: Just-in-time compilation

In [None]:
@jit
def jitted_computation(x):
    """Return the sum of sin(x)**2 over every element of ``x``.

    The ``jit`` decorator traces and compiles the whole expression with XLA
    on the first call for a given input shape/dtype.
    """
    sines = jnp.sin(x)
    return jnp.sum(sines ** 2)

In [None]:
x = jnp.arange(1000000)

# JAX dispatches work asynchronously, so every timed call ends with
# block_until_ready(); otherwise only the dispatch would be measured.
# Both paths are run once first so one-time tracing/compilation is excluded
# from the comparison.
jnp.sum(jnp.sin(x) ** 2).block_until_ready()
jitted_computation(x).block_until_ready()

# Time without JIT
start = time.perf_counter()
regular_result = jnp.sum(jnp.sin(x) ** 2).block_until_ready()
print(f"Regular time: {time.perf_counter() - start:.4f} seconds")

# Time with JIT
start = time.perf_counter()
jitted_result = jitted_computation(x).block_until_ready()
print(f"JIT time: {time.perf_counter() - start:.4f} seconds")

# Example 4: Vectorization

In [None]:
def vectorized_square(x):
    """Square a single element; wrapped with ``vmap`` below to map over axis 0."""
    return x ** 2


# Equivalent to decorating with @vmap: the per-element function is lifted
# so it accepts a batch of inputs along the leading axis.
vectorized_square = vmap(vectorized_square)

In [None]:
x = jnp.arange(1.0, 6.0)  # float32 values 1.0 .. 5.0
print(f"Vectorized square: {vectorized_square(x)}")

In [None]:
def main():
    """Plot f(x) = sin²(x) and its JAX-computed derivative on [0, 4π].

    The derivative comes from ``grad`` (scalar -> scalar) lifted to arrays
    with ``vmap``. It is checked against the closed form
    f'(x) = 2 sin(x) cos(x) = sin(2x) before plotting. Returns None so the
    figure is shown once (via ``plt.show``) rather than also as a cell output.
    """
    # List the devices JAX can run on (e.g. CPU, GPU, TPU).
    print("Devices available:", jax.devices())

    # Scalar-valued function: grad requires a scalar output per input.
    @jit
    def f(x):
        return jnp.sin(x) ** 2

    # grad(f) handles one scalar input; vmap maps it over an array of points.
    df_vectorized = vmap(grad(f))

    x = jnp.linspace(0, 4 * jnp.pi, 100)
    y = vmap(f)(x)  # f is elementwise, so vmap here mirrors df_vectorized
    dy = df_vectorized(x)

    # Sanity-check autodiff against the analytic derivative sin(2x).
    assert jnp.allclose(dy, jnp.sin(2 * x), atol=1e-4), "grad(f) disagrees with sin(2x)"

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(x, y, label='f(x) = sin²(x)')
    ax.plot(x, dy, label='f\'(x)')
    ax.set_xlabel('x')
    ax.set_ylabel('value')
    ax.set_title('Function and its derivative using JAX')
    ax.grid(True)
    ax.legend()
    plt.show()

In [None]:
# Run the full demo: print the available devices, then plot f and its derivative.
main()