# Prototyping

In [1]:
import jax

In [2]:
def f(x):
  """Return sin(cos(x)) computed with `jax.numpy`."""
  inner = jax.numpy.cos(x)
  return jax.numpy.sin(inner)

print(f(3.0))

-0.83602184


In [3]:
jaxpr = jax.make_jaxpr(f)(3.0)
print(jaxpr)

{ lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }


In [4]:
from jax import make_jaxpr
import jax.numpy as jnp


def func1(first, second):
  """Return the sum over all elements of `first + sin(second) * 3.0`."""
  scaled_sine = jnp.sin(second) * 3.0
  return jnp.sum(first + scaled_sine)


print(make_jaxpr(func1)(jnp.zeros(8), jnp.ones(8)))

{ lambda ; a:f32[8] b:f32[8]. let
    c:f32[8] = sin b
    d:f32[8] = mul c 3.0
    e:f32[8] = add a d
    f:f32[] = reduce_sum[axes=(0,)] e
  in (f,) }


In [5]:
from jax import lax, grad, vmap


def one_of_three(index, arg):
  """Apply the branch selected by `index` to `arg`, then add a constant sum.

  Branch 0 adds 1.0, branch 1 subtracts 2.0 and branch 2 adds 3.0. The sum of
  the fixed float32 array [1, 2, 3, 4, 5, 6] is added to the branch result.
  """
  constants = jnp.array([1, 2, 3, 4, 5, 6], dtype=jnp.float32)
  branches = [lambda x: x + 1.0, lambda x: x - 2.0, lambda x: x + 3.0]
  selected = lax.switch(index, branches, arg)
  return selected + jnp.sum(constants)


func2 = vmap(grad(one_of_three, argnums=1))
print(make_jaxpr(func2)(jnp.zeros(2, dtype=jnp.int32) + 1, jnp.zeros(2) + 5))

{ lambda a:f32[6]; b:i32[2] c:f32[2]. let
    d:i32[2] = clamp 0 b 2
    e:bool[2] = eq d 0
    f:f32[2] = stop_gradient c
    g:f32[2] = select_n e f c
    h:f32[2] = add g 1.0
    i:bool[2] = eq d 1
    j:f32[2] = stop_gradient c
    k:f32[2] = select_n i j c
    l:f32[2] = sub k 2.0
    m:bool[2] = eq d 2
    n:f32[2] = stop_gradient c
    o:f32[2] = select_n m n c
    p:f32[2] = add o 3.0
    q:f32[2] = select_n d h l p
    r:f32[] = reduce_sum[axes=(0,)] a
    _:f32[2] = add q r
    s:f32[2] = broadcast_in_dim[broadcast_dimensions=() shape=(2,)] 1.0
    t:bool[2] = eq d 0
    u:f32[2] = stop_gradient s
    v:f32[2] = select_n t u s
    w:bool[2] = eq d 1
    x:f32[2] = stop_gradient s
    y:f32[2] = select_n w x s
    z:bool[2] = eq d 2
    ba:f32[2] = stop_gradient s
    bb:f32[2] = select_n z ba s
    bc:f32[2] = select_n d v y bb
  in (bc,) }


In [6]:
import numpy as np
from functools import wraps

from jax import core
from jax import lax
from jax._src.util import safe_map

In [7]:
def f(x):
  """Return exp(tanh(x)) computed with `jax.numpy`."""
  squashed = jnp.tanh(x)
  return jnp.exp(squashed)

closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5))
print(closed_jaxpr.jaxpr)
print(closed_jaxpr.consts)

{ lambda ; a:f32[5]. let b:f32[5] = tanh a; c:f32[5] = exp b in (c,) }
[]


In [8]:
from jax.core import Jaxpr

def eval_jaxpr(jaxpr: Jaxpr, consts, *args, debug: bool = False):
  """Interpret `jaxpr` one equation at a time by re-binding each primitive.

  Args:
    jaxpr: The Jaxpr to evaluate.
    consts: Values for `jaxpr.constvars`, in the same order.
    *args: Values for `jaxpr.invars`, in the same order.
    debug: When True, print each equation before it is bound and each value
      written to the environment.

  Returns:
    A list with one value per entry of `jaxpr.outvars`.
  """
  assert isinstance(debug, bool)

  # Environment: Jaxpr variable -> value currently bound to it.
  values = {}

  def lookup(atom):
    # A Literal carries its value inline; anything else lives in `values`.
    return atom.val if type(atom) is core.Literal else values[atom]

  def store(var, val):
    values[var] = val
    if debug:
      print(f"[JAX] Writing to {var}: {val}")

  # Seed the environment with the positional inputs, then the constants.
  safe_map(store, jaxpr.invars, args)
  safe_map(store, jaxpr.constvars, consts)

  for eqn in jaxpr.eqns:
    operands = safe_map(lookup, eqn.invars)
    if debug:
      print(f"Processing {eqn}")
    # Calling `bind` applies the primitive to its operands.
    results = eqn.primitive.bind(*operands, **eqn.params)
    # Normalise single-result primitives to a one-element list.
    if not eqn.primitive.multiple_results:
      results = [results]
    safe_map(store, eqn.outvars, results)

  return safe_map(lookup, jaxpr.outvars)

