In [2]:
import jax
import torchax

In [29]:
import functools
from jax.experimental import pallas as pl
import jax.numpy as jnp
import numpy as np
import torch
from torchax import interop

"""

Copy of https://github.com/qihqi/learning_machine/tree/main/torch_pallas

"""

# Switch torchax on for the whole process. Presumably this is what lets the
# plain torch calls below operate on device='jax' tensors — confirm against
# the torchax documentation.
torchax.enable_globally()


def torch_pallas_call(kernel, *args, **kwargs):
  """Torch-facing counterpart of ``pl.pallas_call``.

  The kernel is passed through ``interop.jax_view`` before being handed to
  ``pl.pallas_call``; the callable that ``pallas_call`` returns is then passed
  through ``interop.torch_view`` and returned.

  Args:
    kernel: Kernel function written against the torch API.
    *args: Extra positional arguments forwarded unchanged to ``pl.pallas_call``.
    **kwargs: Keyword arguments forwarded unchanged to ``pl.pallas_call``
      (e.g. ``out_shape``, ``grid``, ``in_specs``, ``out_specs``, ``interpret``).

  Returns:
    The torch view of the callable produced by ``pl.pallas_call``.
  """
  jax_callable = pl.pallas_call(interop.jax_view(kernel), *args, **kwargs)
  return interop.torch_view(jax_callable)


# https://docs.jax.dev/en/latest/pallas/quickstart.html
# easiest hello world
def add_vectors_kernel(x_ref, y_ref, o_ref):
  """Kernel body: write the elementwise sum of two inputs into the output.

  Args:
    x_ref: Reference to the first operand; read in full via ``[...]``.
    y_ref: Reference to the second operand; read in full via ``[...]``.
    o_ref: Output reference; overwritten in full with ``torch.add`` of the
      two loaded operands.
  """
  lhs = x_ref[...]
  rhs = y_ref[...]
  o_ref[...] = torch.add(lhs, rhs)


  
