# Experimenting converting PyTorch code to Jax code

Approach 1: let's print the jaxpr back into Python

In [1]:
%pip install --upgrade pip
%pip install --upgrade JaxDecompiler

[0mNote: you may need to restart the kernel to use updated packages.
[0mNote: you may need to restart the kernel to use updated packages.


In [2]:
import jax
import torch
import torchax
import jax.extend.core as core
from JaxDecompiler import decompiler
from torchax.interop import jax_view

torchax.enable_globally()

In [3]:
# Example torch function
def my_mathy_thing(weight, hidden):
  activation = torch.matmul(weight, hidden)
  output = torch.sin(activation)
  return output + hidden

In [4]:
class NameIterator:
  """
  An iterator that yields strings in lexicographical order:
  'a', 'b', ..., 'z', 'aa', 'ab', ..., 'az', 'ba', ...

  The sequence starts with single letters, then moves to longer
  combinations, following the pattern similar to spreadsheet columns.
  """

  def __init__(self):
    self.alphabet = "abcdefghijklmnopqrstuvwxyz"
    self.current = None

  def __iter__(self):
    return self

  def __next__(self):
    if self.current is None:
      self.current = "a"
      return self.current

    # Convert the current string to a list of indices
    indices = [self.alphabet.index(c) for c in self.current]

    # Try to increment the rightmost index
    pos = len(indices) - 1
    indices[pos] += 1

    # Handle carry, like incrementing digits in a number
    while pos >= 0 and indices[pos] == 26:
      indices[pos] = 0
      pos -= 1

      if pos < 0:
        # We need to add a new position (like going from 'z' to 'aa')
        indices = [0] * (len(indices) + 1)
        break

      indices[pos] += 1

    # Convert back to a string
    self.current = "".join(self.alphabet[i] for i in indices)
    return self.current


def decompile_jaxpr(jaxpr: core.ClosedJaxpr) -> str:
  names = NameIterator()

  class NamedVar(core.Var):
    def __init__(self, suffix, aval):
      self.name = next(names)
      super().__init__(suffix, aval)

    def __str__(self):
      return self.name

    def __repr__(self):
      if self.suffix:
        return f"var_id_{id(self)}_{self.suffix}"
      return f"var_id_{id(self)}"

  def name(vars):
    for i in range(len(vars)):
      v = vars[i]
      if not isinstance(v, NamedVar) and isinstance(v, core.Var):
        vars[i] = NamedVar(v.suffix, v.aval)

  # Create an in-memory file object
  import io

  f = io.StringIO()
  # Name all the variables
  name(jaxpr.jaxpr.invars)
  for eqn in jaxpr.eqns:
      name(eqn.invars)
      name(eqn.outvars)
  name(jaxpr.jaxpr.outvars)
  python_lines = decompiler.decompiler(jaxpr)
  decompiler._recursively_write_python_program(f, python_lines)
  f.seek(0)
  s = f.read()
  f.close()
  return s

In [5]:
my_mathy_thing_jaxpr = jax.make_jaxpr(jax_view(my_mathy_thing))(jax.numpy.ones((4, 4)), jax.numpy.ones((4, 4)))
my_mathy_thing_py = decompile_jaxpr(my_mathy_thing_jaxpr)
print(my_mathy_thing_py)

import jax
from jax.numpy import *
from jax.experimental import sparse
from jax._src import prng
from mpi4py import MPI
def f(a, b):
    e = tensordot(c, d,axes=((1,), (0,)))
    g = sin(f)
    i = h * 1.0
    l = j + k
    return m