In [9]:
closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5))
eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, jnp.ones(5)), f(jnp.ones(5))

([Array([2.1416876, 2.1416876, 2.1416876, 2.1416876, 2.1416876], dtype=float32)],
 Array([2.1416876, 2.1416876, 2.1416876, 2.1416876, 2.1416876], dtype=float32))

In [10]:
closed_jaxpr = make_jaxpr(func2)(jnp.zeros(2, dtype=jnp.int32) + 1, jnp.zeros(2) + 5)
eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, jnp.zeros(2, dtype=jnp.int32) + 1, jnp.zeros(2) + 5), func2(jnp.zeros(2, dtype=jnp.int32) + 1, jnp.zeros(2) + 5)

([Array([1., 1.], dtype=float32)], Array([1., 1.], dtype=float32))

## Example CNN and its jaxpr

In [11]:
from flax import linen as nn
import jax
import jax.numpy as jnp

In [12]:
class CNN(nn.Module):
  """Small CNN: two conv / SiLU / average-pool stages followed by two dense layers."""

  @nn.compact
  def __call__(self, x):
    # The two convolutional stages differ only in their feature count.
    for features in (32, 64):
      x = nn.Conv(features=features, kernel_size=(3, 3))(x)
      x = nn.silu(x)
      x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    # Collapse every axis except the leading one before the dense layers.
    x = x.reshape((x.shape[0], -1))
    hidden = nn.silu(nn.Dense(features=256)(x))
    return nn.Dense(features=10)(hidden)


In [13]:
def apply_model(cnn: CNN, params, images):
  """Call `cnn.apply` with `params` and `images` and return its result."""
  outputs = cnn.apply(params, images)
  return outputs

rng = jax.random.key(0)
cnn = CNN()
params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))

In [14]:
bs = 512
images = jnp.ones((bs, 28, 28, 1))

def partial_apply(images):
  """Run `apply_model` on `images` with the notebook-level `cnn` and `params`."""
  return apply_model(cnn=cnn, params=params, images=images)

partial_apply(images)

Array([[-0.03884823,  0.01811368,  0.05547812, ..., -0.05636393,
        -0.00622258,  0.01677726],
       [-0.03884823,  0.01811368,  0.05547812, ..., -0.05636393,
        -0.00622258,  0.01677726],
       [-0.03884823,  0.01811368,  0.05547812, ..., -0.05636393,
        -0.00622258,  0.01677726],
       ...,
       [-0.03884823,  0.01811368,  0.05547812, ..., -0.05636393,
        -0.00622258,  0.01677726],
       [-0.03884823,  0.01811368,  0.05547812, ..., -0.05636393,
        -0.00622258,  0.01677726],
       [-0.03884823,  0.01811368,  0.05547812, ..., -0.05636393,
        -0.00622258,  0.01677726]], dtype=float32)

In [15]:
closed_jaxpr = make_jaxpr(partial_apply)(images)

In [16]:
from jax.interpreters.partial_eval import dce_jaxpr

opt_jaxpr, _ = dce_jaxpr(closed_jaxpr.jaxpr, [True])
opt_jaxpr

