# JAX Graph Transformation

The code below implements a function transformation that computes the inverse of a function by tracing it to a Jaxpr and interpreting that Jaxpr backwards.

In mathematics, the inverse function of a function $f$ (also called the inverse of $f$) is a function that undoes the operation of $f$.

As an example, consider the real-valued function of a real variable given by $f(x) = 5x - 7$. One can think of $f$ as the function which multiplies its input by $5$ then subtracts $7$ from the result. To undo this, one adds $7$ to the input, then divides the result by $5$. Therefore, the inverse of $f$ is the function
$f^{-1}\colon \mathbb{R} \to \mathbb{R}$ defined by
$$f^{-1}(y)=\frac{y+7}{5}$$

Also see:

- https://jax.readthedocs.io/en/latest/notebooks/Writing_custom_interpreters_in_Jax.html
- Jax installation on Apple-Silicon chips: https://developer.apple.com/metal/jax/

In [17]:
import numpy as np
import jax
import jax.numpy as jnp
from jax import jit
from jax import random

# Sample a 5000x5000 matrix from a fixed PRNG key (seed 0), so every run draws the same values.
x = random.normal(random.PRNGKey(0), (5000, 5000))

def f(w, b, x):
  """Return tanh(jnp.dot(x, w) + b): an affine map of `x` followed by tanh."""
  pre_activation = jnp.dot(x, w) + b
  return jnp.tanh(pre_activation)

# jit-wrapped version of `f`.
fast_f = jit(f)

In [18]:
def examine_jaxpr(closed_jaxpr):
  """Print the parts of a traced function's Jaxpr.

  Prints the input, output and constant variables, then one line per
  equation (inputs, primitive, outputs, params), a blank line, and finally
  the whole Jaxpr.

  Args:
    closed_jaxpr: object with a `.jaxpr` attribute, e.g. the result of
      `jax.make_jaxpr(fn)(*example_args)`.
  """
  inner = closed_jaxpr.jaxpr
  variable_lists = (
      ("invars:", inner.invars),
      ("outvars:", inner.outvars),
      ("constvars:", inner.constvars),
  )
  for label, variables in variable_lists:
    print(label, variables)
  for equation in inner.eqns:
    print("equation:", equation.invars, equation.primitive,
          equation.outvars, equation.params)
  print()
  print("jaxpr:", inner)

# Smallest possible example: traces to a single `add` equation whose second
# operand is the literal 1.
def foo(x):
  return x + 1
print("foo")
print("=====")
# Trace with the Python int 5 as the example argument and dump the result.
examine_jaxpr(jax.make_jaxpr(foo)(5))

print()

# Two-output example: an affine map of `x` plus a vector of ones, and `x`
# itself passed through unchanged as the second output.
def bar(w, b, x):
  return jnp.dot(w, x) + b + jnp.ones(5), x
print("bar")
print("=====")
# Example arguments: w has shape (5, 10), b has shape (5,), x has shape (10,).
examine_jaxpr(jax.make_jaxpr(bar)(jnp.ones((5, 10)), jnp.ones(5), jnp.ones(10)))

foo
=====
invars: [a]
outvars: [b]
constvars: []
equation: [a, 1] add [b] {}

jaxpr: { lambda ; a:i32[]. let b:i32[] = add a 1 in (b,) }

bar
=====
invars: [a, b, c]
outvars: [g, c]
constvars: []
equation: [a, c] dot_general [d] {'dimension_numbers': (((1,), (0,)), ((), ())), 'precision': None, 'preferred_element_type': dtype('float32')}
equation: [d, b] add [e] {}
equation: [1.0] broadcast_in_dim [f] {'shape': (5,), 'broadcast_dimensions': ()}
equation: [e, f] add [g] {}

jaxpr: { lambda ; a:f32[5,10] b:f32[5] c:f32[10]. let
    d:f32[5] = dot_general[
      dimension_numbers=(([1], [0]), ([], []))
      preferred_element_type=float32
    ] a c
    e:f32[5] = add d b
    f:f32[5] = broadcast_in_dim[broadcast_dimensions=() shape=(5,)] 1.0
    g:f32[5] = add e f
  in (g, c) }


In [19]:
# Importing Jax functions useful for tracing/interpreting.
import numpy as np
from functools import wraps

from jax import core
from jax import lax # convert numpy primitives to be compatible with jax
from jax._src.util import safe_map # assert: arg len must stay consistent

def f(x):
  """Return exp(tanh(x)), a composition of two unary primitives."""
  squashed = jnp.tanh(x)
  return jnp.exp(squashed)

# Trace `f` on a length-5 vector of ones, then show the Jaxpr and its constants.
closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5))
print(closed_jaxpr.jaxpr)
print(closed_jaxpr.literals)

{ lambda ; a:f32[5]. let b:f32[5] = tanh a; c:f32[5] = exp b in (c,) }
[]


In [20]:
# def safe_map_temp(func, l1, l2):
#   assert len(l1) == len(l2)
#   zipped = zip(l1, l2)
#   for el1, el2 in zipped:
#     func(el1, el2)

