<a href="https://colab.research.google.com/github/takayama-rado/trado_samples/blob/main/colab_files/exp_jax_static.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Some examples are adapted from JAX's documentation, which is licensed under the Apache License 2.0.

google/jax is licensed under the<br>
Apache License 2.0<br>
A permissive license whose main conditions require preservation of copyright and license notices. <br>
Contributors provide an express grant of patent rights.<br>
Licensed works, modifications, and larger works may be distributed under different terms and without source code.


# 1. Load library

In [1]:
# Standard modules.
import sys

# CV/ML.
import numpy as np

import jax
import jax.numpy as jnp
from jax import jit

# Enable float64.
# The tracers printed in the recorded outputs of this notebook report 64-bit
# dtypes (e.g. int64[] and float64[10]), which is consistent with this flag.
jax.config.update("jax_enable_x64", True)

In [2]:
# Report the interpreter and library versions, one per line.
print(
    f"Python:{sys.version}",
    f"Numpy:{np.__version__}",
    f"JAX:{jax.__version__}",
    sep="\n",
)

Python:3.10.12 (main, Jun 11 2023, 05:26:28) [GCC 11.4.0]
Numpy:1.23.5
JAX:0.4.16


# 2. Examples

In [3]:
def example_fun1(length, val):
    """Build a 1-D array with `length` elements, each equal to `val`.

    `length` is used directly as the array shape.
    """
    base = jnp.ones([length])
    return base * val

In [4]:
# Eager (non-jitted) calls with two different lengths.
print(example_fun1(10, 4), example_fun1(5, 4), sep="\n")



[4. 4. 4. 4. 4. 4. 4. 4. 4. 4.]
[4. 4. 4. 4. 4.]


In [5]:
example_jit1 = jit(example_fun1)
# Expected to fail: under jit, `length` is traced, and a traced value cannot be
# used as an array shape. The exception is caught so its message can be shown.
try:
    print(example_jit1(10, 4))
except Exception as err:
    print(err)