{ lambda a:f32[3,3,1,32] b:f32[32] c:f32[3,3,32,64] d:f32[64] e:f32[3136,256] f:f32[256]
    g:f32[256,10] h:f32[10]; i:f32[512,28,28,1]. let
    j:f32[512,28,28,32] = conv_general_dilated[
      batch_group_count=1
      dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))
      feature_group_count=1
      lhs_dilation=(1, 1)
      padding=((1, 1), (1, 1))
      precision=None
      preferred_element_type=None
      rhs_dilation=(1, 1)
      window_strides=(1, 1)
    ] i a
    k:f32[1,1,1,32] = reshape[dimensions=None new_sizes=(1, 1, 1, 32)] b
    l:f32[512,28,28,32] = add j k
    m:f32[512,28,28,32] = pjit[
      name=silu
      jaxpr={ lambda ; n:f32[512,28,28,32]. let
          o:f32[512,28,28,32] = logistic n
          p:f32[512,28,28,32] = mul n o
        in (p,) }
    ] l
    q:f32[512,14,14,32] = reduce_window_sum[
      base_dilation=(1, 1, 1, 1)
      padding=((0, 0), (0, 0), (0, 0), (0, 0))
      window_dilation=(1, 1, 

In [17]:
assert jnp.all(eval_jaxpr(opt_jaxpr, closed_jaxpr.literals, images)[0] == partial_apply(images))

## Find all operations in this jaxpr

In [18]:
from jax._src.core import AxisPrimitive

def find_ops(jaxpr):
  """Collect the set of primitives used by `jaxpr`, looking inside jit/pjit calls.

  Args:
    jaxpr: Any object exposing an `eqns` sequence of equations, each with a
      `primitive` and a `params` mapping.

  Returns:
    The set of primitives of all visited equations. An equation whose
    primitive is named "pjit" or "jit" is not reported itself; the equations
    of the jaxpr stored under its `jaxpr` parameter are scanned instead.
  """
  prims = set()
  # Explicit work list instead of recursion, so nesting depth is not limited
  # by the interpreter's recursion limit.
  pending = [jaxpr]
  while pending:
    for eqn in pending.pop().eqns:
      # Dispatch on the primitive's name rather than its class, so detection
      # does not depend on the private `AxisPrimitive` type.
      if eqn.primitive.name in ("pjit", "jit"):
        pending.append(eqn.params["jaxpr"])
      else:
        prims.add(eqn.primitive)
  return prims

prims = find_ops(opt_jaxpr)
for p in prims:
  print(p)

add
div
conv_general_dilated
mul
reshape
dot_general
logistic
convert_element_type
reduce_window_sum


## Toy numpy interpreter

In [19]:
def func1(first, second):
  """Sum every element of `first + sin(second) * 3.0`."""
  sine_term = jnp.sin(second)
  shifted = first + sine_term * 3.0
  return jnp.sum(shifted)

closed_jaxpr = make_jaxpr(func1)(5, jnp.array([1, 2, 3]))
find_ops(closed_jaxpr)

{add, convert_element_type, mul, reduce_sum, sin}

In [20]:
from typing import Any, Sequence
from jax.core import Jaxpr
from jax.typing import DTypeLike
from jax import lax

np_registry = {}

def np_convert_element_type(operand: Any, new_dtype: DTypeLike, weak_type: bool = False) -> np.ndarray:
  """NumPy implementation of the `convert_element_type` primitive.

  Args:
    operand: Value to convert: a Python or NumPy scalar, a NumPy array, or
      anything else `np.asarray` can turn into an array.
    new_dtype: Target dtype.
    weak_type: Accepted so the primitive's parameters can be passed through
      unchanged; not used by this implementation.

  Returns:
    A NumPy array holding `operand` cast to `new_dtype`.
  """
  # Go through `np.asarray` for every input so the result is always an
  # ndarray, including for NumPy integer/bool scalars, whose own `.astype`
  # returns a scalar rather than an array.
  return np.asarray(operand).astype(new_dtype)
np_registry[lax.convert_element_type_p] = np_convert_element_type

def np_sin(x: np.ndarray) -> np.ndarray:
  """NumPy implementation of the `sin` primitive: elementwise `np.sin`."""
  sine = np.sin(x)
  return sine
np_registry[lax.sin_p] = np_sin

def np_mul(a: np.ndarray, b: np.ndarray) -> np.ndarray:
  """NumPy implementation of the `mul` primitive: elementwise `np.multiply`."""
  product = np.multiply(a, b)
  return product
np_registry[lax.mul_p] = np_mul

def np_add(a: np.ndarray, b: np.ndarray) -> np.ndarray:
  """NumPy implementation of the `add` primitive: elementwise `np.add`."""
  total = np.add(a, b)
  return total
np_registry[lax.add_p] = np_add

def np_reduce_sum(operand: np.ndarray, axes: Sequence[int]) -> np.ndarray:
  """NumPy implementation of the `reduce_sum` primitive.

  Args:
    operand: Array to reduce.
    axes: Axes to sum over; any sequence of ints (tuple, list, ...).

  Returns:
    `operand` summed over `axes`.
  """
  # `np.sum` accepts an int or a tuple for `axis` but raises TypeError for
  # other sequences such as lists, so normalise to a tuple first.
  return np.sum(operand, axis=tuple(axes))
np_registry[lax.reduce_sum_p] = np_reduce_sum

def eval_jaxpr_np(jaxpr: Jaxpr, consts, *args):
  """Evaluate `jaxpr` with the NumPy implementations stored in `np_registry`.

  Args:
    jaxpr: The Jaxpr to evaluate.
    consts: Values for `jaxpr.constvars`, in the same order.
    *args: Values for `jaxpr.invars`, in the same order.

  Returns:
    A list with one value per entry of `jaxpr.outvars`.

  Raises:
    NotImplementedError: If an equation's primitive is not in `np_registry`.
  """
  # Environment: Jaxpr variable -> value currently bound to it.
  values = {}

  def lookup(atom):
    # A Literal carries its value inline; anything else lives in `values`.
    if type(atom) is core.Literal:
      return atom.val
    return values[atom]

  def store(var, val):
    # Python ints/floats, NumPy floating scalars and JAX arrays are turned
    # into NumPy arrays; every other value is stored unchanged.
    if isinstance(val, (int, float, np.floating, jax.Array)):
      values[var] = np.array(val)
    else:
      values[var] = val

  # Seed the environment with the positional inputs, then the constants.
  safe_map(store, jaxpr.invars, args)
  safe_map(store, jaxpr.constvars, consts)

  for eqn in jaxpr.eqns:
    operands = safe_map(lookup, eqn.invars)
    if eqn.primitive not in np_registry:
      raise NotImplementedError(
          f"{eqn.primitive} does not have an implementation")
    implementation = np_registry[eqn.primitive]
    results = implementation(*operands, **eqn.params)
    # Normalise single-result primitives to a one-element list.
    if not eqn.primitive.multiple_results:
      results = [results]
    safe_map(store, eqn.outvars, results)

  return safe_map(lookup, jaxpr.outvars)

In [21]:
print(eval_jaxpr_np(closed_jaxpr.jaxpr, closed_jaxpr.consts, 5, jnp.array([1, 2, 3])))
print(eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, 5, jnp.array([1, 2, 3])))

[array(20.675665, dtype=float32)]


[Array(20.675665, dtype=float32)]


## Tinygrad interpreter

In [22]:
import tinygrad as tg
import tinygrad.dtype
import tinygrad.function as tgf
from typing import Any, Sequence, Tuple
from jax.core import Jaxpr, ShapedArray, ClosedJaxpr
from jax.typing import DTypeLike
from jax import lax
from jax._src.typing import Shape
from jax._src.pjit import pjit_p

tg_registry = {}

# mul
def tg_mul(a: tg.Tensor, b: tg.Tensor) -> tg.Tensor:
  """tinygrad implementation of the `mul` primitive: `a.mul(b)`."""
  product = a.mul(b)
  return product
tg_registry[lax.mul_p] = tg_mul

# reshape
def tg_reshape(operand: tg.Tensor, new_sizes: Shape, dimensions: Sequence[int] | None = None) -> tg.Tensor:
  """tinygrad implementation of the `reshape` primitive.

  Args:
    operand: Tensor to reshape.
    new_sizes: Target shape.
    dimensions: Optional permutation of the operand's axes. Following
      `jax.lax.reshape`, it is applied to the operand *before* reshaping.

  Returns:
    The (optionally permuted) operand reshaped to `new_sizes`.
  """
  target_shape = tuple(int(size) for size in new_sizes)
  if dimensions:
    # `dimensions` indexes the input axes (its length equals the operand's
    # rank, not len(new_sizes)), so it must permute the operand rather than
    # the target shape.
    operand = operand.permute(tuple(int(axis) for axis in dimensions))
  return operand.reshape(target_shape)
tg_registry[lax.reshape_p] = tg_reshape

# div
def tg_div(a: tg.Tensor, b: tg.Tensor) -> tg.Tensor:
  """tinygrad implementation of the `div` primitive: `a.div(b)`."""
  quotient = a.div(b)
  return quotient
tg_registry[lax.div_p] = tg_div

# reduce_window_sum
def tg_reduce_window_sum(operand: tg.Tensor, window_dimensions: Shape,
                         window_strides: Sequence[int],
                         padding: Sequence[tuple[int, int]],
                         base_dilation: Sequence[int] | None = None,
                         window_dilation: Sequence[int] | None = None) -> tg.Tensor:
  """tinygrad implementation of the `reduce_window_sum` primitive.

  Only unpadded windows without base dilation are supported.

  Args:
    operand: Tensor to pool.
    window_dimensions: Window size per axis.
    window_strides: Window stride per axis.
    padding: (low, high) padding per axis; every pair must be (0, 0).
    base_dilation: Must be None or all ones.
    window_dilation: Dilation per window axis; None or empty means all ones.

  Returns:
    The windows produced by `operand._pool`, summed over the trailing
    `len(window_dimensions)` axes.

  Raises:
    AssertionError: If padding or base dilation is requested.
  """
  # Example equation this was written against:
  # a:f32[3,14,14,32] = reduce_window_sum[
  #   base_dilation=(1, 1, 1, 1)
  #   padding=((0, 0), (0, 0), (0, 0), (0, 0))
  #   window_dilation=(1, 1, 1, 1)
  #   window_dimensions=(1, 2, 2, 1)
  #   window_strides=(1, 2, 2, 1)
  # ]
  rank = len(window_dimensions)
  # Compare element-wise so lists and tuples of any rank are accepted.
  assert all(tuple(pad) == (0, 0) for pad in padding), f"padding is not supported: {padding}"
  # None is the declared default and means "no base dilation".
  assert base_dilation is None or all(d == 1 for d in base_dilation), \
    f"base_dilation is not supported: {base_dilation}"
  dilation = tuple(window_dilation) if window_dilation else (1,) * rank
  pooled = operand._pool(
    k_=tuple(window_dimensions), stride=tuple(window_strides), dilation=dilation)
  # NOTE(review): assumes `_pool` appends one trailing axis per window
  # dimension — confirm against the tinygrad version in use.
  return pooled.sum(axis=tuple(range(-rank, 0)))
tg_registry[lax.reduce_window_sum_p] = tg_reduce_window_sum

# reduce_sum
def tg_reduce_sum(operand: tg.Tensor, axes: Sequence[int]) -> tg.Tensor:
  """tinygrad implementation of the `reduce_sum` primitive via `_reduce` with `tgf.Sum`."""
  reduction_axes = tuple(axes)
  return operand._reduce(tgf.Sum, reduction_axes)
tg_registry[lax.reduce_sum_p] = tg_reduce_sum

# convert_element_type
def tg_convert_element_type(operand: Any, new_dtype: DTypeLike, weak_type: bool = False) -> tg.Tensor:
  """tinygrad implementation of the `convert_element_type` primitive.

  Python/NumPy scalars are converted through NumPy and wrapped in a new
  tensor; tensors are cast to the dtype returned by `convert_dtype`.
  `weak_type` is accepted but not used.

  Raises:
    RuntimeError: If `operand` is neither a supported scalar nor a tensor.
  """
  if isinstance(operand, (int, float, np.floating, np.integer)):
    converted = np.array(operand).astype(new_dtype)
    return tg.Tensor(converted)
  if isinstance(operand, tg.Tensor):
    target = convert_dtype(np.dtype(new_dtype))
    return operand.cast(target)
  raise RuntimeError(f"Unsupported operand type {type(operand)}")
def convert_dtype(dtype: np.dtype) -> tinygrad.dtype.DType:
  """Map a NumPy dtype to the corresponding tinygrad dtype.

  Args:
    dtype: NumPy dtype to translate.

  Returns:
    The tinygrad dtype for float16, float32, int32 or int64.

  Raises:
    RuntimeError: If `dtype` is not one of the supported dtypes.
  """
  supported = (
    (np.float16, tinygrad.dtypes.float16),
    (np.float32, tinygrad.dtypes.float32),
    (np.int32, tinygrad.dtypes.int32),
    (np.int64, tinygrad.dtypes.int64),
  )
  # Compare with `==`, the same test the original value patterns performed.
  for numpy_type, tg_dtype in supported:
    if dtype == numpy_type:
      return tg_dtype
  # The message used to contain a stray "f" before the dtype (e.g. "ffloat64").
  raise RuntimeError(f"Unsupported dtype {dtype}")
tg_registry[lax.convert_element_type_p] = tg_convert_element_type

# add
def tg_add(a: tg.Tensor, b: tg.Tensor) -> tg.Tensor:
  """tinygrad implementation of the `add` primitive: `a.add(b)`."""
  total = a.add(b)
  return total
tg_registry[lax.add_p] = tg_add

# logistic
def tg_logistic(x: tg.Tensor) -> tg.Tensor:
  """tinygrad implementation of the `logistic` primitive: `x.sigmoid()`."""
  activated = x.sigmoid()
  return activated
tg_registry[lax.logistic_p] = tg_logistic

# sin
def tg_sin(x: tg.Tensor) -> tg.Tensor:
  """tinygrad implementation of the `sin` primitive: `x.sin()`."""
  sine = x.sin()
  return sine
tg_registry[lax.sin_p] = tg_sin

# dot_general
def tg_dot_general(lhs: tg.Tensor, rhs: tg.Tensor, dimension_numbers: lax.DotDimensionNumbers,
                   precision: lax.PrecisionLike = None,
                   preferred_element_type: DTypeLike | None = None) -> tg.Tensor:
  """tinygrad implementation of the `dot_general` primitive, expressed as an einsum.

  Args:
    lhs: Left operand.
    rhs: Right operand.
    dimension_numbers: `((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch))`.
    precision: Must be None.
    preferred_element_type: Must be None.

  Returns:
    The contraction, with axes ordered as batch axes, then the remaining lhs
    axes, then the remaining rhs axes (the `jax.lax.dot_general` layout).

  Raises:
    AssertionError: If `precision` or `preferred_element_type` is set.
  """
  # Example equation this was written against:
  # a:f32[3,256] = dot_general[dimension_numbers=(([1], [0]), ([], []))] b c
  assert precision is None
  assert preferred_element_type is None

  ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims)) = dimension_numbers
  # Hand out each einsum label exactly once so that no two unrelated axes can
  # end up sharing a label.
  labels = iter('abcdefghijklmnopqrstuvwxyz')

  lhs_labels = [''] * len(lhs.shape)
  rhs_labels = [''] * len(rhs.shape)

  # Contracting axes: one shared label per pair, absent from the output.
  for lhs_dim, rhs_dim in zip(lhs_contracting_dims, rhs_contracting_dims):
    label = next(labels)
    lhs_labels[lhs_dim] = label
    rhs_labels[rhs_dim] = label

  # Batch axes: one shared label per pair, kept in the output.
  batch_labels = []
  for lhs_dim, rhs_dim in zip(lhs_batch_dims, rhs_batch_dims):
    label = next(labels)
    lhs_labels[lhs_dim] = label
    rhs_labels[rhs_dim] = label
    batch_labels.append(label)

  # Free axes: every still-unlabelled axis gets a label of its own.
  lhs_free = []
  for i, label in enumerate(lhs_labels):
    if label == '':
      lhs_labels[i] = next(labels)
      lhs_free.append(lhs_labels[i])

  rhs_free = []
  for i, label in enumerate(rhs_labels):
    if label == '':
      rhs_labels[i] = next(labels)
      rhs_free.append(rhs_labels[i])

  # Construct the einsum string. Batch labels must appear in the output;
  # omitting them would make einsum sum over the batch axes.
  lhs_subscripts = ''.join(lhs_labels)
  rhs_subscripts = ''.join(rhs_labels)
  result_subscripts = ''.join(batch_labels + lhs_free + rhs_free)

  formula = f'{lhs_subscripts},{rhs_subscripts}->{result_subscripts}'
  return tg.Tensor.einsum(formula, lhs, rhs)
