# Mixture-of-Experts (MoE) in TensorFlow

In [None]:
# Colab-only line magic: selects the TensorFlow 1.x runtime that the
# tf.layers / graph-mode code in this notebook is written against.
# NOTE(review): newer Colab runtimes may no longer offer 1.x — confirm.
%tensorflow_version 1.x

TensorFlow 1.x selected.


In [None]:
import numpy as np
import tensorflow as tf

## Multi-task

In [None]:
def task_network(inputs,
                 hidden_units,
                 hidden_activation=tf.nn.relu,
                 output_activation=tf.nn.sigmoid,
                 hidden_dropout=None,
                 initializer=None,
                 training=False):
    """
    Task tower: a stack of dense hidden layers followed by a 1-unit output.
    :param inputs: tf.Tensor fed to the first dense layer (presumably
        (batch, features) — TODO confirm for other callers).
    :param hidden_units: iterable of Int, width of each hidden dense layer.
    :param hidden_activation: activation used by every hidden layer.
    :param output_activation: callable applied to the 1-unit output, or None
        to return the linear output unchanged.
    :param hidden_dropout: float dropout rate applied after every hidden
        layer, or None for no dropout layers.
    :param initializer: kernel initializer shared by all dense layers
        (None lets tf.layers use its default).
    :param training: bool or scalar bool tensor selecting dropout mode.
        tf.layers.dropout only drops units when this is true; with the
        default False the dropout layers are identity ops (the previous,
        implicit behavior).
    :return: tf.Tensor whose last dimension has size 1.
    """

    x = inputs
    for units in hidden_units:
        x = tf.layers.dense(x,
                            units,
                            activation=hidden_activation,
                            kernel_initializer=initializer)

        if hidden_dropout is not None:
            # tf.layers.dropout defaults to training=False, which makes it a
            # no-op; the mode must be passed explicitly for dropout to apply.
            x = tf.layers.dropout(x, rate=hidden_dropout, training=training)

    outputs = tf.layers.dense(x, 1, kernel_initializer=initializer)

    if output_activation is not None:
        outputs = output_activation(outputs)
    return outputs

In [None]:
def multi_task(inputs,
               num_tasks,
               task_hidden_units,
               task_output_activations,
               **kwargs):
    """
    Build `num_tasks` independent task towers via `task_network`.
    :param inputs: a single tensor shared by every tower, or a list whose
        i-th element is the input of tower i.
    :param num_tasks: Int, number of towers to build.
    :param task_hidden_units: hidden layer widths used by every tower.
    :param task_output_activations: sequence indexed by task; element i is the
        output activation of tower i (None for a linear output).
    :param kwargs: forwarded unchanged to `task_network`.
    :return: list of tower outputs, in task order.
    """
    one_input_per_task = isinstance(inputs, list)

    def _build_tower(task_id):
        # A list carries one input per task; anything else is shared.
        tower_input = inputs[task_id] if one_input_per_task else inputs
        return task_network(tower_input,
                            task_hidden_units,
                            output_activation=task_output_activations[task_id],
                            **kwargs)

    return [_build_tower(task_id) for task_id in range(num_tasks)]

## Mixture-of-experts

In [None]:
def _synthetic_data(num_examples, example_dim=100, c=0.3, p=0.8, m=5):

    mu1 = np.random.normal(size=example_dim)
    mu1 = (mu1 - np.mean(mu1)) / (np.std(mu1) * np.sqrt(example_dim))

    mu2 = np.random.normal(size=example_dim)
    mu2 -= mu2.dot(mu1) * mu1
    mu2 /= np.linalg.norm(mu2)

    w1 = c * mu1
    w2 = c * (p * mu1 + np.sqrt(1. - p ** 2) * mu2)

    alpha = np.random.normal(size=m)
    beta = np.random.normal(size=m)

    examples = np.random.normal(size=(num_examples, example_dim))

    w1x = np.matmul(examples, w1)
    w2x = np.matmul(examples, w2)

    sin1, sin2 = 0., 0.
    for i in range(m):
        sin1 += np.sin(alpha[i] * w1x + beta[i])
        sin2 += np.sin(alpha[i] * w2x + beta[i])

    y1 = w1x + sin1 + np.random.normal(size=num_examples, scale=0.01)
    y2 = w2x + sin2 + np.random.normal(size=num_examples, scale=0.01)

    return examples.astype(np.float32), (y1.astype(np.float32), y2.astype(np.float32))

