In [3]:
import haiku as hk
import jax
import jax.nn as jnn
import jax.numpy as jnp
import numpy as np
from typing import List, Callable
from pprint import pprint

  # Type of a pytree structure object, obtained by inspecting the structure of
  # a trivial tree. Uses jax.tree_util (as the rest of this notebook does)
  # rather than the deprecated top-level `jax.tree_structure` alias.
  PyTreeDef = type(jax.tree_util.tree_structure(None))


In [84]:
class MixtureOfExperts(hk.Module):
    """Top-k mixture-of-experts routing with fixed-size per-expert buffers.

    The pieces are meant to be chained: `route` picks the `k` best-scoring
    experts for every item, `dispatch` scatters the items into one
    `capacity`-row buffer per expert, `wrap` builds a function that runs one
    copy of a module per expert over those buffers, and `combine` gathers the
    per-expert results back per item. Assignments that do not fit into an
    expert's buffer are dropped by `dispatch` and read back as zeros by
    `combine`.
    """

    def __init__(self, num_experts, k, capacity, name=None):
        """Store the routing configuration.

        Args:
            num_experts: number of experts (E).
            k: number of experts selected per item (K).
            capacity: number of rows in each expert's buffer (C).
            name: optional Haiku module name.
        """
        super().__init__(name=name)
        self._num_experts = num_experts
        self._k = k
        # [0, 1, ..., E-1]; compared against the selected expert ids in `route`.
        self._expert_idxs = jnp.arange(self._num_experts, dtype=np.int32)
        self._capacity = capacity

    def route(self, x: jnp.ndarray):
        """Score the items and assign each of them to its top-k experts.

        Args:
            x: I x D array of items (rank 2: the selected-expert indices are
                transposed with a two-axis permutation below).

        Returns:
            A tuple `(top_k_gates, top_k_idxs, dispatch_idxs, combine_idxs)`:
            top_k_gates: I x K, softmax over the K selected logits of an item.
            top_k_idxs: I x K, ids of the selected experts.
            dispatch_idxs: E x I, buffer position of item i inside expert e,
                or -1 where item i is not assigned to expert e.
            combine_idxs: I x K, buffer position of item i inside its k-th
                selected expert.
            Buffer positions are NOT limited to the capacity here; `dispatch`
            and `combine` handle positions that do not fit.
        """
        logits = hk.Linear(self._num_experts)(x)
        top_k_vals, top_k_idxs = jax.lax.top_k(logits, self._k)
        top_k_gates = jax.nn.softmax(top_k_vals, axis=-1)

        # One row per (choice, item) pair in choice-major order, one column per
        # expert: True where that pair selected that expert.
        # K * I x E
        flat_choices = top_k_idxs.transpose((1, 0)).flatten()
        expert_mask = flat_choices[:, None] == self._expert_idxs

        # Running count per expert gives each assignment its 0-based position
        # in that expert's buffer; unassigned entries become -1. Because of
        # the choice-major order, all first choices are placed before any
        # second choice.
        # K * I x E
        buffer_idxs = jnp.cumsum(expert_mask, axis=0) * expert_mask - 1
        # K x I x E
        buffer_idxs = buffer_idxs.reshape(self._k, -1, self._num_experts)

        # dispatch_idxs : E x I (max over the K choices drops the -1 fillers)
        dispatch_idxs = buffer_idxs.transpose((2, 0, 1)).max(axis=1)
        # combine_idxs : I x K (max over the E experts drops the -1 fillers)
        combine_idxs = buffer_idxs.max(axis=-1).transpose(1, 0)
        return top_k_gates, top_k_idxs, dispatch_idxs, combine_idxs

    def dispatch(self, dispatch_idxs: jnp.ndarray, x: jnp.ndarray) -> jnp.ndarray:
        """Scatter the items into one fixed-size buffer per expert.

        Args:
            dispatch_idxs: E x I buffer positions as returned by `route`
                (-1 where an item is not assigned to the expert).
            x: I x D array of items.

        Returns:
            E x C x D array with the dtype of `x`. Row c of expert e holds the
            item whose position in that expert is c; unused rows are zero.
            Items that are unassigned or whose position is >= C are dropped.
        """
        capacity = self._capacity

        def expert_array(exp_buf_idxs):
            # Row `capacity` is a discard slot: every unassigned (-1) or
            # overflowing position is redirected there explicitly instead of
            # relying on negative-index wrapping / out-of-bounds scatter
            # behaviour, and the row is sliced off before returning.
            fits = (exp_buf_idxs >= 0) & (exp_buf_idxs < capacity)
            slots = jnp.where(fits, exp_buf_idxs, capacity)
            # Allocate in the input dtype so the items are not silently cast.
            buffer = jnp.zeros((capacity + 1, x.shape[-1]), dtype=x.dtype)
            return buffer.at[slots].set(x)[:-1]

        # Map over the expert axis of dispatch_idxs; x is shared by closure.
        return jax.vmap(expert_array)(dispatch_idxs)

    def combine(self, expert_idxs, buffer_idxs, data):
        """Gather, for every item, the buffer rows of its selected experts.

        Args:
            expert_idxs: I x K expert ids (`top_k_idxs` from `route`).
            buffer_idxs: I x K buffer positions (`combine_idxs` from `route`).
            data: E x C x D per-expert buffers.

        Returns:
            I x K x D array. Entries whose buffer position does not fit into
            the C rows of `data` (negative or >= C) are all zeros.
        """
        num_rows = data.shape[1]
        # Append one all-zero row per expert to serve as the "dropped" result.
        padded = jnp.pad(data, ((0, 0), (0, 1), (0, 0)), 'constant')
        # Send every position that does not fit to the zero row explicitly so
        # the gather never sees an out-of-range index.
        fits = (buffer_idxs >= 0) & (buffer_idxs < num_rows)
        slots = jnp.where(fits, buffer_idxs, num_rows)
        return padded[expert_idxs, slots]

    def wrap(self, cls, *args, **kwargs):
        """Build a function that runs one copy of a module per expert.

        Args:
            cls: callable creating a module; invoked once per expert as
                `cls(*args, **kwargs, name='expert_<i>')`.
            *args: positional arguments forwarded to `cls`.
            **kwargs: keyword arguments forwarded to `cls`.

        Returns:
            `moe_fun(x)`: initialises the parameters of every expert, stacks
            them along a new leading axis and applies the first expert's
            transformed `apply` under `jax.vmap`, i.e. mapped over axis 0 of
            both the stacked parameters and `x`.

        NOTE(review): the modules are instantiated right here, so this
        presumably has to be called inside an `hk.transform`-ed function;
        confirm against the Haiku documentation.
        """
        functions = [cls(*args, **kwargs, name='expert_%d' % i)
                     for i in range(self._num_experts)]
        # Nested transform of every expert, then pull out the init functions.
        init_applies = [hk.transform(fun) for fun in functions]
        inits = [transformed.init for transformed in init_applies]
        # The same apply function is used for all experts under vmap.
        apply = init_applies[0].apply

        def moe_fun(x):
            # get lifted parameters for each expert
            params = [hk.lift(init, name="inner")(hk.next_rng_key(), x) for init in inits]
            # Only the first entry of each expert's parameter dict is used,
            # stored under the first key of expert 0 so that the shared
            # `apply` can find it.
            # NOTE(review): an expert with more than one entry in its
            # parameter dict would lose the remaining entries here.
            param_name = next(iter(params[0].keys()))
            just_params = [next(iter(p.values())) for p in params]
            stacked_params = jax.tree_util.tree_map(lambda *arrays: jnp.stack(arrays), *just_params)

            def apply_fn(params, x):
                return apply(params, hk.next_rng_key(), x)

            return jax.vmap(apply_fn)({param_name: stacked_params}, x)

        return moe_fun

            