tg_registry[lax.dot_general_p] = tg_dot_general

# conv_general_dilated
def tg_conv_general_dilated(
  lhs: tg.Tensor, rhs: tg.Tensor, window_strides: Sequence[int],
  padding: str | Sequence[tuple[int, int]],
  lhs_dilation: Sequence[int] | None = None,
  rhs_dilation: Sequence[int] | None = None,
  dimension_numbers: lax.ConvGeneralDilatedDimensionNumbers = None,
  feature_group_count: int = 1, batch_group_count: int = 1,
  precision: lax.PrecisionLike = None,
  preferred_element_type: DTypeLike | None = None
) -> tg.Tensor:
  """tinygrad implementation of the `conv_general_dilated` primitive (2D only).

  Supported subset: unit strides, symmetric explicit padding, no input
  (lhs) dilation, no grouping, default precision and element type.

  Args:
    lhs: Input tensor, laid out as described by `dimension_numbers`.
    rhs: Kernel tensor, laid out as described by `dimension_numbers`.
    window_strides: Must be (1, 1).
    padding: Two (low, high) pairs with low == high; strings are rejected.
    lhs_dilation: Must be None or (1, 1).
    rhs_dilation: Kernel dilation; None means (1, 1).
    dimension_numbers: Layout spec passed to `lax.conv_dimension_numbers`.
    feature_group_count: Must be 1.
    batch_group_count: Must be 1.
    precision: Must be None.
    preferred_element_type: Must be None.

  Returns:
    The convolution result, laid out according to the output spec.

  Raises:
    AssertionError: If an unsupported configuration is requested.
  """
  # Example equation this was written against:
  # t:f32[3,14,14,64] = conv_general_dilated[
  #    batch_group_count=1
  #    dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))
  #    feature_group_count=1
  #    lhs_dilation=(1, 1)
  #    padding=((1, 1), (1, 1))
  #    precision=None
  #    preferred_element_type=None
  #    rhs_dilation=(1, 1)
  #    window_strides=(1, 1)
  #  ]
  dim_nums = lax.conv_dimension_numbers(lhs.shape, rhs.shape, dimension_numbers)
  # Bring both operands into the order listed by their spec before calling
  # `conv2d` (for the example above: batch/feature/spatial and out/in/spatial).
  lhs = lhs.permute(dim_nums.lhs_spec)
  rhs = rhs.permute(dim_nums.rhs_spec)
  assert batch_group_count == 1
  assert feature_group_count == 1
  assert precision is None
  assert preferred_element_type is None
  # None is the declared default and means "no dilation".
  assert lhs_dilation is None or tuple(lhs_dilation) == (1, 1)
  assert not isinstance(padding, str)
  assert len(padding) == 2
  assert padding[0][0] == padding[0][1]
  assert padding[1][0] == padding[1][1]
  tg_padding = (int(padding[0][0]), int(padding[1][0]))
  assert tuple(window_strides) == (1, 1)
  dilation = (1, 1) if rhs_dilation is None else tuple(int(x) for x in rhs_dilation)
  result = lhs.conv2d(rhs, dilation=dilation, padding=tg_padding) # type: ignore
  # Axis k of `result` belongs at output position out_spec[k]. Labelling the
  # current axes with `out_spec` and asking for (0, 1, 2, 3) applies the
  # inverse permutation, which is correct for any output layout rather than
  # only for the NHWC spec shown above.
  return permute_to_spec(result, existing=dim_nums.out_spec, desired=(0, 1, 2, 3))