# # wrapper / interpreter (because we can't call the function directly)
# def eval_jaxpr(jaxpr, consts, *args):
#   # map
#   env = {}
# 
#   def read(var):
#     return env[var]
# 
#   def write(var, val):
#     env[var] = val
# 
#   safe_map_temp(write, jaxpr.invars, args) # map key: arguments, map val: args ---> assign value to each argument variable
#   for eqn in jaxpr.eqns: # for each instruction in lambda
#     invals = safe_map_temp(read, eqn.invars) # read all arguments
#     outvals = eqn.primitive.bind(*invals, **eqn.params) # assign value to each input argument variable AND PROCESS (primitives are instructions)
#     if not eqn.primitive.multiple_results: # single result
#       outvals = [outvals]
#     safe_map_temp(write, eqn.outvars, outvals) # tuple result
#   return safe_map_temp(read, jaxpr.outvars)

def eval_jaxpr(jaxpr, consts, *args):
  """Interpret `jaxpr` equation by equation and return its outputs.

  Args:
    jaxpr: the Jaxpr to evaluate (e.g. `closed_jaxpr.jaxpr`).
    consts: values for `jaxpr.constvars`, in the same order.
    *args: values for `jaxpr.invars`, in the same order.

  Returns:
    A list with one value per entry of `jaxpr.outvars`.
  """
  # Variable -> concrete value, filled in as evaluation proceeds.
  bindings = {}

  def lookup(atom):
    # A Literal carries its value inline; any other variable must already
    # have been bound.
    return atom.val if type(atom) is core.Literal else bindings[atom]

  def bind_value(atom, value):
    bindings[atom] = value

  # Seed the environment with the inputs and the constants.
  safe_map(bind_value, jaxpr.invars, args)
  safe_map(bind_value, jaxpr.constvars, consts)

  for equation in jaxpr.eqns:
    operands = safe_map(lookup, equation.invars)
    # `bind` applies the primitive to concrete operands.
    result = equation.primitive.bind(*operands, **equation.params)
    # Normalise a single-result primitive to a one-element list so both
    # cases are stored the same way.
    results = result if equation.primitive.multiple_results else [result]
    safe_map(bind_value, equation.outvars, results)

  return safe_map(lookup, jaxpr.outvars)

# Re-trace `f` and run the Jaxpr through the interpreter on the same input
# it was traced with.
closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5))
# Left as a bare expression so the notebook displays the returned list.
eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, jnp.ones(5))

[Array([2.1416876, 2.1416876, 2.1416876, 2.1416876, 2.1416876], dtype=float32)]

In [21]:
# Maps a primitive to the function that undoes it.
inverse_registry = {
    lax.exp_p: jnp.log,
    lax.tanh_p: jnp.arctanh,
}

def inverse(fun):
  """Return a function that computes the inverse of `fun`.

  The returned function traces `fun` on the arguments it is given, so those
  arguments serve both as example inputs for tracing and as the output
  values to invert. The Jaxpr is then evaluated backwards with
  `inverse_jaxpr`, and only the first recovered input is returned.
  """
  @wraps(fun)
  def wrapped(*args, **kwargs):
    traced = jax.make_jaxpr(fun)(*args, **kwargs)
    recovered_inputs = inverse_jaxpr(traced.jaxpr, traced.literals, *args)
    return recovered_inputs[0]
  return wrapped

def inverse_jaxpr(jaxpr, consts, *args):
  """Evaluate `jaxpr` backwards, from its outputs to its inputs.

  Args:
    jaxpr: the Jaxpr of the forward function.
    consts: values for `jaxpr.constvars`, in the same order.
    *args: values for `jaxpr.outvars` (the forward function's outputs).

  Returns:
    A list with one recovered value per entry of `jaxpr.invars`.

  Raises:
    NotImplementedError: if an equation's primitive has no entry in
      `inverse_registry`.
  """
  # Variable -> concrete value, filled in while walking backwards.
  bindings = {}

  def lookup(atom):
    # A Literal carries its value inline; any other variable must already
    # have been bound.
    return atom.val if type(atom) is core.Literal else bindings[atom]

  def bind_value(atom, value):
    bindings[atom] = value

  # The supplied arguments are the forward function's *outputs*.
  safe_map(bind_value, jaxpr.outvars, args)
  safe_map(bind_value, jaxpr.constvars, consts)

  # Walk the equations last-to-first; each equation's outvars are now the
  # known values and its invars are what gets computed.
  for equation in reversed(jaxpr.eqns):
    known_outputs = safe_map(lookup, equation.outvars)
    primitive = equation.primitive
    if primitive not in inverse_registry:
      raise NotImplementedError(
          f"{primitive} does not have registered inverse.")
    recovered = inverse_registry[primitive](*known_outputs)
    # The inverse yields a single value, so this only lines up for
    # equations with exactly one input.
    safe_map(bind_value, equation.invars, [recovered])

  return safe_map(lookup, jaxpr.invars)

def f(x):
  """Return exp(tanh(x)); both primitives have entries in `inverse_registry`."""
  squashed = jnp.tanh(x)
  return jnp.exp(squashed)

f_inv = inverse(f)
# Round trip: inverting f at f(1.0) must give back 1.0.
assert jnp.allclose(f_inv(f(1.0)), 1.0)

# Trace the inverse itself; left as a bare expression so the notebook
# displays the resulting Jaxpr.
jax.make_jaxpr(inverse(f))(f(1.))

{ lambda ; a:f32[]. let b:f32[] = log a; c:f32[] = atanh b in (c,) }

In [22]:
# Apply the jit-compiled inverse to 0.2, 0.4, ..., 1.0. The first entry of the
# displayed result is nan: log(0.2) is below -1, outside the domain of arctanh.
jit(inverse(f))((jnp.arange(5) + 1.) / 5.)

Array([        nan, -1.5653983 , -0.56384623, -0.22696194,  0.        ],      dtype=float32, weak_type=True)