In [1]:
import jax
import torchax

In [2]:
import functools
from jax.experimental import pallas as pl
import jax.numpy as jnp
import numpy as np
import torch
from torchax import interop

"""

Copy of https://github.com/qihqi/learning_machine/tree/main/torch_pallas

"""

torchax.enable_globally()


def torch_pallas_call(kernel, *args, **kwargs):
  kernel_as_jax = interop.jax_view(kernel)
  orig_pallas_callable = pl.pallas_call(
      kernel_as_jax,
      *args,
      **kwargs,
  )
  return interop.torch_view(orig_pallas_callable)


# https://docs.jax.dev/en/latest/pallas/quickstart.html
# easiest hello world
def add_vectors_kernel(x_ref, y_ref, o_ref):
  x, y = x_ref[...], y_ref[...]
  o_ref[...] = torch.add(x, y)


  
def add_vectors(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
  return torch_pallas_call(
      add_vectors_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, interop.jax_view(x.dtype)),
      interpret=True
  )(x, y)

print('add vector result', add_vectors(torch.randn(8, device='jax'), torch.randn(8, device='jax')))


# =====  matmul example ===
def matmul_kernel(x_ref, y_ref, z_ref, *, activation):
  z_ref[...] = activation(torch.matmul(x_ref[...], y_ref[...]))

