In [1]:
import haiku as hk
import jax
import jax.nn as jnn
import jax.numpy as jnp
import numpy as np

  PyTreeDef = type(jax.tree_structure(None))


In [31]:
class MixtureOfExperts(hk.Module):
    """Top-k mixture-of-experts routing, split into explicit stages.

    Typical use: ``route`` -> ``map`` -> ``apply_experts`` -> ``reduce``
    (or ``gather`` to keep the per-slot outputs separate). ``map`` relies on
    ``jnp.where`` with data-dependent output sizes, so these stages run eagerly
    and are not ``jax.jit``-compatible as written.
    """

    def __init__(self, num_experts, k, name=None):
        """
        Args:
          num_experts: number of experts a token can be routed to.
          k: number of experts each token is sent to (``jax.lax.top_k``
            requires k <= num_experts).
          name: optional Haiku module name.
        """
        super().__init__(name=name)
        self._num_experts = num_experts
        self._k = k

    def route(self, x: jnp.ndarray):
        """Select the top-k experts for every row of ``x``.

        Args:
          x: token activations, (num_tokens, embed_dim).
        Returns:
          (top_k_gates, top_k_idxs), each (num_tokens, k): gate weights
          (softmax over the k selected logits, so each row sums to 1) and the
          indices of the selected experts.
        """
        logits = hk.Linear(self._num_experts)(x)
        top_k_vals, top_k_idxs = jax.lax.top_k(logits, self._k)
        # Normalise only over the selected experts, not all num_experts logits.
        top_k_gates = jax.nn.softmax(top_k_vals, axis=-1)
        return top_k_gates, top_k_idxs

    def map(self, top_k_idxs: jnp.ndarray, x: jnp.ndarray) -> tuple:
        """Scatter tokens to the experts that selected them.

        Args:
          top_k_idxs: (num_tokens, k) expert indices from ``route``.
          x: (num_tokens, embed_dim) token activations.
        Returns:
          (idxs, sharded_x), two lists of length num_experts. ``idxs[e]`` is a
          (source_idx, slot_idx) pair of index arrays: which tokens went to
          expert ``e`` and in which top-k slot. ``sharded_x[e]`` is
          ``x[source_idx]`` and may have zero rows.
        Raises:
          ValueError: if ``x`` is not 2-D.
        """
        if x.ndim != 2:
            raise ValueError(
                f"expected x of shape (num_tokens, embed_dim), got {x.shape}")
        expert_ids = jnp.arange(self._num_experts, dtype=np.int32)
        # (num_experts, num_tokens, k): True where slot j of token t chose expert e.
        top_k_mask = expert_ids[:, None, None] == top_k_idxs[None, :, :]
        idxs = [None] * self._num_experts
        sharded_x = [None] * self._num_experts
        for i in range(self._num_experts):
            # Data-dependent output size: this is what prevents jit.
            source_idx, slot_idx = jnp.where(top_k_mask[i])
            idxs[i] = (source_idx, slot_idx)
            sharded_x[i] = x[source_idx]
        return idxs, sharded_x

    def apply_experts(self, sharded_y, experts):
        """Apply ``experts[e]`` to ``sharded_y[e]`` for every expert.

        Args:
          sharded_y: per-expert inputs (the ``sharded_x`` list from ``map``).
          experts: list of callables, one per expert, same length.
        Returns:
          List of per-expert outputs. ``tree_map`` raises if the two lists do
          not have matching structure.
        """
        return jax.tree_util.tree_map(lambda f, x: f(x), experts, sharded_y)

    def gather(self, in_shape, idxs, sharded_y):
        """Un-permute expert outputs into a per-(token, slot) layout.

        Args:
          in_shape: number of tokens (an int, despite the name).
          idxs: per-expert (source_idx, slot_idx) pairs from ``map``.
          sharded_y: per-expert outputs, ``sharded_y[e]`` of shape (n_e, out_dim).
        Returns:
          (in_shape, k, out_dim) array; entry [t, j] holds the output of the
          expert chosen in top-k slot j for token t.
        """
        out_dim = sharded_y[0].shape[-1]
        # Explicit zeros: JAX has no uninitialised arrays, so don't imply one.
        out_buf = jnp.zeros((in_shape, self._k, out_dim), dtype=sharded_y[0].dtype)
        for i in range(self._num_experts):
            out_buf = out_buf.at[idxs[i]].set(sharded_y[i])
        return out_buf

    def reduce(self, in_shape, idxs, sharded_y, top_k_gates):
        """Combine expert outputs per token, weighted by the routing gates.

        Args:
          in_shape: number of tokens (an int, despite the name).
          idxs: per-expert (source_idx, slot_idx) pairs from ``map``.
          sharded_y: per-expert outputs, ``sharded_y[e]`` of shape (n_e, out_dim).
          top_k_gates: (num_tokens, k) gate weights from ``route``.
        Returns:
          (in_shape, out_dim) array: for each token, the sum over its k selected
          experts of gate * expert_output.
        """
        out_dim = sharded_y[0].shape[-1]
        # Accumulated with .add below, so the buffer must start at zero.
        out_buf = jnp.zeros((in_shape, out_dim), dtype=sharded_y[0].dtype)
        for i in range(self._num_experts):
            source_idx, slot_idx = idxs[i]
            gates = top_k_gates[source_idx, slot_idx][:, None]
            out_buf = out_buf.at[source_idx].add(gates * sharded_y[i])
        return out_buf
        
