**Setup**

Here we import the required libraries.

Then we set the seed so the results are similar between runs.**bold text**

In [32]:
# Import necessary libraries
import numpy as np
import pandas as pd
import jax.numpy as jnp
from jax import random, jit
import timeit

# Set random seed for reproducibility
np.random.seed(42)
key = random.PRNGKey(42)

Define the dataset

In [33]:
# Generate random data
size = 1000000
data = {
    'A': np.random.rand(size),
    'B': np.random.randn(size),
    'C': np.random.randint(0, 2, size),
}

Perform the numpy operation

In [34]:
# NumPy operations
def numpy_operations(data):
    a = np.sum(data['A'])
    b = np.mean(data['B'])
    c = np.max(data['C'])
    return a, b, c

# Time the operations and compare results
numpy_time = timeit.timeit(lambda: numpy_operations(data), number=100)

Perform the same operation on Pandas

In [35]:
# Pandas operations
def pandas_operations(data):
    df = pd.DataFrame(data)
    a = df['A'].sum()
    b = df['B'].mean()
    c = df['C'].max()
    return a, b, c

pandas_time = timeit.timeit(lambda: pandas_operations(data), number=100)

Perform the same operation on JAX

In [36]:
# JAX operations
def jax_operations(data):
    a = jnp.sum(data['A'])
    b = jnp.mean(data['B'])
    c = jnp.max(data['C'])
    return a, b, c
jax_time = timeit.timeit(lambda: jax_operations(data), number=100)


In [37]:
# Display timing results
time_df = pd.DataFrame({
    'Operation': ['Total time'],
    'NumPy': numpy_time,
    'Pandas': pandas_time,
    'JAX': jax_time,
})

time_df

Unnamed: 0,Operation,NumPy,Pandas,JAX
0,Total time,0.182291,0.875292,0.430037


Now let's look at how JIT compares

In [38]:
@jit
def jax_jit_operations(data):
    a = jnp.sum(data['A'])
    b = jnp.mean(data['B'])
    c = jnp.max(data['C'])
    return a, b, c

jax_jit_time = timeit.timeit(lambda: jax_jit_operations(data), number=100)

In [39]:
# Display timing results
time_df = pd.DataFrame({
    'Operation': ['Total time'],
    'NumPy': numpy_time,
    'Pandas': pandas_time,
    'JAX': jax_time,
    'JAX with JIT': jax_jit_time,
})

time_df

Unnamed: 0,Operation,NumPy,Pandas,JAX,JAX with JIT
0,Total time,0.182291,0.875292,0.430037,0.784259