In [None]:
def synthetic_data_input_fn(num_examples, epochs=1, batch_size=256, buffer_size=256, **kwargs):
    """
    Wrap freshly generated synthetic data in a batched tf.data pipeline.
    :param num_examples: number of examples to generate.
    :param epochs: number of times the batched data is repeated.
    :param batch_size: examples per batch.
    :param buffer_size: number of elements (batches) handed to `prefetch`.
        It is not a shuffle buffer — the pipeline does not shuffle.
    :param kwargs: forwarded to `_synthetic_data` (example_dim, c, p, m).
    :return: tf.data.Dataset of (features, (y1, y2)) batches.
    """
    features_and_labels = _synthetic_data(num_examples, **kwargs)

    return (tf.data.Dataset.from_tensor_slices(features_and_labels)
            .batch(batch_size)
            .repeat(epochs)
            .prefetch(buffer_size))

In [None]:
def gating_network(inputs, num_experts, expert_index=None):
    """
    Softmax gate over experts: y = SoftMax(W * inputs), with a bias-free W.
    :param inputs: tf.Tensor fed to the gate's dense layer.
    :param num_experts: Int > 0, number of expert networks (the gate width).
    :param expert_index: Int or None; only used to build the layer name
        "expert{index}_gate" (None yields "expertNone_gate").
    :return: tf.Tensor of gate weights, softmax-normalised by tf.nn.softmax.
    """
    layer_name = "expert{}_gate".format(expert_index)

    gate_logits = tf.layers.dense(inputs,
                                  units=num_experts,
                                  use_bias=False,
                                  name=layer_name)
    return tf.nn.softmax(gate_logits)

In [None]:
def one_gate(inputs,
             num_tasks,
             num_experts,
             task_hidden_units,
             task_output_activations,
             expert_hidden_units,
             expert_hidden_activation=tf.nn.relu,
             task_hidden_activation=tf.nn.relu,
             task_initializer=None,
             task_dropout=None):
    """
    One-gate mixture of experts (OMoE) feeding shared-input task towers.

    A single softmax gate weights the outputs of `num_experts` expert MLPs;
    the gate-weighted mixture is the input of every task tower.

    :param inputs: tf.Tensor, presumably (batch, features) — TODO confirm.
    :param num_tasks: Int, number of task towers.
    :param num_experts: Int > 0, number of expert networks.
    :param task_hidden_units: hidden layer widths of every task tower.
    :param task_output_activations: per-task output activations (None = linear).
    :param expert_hidden_units: layer widths of each expert; every expert
        layer, including the last, uses `expert_hidden_activation`.
    :param expert_hidden_activation: activation of the expert layers.
    :param task_hidden_activation: activation of the task hidden layers.
    :param task_initializer: kernel initializer for the task towers.
    :param task_dropout: dropout rate for the task towers, or None.
    :return: list of `num_tasks` task output tensors.
    """

    experts_gate = gating_network(inputs, num_experts)

    experts_outputs = []
    for i in range(num_experts):
        x = inputs
        for j, units in enumerate(expert_hidden_units):
            x = tf.layers.dense(x, units, activation=expert_hidden_activation, name="expert{}_dense{}".format(i, j))
        experts_outputs.append(x)

    # Stack experts along axis 1 and turn the gate into a per-example row
    # vector, so the batched matmul below computes the gate-weighted sum.
    experts_outputs = tf.stack(experts_outputs, axis=1)
    experts_selector = tf.expand_dims(experts_gate, axis=1)

    outputs = tf.linalg.matmul(experts_selector, experts_outputs)

    # Squeeze only the singleton axis introduced by expand_dims. A bare
    # tf.squeeze would also drop a batch (or unit) dimension of size 1, and
    # with an unknown batch size it yields a tensor of unknown rank, which
    # tf.layers.dense cannot consume.
    multi_task_inputs = tf.squeeze(outputs, axis=1)

    return multi_task(multi_task_inputs,
                      num_tasks,
                      task_hidden_units,
                      task_output_activations,
                      hidden_activation=task_hidden_activation,
                      hidden_dropout=task_dropout,
                      initializer=task_initializer)