def permute_to_spec(v: tg.Tensor, existing: Sequence[int], desired: Sequence[int]) -> tg.Tensor:
  """Permute `v` so its axes, currently labelled `existing`, follow the order `desired`.

  New axis k is the old axis whose label in `existing` equals `desired[k]`.
  All labels must lie in the range 0..3.
  """
  for spec in (existing, desired):
    assert all(0 <= i <= 3 for i in spec)
  order = tuple(int(existing.index(label)) for label in desired)
  return v.permute(order)
tg_registry[lax.conv_general_dilated_p] = tg_conv_general_dilated

# pjit
def tg_pjit(*inputs, jaxpr: ClosedJaxpr, **kwargs):
  """tinygrad implementation of the `pjit` primitive.

  Evaluates the closed jaxpr carried by the equation with `eval_jaxpr_tg`.
  All other keyword parameters are accepted and ignored.
  """
  inner_jaxpr, inner_consts = jaxpr.jaxpr, jaxpr.consts
  return eval_jaxpr_tg(inner_jaxpr, inner_consts, *inputs)
tg_registry[pjit_p] = tg_pjit

def eval_jaxpr_tg(jaxpr: Jaxpr, consts, *args, debug: bool = False):
  """Evaluate `jaxpr` with the tinygrad implementations stored in `tg_registry`.

  Args:
    jaxpr: The Jaxpr to evaluate.
    consts: Values for `jaxpr.constvars`, in the same order.
    *args: Values for `jaxpr.invars`, in the same order.
    debug: When True, print each equation before it is evaluated and each
      value written to the environment.

  Returns:
    A list with one value per entry of `jaxpr.outvars`.

  Raises:
    NotImplementedError: If an equation's primitive is not in `tg_registry`.
    AssertionError: If a result's shape differs from the shape recorded on
      the equation's output variable.
  """
  assert isinstance(debug, bool)

  # Environment: Jaxpr variable -> value currently bound to it.
  values = {}

  def lookup(atom):
    # A Literal carries its value inline; wrap it in a tensor on every read.
    if type(atom) is core.Literal:
      return tg.Tensor(np.array(atom.val))
    return values[atom]

  def store(var, val):
    # Python/NumPy scalars and JAX arrays are converted to tinygrad tensors
    # through NumPy; every other value is stored unchanged.
    if isinstance(val, (int, float, np.floating, np.integer, jax.Array)):
      values[var] = tg.Tensor(np.array(val))
    else:
      values[var] = val
    if debug:
      print(f"Writing to {var}: {values[var].numpy()}")

  # Seed the environment with the positional inputs, then the constants.
  safe_map(store, jaxpr.invars, args)
  safe_map(store, jaxpr.constvars, consts)

  for eqn in jaxpr.eqns:
    operands = safe_map(lookup, eqn.invars)
    if eqn.primitive not in tg_registry:
      raise NotImplementedError(
          f"{eqn.primitive} does not have an implementation: {eqn}")
    if debug:
      print(f"Processing {eqn}")
    implementation = tg_registry[eqn.primitive]
    results = implementation(*operands, **eqn.params)
    # Normalise single-result primitives to a one-element list.
    if not eqn.primitive.multiple_results:
      results = [results]
    # Check each result against the shape recorded in the Jaxpr.
    for var, val in zip(eqn.outvars, results):
      if isinstance(var.aval, ShapedArray):
        assert var.aval.shape == val.shape, f"{var.aval.shape} != {val.shape}"
    safe_map(store, eqn.outvars, results)

  return safe_map(lookup, jaxpr.outvars)

