<a href="https://colab.research.google.com/github/sineeli/jax_series/blob/main/01_framework_benchmarks_jax_basics.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import tensorflow as tf
import jax
import numpy as np

## PRNG Key in Torch, Tensorflow and JAX

- In PyTorch and TensorFlow, setting a seed initializes a hidden global state. Every call that generates a random number implicitly advances that state.
- In **JAX** there is no global state that gets mutated. Instead, you pass the key explicitly every time, and the same key always produces the same number.
- To get a new random number, you split the existing key into new keys and generate again with one of them.


### Torch

In [None]:
torch.manual_seed(42)
print(torch.randn(1)) # Value A
print(torch.randn(1)) # Value B (the global state has advanced, so a new number is generated)

tensor([0.3367])
tensor([0.1288])


### Tensorflow

In [None]:
tf.random.set_seed(42)
print(tf.random.normal([1])) # Value A
print(tf.random.normal([1])) # Value B (the global state has advanced, so a new number is generated)

tf.Tensor([0.3274685], shape=(1,), dtype=float32)
tf.Tensor([0.08422458], shape=(1,), dtype=float32)


### JAX

In [None]:
key = jax.random.key(42)
print(jax.random.normal(key)) # Value X
print(jax.random.normal(key)) # Value X (Always the same!)

-0.028304616
-0.028304616


#### To generate new numbers, we need to split the key

In [None]:
# Split the master key into two new keys
key, subkey = jax.random.split(key)

# Use the subkey for your random number
print(jax.random.normal(subkey)) # Value Y (New!)

0.60576403
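When several random draws are needed, one split can produce all the subkeys at once. The following is a minimal sketch of that pattern; the seed `0`, the sample shape `(3,)`, and the variable names are illustrative assumptions, not part of the notebook above.

```python
import jax

key = jax.random.key(0)  # illustrative seed

# Reusing one key is deterministic: both draws are identical.
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))

# Split once into four keys: keep the first to carry forward,
# and hand the other three to independent consumers.
keys = jax.random.split(key, 4)
key, subkeys = keys[0], keys[1:]
draws = [jax.random.normal(subkeys[i], (3,)) for i in range(3)]

print(bool((a == b).all()))  # the same key gives the same numbers
print(len(draws))            # one draw per subkey
```

Keeping the first key as the new `key` means later cells can keep splitting without ever reusing a key that has already been consumed.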


## Now let's compare speeds


Notes:

- JAX/TF: They are "greedy." By default, JAX preallocates 75% of the GPU memory when the first JAX operation runs, and TensorFlow maps nearly all of it. This optimizes for speed and avoids the overhead of requesting memory repeatedly.

- PyTorch: It is "Lazy." It allocates memory on-demand. This is why it feels "lighter" initially, but it can lead to fragmentation in long-running training jobs.


In [None]:
size = 3000

In [None]:
import os

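# Avoid preallocating GPU memory in JAX (must be set before JAX initializes the GPU backend)
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

# Avoid preallocating GPU memory in TensorFlow (must be set before TensorFlow initializes the GPU)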
gpus = tf.config.list_physical_devices('GPU')
if gpus:
  tf.config.experimental.set_memory_growth(gpus[0], True)

In [None]:
key = jax.random.key(42)

# Place the JAX array on the CPU explicitly; when a GPU is available, JAX puts new arrays directly on it
cpus = jax.devices("cpu")
with jax.default_device(cpus[0]):
    # This is created directly in RAM, GPU is never touched
    x_jnp = jax.random.normal(key, (size, size))

# Same with TensorFlow: we need to say where to place the data
with tf.device('/CPU:0'):
  x_tf = tf.random.normal((size, size))


# PyTorch keeps tensors on the CPU until you explicitly move them to a CUDA device.
x_torch = torch.randn(size, size)

In [None]:
x_jnp.device, x_tf.device, x_torch.device

(CpuDevice(id=0),
 '/job:localhost/replica:0/task:0/device:CPU:0',
 device(type='cpu'))

### CPU

#### JAX

- When you call `jnp.dot(x, y)`, JAX doesn't wait for the CPU/GPU to finish the math. Instead, it immediately returns a `jax.Array` that acts as a placeholder for the future result, and the Python thread continues while the computation runs in the background.

- `block_until_ready()` waits until the computation has finished, so you can time it properly.
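The difference between dispatching and finishing can be measured directly. The sketch below is an illustration only: the matrix size of 1000 and the use of `time.perf_counter` are assumptions, and the gap between the two timings depends on the backend (it is usually most visible on GPU or TPU).

```python
import time

import jax.numpy as jnp

x = jnp.ones((1000, 1000))
jnp.dot(x, x.T).block_until_ready()  # warm-up so compilation is not timed

# Time only the dispatch: the call returns before the result is ready.
start = time.perf_counter()
y = jnp.dot(x, x.T)
dispatch_s = time.perf_counter() - start

# Time dispatch plus computation by waiting for the result.
start = time.perf_counter()
y = jnp.dot(x, x.T).block_until_ready()
total_s = time.perf_counter() - start

print(f"dispatch only: {dispatch_s:.6f}s, dispatch + compute: {total_s:.6f}s")
```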

In [None]:
with jax.default_device(cpus[0]):
  %time jax.numpy.dot(x_jnp, x_jnp.T) # without block_until_ready(), this times first-call compilation plus dispatch, not the full computation

CPU times: user 187 ms, sys: 16.6 ms, total: 204 ms
Wall time: 130 ms


In [None]:
# Dispatch returns immediately and lets the Python thread continue; the actual computation happens in the background on the device.
with jax.default_device(cpus[0]):
  %timeit jax.numpy.dot(x_jnp, x_jnp.T).block_until_ready() # waits until the computation completes

533 ms ± 10.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


#### Tensorflow

In [None]:
with tf.device('/CPU:0'):
  %timeit tf.matmul(x_tf, x_tf) # unlike JAX, this executes eagerly, so no explicit blocking is needed

599 ms ± 96.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


#### Torch

In [None]:
%timeit torch.matmul(x_torch, x_torch) # unlike JAX, this executes eagerly, so no explicit blocking is needed

685 ms ± 298 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### GPU
*   **Inputs are Pre-Loaded on Device:**
    The input arrays (`x_jnp`, `x_tf`, `x_torch`) are moved to the GPU memory (VRAM) *before* the timer starts. The overhead of moving data **to** the GPU is **excluded** from the benchmark.
*   **Computation Happens on GPU:**
    The matrix multiplication logic is executed entirely by the GPU cores.
*   **Result is Transferred Back to Host (CPU):**
    By calling `np.array(...)` (JAX), `.numpy()` (TF), or `.cpu()` (PyTorch), you force the resulting tensor to be copied from GPU memory back to CPU RAM.
*   **Implicit Synchronization:**
    Because the CPU cannot access the data until the GPU is finished calculating and transferring it, this forces the CPU to wait. This ensures `%timeit` captures the full duration of the operation, effectively "blocking" the asynchronous nature of the GPU.
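The same idea can be written as a small timing helper outside of `%timeit`. This is a sketch under stated assumptions: `time_with_host_copy` is a hypothetical helper name, and the matrix size and repeat count are arbitrary. It relies only on the fact that copying a result to a NumPy array cannot finish before the device has produced it.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np


def time_with_host_copy(fn, *args, repeats=5):
    """Return the best wall-clock time of fn(*args), including the copy to host RAM."""
    np.asarray(fn(*args))  # warm-up: compile and run once before timing
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        np.asarray(fn(*args))  # the host copy forces synchronization with the device
        best = min(best, time.perf_counter() - start)
    return best


x = jax.random.normal(jax.random.key(0), (512, 512))
elapsed = time_with_host_copy(lambda a: jnp.dot(a, a.T), x)
print(f"best of 5: {elapsed * 1e3:.3f} ms")
```

Because the host copy is inside the timed region, the reported time includes the device-to-host transfer as well as the matrix multiplication.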

In [None]:
x_jnp = jax.random.normal(key, (size, size)) # placed on the GPU by default
x_tf = tf.random.normal((size, size)) # placed on the GPU by default
if torch.cuda.is_available():
  x_torch = torch.randn(size, size).cuda() # moved to the GPU explicitly

#### JAX

In [None]:
%timeit np.array(jax.numpy.dot(x_jnp, x_jnp.T).block_until_ready())

36.9 ms ± 567 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


#### Tensorflow

- As with JAX, TensorFlow dispatches GPU work asynchronously, so timing the call directly would measure how long it takes to queue the work, not the math itself.

In [None]:
%timeit tf.matmul(x_tf, x_tf).numpy()

59.6 ms ± 7.9 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


#### Torch

In [None]:
%timeit torch.matmul(x_torch, x_torch).cpu()

36.5 ms ± 2.11 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


## TPU (Tensor Processing Unit)

- JAX works natively on TPUs.
- TPUs are custom chips built by Google to train ML models.
- They are built specifically for large matrix operations; see the full [page](https://docs.cloud.google.com/tpu/docs/system-architecture-tpu-vm) for details.
- We can run TensorFlow on TPUs as well, but it needs a bit of initial setup in Colab.

In [None]:
import jax
import numpy as np

In [None]:
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)]

In [None]:
key = jax.random.key(42)
size = 3000

In [None]:
x_jnp = jax.random.normal(key, (size, size)) # placed on the TPU by default

In [None]:
%timeit np.array(jax.numpy.dot(x_jnp, x_jnp.T).block_until_ready()) # the result is copied back to the host (CPU)

10.7 ms ± 157 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
%timeit jax.numpy.dot(x_jnp, x_jnp.T).block_until_ready().device # the result stays on the TPU

482 µs ± 5.75 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


## More about JAX


1. **JAX Arrays are immutable** - You cannot modify arrays in-place
2. **Functional programming** - JAX relies on pure functions
3. **Different random number generation** - Explicit key-based PRNG
4. **Stateless** - State must be passed explicitly
5. **Accelerator agnostic** - Same code runs on CPU, GPU, or TPU
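The first point, immutability, is the one that most often surprises NumPy users. As a minimal sketch (the array contents and variable names are illustrative), in-place assignment fails, while the `.at[...].set(...)` syntax returns an updated copy:

```python
import jax.numpy as jnp

x = jnp.arange(5)

# In-place assignment is not allowed on a JAX array.
try:
    x[0] = 10
    mutated = True
except TypeError:
    mutated = False

# The functional update returns a new array and leaves x untouched.
y = x.at[0].set(10)

print(mutated)
print(x)
print(y)
```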

In [None]:
def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)

In [None]:
%timeit selu(x_jnp).block_until_ready()

972 µs ± 6.38 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [None]:
%timeit jax.jit(selu)(x_jnp).block_until_ready() # JIT(Just in Time Compilation)

389 µs ± 9.96 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


### How JIT Works: Tracing

- JIT works by **tracing** your function. During tracing, JAX replaces actual values with abstract "tracers" that only track shapes and types.

- **Key Insight:** Same shape + same type = reuse cached compiled function!
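One way to observe the cache is to record a Python side effect, which only runs while JAX traces the function. The sketch below is illustrative: `trace_log` and `double` are hypothetical names, and appending to a list inside a jitted function is done here purely to make tracing visible.

```python
import jax
import jax.numpy as jnp

trace_log = []  # Python side effects only run while JAX traces the function


@jax.jit
def double(x):
    trace_log.append((x.shape, str(x.dtype)))
    return x * 2


double(jnp.ones(3))                   # new shape and dtype: traced
double(jnp.zeros(3))                  # same shape and dtype: cached, no trace
double(jnp.ones(4))                   # new shape: traced again
double(jnp.ones(3, dtype=jnp.int32))  # new dtype: traced again

print(len(trace_log))
print(trace_log)
```

Four calls produce only three traces, because the second call matches the shape and dtype of the first and reuses its compiled version.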


In [None]:
@jax.jit
def f(x, y):
    print("Running f():")
    print(f" x = {x}")
    print(f" y = {y}")
    result = jax.numpy.dot(x + 1, y + 1)
    print(f" result = {result}")
    return result

x = np.random.randn(3, 4)
y = np.random.randn(4)

In [None]:
print(f(x, y)) # first call traces function

Running f():
 x = JitTracer<float32[3,4]>
 y = JitTracer<float32[4]>
 result = JitTracer<float32[3]>
[8.396521  0.6653428 2.259801 ]


In [None]:
print(f(x, y)) # Second call - uses cached compiled version (no print!)

[8.396521  0.6653428 2.259801 ]


### Viewing the JAX Expression (jaxpr)

In [None]:
def f(x, y):
    return jax.numpy.dot(x + 1, y + 1)

print(jax.make_jaxpr(f)(x, y))

{ lambda ; a:f32[3,4] b:f32[4]. let
    c:f32[3,4] = add a 1.0:f32[]
    d:f32[4] = add b 1.0:f32[]
    e:f32[3] = dot_general[
      dimension_numbers=(([1], [0]), ([], []))
      preferred_element_type=float32
    ] c d
  in (e,) }


### JIT Pitfalls

- The official JAX guide to common gotchas is much more extensive: [🔪 sharpbits](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html). If I missed something or made a mistake, please correct me.

#### 1. Dynamic Shapes
- JIT requires **static shapes**, but boolean indexing produces an output whose shape depends on the data.
- The number of elements returned changes with the values in the boolean mask, so the output shape cannot be known at trace time. This is what causes the error under `jit`.

In [None]:
def get_negatives(x):
    return x[x < 0]  # Shape depends on values!

x = jax.random.normal(key, (10,))
get_negatives(x)  # Works without JIT

Array([-0.02830462, -0.12403281, -1.4408795 ], dtype=float32)

In [None]:
try:
  jax.jit(get_negatives)(x)
except jax.errors.NonConcreteBooleanIndexError as err:
  print(err)

Array boolean indices must be concrete; got bool[10]

See https://docs.jax.dev/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError
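A common workaround is to keep the output the same shape as the input and mask out the unwanted entries instead of removing them. The following is a sketch of that approach, not a drop-in replacement for `get_negatives`; the function names and the fill value of `0.0` are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp


@jax.jit
def zero_out_non_negatives(x):
    # The output has the same shape as x, so the shape is known at trace time.
    return jnp.where(x < 0, x, 0.0)


@jax.jit
def sum_of_negatives(x):
    # Reductions over the masked array also keep every shape static.
    return jnp.where(x < 0, x, 0.0).sum()


x = jnp.array([1.0, -2.0, 3.0, -4.0])
print(zero_out_non_negatives(x))
print(sum_of_negatives(x))
```

Whether masking is acceptable depends on what happens downstream: it works well when the next step is a reduction, since the fill value can be chosen so that it does not affect the result.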


#### 2. Value-Dependent Control Flow

- Python tries to evaluate the `if` immediately during tracing.
- It needs the value of `neg` right away to decide which branch to trace.
- But `neg` is just a placeholder (a tracer) that doesn't have a value yet.
- Since Python can't decide, tracing fails with an error.

In [None]:
@jax.jit
def f(x, neg):
    return -x if neg else x  # Control flow depends on VALUE

f(1, True)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function f at /tmp/ipython-input-2759705761.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument neg.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError

- We can use `static_argnames` if that argument doesn't change during training and isn't a data-dependent flag. Each distinct value of a static argument triggers a recompilation.

In [None]:
from functools import partial


@partial(jax.jit, static_argnames=['neg'])
def f(x, neg=True):
    return -x if neg else x  # neg is static, so Python can branch on its value

f(1, True)
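When the flag does depend on runtime data, marking it static is not an option. One alternative is `jax.lax.cond`, which traces both branches and selects between them when the compiled code runs. The sketch below assumes a scalar boolean predicate; `maybe_negate` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


@jax.jit
def maybe_negate(x, neg):
    # Both branches are traced; the choice is made when the compiled code runs.
    return jax.lax.cond(neg, lambda v: -v, lambda v: v, x)


print(maybe_negate(jnp.float32(1.0), True))
print(maybe_negate(jnp.float32(1.0), False))
```

Unlike the `static_argnames` version, changing `neg` here does not trigger a recompilation, because `neg` is an ordinary traced input.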

#### 3. Using JAX Arrays for Shapes

- Here the shape `(2, 3)` is converted into a JAX array.
- Under `jit`, values computed from JAX arrays are tracers: placeholders whose concrete values are only known when the compiled code runs.
- `reshape` needs to know the exact size at trace time to build the compiled graph. Given a tracer instead of a concrete integer, it cannot build a graph with unknown dimensions, so tracing fails.

In [None]:
@jax.jit
def f(x):
    # BAD: jax.numpy.array(x.shape) creates a traced value
    return x.reshape(jax.numpy.array(x.shape).prod())

f(jax.numpy.ones((2, 3)))  # ERROR!

In [None]:
@jax.jit
def f(x):
    # GOOD: np.prod on the static shape tuple yields a concrete integer
    return x.reshape((np.prod(x.shape),))

f(jax.numpy.ones((2, 3)))  # Works!