In [None]:
def multi_gate(inputs,
               num_tasks,
               num_experts,
               task_hidden_units,
               task_output_activations,
               expert_hidden_units,
               expert_hidden_activation=tf.nn.relu,
               task_hidden_activation=tf.nn.relu,
               task_initializer=None,
               task_dropout=None):
    """
    Multi-gate mixture of experts (MMoE) feeding per-task towers.

    `num_experts` expert MLPs are shared by all tasks. Every task has its
    own softmax gate over those experts, and tower i consumes gate i's
    mixture.

    :param inputs: tf.Tensor, presumably (batch, features) — TODO confirm.
    :param num_tasks: Int, number of tasks, and therefore of gates.
    :param num_experts: Int > 0, number of shared expert networks.
    :param task_hidden_units: hidden layer widths of every task tower.
    :param task_output_activations: per-task output activations (None = linear).
    :param expert_hidden_units: non-empty list of expert layer widths; all
        but the last layer use `expert_hidden_activation`, the last is linear.
    :param expert_hidden_activation: activation of the expert hidden layers.
    :param task_hidden_activation: activation of the task hidden layers.
    :param task_initializer: kernel initializer for the task towers.
    :param task_dropout: dropout rate for the task towers, or None.
    :return: list of `num_tasks` task output tensors.
    """

    experts_outputs = []
    for i in range(num_experts):
        x = inputs
        for j, units in enumerate(expert_hidden_units[:-1]):
            x = tf.layers.dense(x, units, activation=expert_hidden_activation, name="expert{}_dense{}".format(i, j))

        x = tf.layers.dense(x, expert_hidden_units[-1], name="expert{}_out".format(i))

        experts_outputs.append(x)

    # Experts stacked along axis 1, shared by every task's gate.
    experts_outputs = tf.stack(experts_outputs, axis=1)

    # One gate per TASK: multi_task reads outputs[i] for every i < num_tasks.
    # Building one gate per expert instead raises IndexError when
    # num_experts < num_tasks and leaves unused gates otherwise.
    outputs = []
    for i in range(num_tasks):
        # gating_network names this layer "expert{i}_gate"; i is the task index.
        task_gate = gating_network(inputs, num_experts, expert_index=i)
        task_selector = tf.expand_dims(task_gate, axis=1)

        output = tf.linalg.matmul(task_selector, experts_outputs)

        # Squeeze only the singleton axis added by expand_dims. A bare
        # tf.squeeze also drops a batch (or unit) dimension of size 1, and
        # with an unknown batch size it produces an unknown-rank tensor that
        # tf.layers.dense rejects.
        outputs.append(tf.squeeze(output, axis=1))

    return multi_task(outputs,
                      num_tasks,
                      task_hidden_units,
                      task_output_activations,
                      hidden_activation=task_hidden_activation,
                      hidden_dropout=task_dropout,
                      initializer=task_initializer)

## Testing

In [None]:
from absl.testing import parameterized
import sys

In [None]:
# Build everything as TF1 graphs and show INFO-level TensorFlow logs.
tf.disable_eager_execution()
tf.logging.set_verbosity(tf.logging.INFO)

# Keep the real sys.exit so the runner cell can put it back afterwards.
old_sysexit = sys.exit
# Keep only the program name so tf.test.main(), which parses sys.argv, does
# not see the notebook kernel's own launch arguments.
sys.argv = sys.argv[:1]
sys.dont_write_bytecode = True

