In [35]:
import jax
import numpy as np
from jax import numpy as jnp
from functools import partial
from jax import tree_util


**jax.jit annotation on class method pitfall**

In [25]:
class MultiCal:
  def __init__(self, x:jnp.ndarray, isMul: bool):
    self.x = x
    self.isMul = isMul

  @jax.jit
  def optMul(self, y):
    if self.isMul:
      return self.x * y
    else:
      return y

mc = MultiCal(jnp.zeros((3,3), dtype=jnp.float32), True)
print(mc)
# The first arg type is self and jit doesn't know how to handle it.
# print(mc.optMul(10))

<__main__.MultiCal object at 0x7ea8fc1091f0>


**Strategy 1: Use jit-annotated helper functions**:

In [15]:
class MultiCal:
  def __init__(self, x:jnp.ndarray, isMul: bool):
    self.x = x
    self.isMul = isMul

  def optMul(self, y):
    return _optMul(self.x, self.isMul, y)

@partial(jax.jit, static_argnums=1)
def _optMul(x, isMul, y):
  if isMul:
    return x*y
  else:
    return x

mc = MultiCal(jnp.ones((3,3), dtype=jnp.float32), True)
print(mc.optMul(10))

[[10. 10. 10.]
 [10. 10. 10.]
 [10. 10. 10.]]


**Strategy 2a: Marking self argument as static - without hash and eq change**

In [31]:
class MultiCal:
  def __init__(self, x:jnp.ndarray, isMul: bool):
    self.x = x
    self.isMul = isMul

  @partial(jax.jit, static_argnums=0)
  def optMul(self, y):
    if self.isMul:
      return self.x * y
    else:
      return y

mc = MultiCal(jnp.full((3,3), 2.0), True)
print(mc.optMul(10))
mc.isMul = False
print(mc.optMul(4.0))
mc.isMul = True
# The following output is wrong because the cached compilation is used.
print(mc.optMul(8.0))

[[20. 20. 20.]
 [20. 20. 20.]
 [20. 20. 20.]]
4.0
8.0


**Strategy 2b: Marking self argument a static - with hash and eq change**

In [34]:
class MultiCal:
  def __init__(self, x:jnp.ndarray, isMul: bool):
    self.x = x
    self.isMul = isMul

  @partial(jax.jit, static_argnums=0)
  def optMul(self, y):
    if self.isMul:
      return self.x * y
    else:
      return y

  def __hash__(self):
    # We can't hash jax array so I used id() of the array to compute hash.
    return hash((id(self.x), self.isMul))

  def __eq__(self, other):
    if not isinstance(other, MultiCal):
      return False
    return (self.x, self.isMul) == (other.x, other.isMul)

mc = MultiCal(jnp.full((3,3), 2.0), True)
print(mc.optMul(10))
mc.isMul = False
print(mc.optMul(4.0))
mc.isMul = True
print(mc.optMul(8.0))

[[20. 20. 20.]
 [20. 20. 20.]
 [20. 20. 20.]]
4.0
[[16. 16. 16.]
 [16. 16. 16.]
 [16. 16. 16.]]


**Strategy 3: Use PyTree**

In [37]:
class MultiCal:
  def __init__(self, x:jnp.ndarray, isMul: bool):
    self.x = x
    self.isMul = isMul

  def optMul(self, y):
    if self.isMul:
      return self.x * y
    else:
      return y

  def _tree_flatten(self):
    dynamic_args = (self.x,)
    static_args = {'isMul': self.isMul}
    return (dynamic_args, static_args)

  @classmethod
  def _tree_unflatten(cls, dynamic_args, static_args):
    # dynamic_args uses Positional Unpacking; static_args uses Keyword Unpacking
    return cls(*dynamic_args, **static_args)

tree_util.register_pytree_node(MultiCal,
                               MultiCal._tree_flatten,
                               MultiCal._tree_unflatten)
mc = MultiCal(jnp.full((3,3), 2.0), True)
print(mc.optMul(10))
mc.isMul = False
print(mc.optMul(4.0))
mc.isMul = True
print(mc.optMul(8.0))



[[20. 20. 20.]
 [20. 20. 20.]
 [20. 20. 20.]]
4.0
[[16. 16. 16.]
 [16. 16. 16.]
 [16. 16. 16.]]