def moe(x):
    """Push a batch through a 4-expert, top-2, capacity-10 mixture of experts.

    The values of `x` are only used for routing. What is actually dispatched
    is an array shaped like `x` in which every row is filled with its own row
    number, so the returned array shows where each item ended up. The gates
    returned by the router are not applied to the result.

    Returns:
        The per-item, per-choice expert outputs produced by `combine`.
    """
    layer = MixtureOfExperts(num_experts=4, capacity=10, k=2)
    _gates, chosen_experts, dispatch_idxs, combine_idxs = layer.route(x)

    # Same shape as x, row i filled with the value i.
    row_numbers = jnp.arange(x.shape[0])[:, None]
    z = x * 0 + row_numbers

    expert_inputs = layer.dispatch(dispatch_idxs, z)
    # One hk.Linear with 3 outputs per expert.
    expert_outputs = layer.wrap(hk.Linear, 3)(expert_inputs)
    return layer.combine(chosen_experts, combine_idxs, expert_outputs)


# Pure init/apply pair for the function above.
forward_moe = hk.transform(moe)

In [85]:
# Toy batch: 13 items with 3 features each, drawn with a fixed key.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (13, 3))
# print(x)
# Initialise the parameters for this input shape, jit the apply function and
# run it once. The same key is passed to both init and apply.
rng_key = jax.random.PRNGKey(42)
params = forward_moe.init(rng=rng_key, x=x)
fn_apply = jax.jit(forward_moe.apply)
out_buf = fn_apply(params=params, x=x, rng=rng_key)
print(out_buf)

