# Experiment: converting PyTorch code to JAX code

Approach 1: trace the PyTorch function to a jaxpr with `torchax`, then print the jaxpr back out as Python source using `JaxDecompiler`.

In [1]:
# NOTE(review): versions are unpinned, so re-running this cell may install a
# different JaxDecompiler than the one that produced the outputs below; consider
# pinning. `decompile_jaxpr` further down also imports `yapf`, which is not
# installed explicitly here.
%pip install --upgrade pip
%pip install --upgrade JaxDecompiler

[0mNote: you may need to restart the kernel to use updated packages.
[0mNote: you may need to restart the kernel to use updated packages.


In [2]:
import jax
import torch
import torchax
import jax.extend.core as core
from JaxDecompiler import decompiler
from torchax.interop import jax_view

# Turn torchax on globally (going by the function name). Presumably this is what
# lets later cells use the 'jax' device (`m.to('jax')`, `device='jax:0'`) — TODO confirm.
torchax.enable_globally()

In [3]:
# Example torch function
def my_mathy_thing(weight, hidden):
  """Toy torch computation used to demo jaxpr decompilation.

  Computes sin(weight @ hidden) + (weight + hidden) @ (weight * hidden) + hidden.
  The nested helper gives the tracer a Python function call to flatten.
  """
  def sum_times_product(lhs, rhs):
    # Matrix product of the elementwise sum with the elementwise product.
    elementwise_sum = lhs + rhs
    elementwise_product = lhs * rhs
    return elementwise_sum @ elementwise_product

  projected = torch.matmul(weight, hidden)
  combined = torch.sin(projected) + sum_times_product(weight, hidden)
  return combined + hidden

In [4]:
import builtins
import keyword


class NameIterator:
  """An infinite iterator over spreadsheet-column-style names.

  Yields 'a', 'b', ..., 'z', 'aa', 'ab', ..., 'az', 'ba', ..., 'zz', 'aaa', ...
  That is every non-empty string over ``alphabet``, shortest first and in
  alphabet order within each length (bijective base-N counting). Note this is
  *not* plain lexicographic order, in which 'aa' would come before 'b'.
  """

  def __init__(self, alphabet="abcdefghijklmnopqrstuvwxyz"):
    """Create the iterator.

    Args:
      alphabet: non-empty string of distinct symbols to build names from.
        Defaults to the lowercase ASCII letters.

    Raises:
      ValueError: if ``alphabet`` is empty or repeats a symbol (digit lookup
        via ``str.index`` would then be ambiguous).
    """
    if not alphabet or len(set(alphabet)) != len(alphabet):
      raise ValueError("alphabet must be a non-empty string of distinct symbols")
    self.alphabet = alphabet
    self.current = None

  def __iter__(self):
    return self

  def __next__(self):
    base = len(self.alphabet)
    if self.current is None:
      self.current = self.alphabet[0]
      return self.current

    # Convert the current name to a list of digits in [0, base).
    indices = [self.alphabet.index(c) for c in self.current]

    # Increment the rightmost digit and propagate the carry leftwards, like an
    # odometer.
    pos = len(indices) - 1
    indices[pos] += 1
    while pos >= 0 and indices[pos] == base:
      indices[pos] = 0
      pos -= 1
      if pos < 0:
        # Every digit overflowed ('z' -> 'aa', 'zz' -> 'aaa'): grow by one.
        indices = [0] * (len(indices) + 1)
        break
      indices[pos] += 1

    self.current = "".join(self.alphabet[i] for i in indices)
    return self.current


class LegalNameIterator:
  """Yields names from NameIterator that are safe as generated local variables.

  Skips:
    * Python keywords, hard and soft (e.g. 'if', 'in', 'as', 'def', 'for').
      Emitting one of these as a variable makes the generated code a syntax
      error.
    * Builtin names (e.g. 'id', 'len', 'sum'). Assigning one anywhere in the
      generated function makes it local for the whole function, breaking any
      call to it there (the generated code calls `len(...)`, for instance).
    * Module names bound by the decompiled file's header ('jax', 'sparse',
      'prng', 'MPI'), for the same shadowing reason.
    * Anything passed in ``extra_reserved``. For example, pass
      ``dir(jax.numpy)`` to also avoid the star-imported jax.numpy names
      (note that this includes 'e').
  """

  _DEFAULT_RESERVED = frozenset(
      keyword.kwlist
      + list(getattr(keyword, "softkwlist", []))
      + dir(builtins)
      + ["jax", "sparse", "prng", "MPI"])

  def __init__(self, extra_reserved=()):
    """Args:
      extra_reserved: additional names to skip, on top of the defaults.
    """
    self.iterator = NameIterator()
    self.reserved = self._DEFAULT_RESERVED | frozenset(extra_reserved)

  def __iter__(self):
    return self

  def __next__(self):
    while True:
      name = next(self.iterator)
      if name not in self.reserved:
        return name



def decompile_jaxpr(jaxpr: core.ClosedJaxpr) -> str:
  """Decompile a closed jaxpr into yapf-formatted Python source.

  Every variable of the jaxpr, and of every Jaxpr/ClosedJaxpr found in equation
  params (also inside tuples/lists, e.g. pjit's `jaxpr`), is replaced by a
  `NamedVar` whose str() is a short name from `LegalNameIterator`
  ('a', 'b', ...). JaxDecompiler then writes the program as a function named
  `decompiled`, and yapf formats it.

  Side effect: the renaming is done in place on the variable lists of `jaxpr`
  and of its nested jaxprs.

  Args:
    jaxpr: the ClosedJaxpr to decompile, e.g. the result of
      `jax.make_jaxpr(f)(*args)`.

  Returns:
    The formatted Python source.

  Raises:
    Any exception from importing or running yapf. The unformatted source is
    printed first so it is not lost.
  """
  import io

  names = LegalNameIterator()

  class NamedVar(core.Var):
    """A core.Var that prints as a fixed, human-readable name."""

    def __init__(self, name, suffix, aval):
      self.name = name
      super().__init__(suffix, aval)

    def __str__(self):
      return self.name

    def __repr__(self):
      if self.suffix:
        return f"var_id_{id(self)}_{self.suffix}"
      return f"var_id_{id(self)}"

  # repr(original Var) -> assigned name, so every occurrence of the same
  # variable (as an outvar of one eqn and an invar of others) shares one name.
  var_names = {}

  def get_name(v: core.Var) -> str:
    key = repr(v)
    if key not in var_names:
      var_names[key] = next(names)
    return var_names[key]

  def name(var_list):
    """Replace each plain core.Var in `var_list`, in place, with a NamedVar."""
    for i, v in enumerate(var_list):
      if isinstance(v, core.Var) and not isinstance(v, NamedVar):
        nv = NamedVar(get_name(v), v.suffix, v.aval)
        # Carry over the original var's `count` attribute.
        nv.count = v.count
        var_list[i] = nv

  def nested_jaxprs(value):
    """Yield the open jaxprs contained in an equation param value."""
    if isinstance(value, (tuple, list)):
      for item in value:
        yield from nested_jaxprs(item)
    elif hasattr(value, "jaxpr") and hasattr(value, "consts"):
      yield value.jaxpr  # a ClosedJaxpr
    elif hasattr(value, "eqns") and hasattr(value, "invars"):
      yield value  # an open Jaxpr

  visited = set()

  def recursive_name(jx):
    """Name the variables of open jaxpr `jx` and of every jaxpr nested in it."""
    if id(jx) in visited:  # a jaxpr reachable twice only needs one pass
      return
    visited.add(id(jx))
    name(jx.invars)
    for eqn in jx.eqns:
      name(eqn.invars)
      name(eqn.outvars)
      for value in getattr(eqn, "params", {}).values():
        for inner in nested_jaxprs(value):
          recursive_name(inner)
    name(jx.outvars)

  recursive_name(jaxpr.jaxpr)

  python_lines = decompiler.decompiler(jaxpr, python_func_name='decompiled')
  with io.StringIO() as buffer:
    decompiler._recursively_write_python_program(buffer, python_lines)
    source = buffer.getvalue()

  try:
    from yapf.yapflib.yapf_api import FormatCode
    formatted_code, _ = FormatCode(source)
  except Exception:
    print("yapf error")
    print(source)
    raise
  return formatted_code

In [5]:
# Trace the torch function to a jaxpr through torchax's `jax_view`, using two 4x4
# arrays of ones as example inputs, then decompile that jaxpr and print the source.
example_inputs = (jax.numpy.ones((4, 4)), jax.numpy.ones((4, 4)))
jaxpr_maker = jax.make_jaxpr(jax_view(my_mathy_thing))
my_mathy_thing_jaxpr = jaxpr_maker(*example_inputs)
my_mathy_thing_py = decompile_jaxpr(my_mathy_thing_jaxpr)
print(my_mathy_thing_py)

import jax
from jax.numpy import *
from jax.experimental import sparse
from jax._src import prng
from mpi4py import MPI


def decompiled(a, b):
    c = tensordot(a, b, axes=((1, ), (0, )))
    d = sin(c)
    e = b * 1.0
    f = a + e
    g = a * b
    h = tensordot(f, g, axes=((1, ), (0, )))
    i = h * 1.0
    j = d + i
    k = b * 1.0
    l = j + k
    return l



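This works, but the output has significant readability drawbacks:

- Function calls are all inlined and flattened; for example, the nested `other_math` helper does not appear in the output.
- Variable names carry no meaning; they are generated sequentially (`a`, `b`, `c`, ...).

Let's try this on a Llama transformer block.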

In [6]:
import os
import sys
from pathlib import Path
import torchax.interop

# Directory holding torchax's llama test model (`llama_model`). Set the
# TORCHAX_LLAMA_DIR environment variable to point elsewhere instead of editing
# this path.
DEFAULT_LLAMA_DIR = Path('/') / 'workspaces' / 'torch' / 'pytorch' / 'xla' / 'torchax' / 'test' / 'llama'
p = Path(os.environ.get('TORCHAX_LLAMA_DIR', DEFAULT_LLAMA_DIR))
if not p.exists():
  # An explicit error rather than `assert`, so the check survives `python -O`.
  raise FileNotFoundError(
      f"llama test model directory not found: {p}; set TORCHAX_LLAMA_DIR")
# Only add the entry once, so re-running this cell does not keep growing sys.path.
if str(p) not in sys.path:
  sys.path.append(str(p))


def setup_llama(block_size=2048, dim=256, n_layer=2, n_head=4, vocab_size=32000,
                batch_size=1, layer_index=0):
  """Build a small bf16 Llama model on the 'jax' device and trace one layer.

  The sequence length used for the caches, the input positions and the sample
  embedding all come from `block_size`, so they stay consistent when it changes.

  Args:
    block_size: passed to `llama_model.ModelArgs`. Also used as the second
      argument of `m.setup_caches` (presumably the max sequence length — TODO
      confirm) and as the traced sequence length.
    dim: model width, passed to `ModelArgs`. Also the last dimension of the
      sample embedding input.
    n_layer: passed to `llama_model.ModelArgs`.
    n_head: passed to `llama_model.ModelArgs`.
    vocab_size: passed to `llama_model.ModelArgs`.
    batch_size: first argument of `m.setup_caches` (presumably the max batch
      size — TODO confirm) and leading dimension of the sample embedding.
    layer_index: which `m.layers[...]` entry to trace.

  Returns:
    (jaxpr, sample_inputs). `jaxpr` is the ClosedJaxpr of the layer's extracted
    JAX function applied to (states, sample_inputs). `sample_inputs` is the
    `jax_view` of the sample arguments.
  """
  import llama_model  # type: ignore  # importable via the sys.path entry added above

  model_args = llama_model.ModelArgs(
      block_size=block_size,
      vocab_size=vocab_size,
      n_layer=n_layer,
      n_head=n_head,
      dim=dim,
  )
  m = llama_model.Transformer(model_args)
  m.to(torch.bfloat16)
  m.setup_caches(batch_size, block_size)
  m = m.to('jax')

  # Extract jaxpr
  input_pos = torch.arange(0, block_size, device='jax:0')
  sample_args = (
      torch.rand((batch_size, block_size, dim), device='jax:0'),  # Embedding
      input_pos,  # Input pos
      m.freqs_cis[input_pos],  # Freqs-cis
      m.causal_mask[None, None, input_pos],  # Mask
  )
  states, jax_func = torchax.extract_jax(m.layers[layer_index])
  sample_inputs = jax_view(sample_args)
  jaxpr = jax.make_jaxpr(jax_func)(states, sample_inputs)
  return jaxpr, sample_inputs


jaxpr, sample_inputs = setup_llama()

q= (1, 4, 2048, 64)
k= (1, 4, 2048, 64)
v= (1, 4, 2048, 64)
mask= (1, 1, 2048, 2048)




We also need to add support for some extra primitives to JaxDecompiler by registering handlers for `rsqrt`, `scatter`, `not`, `logistic`, and `pjit`.

In [None]:
from JaxDecompiler import primitive_mapping # Prod phase


def register(fn):
  """Install `fn` as a JaxDecompiler primitive handler named `fn.__name__`.

  Sets `primitive_mapping.<fn.__name__> = fn`, replacing any existing attribute
  of that name. Returns `fn` unchanged so that, when used as a decorator, the
  decorated name in this notebook still refers to the handler rather than None.
  """
  setattr(primitive_mapping, fn.__name__, fn)
  return fn


@register
def rsqrt(input_var, output_var, params):
  """Emit `out = jax.lax.rsqrt(x)` for the `rsqrt` primitive; params are unused."""
  target, operand = output_var[0], input_var[0]
  return "{} = jax.lax.rsqrt({})".format(target, operand)


@register
def scatter(input_var, output_var, params):
  """Emit a `jax.lax.scatter(...)` call for the `scatter` primitive.

  Args:
    input_var: names of (operand, scatter_indices, updates), in that order.
    output_var: names of the outputs; only the first is used.
    params: equation params. `dimension_numbers`, `indices_are_sorted`,
      `unique_indices` and `mode` are forwarded when set.

  Returns:
    A single line of Python source.
  """
  operand = input_var[0]
  scatter_indices = input_var[1]
  updates = input_var[2]

  dimension_numbers = params.get("dimension_numbers", None)
  indices_are_sorted = params.get("indices_are_sorted", False)
  unique_indices = params.get("unique_indices", False)
  mode = params.get("mode", None)

  options = []
  if dimension_numbers:
    # Relies on str(dimension_numbers) being a `ScatterDimensionNumbers(...)`
    # constructor expression, so that prefixing `jax.lax.` yields valid source.
    options.append(f"dimension_numbers=jax.lax.{dimension_numbers}")
  if indices_are_sorted:
    options.append(f"indices_are_sorted={indices_are_sorted}")
  if unique_indices:
    options.append(f"unique_indices={unique_indices}")
  if mode is not None:
    # Forward the out-of-bounds index handling explicitly. Otherwise the
    # generated call falls back to jax.lax.scatter's default mode, which may
    # differ from the traced one.
    mode_name = getattr(mode, "name", mode)
    options.append(f"mode=jax.lax.GatherScatterMode.{mode_name}")

  options_str = ", ".join(options)
  if options_str:
    options_str = ", " + options_str

  return f"{output_var[0]} = jax.lax.scatter({operand}, {scatter_indices}, {updates}{options_str})"


@register
def not__(input_var, output_var, params):
  """Emit `out = ~x` for the `not` primitive; params are unused."""
  result, value = output_var[0], input_var[0]
  return "{} = ~{}".format(result, value)


@register
def logistic(input_var, output_var, params):
  """Emit `out = jax.lax.logistic(x)` for the `logistic` primitive; params unused.

  This maps the primitive one-to-one onto the public `jax.lax.logistic`, the
  same way `rsqrt` maps onto `jax.lax.rsqrt`. It does not spell out
  `1.0 / (1.0 + exp(-x))`, which depends on the star-imported `exp` (a
  generated local named `exp` would shadow it) and traces to different ops.
  """
  return f"{output_var[0]} = jax.lax.logistic({input_var[0]})"


@register
def pjit(input_var, output_var, params) -> list[str]:
  """Decompile a `pjit` equation into a local function plus a `jax.jit` call.

  The equation's inner jaxpr (`params["jaxpr"].jaxpr`) is handed to
  JaxDecompiler's `_recurive_op` together with the generated call line, and
  presumably becomes a local function named `local_f<N>`. The call line is

      out_0, out_1 = jax.jit(local_f<N>[, donate_argnums=(...)])(in_0, in_1)

  NOTE(review): the call is emitted as `jax.jit` rather than `jax.pjit`, which
  is not a public JAX attribute. `jax.jit` expresses donation as
  `donate_argnums`, so the equation's per-input `donated_invars` flags are
  converted to positional indices. Sharding/`resource_env` params are not
  forwarded, because `jax.jit` does not take them in this form.

  Args:
    input_var: names of the equation's inputs (positional args of local_f).
    output_var: names of the equation's outputs.
    params: the equation's params. Not modified; a copy is passed on.

  Returns:
    The lines produced by `primitive_mapping._recurive_op`.
  """
  local_f_name = "local_f" + str(primitive_mapping._LOCAL_F_COUNT)
  primitive_mapping._LOCAL_F_COUNT += 1

  lvalue = ", ".join(output_var)
  rvalue = ", ".join(input_var)

  options = []
  donated_invars = params.get("donated_invars", None)
  if donated_invars is not None:
    donate_argnums = tuple(
        i for i, donated in enumerate(donated_invars) if donated)
    if donate_argnums:
      options.append(f"donate_argnums={donate_argnums}")
  options_str = "".join(", " + option for option in options)

  line = f"{lvalue} = jax.jit({local_f_name}{options_str})({rvalue})"

  # Pass a copy so the equation's own params dict (part of the traced jaxpr)
  # does not gain an extra "call_jaxpr" entry.
  call_params = dict(params)
  call_params["call_jaxpr"] = params["jaxpr"].jaxpr
  lines = primitive_mapping._recurive_op(call_params, line, local_f_name)
  return lines


In [8]:
# Decompile the traced Llama layer. The full source stays in `llama_py`. Only the
# first LLAMA_PREVIEW_LINES lines are printed, because the full listing is long.
LLAMA_PREVIEW_LINES = 40

llama_py = decompile_jaxpr(jaxpr)
llama_py_lines = llama_py.splitlines()
print("\n".join(llama_py_lines[:LLAMA_PREVIEW_LINES]))
if len(llama_py_lines) > LLAMA_PREVIEW_LINES:
  remaining_lines = len(llama_py_lines) - LLAMA_PREVIEW_LINES
  print(f"... ({remaining_lines} more lines; print(llama_py) to see everything)")

import jax
from jax.numpy import *
from jax.experimental import sparse
from jax._src import prng
from mpi4py import MPI


def decompiled(a, b, c, d, e, f, g, h, i, j, k, l, m):
    n = j * j
    o = sum(n, axis=(2, ))
    tmp_broadcast = array(o) if isinstance(o, ndarray) or isscalar(o) else -1
    p = array(jax.numpy.broadcast_to(tmp_broadcast, (1, 2048, 1)))
    q = p / 256.0
    r = q + 9.999999747378752e-06
    s = jax.lax.rsqrt(r)
    t = j * s
    u = array(i).astype(float32)
    tmp_broadcast = array(u) if isinstance(u, ndarray) or isscalar(u) else -1
    v = array(jax.numpy.broadcast_to(tmp_broadcast, (1, 1, 256)))
    w = t * v
    x = tensordot(w, a, axes=((2, ), (1, )))
    y = x[0:1:][0:2048:][0:256:] if len(
        x.shape
    ) >= 3 else x  # static slice inputs:[(0, 0, 0), (1, 2048, 256), None]
    z = x[0:1:][0:2048:][256:512:] if len(
        x.shape
    ) >= 3 else x  # static slice inputs:[(0, 0, 256), (1, 2048, 512), None]
    aa = x[0:1:][0:2048:][512:768:] if len

The result is not very readable. If we were to check in the converted JAX code,
it would probably need a lot of cleanup; at that point one might as well write
the JAX by hand or just use `torchax` at runtime.

# Future ideas

- Can we intercept `jax`, `jax.numpy`, and `jax.lax` ops instead of JAX primitives?
  - Assumption: these ops are pure.
  - Benefit: improves readability.
- Can we inspect the stack to find better variable names?
  - Approach: while running the model file, inspect its stack frames to find
    which Python variable corresponds to which `jaxpr` variable.
- Can we recreate the function structure?
  - Approach: figure out when the Python interpreter enters another function when
    running the model file.
  - Translate function by function.
- Can we use an LLM to clean things up?
- Instead of generating JAX code from scratch, what if we added the equivalent
  JAX code as comments next to each line in the PyTorch model file?

This will be a non-trivial project, but the foundation is in place.