In [23]:
def func1(first, second):
  """Return the total of `first + sin(second) * 3.0` over all elements."""
  weighted = jnp.sin(second) * 3.0
  combined = first + weighted
  return jnp.sum(combined)

closed_jaxpr = make_jaxpr(func1)(5, jnp.array([1, 2, 3]))
find_ops(closed_jaxpr)

print(list(map(lambda x: x.numpy(), eval_jaxpr_tg(closed_jaxpr.jaxpr, closed_jaxpr.consts, 5, jnp.array([1, 2, 3])))))
print(eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, 5, jnp.array([1, 2, 3])))

[array(20.675665, dtype=float32)]
[Array(20.675665, dtype=float32)]


In [24]:
closed_jaxpr = make_jaxpr(partial_apply)(images)
opt_jaxpr, _ = dce_jaxpr(closed_jaxpr.jaxpr, [True])

In [25]:
tg_out = eval_jaxpr_tg(opt_jaxpr, closed_jaxpr.literals, images)[0].numpy()
jax_out = np.array(eval_jaxpr(opt_jaxpr, closed_jaxpr.literals, images)[0])
print("tinygrad: ", tg_out)
print("jax: ", jax_out)
assert np.allclose(tg_out, jax_out, rtol=1e-2, atol=1e-4)