Shapes must be 1D sequences of concrete values of integer type, got [Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function example_fun1 at <ipython-input-3-afa096f8a3a6>:1 for jit. This concrete value was not available in Python because it depends on the value of the argument length.


In [6]:
# Marking argument 0 (`length`) as static lets it be used as a shape under jit.
example_jit1 = jit(example_fun1, static_argnums=(0,))
print(example_jit1(10, 4), example_jit1(5, 4), sep="\n")

[4. 4. 4. 4. 4. 4. 4. 4. 4. 4.]
[4. 4. 4. 4. 4.]


In [7]:
def example_fun2(array, val):
    """Build a 1-D array of `val` with as many elements as `array` has rows.

    Only `array.shape[0]` is read; the contents of `array` are not used.
    """
    n_items = array.shape[0]
    return jnp.ones([n_items]) * val

In [8]:
# The output length follows the first dimension of the input array.
arr = jnp.ones([10])
print(example_fun2(arr, 4), example_fun2(arr[:5], 4), sep="\n")

[4. 4. 4. 4. 4. 4. 4. 4. 4. 4.]
[4. 4. 4. 4. 4.]


In [9]:
def example_fun3(array, val):
    """Build a 1-D array of `val` whose length is the sum of `array`.

    The sum is cast to int32 and used as the output shape, so the output
    shape depends on the *values* stored in `array`, not on its shape.
    """
    total = array.sum()
    count = total.astype(jnp.int32)
    return jnp.ones([count]) * val

In [10]:
# All ones: the sum, and therefore the output length, is 10.
arr = jnp.ones([10])
print(arr, example_fun3(arr, 4), sep="\n")

# Zero the second half so the sum, and therefore the output length, drops to 5.
arr = arr.at[5:].set(0)
print(arr, example_fun3(arr, 4), sep="\n")

[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
[4. 4. 4. 4. 4. 4. 4. 4. 4. 4.]
[1. 1. 1. 1. 1. 0. 0. 0. 0. 0.]
[4. 4. 4. 4. 4.]


In [11]:
example_jit3 = jit(example_fun3)
arr = jnp.ones([10])
# Expected to fail: the output shape is derived from the values of `array`,
# which are traced under jit. The exception is caught to show its message.
try:
    print(example_jit3(arr, 4))
except Exception as err:
    print(err)

Shapes must be 1D sequences of concrete values of integer type, got [Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function example_fun3 at <ipython-input-9-00f5511bb8ac>:1 for jit. This concrete value was not available in Python because it depends on the value of the argument array.


In [12]:
arr = jnp.ones([10])
print(arr)
example_jit3 = jit(example_fun3, static_argnums=(0,))
# Expected to fail: a static argument must be hashable, and the array passed
# here is not. The exception is caught so its message can be shown.
try:
    print(example_jit3(arr, 4))
except Exception as err:
    print(err)

[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
Non-hashable static arguments are not supported. An error occurred during a call to 'example_fun3' while trying to hash an object of type <class 'jaxlib.xla_extension.ArrayImpl'>, [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]. The error was:
TypeError: unhashable type: 'ArrayImpl'



In [13]:
from typing import Generic, TypeVar
from functools import partial

T = TypeVar('T')      # Declare type variable


def _array_identity(val):
    """Return a hashable key built from an array's shape, dtype and raw bytes.

    The raw bytes alone are ambiguous: arrays with different shapes or dtypes
    can share the same byte string, so shape and dtype are part of the key.
    Objects without `shape`/`dtype` attributes fall back to `None` for them.
    """
    shape = getattr(val, "shape", None)
    if shape is not None:
        shape = tuple(shape)
    return (shape, str(getattr(val, "dtype", None)), val.tobytes())


# Workaround to avoid unhashable error.
# https://github.com/google/jax/issues/4572
class HashableArrayWrapper(Generic[T]):
    """Proxy that makes an array usable as a hashable (static) jit argument.

    The wrapped array is stored in `val`. Hashing and wrapper-to-wrapper
    equality are based on the array's shape, dtype and contents; every other
    attribute access is forwarded to the wrapped array.

    Args:
        val: The array to wrap. It must provide `tobytes()`.
    """

    def __init__(self, val: T):
        self.val = val

    def __getattribute__(self, prop):
        # Only these names are resolved on the wrapper itself; everything else
        # (shape, sum, astype, ...) is looked up on the wrapped array.
        if prop in ("val", "__hash__", "__eq__"):
            return super(HashableArrayWrapper, self).__getattribute__(prop)
        return getattr(self.val, prop)

    def __getitem__(self, key):
        return self.val[key]

    def __setitem__(self, key, val):
        self.val[key] = val

    def __hash__(self):
        return hash(_array_identity(self.val))

    def __eq__(self, other):
        """Compare with another wrapper by content, or delegate to the array.

        Returns:
            A bool when `other` is a HashableArrayWrapper; otherwise whatever
            `self.val == other` returns.
        """
        if isinstance(other, HashableArrayWrapper):
            # Compare the actual identity keys rather than their hashes, so a
            # hash collision can never make two different arrays look equal.
            return _array_identity(self.val) == _array_identity(other.val)

        # Delegate to the wrapped array's own comparison.
        return self.val == other

In [14]:
arr = jnp.ones([10])
print(arr)
example_jit3 = jit(example_fun3, static_argnums=(0,))
# Expected to fail: the wrapper makes the argument hashable, but summing the
# wrapped jnp array inside jit still produces a traced value, which cannot be
# used as a shape. The exception is caught so its message can be shown.
try:
    print(example_jit3(HashableArrayWrapper(arr), 4))
except Exception as err:
    print(err)

[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
Shapes must be 1D sequences of concrete values of integer type, got [Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function example_fun3 at <ipython-input-9-00f5511bb8ac>:1 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:f64[] = reduce_sum[axes=(0,)] b
    from line <ipython-input-9-00f5511bb8ac>:2 (example_fun3)


In [15]:
# Wrapping a NumPy array instead: this call succeeds, and the output length
# follows the values stored in the array.
example_jit3 = jit(example_fun3, static_argnums=(0,))
arr = np.ones([10])
print(arr, example_jit3(HashableArrayWrapper(arr), 4), sep="\n")

# Zero the second half in place; the re-wrapped array now sums to 5.
arr[5:] = 0
print(example_jit3(HashableArrayWrapper(arr), 4))

[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
[4. 4. 4. 4. 4. 4. 4. 4. 4. 4.]
[4. 4. 4. 4. 4.]


In [16]:
@jit
def check(arr):
    """Return the sum of `arr`, printing what the function body sees.

    In the recorded output of this notebook the printed lines appear only for
    some calls (the first call and the call with a new shape), and they show
    tracer objects rather than concrete values.
    """
    report = (
        f"x = {arr}",
        f"x.shape = {arr.shape}",
        f"x.sum() = {arr.sum()}",
    )
    print(*report, sep="\n")
    return arr.sum()

# Call `check` with a sequence of inputs. The label is printed before each
# call so that anything `check` prints appears underneath it.
for label, arr in (
    ("Input jnp.array", jnp.ones([10])),
    ("Input jnp.array with same shape", jnp.ones([10])),
    ("Input jnp.array with different shape", jnp.ones([5])),
    # NOTE: this NumPy array has shape (10,), the same as the first input;
    # it differs only from the (5,) input immediately before it.
    ("Input np.array with different shape", np.ones([10])),
):
    print(label)
    print(check(arr))

Input jnp.array
x = Traced<ShapedArray(float64[10])>with<DynamicJaxprTrace(level=1/0)>
x.shape = (10,)
x.sum() = Traced<ShapedArray(float64[])>with<DynamicJaxprTrace(level=1/0)>
10.0
Input jnp.array with same shape
10.0
Input jnp.array with different shape
x = Traced<ShapedArray(float64[5])>with<DynamicJaxprTrace(level=1/0)>
x.shape = (5,)
x.sum() = Traced<ShapedArray(float64[])>with<DynamicJaxprTrace(level=1/0)>
5.0
Input np.array with different shape
10.0
