In [10]:
from collections import namedtuple
from functools import partial
import os
import time
import requests

import jax.random as random
import jax.numpy as jnp
from jax import jacfwd, jit, grad

1. JIT-compiling a function with static arguments

In [16]:
def g(x, n):
    """Square ``x`` repeatedly, ``n`` times in total.

    Args:
        x: Value (scalar or array) supporting the ``**`` operator.
        n: Number of squaring steps. It drives a Python-level ``range(n)``
            loop, so it has to be a concrete integer.

    Returns:
        ``x`` after ``n`` successive squarings; ``x`` itself when ``n`` is 0.
    """
    squared = x
    for _ in range(n):
        squared = squared ** 2
    return squared

# Mark `n` (positional index 1) as static so that the Python-level `range(n)`
# loop inside `g` receives a concrete integer while the function is traced.
jit(g, static_argnums=1)(jnp.arange(4), 3)
# Alternative spelling that names the static argument instead of indexing it:
# jit(g, static_argnames='n')(jnp.arange(4), 3)

Array([   0,    1,  256, 6561], dtype=int32)

In [17]:
# Same function as before, but jitted at definition time with `n`
# (positional index 1) declared static.
# Alternative decorator that names the static argument instead:
#   @partial(jit, static_argnames=['n'])
@partial(jit, static_argnums=1)
def g(x, n):
    """Square ``x`` repeatedly, ``n`` times in total.

    Args:
        x: Array (or scalar) to be squared.
        n: Number of squaring steps; static under ``jit`` because it drives
            a Python-level ``range(n)`` loop.

    Returns:
        ``x`` after ``n`` successive squarings.
    """
    squared = x
    for _ in range(n):
        squared = squared ** 2
    return squared

# Call the jitted function directly; the second positional argument (3) is `n`.
g(jnp.arange(4), 3)

Array([   0,    1,  256, 6561], dtype=int32)

## evaluating a composition of functions iteratively

In [4]:
def apply_activation(x):
    """Apply a ReLU activation: the element-wise maximum of ``0.0`` and ``x``.

    Args:
        x: Input array (or scalar).

    Returns:
        Result of ``jnp.maximum(0.0, x)``.
    """
    zero = 0.0
    return jnp.maximum(zero, x)

def get_dot_product(u, v):
    """Return the dot product of ``u`` and ``v`` as computed by ``jnp.dot``.

    Args:
        u: Left operand.
        v: Right operand.

    Returns:
        ``jnp.dot(u, v)``.
    """
    product = jnp.dot(u, v)
    return product

In [5]:
# Root PRNG key. It is only ever split and never handed to a sampler directly:
# a JAX key must not be reused once it has been consumed, so every draw below
# gets its own fresh subkey.
key = random.PRNGKey(1)
n, m, p = (1000, 10000, 1000)

# subkey for u, an (n, m) float32 matrix of standard-normal samples
key, subkey = random.split(key)
u = random.normal(key=subkey, shape=[n, m], dtype=jnp.float32)

# fresh subkey for v, an (m, p) float32 matrix of standard-normal samples
key, subkey = random.split(key)
v = random.normal(key=subkey, shape=[m, p], dtype=jnp.float32)

# Wrap each helper with jit on its own, giving two independently callable
# compiled functions.
dot_product_jit = jit(get_dot_product)
activation_jit = jit(apply_activation)

# Time the two jitted stages separately over repeated calls.
# block_until_ready() makes each measurement wait for the result to be computed;
# the first iteration additionally pays the one-off compilation of each function.
for i in range(10):

    # perf_counter() is a monotonic, high-resolution clock meant for measuring
    # short intervals, unlike the adjustable wall clock behind time.time().
    start = time.perf_counter()
    z = dot_product_jit(u, v).block_until_ready()
    dt_dot_product = time.perf_counter() - start

    start = time.perf_counter()
    a = activation_jit(z).block_until_ready()
    dt_activation = time.perf_counter() - start

    print(f'it.: {i+1}, dot product time: {dt_dot_product:.1e}, activation time: {dt_activation:.1e}')

it.: 1, dot product time: 5.3e-02, activation time: 1.3e-02
it.: 2, dot product time: 3.9e-02, activation time: 2.7e-04
it.: 3, dot product time: 3.8e-02, activation time: 2.7e-04
it.: 4, dot product time: 3.8e-02, activation time: 3.1e-04
it.: 5, dot product time: 3.9e-02, activation time: 4.5e-04
it.: 6, dot product time: 3.9e-02, activation time: 5.6e-04
it.: 7, dot product time: 4.0e-02, activation time: 2.8e-04
it.: 8, dot product time: 3.9e-02, activation time: 2.4e-04
it.: 9, dot product time: 4.0e-02, activation time: 2.6e-04
it.: 10, dot product time: 3.9e-02, activation time: 3.7e-04


## JIT-compiling a method inside a class

In [6]:
class FeedForwardNN:
    """Groups the two forward-pass building blocks as methods; holds no state."""

    def __init__(self):
        """Nothing to initialise: no parameters or attributes are stored."""

    def get_dot_product(self, u, v):
        """Return ``jnp.dot(u, v)``; ``self`` is not used."""
        product = jnp.dot(u, v)
        return product

    def apply_activation(self, x):
        """Return the element-wise maximum of ``0.0`` and ``x`` (ReLU); ``self`` is not used."""
        zero = 0.0
        return jnp.maximum(zero, x)

In [7]:
# Root PRNG key. It is only ever split and never handed to a sampler directly:
# a JAX key must not be reused once it has been consumed, so every draw below
# gets its own fresh subkey.
key = random.PRNGKey(1)
n, m, p = (10, 100, 50)

# subkey for u, an (n, m) float32 matrix of standard-normal samples
key, subkey = random.split(key)
u = random.normal(key=subkey, shape=[n, m], dtype=jnp.float32)

# fresh subkey for v, an (m, p) float32 matrix of standard-normal samples
key, subkey = random.split(key)
v = random.normal(key=subkey, shape=[m, p], dtype=jnp.float32)

# Instantiate the network; the constructor takes no arguments.
nn = FeedForwardNN()

In [8]:
# Forward pass: dot product of u and v, then the activation applied to the result.
x = nn.apply_activation(nn.get_dot_product(u, v))