In [3]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input

# Turn on eager execution so the tensors built in the following cells are
# evaluated immediately and their values appear in the cell outputs.
# NOTE(review): this looks like the TF 1.x entry point — confirm the target
# TensorFlow version before running this notebook on TF 2.x.
tf.enable_eager_execution()

In [7]:
# Toy setup: 4 input features, 3 experts, 2 tasks.
# Expert weights are indexed (input feature, hidden unit, expert), i.e. one
# row below per input feature, each holding 2 hidden units x 3 experts.
expert_kernels = tf.constant(
    [
        [[1., 1., 1.], [2., 2., 1.]],
        [[0.1, 0.5, 1.], [0.4, 0.1, 1.]],
        [[1., 1., 1.], [2., 2., 1.]],
        [[0., 1., 6.], [0., 2., 0.]],
    ],
    dtype=tf.float64,
)
expert_kernels

<tf.Tensor: id=2, shape=(4, 2, 3), dtype=float64, numpy=
array([[[1. , 1. , 1. ],
        [2. , 2. , 1. ]],

       [[0.1, 0.5, 1. ],
        [0.4, 0.1, 1. ]],

       [[1. , 1. , 1. ],
        [2. , 2. , 1. ]],

       [[0. , 1. , 6. ],
        [0. , 2. , 0. ]]])>

In [9]:
# One gate kernel per task; each kernel is indexed (input feature, expert),
# so it maps the 4 input features to one logit per expert.
gate_kernels = [
    tf.constant(
        [[0.1, 0.5, 1.], [0.4, 0.1, 1.], [1., 1., 1.], [2., 2., 1.]],
        dtype=tf.float64,
    ),
    tf.constant(
        [[1., 2., 1.], [4., 0.2, 1.5], [2., 1., 0.], [5., 2., 1.]],
        dtype=tf.float64,
    ),
]
gate_kernels

[<tf.Tensor: id=3, shape=(4, 3), dtype=float64, numpy=
 array([[0.1, 0.5, 1. ],
        [0.4, 0.1, 1. ],
        [1. , 1. , 1. ],
        [2. , 2. , 1. ]])>,
 <tf.Tensor: id=4, shape=(4, 3), dtype=float64, numpy=
 array([[1. , 2. , 1. ],
        [4. , 0.2, 1.5],
        [2. , 1. , 0. ],
        [5. , 2. , 1. ]])>]

In [10]:
# Mini-batch of 2 samples, each with 4 input features.
inputs = tf.constant(
    [
        [1., 2., 1., 0.],
        [4., 0.2, 1., 1.],
    ],
    dtype=tf.float64,
)
inputs

<tf.Tensor: id=5, shape=(2, 4), dtype=float64, numpy=
array([[1. , 2. , 1. , 0. ],
       [4. , 0.2, 1. , 1. ]])>

In [11]:
# Contract the feature axis of `inputs` with the leading (feature) axis of
# `expert_kernels`. The result is indexed (sample, hidden unit, expert):
# 2 samples, 2 hidden units, 3 experts.
expert_outputs = tf.tensordot(inputs, expert_kernels, axes=1)
expert_outputs

<tf.Tensor: id=16, shape=(2, 2, 3), dtype=float64, numpy=
array([[[ 2.2 ,  3.  ,  4.  ],
        [ 4.8 ,  4.2 ,  4.  ]],

       [[ 5.02,  6.1 , 11.2 ],
        [10.08, 12.02,  5.2 ]]])>

In [13]:
import tensorflow.keras.backend as K

In [14]:
# Gate logits: one (sample, expert) matrix per task, obtained by projecting
# the inputs through that task's gate kernel.
gate_logits = [K.dot(inputs, gate_kernel) for gate_kernel in gate_kernels]

# tf.nn.softmax takes the list as a single (task, sample, expert) tensor and
# normalises along the last axis, so for every task and sample the weights
# over the experts sum to 1.
gate_outputs = tf.nn.softmax(gate_logits)
gate_outputs

<tf.Tensor: id=20, shape=(2, 2, 3), dtype=float64, numpy=
array([[[1.00151222e-01, 8.19968851e-02, 8.17851893e-01],
        [4.79733364e-02, 2.23775958e-01, 7.28250705e-01]],

       [[9.98589658e-01, 4.99745626e-04, 9.10595901e-04],
        [6.80656487e-01, 3.18320187e-01, 1.02332564e-03]]])>

In [15]:
# Number of hidden units per expert; must equal the size of axis 1 of
# `expert_outputs`.
hidden_units = 2
final_outputs = []

# Iterating over `gate_outputs` yields one (sample, expert) gate matrix per task.
for task_gate in gate_outputs:
    # (sample, expert) -> (sample, 1, expert)
    gate = K.expand_dims(task_gate, axis=1)
    # Copy the gate weights across the hidden-unit axis so the shape matches
    # `expert_outputs`: (sample, hidden unit, expert).
    gate = K.repeat_elements(gate, rep=hidden_units, axis=1)
    gated_experts = expert_outputs * gate
    # Sum over the expert axis: gate-weighted mixture, shape (sample, hidden unit).
    final_outputs.append(K.sum(gated_experts, axis=2))
final_outputs

[<tf.Tensor: id=34, shape=(2, 2), dtype=float64, numpy=
 array([[3.73773092, 4.09652035],
        [9.76226739, 6.96026192]])>,
 <tf.Tensor: id=48, shape=(2, 2), dtype=float64, numpy=
 array([[ 2.20203887,  4.79897168],
        [ 5.37010995, 10.69254733]])>]