In [5]:
import haiku as hk
import jax
import optax
from jax import random
from jax import numpy as jnp
from jax import jit

In [6]:
class MLP_test(hk.Module):
    def __init__(self, dimensions, name=None):
        super().__init__(name=name)
        self.dimensions = dimensions
        self.mlp = hk.nets.MLP(output_sizes=dimensions, name='mlp')

    def __call__(self, x):
        return self.mlp(x)
    
    
    
def _mlp_forward(x):
    module = MLP_test([1,2,3])
    return module(x)

init, apply = hk.transform(_mlp_forward)


In [58]:
class conv_test(hk.Module):
    def __init__(self, name=None):
        super().__init__(name=name)    
        
        #TODO: add state for batch norm 
        
        self.conv1 = hk.Conv2D(32, 2) #out_channels, kernel_size (stride defaults to 1, in_channels done automatically) 
        self.bn1 = hk.BatchNorm(False, False, 0.995)
        self.conv2 = hk.Conv2D(64, 2)
        self.bn2 = hk.BatchNorm(False, False, 0.995)
        self.fc = hk.Linear(128)
        self.head = hk.Linear(1)

    def __call__(self, x, is_training=True):
        x = self.conv1(x)
        x = jax.nn.relu(self.bn1(x, is_training))
        x = jax.nn.relu(self.bn2(self.conv2(x), is_training))
        x = jax.nn.relu(self.fc(x)) #check correct when compared to torch version
        x = jax.nn.sigmoid(self.head(x))
        
        return x
    
    
    
def _conv_forward(x):
    module = conv_test()
    return module(x)

init, apply = hk.transform_with_state(_mlp_forward)


In [59]:
rng_key = random.PRNGKey(0)
x = random.normal(rng_key, (1,6,128))
initial_params, state = init(rng_key,x)


%timeit new_params, new_state = apply(initial_params, state, rng_key,x)
%timeit new_params_jit, new_state = jit(apply)(initial_params,state, rng_key,x)

5.31 ms ± 220 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
374 µs ± 1.94 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [7]:
rng_key = random.PRNGKey(0)
x = jnp.array([1.,2.,3.,4.])
initial_params = init(rng_key,x)

%timeit new_params = apply(initial_params,rng_key,x)
%timeit new_params_jit = jit(apply)(initial_params,rng_key,x)




1.11 ms ± 2.44 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
168 µs ± 512 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [51]:
print(hk.experimental.tabulate(_conv_forward)(x))

+----------------------------------------------------------------+----------------------------------------------------------------------+--------------------+------------------+--------------+---------------+---------------+
| Module                                                         | Config                                                               | Module params      | Input            | Output       |   Param count |   Param bytes |
| conv_test (conv_test)                                          | conv_test()                                                          |                    | f32[1,6,128]     | f32[1,6,1]   |        33,121 |     132.48 KB |
+----------------------------------------------------------------+----------------------------------------------------------------------+--------------------+------------------+--------------+---------------+---------------+
| conv_test/~/conv2_d (Conv2D)                                   | Conv2D(output_channels=32, kernel

In [8]:
print(hk.experimental.tabulate(_mlp_forward)(x))

+------------------------------------+-----------------------------------------+-----------------+---------+----------+---------------+---------------+
| Module                             | Config                                  | Module params   | Input   | Output   |   Param count |   Param bytes |
| mlp_test (MLP_test)                | MLP_test(dimensions=[1, 2, 3])          |                 | f32[4]  | f32[3]   |            18 |       72.00 B |
+------------------------------------+-----------------------------------------+-----------------+---------+----------+---------------+---------------+
| mlp_test/~/mlp (MLP)               | MLP(output_sizes=[1, 2, 3], name='mlp') |                 | f32[4]  | f32[3]   |            18 |       72.00 B |
|  └ mlp_test (MLP_test)             |                                         |                 |         |          |               |               |
+------------------------------------+-----------------------------------------+--------