[[[  0.           0.           0.        ]
  [  0.           0.           0.        ]]

 [[ -1.2483661    0.64470327   0.79231584]
  [ -0.7740661   -0.01859003   0.2117182 ]]

 [[ -1.5481322   -0.03718007   0.4234364 ]
  [  1.1794753    2.540656     0.15744385]]

 [[  1.7692131    3.8109846    0.2361658 ]
  [ -3.741686     0.32243448  -1.4485036 ]]

 [[  2.3589506    5.081312     0.3148877 ]
  [ -4.9934645    2.578813     3.1692634 ]]

 [[ -6.241831     3.223516     3.961579  ]
  [ -6.2361436    0.53739077  -2.4141726 ]]

 [[  3.5384262    7.621969     0.4723316 ]
  [ -7.490197     3.8682194    4.753895  ]]

 [[ -8.738564     4.512923     5.5462112 ]
  [ -5.4184628   -0.13013017   1.4820267 ]]

 [[  4.717901    10.162624     0.6297754 ]
  [ -9.977829     0.85982525  -3.8626761 ]]

 [[-11.235295     5.802329     7.130842  ]
  [ -6.966595    -0.16731024   1.9054636 ]]

 [[-12.483662     6.447032     7.923158  ]
  [ -7.740661    -0.18590021   2.1171815 ]]

 [[-13.732028     7.0917354    8

In [12]:
# Call the jitted apply with several batch sizes (including repeats of sizes
# already seen) and print the output shape each time. `trim` is the number of
# trailing rows removed from the batch.
for trim in (0, 1, 2, 3, 0, 3):
    out_buf = fn_apply(params=params, x=x[:x.shape[0] - trim], rng=rng_key)
    print(out_buf.shape)

# Repeat with a larger batch. The feature size is taken from the current `x`
# so that the batch stays compatible with the already-initialised `params`
# (a hard-coded feature size only works if it matches the one used at init).
x = jax.random.normal(key, (200, x.shape[-1]))
for trim in (0, 5):
    out_buf = fn_apply(params=params, x=x[:x.shape[0] - trim], rng=rng_key)
    print(out_buf.shape)

(13, 2, 4)
(12, 2, 4)
(11, 2, 4)
(10, 2, 4)
(13, 2, 4)
(10, 2, 4)
(200, 2, 4)
(195, 2, 4)


In [None]:
from pprint import pprint

# Larger batch: 2000 items with 20 features each.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (2000, 20))
# print(x)
# Parameters are re-initialised for the new input shape and the apply
# function is jitted again.
rng_key = jax.random.PRNGKey(42)
params = forward_moe.init(rng=rng_key, x=x)
fn_apply = jax.jit(forward_moe.apply)
# The identical call is issued twice.
# NOTE(review): presumably to compare the first (compiling) call with a
# repeated call on the same input shape -- confirm the intent.
out_buf = fn_apply(params=params, x=x, rng=rng_key)
print(out_buf.shape)
out_buf = fn_apply(params=params, x=x, rng=rng_key)
print(out_buf.shape)