def matmul(x: torch.Tensor, y: torch.Tensor, *, activation):
  return torch_pallas_call(
    functools.partial(matmul_kernel, activation=activation),
    out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), interop.jax_view(x.dtype)),
    grid=(2, 2),
    in_specs=[
        pl.BlockSpec((x.shape[0] // 2, x.shape[1]), lambda i, j: (i, 0)),
        pl.BlockSpec((y.shape[0], y.shape[1] // 2), lambda i, j: (0, j))
    ],
    out_specs=pl.BlockSpec(
        (x.shape[0] // 2, y.shape[1] // 2), lambda i, j: (i, j)
    ),
    interpret=True,
  )(x, y)

a = torch.randn((1024, 1024), device='jax')
b = torch.randn((1024, 1024), device='jax')


z = matmul(a, b, activation=torch.nn.functional.relu)
print('matmul result: ', z)


add vector result Tensor(<class 'jaxlib._jax.ArrayImpl'> [-0.53954804  1.0754876  -3.1762674   0.05775923  1.8583001   0.240334
  2.044621   -1.612478  ])
matmul result:  Tensor(<class 'jaxlib._jax.ArrayImpl'> [[ 0.         0.         0.        ...  0.         0.        55.06069  ]
 [31.81658    0.        21.830395  ...  0.        39.03266   13.942827 ]
 [ 0.         0.        36.42938   ...  0.        45.35937   28.471943 ]
 ...
 [ 0.        28.596521   7.1729493 ...  0.         0.         0.       ]
 [42.961147  31.85172    0.        ...  0.        11.14733    0.       ]
 [30.030876   0.        27.99533   ...  0.         9.47911   60.985184 ]])


In [10]:
def simple_kernel(x_ref, y_ref, o_ref):
  print(f'x_ref = {x_ref} ({type(x_ref)})')
  print(f'y_ref = {y_ref} ({type(y_ref)})')
  print(f'o_ref = {o_ref} ({type(o_ref)})')
  x, y = x_ref[...], y_ref[...]
  x, y = x_ref[...], y_ref[...]
  print(f'x = {x} ({type(x)})')
  print(f'y = {y} ({type(y)})')
  o_ref[...] = x + y


In [11]:
lifted = pl.pallas_call(
  simple_kernel,
  out_shape=jax.ShapeDtypeStruct((2, 2), jnp.float32),
  debug=True,
)
jaxpr = jax.make_jaxpr(lifted)(jnp.ones((2, 2)), jnp.ones((2, 2)))
jaxpr

x_ref = Traced<MemRef<None>{float32[2,2]}>with<DynamicJaxprTrace> (<class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>)
y_ref = Traced<MemRef<None>{float32[2,2]}>with<DynamicJaxprTrace> (<class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>)
o_ref = Traced<MemRef<None>{float32[2,2]}>with<DynamicJaxprTrace> (<class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>)
x = Traced<ShapedArray(float32[2,2])>with<DynamicJaxprTrace> (<class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>)
y = Traced<ShapedArray(float32[2,2])>with<DynamicJaxprTrace> (<class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>)


{ lambda ; a:f32[2,2] b:f32[2,2]. let
    c:f32[2,2] = pallas_call[
      backend=None
      compiler_params={}
      cost_estimate=None
      debug=True
      grid_mapping=GridMapping(grid=(), grid_names=None, block_mappings=(BlockMapping(block_shape=(Blocked(block_size=2), Blocked(block_size=2)), transformed_block_aval=MemRef<None>{float32[2,2]}, index_map_jaxpr={ lambda ; . let  in (0, 0) }, array_shape_dtype=ShapeDtypeStruct(shape=(2, 2), dtype=float32), origin='args[0]', transforms=(), pipeline_mode=None), BlockMapping(block_shape=(Blocked(block_size=2), Blocked(block_size=2)), transformed_block_aval=MemRef<None>{float32[2,2]}, index_map_jaxpr={ lambda ; . let  in (0, 0) }, array_shape_dtype=ShapeDtypeStruct(shape=(2, 2), dtype=float32), origin='args[1]', transforms=(), pipeline_mode=None), BlockMapping(block_shape=(Blocked(block_size=2), Blocked(block_size=2)), transformed_block_aval=MemRef<None>{float32[2,2]}, index_map_jaxpr={ lambda ; . let  in (0, 0) }, array_shape_dtype=Shap

In [5]:
inner_jaxpr = jaxpr.eqns[0].params['jaxpr']
inner_jaxpr.eqns

[_:f32[2,2] <- a[:,:],
 _:f32[2,2] <- a[:,:],
 a:f32[2,2] <- b[:,:],
 a:f32[2,2] <- b[:,:],
 a:f32[2,2] = add b c,
 a[:,:] <- b]

In [6]:
for eqn in inner_jaxpr.eqns:
  print(eqn, '#', eqn.primitive, type(eqn.primitive))

_:f32[2,2] <- a[:,:] # get <class 'jax._src.core.Primitive'>
_:f32[2,2] <- a[:,:] # get <class 'jax._src.core.Primitive'>
a:f32[2,2] <- b[:,:] # get <class 'jax._src.core.Primitive'>
a:f32[2,2] <- b[:,:] # get <class 'jax._src.core.Primitive'>
a:f32[2,2] = add b c # add <class 'jax._src.core.Primitive'>
a[:,:] <- b # swap <class 'jax._src.core.Primitive'>


In [7]:
jax.make_jaxpr(torch_pallas_call(
  add_vectors_kernel,
  out_shape=jax.ShapeDtypeStruct((2, 2), jnp.float32),
  debug=True,
))(jnp.ones((2, 2)), jnp.ones((2, 2)))

{ lambda a:f32[2,2]; b:f32[2,2] c:f32[2,2]. let
    _:f32[2,2] = pallas_call[
      backend=None
      compiler_params={}
      cost_estimate=None
      debug=True
      grid_mapping=GridMapping(grid=(), grid_names=None, block_mappings=(BlockMapping(block_shape=(Blocked(block_size=2), Blocked(block_size=2)), transformed_block_aval=MemRef<None>{float32[2,2]}, index_map_jaxpr={ lambda ; . let  in (0, 0) }, array_shape_dtype=ShapeDtypeStruct(shape=(2, 2), dtype=float32), origin='args[0]', transforms=(), pipeline_mode=None), BlockMapping(block_shape=(Blocked(block_size=2), Blocked(block_size=2)), transformed_block_aval=MemRef<None>{float32[2,2]}, index_map_jaxpr={ lambda ; . let  in (0, 0) }, array_shape_dtype=ShapeDtypeStruct(shape=(2, 2), dtype=float32), origin='args[1]', transforms=(), pipeline_mode=None), BlockMapping(block_shape=(Blocked(block_size=2), Blocked(block_size=2)), transformed_block_aval=MemRef<None>{float32[2,2]}, index_map_jaxpr={ lambda ; . let  in (0, 0) }, array_shape_