def moe(x, num_experts=4, k=2, output_size=7):
    """MoE forward pass with one ``hk.Linear`` per expert.

    Each row of ``x`` is routed to its top-``k`` of ``num_experts`` experts, and
    the expert outputs are combined with the routing gate weights.

    Args:
      x: token activations, (num_tokens, embed_dim).
      num_experts: number of linear experts (default 4, as before).
      k: experts per token (default 2, as before).
      output_size: output features of every expert (default 7, as before).
    Returns:
      (num_tokens, output_size) array.
    """
    # Experts are created before the router module, as before, so the Haiku
    # parameter names are unchanged.
    experts = [hk.Linear(output_size) for _ in range(num_experts)]
    layer = MixtureOfExperts(num_experts, k=k)
    top_k_gates, top_k_idxs = layer.route(x)
    idxs, sharded_x = layer.map(top_k_idxs, x)
    sharded_outs = layer.apply_experts(sharded_x, experts)
    return layer.reduce(x.shape[0], idxs, sharded_outs, top_k_gates)


forward_moe = hk.transform(moe)

In [32]:
from pprint import pprint
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (15, 20))
rng_key = jax.random.PRNGKey(42)
params = forward_moe.init(rng=rng_key, x=x)
# The model draws no randomness at apply time, so pass rng=None rather than
# reusing the init key (JAX keys should not be consumed twice).
out_buf = forward_moe.apply(params=params, x=x, rng=None)
pprint(out_buf)
print(out_buf.shape)
assert out_buf.shape == (x.shape[0], 7), out_buf.shape

DeviceArray([[ 1.23953128e+00, -3.22048664e-02,  1.61771923e-01,
              -1.40354782e-03,  1.21511295e-01, -1.81599051e-01,
              -6.96900785e-02],
             [ 5.48718631e-01,  1.28376633e-01, -5.39104223e-01,
              -6.70641184e-01,  1.09051979e+00, -1.28412640e+00,
              -1.14772761e+00],
             [ 8.41994166e-01, -1.02324653e+00,  3.11477125e-01,
              -4.01870668e-01, -1.49021864e-01,  5.34370542e-01,
               3.68867695e-01],
             [ 6.75134182e-01, -5.36827803e-01, -5.43348968e-01,
               3.53505313e-01,  3.08738649e-01, -6.90067232e-01,
              -6.90422177e-01],
             [-2.37773836e-01, -6.79171264e-01,  1.47294909e-01,
               7.21722841e-04,  5.62279522e-01,  3.10913086e-01,
              -6.09521687e-01],
             [ 2.00517774e-01, -9.62487936e-01,  1.20790944e-01,
               4.39302444e-01,  2.74160624e-01,  3.32039446e-01,
              -2.17610329e-01],
             [ 2.46761292e-0