def add_vectors(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
  """Add ``x`` and ``y`` by running ``add_vectors_kernel`` through Pallas.

  The output is declared with the shape of ``x`` and the JAX view of
  ``x.dtype``; ``y`` is not consulted for either. The call is made with
  ``interpret=True``.

  Args:
    x: First operand.
    y: Second operand.

  Returns:
    Whatever the wrapped Pallas callable returns for ``(x, y)``.
  """
  result_spec = jax.ShapeDtypeStruct(x.shape, interop.jax_view(x.dtype))
  pallas_add = torch_pallas_call(
      add_vectors_kernel,
      out_shape=result_spec,
      interpret=True,
  )
  return pallas_add(x, y)

# Smoke-test add_vectors on two random length-8 vectors created on the 'jax' device.
print('add vector result', add_vectors(torch.randn(8, device='jax'), torch.randn(8, device='jax')))


# =====  matmul example ===
def matmul_kernel(x_ref, y_ref, z_ref, *, activation):
  """Kernel body: store ``activation(x @ y)`` into the output reference.

  Args:
    x_ref: Reference to the left operand; read in full via ``[...]``.
    y_ref: Reference to the right operand; read in full via ``[...]``.
    z_ref: Output reference; overwritten in full with the activated product.
    activation: Callable applied to the result of ``torch.matmul``.
  """
  lhs = x_ref[...]
  rhs = y_ref[...]
  product = torch.matmul(lhs, rhs)
  z_ref[...] = activation(product)

def matmul(x: torch.Tensor, y: torch.Tensor, *, activation):
  """Compute ``activation(x @ y)`` with ``matmul_kernel`` tiled over a 2x2 grid.

  Grid cell ``(i, j)`` receives the ``i``-th half of the rows of ``x`` (all
  columns) and the ``j``-th half of the columns of ``y`` (all rows), and writes
  the matching quadrant of the output. The call is made with
  ``interpret=True``.

  Args:
    x: Left operand of shape ``(M, K)``; ``M`` must be even.
    y: Right operand of shape ``(K, N)``; ``N`` must be even.
    activation: Callable applied to each output tile inside the kernel.

  Returns:
    The ``(M, N)`` result with the dtype of ``x``.

  Raises:
    ValueError: If either operand is not 2-D, if the inner dimensions differ,
      or if ``M`` or ``N`` is odd.
  """
  if len(x.shape) != 2 or len(y.shape) != 2:
    raise ValueError(
        f'matmul expects 2-D operands, got shapes {tuple(x.shape)} and {tuple(y.shape)}')
  rows, inner = x.shape
  inner_y, cols = y.shape
  if inner != inner_y:
    raise ValueError(
        f'inner dimensions do not match: {tuple(x.shape)} @ {tuple(y.shape)}')
  # The grid is fixed at 2x2 and the block sizes are computed with floor
  # division, so an odd row/column count would leave the last row/column of
  # the output unwritten. Reject such shapes instead of returning bad data.
  if rows % 2 != 0 or cols % 2 != 0:
    raise ValueError(
        'x.shape[0] and y.shape[1] must be even to tile over the 2x2 grid, '
        f'got {rows} and {cols}')
  return torch_pallas_call(
    functools.partial(matmul_kernel, activation=activation),
    out_shape=jax.ShapeDtypeStruct((rows, cols), interop.jax_view(x.dtype)),
    grid=(2, 2),
    in_specs=[
        pl.BlockSpec((rows // 2, inner), lambda i, j: (i, 0)),
        pl.BlockSpec((inner, cols // 2), lambda i, j: (0, j))
    ],
    out_specs=pl.BlockSpec(
        (rows // 2, cols // 2), lambda i, j: (i, j)
    ),
    interpret=True,
  )(x, y)

# Two random 1024x1024 operands on the 'jax' device; both dimensions are even,
# so they split evenly across matmul's 2x2 grid.
a = torch.randn((1024, 1024), device='jax')
b = torch.randn((1024, 1024), device='jax')


# Pass ReLU as the activation so it is applied inside the kernel.
z = matmul(a, b, activation=torch.nn.functional.relu)
print('matmul result: ', z)


add vector result Tensor(<class 'jaxlib._jax.ArrayImpl'> [ 0.67377186  0.07321142 -2.9761477   1.3824334  -2.3937879   0.12552863
 -0.53504986 -0.64411575])
matmul result:  Tensor(<class 'jaxlib._jax.ArrayImpl'> [[ 0.        15.863418   9.280289  ...  5.2803307  0.         0.       ]
 [ 9.733851   0.         5.2965965 ...  0.         0.        54.980007 ]
 [ 0.         0.        31.144966  ...  0.        18.25214   43.20775  ]
 ...
 [ 0.        10.782309  25.80857   ...  0.        13.586363  58.227524 ]
 [23.757692   0.         0.        ... 38.23269   11.854513   0.       ]
 [ 7.3983526 32.635433  13.593995  ... 28.632645  40.917164   0.       ]])


In [2]:
def simple_kernel(x_ref, y_ref, o_ref):
  """Kernel body with no torch calls: write ``x + y`` into the output.

  Args:
    x_ref: Reference to the first operand; read in full via ``[...]``.
    y_ref: Reference to the second operand; read in full via ``[...]``.
    o_ref: Output reference; overwritten in full with the sum.
  """
  o_ref[...] = x_ref[...] + y_ref[...]


In [6]:
# Wrap the torch-free kernel directly with pl.pallas_call; no grid or block
# specs are passed, and debug=True is set.
lifted = pl.pallas_call(
  simple_kernel,
  out_shape=jax.ShapeDtypeStruct((2, 2), jnp.float32),
  debug=True,
)
# Trace the wrapped call on two 2x2 arrays of ones to obtain its jaxpr, then
# display it as the cell output.
jaxpr = jax.make_jaxpr(lifted)(jnp.ones((2, 2)), jnp.ones((2, 2)))
jaxpr

{ lambda ; a:f32[2,2] b:f32[2,2]. let
    c:f32[2,2] = pallas_call[
      backend=None
      compiler_params={}
      cost_estimate=None
      debug=True
      grid_mapping=GridMapping(grid=(), grid_names=None, block_mappings=(BlockMapping(block_shape=(Blocked(block_size=2), Blocked(block_size=2)), transformed_block_aval=MemRef<None>{float32[2,2]}, index_map_jaxpr={ lambda ; . let  in (0, 0) }, array_shape_dtype=ShapeDtypeStruct(shape=(2, 2), dtype=float32), origin='args[0]', transforms=(), pipeline_mode=None), BlockMapping(block_shape=(Blocked(block_size=2), Blocked(block_size=2)), transformed_block_aval=MemRef<None>{float32[2,2]}, index_map_jaxpr={ lambda ; . let  in (0, 0) }, array_shape_dtype=ShapeDtypeStruct(shape=(2, 2), dtype=float32), origin='args[1]', transforms=(), pipeline_mode=None), BlockMapping(block_shape=(Blocked(block_size=2), Blocked(block_size=2)), transformed_block_aval=MemRef<None>{float32[2,2]}, index_map_jaxpr={ lambda ; . let  in (0, 0) }, array_shape_dtype=Shap

In [22]:
# The outer jaxpr's first equation is the pallas_call; its 'jaxpr' param holds
# the traced kernel body. Display that body's equations.
inner_jaxpr = jaxpr.eqns[0].params['jaxpr']
inner_jaxpr.eqns

[a:f32[2,2] <- b[:,:], a:f32[2,2] <- b[:,:], a:f32[2,2] = add b c, a[:,:] <- b]

In [28]:
# Print each equation of the kernel jaxpr next to its primitive and the
# Python type of that primitive.
for eqn in inner_jaxpr.eqns:
  print(eqn, '#', eqn.primitive, type(eqn.primitive))

a:f32[2,2] <- b[:,:] # get <class 'jax._src.core.Primitive'>
a:f32[2,2] <- b[:,:] # get <class 'jax._src.core.Primitive'>
a:f32[2,2] = add b c # add <class 'jax._src.core.Primitive'>
a[:,:] <- b # swap <class 'jax._src.core.Primitive'>
