In [1]:
import jax
import jax.numpy as jnp
import jax.example_libraries.optimizers as jopt
import haiku as hk
import mnist
import numpy as np
import math
import datetime

# Dtype used for every Haiku parameter created below.
float_type = jnp.float32
# NaN checking stays off (this is already JAX's default); set it to True to make
# JAX raise on the first NaN it produces, which slows execution.
jax.config.update("jax_debug_nans", False)

In [18]:
class Base(hk.Module):
  """Dense ReLU layer with an identity residual connection.

  Computes ``relu(x @ w + b) + x`` with ``w`` of shape ``(j, k)`` and ``b`` of
  shape ``(k,)``. Because the input is added back unchanged, the layer is only
  well-defined when ``j == k``.

  Args:
    j: Input feature size (last dimension of ``x``).
    k: Output feature size.
    name: Optional Haiku module name.

  Raises:
    ValueError: if ``j != k``. Without this check the identity skip ``+ x``
      would either fail with an opaque shape error or, when one of the sizes
      is 1, silently broadcast to a wrong-shaped result.
  """

  def __init__(self, j, k, name=None):
    super().__init__(name=name)
    if j != k:
      raise ValueError(
          f"Base uses an identity residual (relu(x @ w + b) + x), which "
          f"requires j == k; got j={j}, k={k}.")
    self.j = j
    self.k = k

  def __call__(self, x):
    j = self.j
    k = self.k
    # Truncated-normal init with stddev 1/sqrt(fan_in) keeps the scale of
    # x @ w roughly independent of the input width j.
    w_init = hk.initializers.TruncatedNormal(1.0 / np.sqrt(j))
    w = hk.get_parameter("w", shape=[j, k], dtype=float_type, init=w_init)
    # NOTE(review): the bias starts at ones rather than the usual zeros, so the
    # ReLU units start out active for small inputs — confirm this is intended.
    b = hk.get_parameter("b", shape=[k], dtype=float_type, init=jnp.ones)
    return jax.nn.relu(jnp.dot(x, w) + b) + x

In [19]:
@hk.transform
def model(x):
    """Forward pass through a single 2 -> 2 ``Base`` layer named "MyLayer".

    ``hk.transform`` turns this into a pair of pure ``init``/``apply`` functions.
    """
    layer = Base(2, 2, name="MyLayer")
    return layer(x)

In [20]:
# One example with two features, built directly in the parameter dtype so the
# input matches the float weights instead of relying on int -> float promotion.
x = jnp.array([[0.0, 1.0]], dtype=float_type)
# The fixed key seeds the truncated-normal draw for `w`, making init reproducible.
params = model.init(rng=jax.random.PRNGKey(0), x=x)

In [21]:
# Base draws no random numbers in its forward pass, but the apply function from
# hk.transform still takes an rng argument; reuse the fixed key for determinism.
output = model.apply(params, jax.random.PRNGKey(0), x)
output

Array([[-0.39705074,  0.530758  ]], dtype=float32)

In [22]:
# Inspect the initialised parameters: a nested mapping keyed first by module
# name ("MyLayer") and then by parameter name ("w", "b").
params

{'MyLayer': {'w': Array([[ 0.74966276, -0.09669809],
         [-0.39705074,  0.530758  ]], dtype=float32),
  'b': Array([1., 1.], dtype=float32)}}