# API demo for pax (lightweight)

In [1]:
import sys
sys.path.append('../src')

%load_ext autoreload
%autoreload 2
%aimport pax

from pax import Dense
import jax
import jax.numpy as jnp
import pandas as pd

from io import BytesIO
from PIL import Image

In [4]:
df = pd.read_parquet('../data/mnist/mnist_train.parquet')
df

Unnamed: 0,image,label
0,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,5
1,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,0
2,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,4
3,"{'bytes': b""\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...",1
4,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,9
...,...,...
59995,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,8
59996,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,3
59997,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,5
59998,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,6


In [5]:
def bytes_dict_to_jax_array(d):
    img = Image.open(BytesIO(d['bytes']))
    return jnp.array(img)

def preprocess_df(df):
    df = df.copy()
    df['image'] = df['image'].map(bytes_dict_to_jax_array)
    X, y = jnp.stack(df['image'].tolist()), jax.nn.one_hot(df['label'], 10)
    X = X.reshape(X.shape[0], -1) / 255
    return X, y

X_train, y_train = preprocess_df(df)
X_train.shape, y_train.shape

((60000, 784), (60000, 10))

In [6]:
# layer1 = Dense(in_nodes, out_nodes) # or whatever
# layer2 = BatchNorm() # or whatever
# ...
# layerN = ...
# layers = [layer1, ..., layerN]

# params = [layer1.weights(), layer2.weights(), ...., layerN.weights()]
# state = [layer1.state(), layer2.state(), ...., layerN.state()]
# layer_funcs = [layer1.func(), layer2.func(), ..., layerN.func()]

# @partial(jax.jit, static_argnames="is_training")
# def apply(params, state, X, *, is_training, key):
#   for i in range(N):
#     cur_key, key = jax.random.split(key)
#     X, state[i] = layer_funcs[i](params[i], state[i], X, is_training=is_training, key=cur_key)
#   return X, state

In [7]:
shapes = [X_train.shape[1], 128, 64, 10]

layers = []
for i in range(len(shapes)-2):
    layers.append(Dense(shapes[i], shapes[i+1], activation=jax.nn.relu))
layers.append(Dense(shapes[-2], shapes[-1]))

layers


[Dense(in_nodes=784, out_nodes=128, activation=<jax._src.custom_derivatives.custom_jvp object at 0x718553ceb050>),
 Dense(in_nodes=128, out_nodes=64, activation=<jax._src.custom_derivatives.custom_jvp object at 0x718553ceb050>),
 Dense(in_nodes=64, out_nodes=10, activation=<function Dense.__post_init__.<locals>.<lambda> at 0x718546c5af20>)]

In [8]:
params = tuple(l.weights(jax.random.key(0)) for l in layers)
state = tuple(l.state() for l in layers)

jax.tree.structure(params)

PyTreeDef(({'b': *, 'w': *}, {'b': *, 'w': *}, {'b': *, 'w': *}))

In [9]:
@jax.jit
def fwd(params, state, X, *, is_training, key):
    for i, layer in enumerate(layers):
        key, subkey = jax.random.split(key)
        print(jax.tree.structure(params[i]))
        X, state_i = layer.func()(params[i], state[i], X, is_training=is_training, key=subkey)
        state = state[:i] + (state_i,) + state[i+1:]  # update tuple immutably
    return X, state

fwd(params, state, X_train, is_training=False, key=jax.random.key(1))


PyTreeDef({'b': *, 'w': *})
PyTreeDef({'b': *, 'w': *})
PyTreeDef({'b': *, 'w': *})


(Array([[-0.15497945, -0.4568455 , -0.39541277, ..., -0.13534677,
          0.32761425, -0.19820786],
        [-0.05561724, -0.49285766, -0.10636605, ..., -0.24057765,
          0.07059948, -0.31488508],
        [ 0.12788975, -0.14176412, -0.0644017 , ..., -0.2389179 ,
         -0.296124  ,  0.12700762],
        ...,
        [ 0.13261902, -0.45547187, -0.2919717 , ..., -0.10995544,
          0.2916181 , -0.04156661],
        [ 0.17468776, -0.6099087 , -0.51634234, ...,  0.0476726 ,
         -0.05541314, -0.24837567],
        [ 0.103691  , -0.23499839, -0.41506302, ..., -0.2683155 ,
         -0.00104973, -0.01717444]], dtype=float32),
 ({}, {}, {}))

In [10]:
def loss_fn(params, state, X, y, *, key):
    logits, new_state = fwd(params, state, X, is_training=True, key=key)
    loss = -jnp.mean(jnp.sum(y * jax.nn.log_softmax(logits), axis=1))
    return loss, new_state

(loss_val, new_state), grads = jax.value_and_grad(loss_fn, has_aux=True)(params, state, X_train, y_train, key=jax.random.key(0))
# loss_val, new_state, grads

In [11]:
# grad_loss = jax.jit(jax.grad(loss))

@jax.jit
def train(params, state, X_train, y_train, *, key, num_iters=30000):
    lr = 0.001
    keys = jax.random.split(key, num_iters)

    def body(i, carry):
        (params, state) = carry
        (loss_val, new_state), grads = jax.value_and_grad(loss_fn, has_aux=True)(params, state, X_train, y_train, key=keys[i])
        params = jax.tree.map(lambda param, grad: param - lr * grad, params, grads)
        
        def do_print(_):
            jax.debug.print("step {i}, loss: {l}", i=i, l=loss_val)
            return None

        _ = jax.lax.cond(i % 100 == 0, do_print, lambda _: None, operand=None)
        return params, new_state

    params, final_state = jax.lax.fori_loop(0, num_iters, body, (params, state))
    return params

params = train(params, state, X_train, y_train, key=jax.random.key(0))

step 0, loss: 2.317497968673706
step 100, loss: 2.2239949703216553
step 200, loss: 2.140049695968628
step 300, loss: 2.0588788986206055
step 400, loss: 1.9777990579605103
step 500, loss: 1.8962042331695557
step 600, loss: 1.814244031906128
step 700, loss: 1.7326817512512207
step 800, loss: 1.65241277217865
step 900, loss: 1.574188470840454
step 1000, loss: 1.4984387159347534
step 1100, loss: 1.4256908893585205
step 1200, loss: 1.3563148975372314
step 1300, loss: 1.2907150983810425
step 1400, loss: 1.2293421030044556
step 1500, loss: 1.1722177267074585
step 1600, loss: 1.119492769241333
step 1700, loss: 1.0709500312805176
step 1800, loss: 1.0264092683792114
step 1900, loss: 0.9856464862823486
step 2000, loss: 0.9482986927032471
step 2100, loss: 0.9141151309013367
step 2200, loss: 0.8827757835388184
step 2300, loss: 0.8539860248565674
step 2400, loss: 0.8275564312934875
step 2500, loss: 0.8031908869743347
step 2600, loss: 0.7807102799415588
step 2700, loss: 0.7599031925201416
step 2800, 