tinygrad:  [[-0.03884818  0.01811363  0.05547819 ... -0.05636396 -0.00622255
   0.01677731]
 [-0.03884818  0.01811363  0.05547819 ... -0.05636396 -0.00622255
   0.01677731]
 [-0.03884818  0.01811363  0.05547819 ... -0.05636396 -0.00622255
   0.01677731]
 ...
 [-0.03884818  0.01811363  0.05547819 ... -0.05636396 -0.00622255
   0.01677731]
 [-0.03884818  0.01811363  0.05547819 ... -0.05636396 -0.00622255
   0.01677731]
 [-0.03884818  0.01811363  0.05547819 ... -0.05636396 -0.00622255
   0.01677731]]
jax:  [[-0.03884823  0.01811368  0.05547812 ... -0.05636393 -0.00622258
   0.01677726]
 [-0.03884823  0.01811368  0.05547812 ... -0.05636393 -0.00622258
   0.01677726]
 [-0.03884823  0.01811368  0.05547812 ... -0.05636393 -0.00622258
   0.01677726]
 ...
 [-0.03884823  0.01811368  0.05547812 ... -0.05636393 -0.00622258
   0.01677726]
 [-0.03884823  0.01811368  0.05547812 ... -0.05636393 -0.00622258
   0.01677726]
 [-0.03884823  0.01811368  0.05547812 ... -0.05636393 -0.00622258
   0.01677726]]