In [None]:
class TestMixtureOfExperts(tf.test.TestCase, parameterized.TestCase):
    """Graph-mode tests for the synthetic data and the MoE model builders."""

    NUM_EXAMPLES = 1000
    EXAMPLE_DIM = 128
    NUM_TASKS = 2

    def _check_model(self, model_fn):
        """Build `model_fn` on random inputs and validate its task outputs.

        Asserts one output per task, each with static shape
        (NUM_EXAMPLES, 1), and that a forward pass yields only finite values.
        """
        inputs = tf.random.normal(shape=(self.NUM_EXAMPLES, self.EXAMPLE_DIM))

        outputs = model_fn(inputs,
                           num_tasks=self.NUM_TASKS,
                           num_experts=3,
                           task_hidden_units=[10, 5],
                           task_output_activations=[None, None],
                           expert_hidden_units=[64, 32],
                           expert_hidden_activation=tf.nn.relu,
                           task_hidden_activation=tf.nn.relu,
                           task_initializer=None,
                           task_dropout=None)

        self.assertLen(outputs, self.NUM_TASKS)
        for output in outputs:
            self.assertEqual(output.shape.as_list(), [self.NUM_EXAMPLES, 1])

        with self.cached_session() as sess:
            sess.run(tf.global_variables_initializer())
            for value in sess.run(outputs):
                self.assertTrue(np.all(np.isfinite(value)))

    @parameterized.parameters(42, 256, 1024, 2021)
    def test_synthetic_data(self, random_seed):
        np.random.seed(random_seed)
        _, (y1, y2) = _synthetic_data(1000, p=0.8)
        cor = np.corrcoef(y1, y2)[0, 1]
        # The label correlation should stay close to the requested p
        # (about 0.80 was observed for these seeds).
        self.assertAlmostEqual(cor, 0.8, delta=0.05)

    def test_one_gate(self):
        self._check_model(one_gate)

    def test_multi_gate(self):
        self._check_model(multi_gate)

In [None]:
# Run the suite inside the notebook. tf.test.main() finishes by calling
# sys.exit(), which would raise SystemExit in the cell, so sys.exit is swapped
# for a no-op for the duration of the run. The original is captured in this
# same cell (no dependency on state from other cells) and always restored.
_saved_sysexit = sys.exit
try:
    sys.exit = lambda *args: None
    tf.test.main()
finally:
    sys.exit = _saved_sysexit

Running tests under Python 3.7.11: /usr/bin/python3
[ RUN      ] TestMixtureOfExperts.test_multi_gate
[       OK ] TestMixtureOfExperts.test_multi_gate
[ RUN      ] TestMixtureOfExperts.test_one_gate
[       OK ] TestMixtureOfExperts.test_one_gate
[ RUN      ] TestMixtureOfExperts.test_session
[  SKIPPED ] TestMixtureOfExperts.test_session
[ RUN      ] TestMixtureOfExperts.test_synthetic_data0 (42)
[       OK ] TestMixtureOfExperts.test_synthetic_data0 (42)
[ RUN      ] TestMixtureOfExperts.test_synthetic_data1 (256)
[       OK ] TestMixtureOfExperts.test_synthetic_data1 (256)
[ RUN      ] TestMixtureOfExperts.test_synthetic_data2 (1024)
[       OK ] TestMixtureOfExperts.test_synthetic_data2 (1024)
[ RUN      ] TestMixtureOfExperts.test_synthetic_data3 (2021)
[       OK ] TestMixtureOfExperts.test_synthetic_data3 (2021)


[[1.         0.81089216]
 [0.81089216 1.        ]]
[[1.         0.79510139]
 [0.79510139 1.        ]]
[[1.         0.81083881]
 [0.81083881 1.        ]]
[[1.         0.80292345]
 [0.80292345 1.        ]]


----------------------------------------------------------------------
Ran 7 tests in 0.548s

OK (skipped=1)
