# Tinyjax

JAX + Tinygrad = Run everywhere

This is a prototype Tinygrad backend for JAX. Tinygrad is a simple ML framework
and compiler that supports many devices, such as CUDA, OpenCL, Metal, and even
WebGPU and C. JAX is a powerful ML framework that dispatches operations to XLA.
By running those operations with Tinygrad instead, we get Tinygrad's broader
device support:

- Run JAX on OpenCL (e.g. Intel and AMD GPUs) without experimental JAX builds
- Run JAX on Apple Silicon/Metal without experimental prebuilts
- Compile JAX to WebGL and WebGPU with fused kernels
- Compile JAX to C

In fact, this notebook was run on an Intel laptop with an integrated GPU: JAX
operations run on the CPU, while Tinygrad operations run on the Intel iGPU.

## How does it work?

The [Tinygrad] API builds a lazy computation graph by recording the operations
called on its tensors. When the data is actually needed, the graph is
JIT-compiled into one or more kernels and scheduled on the device (this step is
called "realize"). Everything lowers into only 25 fundamental operations, which
makes it very easy to add new backends.
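To make the lazy-graph idea concrete, here is a toy pure-Python sketch, not tinygrad code: operators only record nodes, and nothing is computed until `realize()` walks the graph. `LazyNode`, `KERNELS` and the list-based kernels are hypothetical names; real tinygrad additionally fuses the graph and compiles it into device kernels.

```python
import math


class LazyNode:
    """One node of a toy lazy graph (illustrative only, not tinygrad's internals)."""

    def __init__(self, op, srcs=(), value=None):
        self.op = op        # operation name: "const", "add", "mul" or "sin"
        self.srcs = srcs    # input nodes
        self.value = value  # None until realized (constants start filled in)

    # Operators only record a new node; no arithmetic happens here.
    def __add__(self, other):
        return LazyNode("add", (self, _wrap(other)))

    def __mul__(self, other):
        return LazyNode("mul", (self, _wrap(other)))

    def sin(self):
        return LazyNode("sin", (self,))

    def realize(self):
        """Evaluate the graph on demand and cache the result on each node."""
        if self.value is None:
            args = [src.realize() for src in self.srcs]
            self.value = KERNELS[self.op](*args)
        return self.value


def _wrap(x):
    """Turn Python scalars into constant nodes."""
    return x if isinstance(x, LazyNode) else LazyNode("const", value=x)


def _elementwise(fn):
    """Lift a scalar function to lists, broadcasting scalar operands."""
    def kernel(*args):
        n = max(len(a) for a in args if isinstance(a, list))
        cols = [a if isinstance(a, list) else [a] * n for a in args]
        return [fn(*xs) for xs in zip(*cols)]
    return kernel


KERNELS = {
    "add": _elementwise(lambda a, b: a + b),
    "mul": _elementwise(lambda a, b: a * b),
    "sin": _elementwise(math.sin),
}

x = LazyNode("const", value=[1.0, 2.0, 3.0])
y = LazyNode("const", value=[4.0, 5.0, 6.0])
out = x + y.sin() * 3.0  # records add(x, mul(sin(y), 3.0)); nothing computed yet
assert out.value is None
print(out.realize())     # only now are the kernels executed
```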

JAX turns a Python function into a Jaxpr by tracing it with abstract values, much
like Tinygrad records operations on lazy tensors. The resulting Jaxpr is a
functional expression language closely related to XLA.
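The tracing idea can be sketched in a few lines of plain Python. In this toy version (all names are hypothetical, not JAX internals), each input is replaced by an abstract `Tracer` that holds no data, and every operator call appends an equation to a list, which is roughly what a Jaxpr records. With real JAX, `jax.make_jaxpr(fun)(x, y)` prints the actual Jaxpr.

```python
class Tracer:
    """An abstract value: it knows its name and trace, but holds no data."""

    def __init__(self, trace, name):
        self.trace = trace
        self.name = name

    def __add__(self, other):
        return self.trace.bind("add", self, other)

    def __mul__(self, other):
        return self.trace.bind("mul", self, other)

    def __repr__(self):
        return self.name


class Trace:
    """Collects one equation per primitive applied to a tracer."""

    def __init__(self):
        self.eqns = []  # (output tracer, primitive name, inputs)
        self.n = 0

    def new_var(self):
        self.n += 1
        return Tracer(self, f"v{self.n}")

    def bind(self, prim, *inputs):
        out = self.new_var()
        self.eqns.append((out, prim, inputs))
        return out


def sin(x):
    return x.trace.bind("sin", x)


def make_toy_jaxpr(fn, num_args):
    """Call fn on fresh tracers and return (input vars, equations, output var)."""
    trace = Trace()
    invars = [trace.new_var() for _ in range(num_args)]
    outvar = fn(*invars)
    return invars, trace.eqns, outvar


def fun(first, second):
    return first + sin(second) * 3.0


invars, eqns, outvar = make_toy_jaxpr(fun, 2)
for out, prim, inputs in eqns:
    print(f"{out} = {prim} {' '.join(repr(i) for i in inputs)}")
```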

We implement a Jaxpr interpreter that translates each Jaxpr operation (such as
`dot_general`) into a Tinygrad operation (e.g. `einsum`). Because Tinygrad
operations are lazy, the output of the interpreter is a computation graph rather
than concrete values. The graph can then be JIT-compiled into one big fused GPU
kernel via Tinygrad's codegen.
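A minimal sketch of such an interpreter, under simplifying assumptions: the Jaxpr is a hand-written list of `(output, primitive, inputs)` equations, and nested tuples stand in for lazy Tinygrad tensors. `EQNS`, `TRANSLATE`, `interpret` and `evaluate` are hypothetical names, not the actual tinyjax code. The point is that `interpret` only builds a graph; values appear only when `evaluate` (the analogue of realize) runs.

```python
import math
import operator

# Hand-written Jaxpr-like program for `sum(first + sin(second) * 3.0)`.
# Strings are variable names; floats are literals.
EQNS = [
    ("c", "sin", ("b",)),
    ("d", "mul", ("c", 3.0)),
    ("e", "add", ("a", "d")),
    ("f", "reduce_sum", ("e",)),
]

# Primitive -> graph-building rule. Each rule returns a graph node (a tuple),
# playing the role of a lazy Tinygrad tensor in this sketch.
TRANSLATE = {
    "sin": lambda x: ("sin", x),
    "mul": lambda x, y: ("mul", x, y),
    "add": lambda x, y: ("add", x, y),
    "reduce_sum": lambda x: ("sum", x),
}


def interpret(eqns, env):
    """Walk the equations in order, building one graph node per equation."""
    env = dict(env)
    for out, prim, inputs in eqns:
        args = [env[i] if isinstance(i, str) else ("const", i) for i in inputs]
        env[out] = TRANSLATE[prim](*args)
    return env[eqns[-1][0]]


def evaluate(node):
    """'Realize' a graph node with a simple recursive evaluator."""
    op, *srcs = node
    if op in ("input", "const"):
        return srcs[0]
    vals = [evaluate(s) for s in srcs]
    if op == "sin":
        return [math.sin(v) for v in vals[0]]
    if op == "sum":
        return sum(vals[0])
    if op in ("add", "mul"):
        a, b = vals
        n = len(a) if isinstance(a, list) else len(b)
        a = a if isinstance(a, list) else [a] * n  # broadcast scalars
        b = b if isinstance(b, list) else [b] * n
        fn = operator.add if op == "add" else operator.mul
        return [fn(x, y) for x, y in zip(a, b)]
    raise NotImplementedError(op)


graph = interpret(EQNS, {"a": ("input", [1.0, 2.0, 3.0]),
                         "b": ("input", [4.0, 5.0, 6.0])})
print(graph[0])         # "sum": the interpreter returned a graph, not a number
print(evaluate(graph))  # ≈ 0.01457
```

In the real backend, the translation table would map primitives to Tinygrad tensor operations (for example `dot_general` to `einsum`), and realizing would hand the graph to Tinygrad's compiler instead of a recursive evaluator.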

## Current state

Enough ops are currently implemented to convert a ConvNet (see `ops.py`), and
adding new ops is straightforward.
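The layout of `ops.py` isn't shown here, so the following is only a hypothetical sketch of what "adding a new op" can look like in a Jaxpr interpreter: a dictionary from primitive names to translation rules, filled in by a decorator. `OPS`, `register_op` and `apply_primitive` are invented for illustration, and the rules operate on plain lists rather than Tinygrad tensors.

```python
# Hypothetical op registry: Jaxpr primitive name -> translation rule.
OPS = {}


def register_op(name):
    """Decorator that registers a translation rule under a primitive name."""
    def decorator(fn):
        OPS[name] = fn
        return fn
    return decorator


@register_op("neg")
def neg(x):
    return [-v for v in x]


@register_op("max")
def maximum(x, y):
    # Elementwise max; a real rule would call the matching Tinygrad tensor op.
    return [max(a, b) for a, b in zip(x, y)]


def apply_primitive(name, *args):
    """Dispatch a primitive, failing loudly if no rule is registered."""
    if name not in OPS:
        raise NotImplementedError(f"primitive {name!r} is not implemented yet")
    return OPS[name](*args)


print(apply_primitive("max", [1, -2, 3], [0, 0, 0]))  # [1, 0, 3]
```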

## Examples

The rest of the notebook shows some example conversions.

[Tinygrad]: https://github.com/tinygrad/tinygrad

In [1]:
import os
os.environ["JIT"] = "1"

import tinygrad as tg
import jax

tg.Device.DEFAULT, jax.devices()

('GPU', [CpuDevice(id=0)])

In [2]:
import jax.numpy as jnp
import tinygrad as tg
from tinyjax.decorator import tiny

"""
Demonstrate wrapping a JAX function into a Tinygrad function.
"""

@tiny
def fun(first, second):
  temp = first + jnp.sin(second) * 3.0
  return jnp.sum(temp)

# A lazy computation graph will be returned
buf = fun(tg.Tensor([1, 2, 3]), tg.Tensor([4, 5, 6]))
buf

<Tensor <LB GPU () float ShapeTracker(views=(View(shape=(), strides=(), offset=0, mask=None, contiguous=True),))> on GPU with grad None>

In [3]:
# We can compute (realize) the value using `buf.realize().numpy()`
buf.realize().numpy()

array(0.01457328, dtype=float32)
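As a quick sanity check, the same expression can be evaluated eagerly with Python's `math` module. Any difference in the trailing digits is expected, since the Tinygrad result is `float32` (as the output's dtype shows) while this computes in float64.

```python
import math

first = [1, 2, 3]
second = [4, 5, 6]

# sum(first + sin(second) * 3.0), computed eagerly in float64
reference = sum(a + math.sin(b) * 3.0 for a, b in zip(first, second))
print(reference)  # ≈ 0.0145732
```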

Now let's try a simple conv net. I didn't use `relu` since the corresponding
higher-order `custom_jvp` JAX primitive isn't implemented in the interpreter yet.

In [4]:
from flax import linen as nn
import jax
import numpy as np

"""Define a CNN and a pure function that runs the CNN."""

class CNN(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.silu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.silu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.silu(x)
    x = nn.Dense(features=10)(x)
    return x

bs = 512
rng = jax.random.key(42)
rng1, rng2 = jax.random.split(rng)
cnn = CNN()
params = cnn.init(rng1, jnp.ones([1, 28, 28, 1]))
images = jax.random.uniform(rng2, (bs, 28, 28, 1))

def apply_model(params, images):
  """This function takes model params and input images and runs the CNN on it."""
  global cnn
  out = cnn.apply(params, images)
  assert isinstance(out, jax.Array)
  return out

apply_model_jax = jax.jit(apply_model)
apply_model_tinygrad = tiny(apply_model)
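For reference, the shapes flowing through this CNN and its parameter count can be worked out by hand. The sketch below assumes Flax's defaults (`padding="SAME"` and a bias for `nn.Conv`, `padding="VALID"` for `nn.avg_pool`); the helper names are made up. With `bs = 512`, the flattened activations entering the first `Dense` layer have shape `(512, 3136)`.

```python
def conv_params(k, c_in, c_out):
    """Parameters of a k x k convolution: kernel weights plus one bias per output."""
    return k * k * c_in * c_out + c_out


def dense_params(n_in, n_out):
    """Parameters of a dense layer: weight matrix plus bias vector."""
    return n_in * n_out + n_out


h = w = 28
c = 1
total = 0

total += conv_params(3, c, 32)  # SAME padding keeps 28x28, now 32 channels
c = 32
h, w = h // 2, w // 2           # avg_pool 2x2, stride 2 -> 14x14
total += conv_params(3, c, 64)  # 14x14x64
c = 64
h, w = h // 2, w // 2           # -> 7x7
flat = h * w * c                # flatten -> 3136 features per image
total += dense_params(flat, 256)
total += dense_params(256, 10)

print(flat, total)  # 3136 824458
```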

In [5]:

# Run the jax version.
jax_out = np.array(apply_model_jax(params, images))

# Run the tinygrad version.
tinygrad_params = jax.tree_util.tree_map(lambda p: tg.Tensor(np.array(p)), params)
tinygrad_images = tg.Tensor(np.array(images))
tinygrad_out = apply_model_tinygrad(tinygrad_params, tinygrad_images).numpy()

assert np.allclose(jax_out, tinygrad_out, rtol=1e-3, atol=1e-5)
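The check above uses `np.allclose`, which NumPy documents as the elementwise test `absolute(a - b) <= atol + rtol * absolute(b)`. With `rtol=1e-3, atol=1e-5`, outputs of magnitude around 1 may differ by roughly 1e-3. Below is a minimal pure-Python sketch of that test for flat lists (no broadcasting, no NaN/inf handling); the function name is hypothetical.

```python
def allclose(a, b, rtol=1e-5, atol=1e-8):
    """Sketch of NumPy's documented test |a - b| <= atol + rtol * |b|.

    Simplified: flat lists only, no broadcasting, no NaN/inf special cases.
    """
    if len(a) != len(b):
        return False
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))


# Tolerances used by the notebook's check.
print(allclose([1.0005], [1.0], rtol=1e-3, atol=1e-5))  # True: 5e-4 <= 1.01e-3
print(allclose([1.002], [1.0], rtol=1e-3, atol=1e-5))   # False: 2e-3 > 1.01e-3
```

Note that the test is asymmetric: only `|b|` scales the relative tolerance, so the argument order matters slightly.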

In [8]:
%timeit apply_model_tinygrad(tinygrad_params, tinygrad_images).realize()

58.3 ms ± 622 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [9]:
%timeit apply_model_jax(params, images).block_until_ready()

45.9 ms ± 1.33 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
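Both timings process the same batch of `bs = 512` images, so they can be converted into throughput. These are single-machine measurements from this notebook (JAX on the CPU, Tinygrad on the Intel iGPU), not a general benchmark.

```python
bs = 512  # batch size used in the benchmark
timings_ms = {"tinygrad": 58.3, "jax": 45.9}  # mean times reported by %timeit above

for name, ms in timings_ms.items():
    print(f"{name}: {bs / (ms / 1000):.0f} images/s")

ratio = timings_ms["tinygrad"] / timings_ms["jax"]
print(f"tinygrad / jax time ratio: {ratio:.2f}")
```

This works out to roughly 8,800 images/s for Tinygrad versus 11,200 images/s for JAX, i.e. the Tinygrad version takes about 1.27 times as long per batch.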