In [26]:
from tinygrad import TinyJit

@TinyJit
def eval_tg(images):
  """Evaluate `opt_jaxpr` on `images` with the tinygrad interpreter and realize the first output."""
  outputs = eval_jaxpr_tg(opt_jaxpr, closed_jaxpr.literals, images)
  return outputs[0].realize()

@jax.jit
def eval_jax(images):
  """Evaluate `opt_jaxpr` on `images` with the JAX interpreter and return the first output."""
  outputs = eval_jaxpr(opt_jaxpr, closed_jaxpr.literals, images)
  return outputs[0]

random_images = [ np.random.rand(bs, 28, 28, 1).astype(np.float32) for _ in range(100) ]
random_images_tg = [ tg.Tensor(x) for x in random_images ]
random_images_jax = [ jax.numpy.array(x) for x in random_images ]

def run_tg():
  """Run `eval_tg` on every tensor in `random_images_tg`; return the results as NumPy arrays."""
  outputs = []
  for batch in random_images_tg:
    outputs.append(eval_tg(batch).numpy())
  return outputs

def run_jax():
  """Run `eval_jax` on every array in `random_images_jax`; return the results as NumPy arrays."""
  outputs = []
  for batch in random_images_jax:
    outputs.append(np.array(eval_jax(batch)))
  return outputs

## Tinygrad on OpenCL/GPU vs JAX on CPU

CPU: 11th Gen Intel i7-11850H (16) @ 4.800GHz

GPU: Intel TigerLake-H GT1 [UHD Graphics]

In [27]:
%timeit run_tg()

5.51 s ± 121 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [28]:
%timeit run_jax()

8.07 s ± 56.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
