In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


In [None]:
# Vocabulary size passed to the IMDB loader as `num_words`.
vocab_size = 20000
# Every review is padded/truncated to this many token ids before training.
num_tokens_per_example = 200

# The loader returns (train, test) splits; the test split serves as validation.
(x_train, y_train), (x_val, y_val) = keras.datasets.imdb.load_data(
    num_words=vocab_size
)

print(len(x_train), "Training sequences")
print(len(x_val), "Validation sequences")


Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/imdb.npz


  x_train, y_train = np.array(xs[:idx]), np.array(labels[:idx])


25000 Training sequences
25000 Validation sequences


  x_test, y_test = np.array(xs[idx:]), np.array(labels[idx:])


In [None]:
# Peek at the first 20 token ids of one review rather than dumping the whole
# object array of variable-length lists.
x_train[0][:20]

array([list([1, 14, 22, 16, 43, 530, 973, 1622, 1385, 65, 458, 4468, 66, 3941, 4, 173, 36, 256, 5, 25, 100, 43, 838, 112, 50, 670, 2, 9, 35, 480, 284, 5, 150, 4, 172, 112, 167, 2, 336, 385, 39, 4, 172, 4536, 1111, 17, 546, 38, 13, 447, 4, 192, 50, 16, 6, 147, 2025, 19, 14, 22, 4, 1920, 4613, 469, 4, 22, 71, 87, 12, 16, 43, 530, 38, 76, 15, 13, 1247, 4, 22, 17, 515, 17, 12, 16, 626, 18, 19193, 5, 62, 386, 12, 8, 316, 8, 106, 5, 4, 2223, 5244, 16, 480, 66, 3785, 33, 4, 130, 12, 16, 38, 619, 5, 25, 124, 51, 36, 135, 48, 25, 1415, 33, 6, 22, 12, 215, 28, 77, 52, 5, 14, 407, 16, 82, 10311, 8, 4, 107, 117, 5952, 15, 256, 4, 2, 7, 3766, 5, 723, 36, 71, 43, 530, 476, 26, 400, 317, 46, 7, 4, 12118, 1029, 13, 104, 88, 4, 381, 15, 297, 98, 32, 2071, 56, 26, 141, 6, 194, 7486, 18, 4, 226, 22, 21, 134, 476, 26, 480, 5, 144, 30, 5535, 18, 51, 36, 28, 224, 92, 25, 104, 4, 226, 65, 16, 38, 1334, 88, 12, 16, 283, 5, 16, 4472, 113, 103, 32, 15, 16, 5345, 19, 178, 32]),
       list([1, 194, 1153, 194, 82

In [None]:
# Pad/truncate every review to exactly `num_tokens_per_example` token ids so
# both splits become dense integer arrays.
x_train, x_val = [
    tf.keras.preprocessing.sequence.pad_sequences(sequences, maxlen=num_tokens_per_example)
    for sequences in (x_train, x_val)
]

In [None]:
# Model and training hyperparameters.
embed_dim = 32  # width of the token/position embeddings and of each expert
num_heads = 2  # attention heads in the transformer block
ff_dim = 32  # hidden width of the Dense layer in the classification head
num_experts = 10  # feed-forward experts inside the Switch layer
batch_size = 50
# Adam step size. 0.25 is far too large for Adam; that value belongs to
# `dropout_rate` below.
learning_rate = 0.001
dropout_rate = 0.25  # dropout rate for the classification head
num_epochs = 3
# The Switch layer reshapes each batch to exactly this many tokens, so training
# must feed batches of exactly `batch_size` examples.
num_tokens_per_batch = batch_size * num_tokens_per_example
print(f"Number of tokens per batch: {num_tokens_per_batch}")

Number of tokens per batch: 10000


In [None]:
class TokenAndPositionEmbedding(layers.Layer):
  """Learned token embedding plus learned absolute-position embedding.

  Args:
    maxlen: number of rows in the position table (max sequence length).
    vocab_size: number of rows in the token table.
    embed_dim: width of both embedding tables.
  """

  def __init__(self, maxlen, vocab_size, embed_dim):
    super().__init__()
    self.token_emb = layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)
    self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=embed_dim)

  def call(self, x):
    """Embed integer ids `x` (sequence on the last axis) and add positions."""
    seq_len = tf.shape(x)[-1]
    # (seq_len, embed_dim); broadcasts over the leading axes of the token
    # embeddings when added below.
    position_vectors = self.pos_emb(tf.range(seq_len))
    token_vectors = self.token_emb(x)
    return token_vectors + position_vectors

In [None]:
def create_feedforward_network(ff_dim, name=None):
  """Build a two-layer MLP: Dense(ff_dim, relu) followed by a linear Dense(ff_dim).

  Args:
    ff_dim: width of both Dense layers, and therefore of the output.
    name: optional name for the returned `keras.Sequential`.
  """
  hidden_layer = layers.Dense(ff_dim, activation="relu")
  output_layer = layers.Dense(ff_dim)
  return keras.Sequential([hidden_layer, output_layer], name=name)

In [None]:
def load_balanced_loss(router_probs, expert_mask):
  """Auxiliary load-balancing loss for the Switch router.

  Returns num_experts**2 * mean(p * f), where
    f = fraction of tokens dispatched to each expert (mean of `expert_mask`
        over axis 0), and
    p = mean router probability per expert (mean of `router_probs` over axis 0).
  The number of experts is read from the last axis of `expert_mask`.

  Args:
    router_probs: softmax router probabilities, experts on the last axis.
    expert_mask: one-hot routing decisions, experts on the last axis
      (as produced by `Router`, shape (num_tokens, 1, num_experts)).

  Returns:
    Scalar float32 tensor.
  """
  experts_squared = tf.cast(tf.shape(expert_mask)[-1] ** 2, tf.dtypes.float32)
  fraction_dispatched = tf.reduce_mean(expert_mask, axis=0)
  mean_probability = tf.reduce_mean(router_probs, axis=0)
  return tf.reduce_mean(mean_probability * fraction_dispatched) * experts_squared

In [None]:
class Router(layers.Layer):
  """Top-1 (Switch) router that assigns each token to one slot of one expert.

  Args:
    num_experts: number of experts to route between.
    expert_capacity: maximum number of tokens each expert accepts per call;
      tokens beyond that are dropped (their dispatch/combine rows are zero).
  """

  def __init__(self, num_experts, expert_capacity):
    # Initialise Keras Layer state before assigning sub-layers/attributes.
    super(Router, self).__init__()
    self.num_experts = num_experts
    self.route = layers.Dense(units=num_experts)
    self.expert_capacity = expert_capacity

  def call(self, inputs, training=False):
    """Route each row of `inputs` to one (expert, slot) pair.

    Args:
      inputs: token matrix; as called from `Switch`, shape (num_tokens, embed_dim).
      training: when truthy, jitter noise is added to the router logits.

    Returns:
      (dispatch_tensor, combined_tensor), both (num_tokens, num_experts,
      expert_capacity). `dispatch_tensor` is a 0/1 indicator of each token's
      (expert, slot); `combined_tensor` is that indicator scaled by the token's
      router gate probability. A load-balancing loss is registered via add_loss.
    """
    router_logits = self.route(inputs)
    if training:
      # Exploration noise. Softmax is shift-invariant, so adding U(0.9, 1.1)
      # acts like U(-0.1, 0.1) jitter on the logits. tf.shape keeps this valid
      # when the static token dimension is unknown.
      router_logits += tf.random.uniform(
          shape=tf.shape(router_logits), minval=0.9, maxval=1.1)
    router_probs = keras.activations.softmax(router_logits, axis=-1)
    # Top-1 routing; both outputs have shape (num_tokens, 1).
    expert_gate, expert_index = tf.math.top_k(router_probs, k=1)
    # (num_tokens, 1, num_experts): one-hot of each token's chosen expert.
    expert_mask = tf.one_hot(expert_index, depth=self.num_experts)
    # Balancing loss uses the routing decision before capacity dropping.
    self.add_loss(load_balanced_loss(router_probs, expert_mask))
    # 0-based slot of each token inside its chosen expert: the exclusive cumsum
    # counts earlier tokens routed to the same expert. (An inclusive cumsum gives
    # 1-based slots, wasting slot 0 and admitting only capacity-1 tokens.)
    position_in_expert = tf.cast(
        tf.math.cumsum(expert_mask, axis=0, exclusive=True) * expert_mask,
        tf.dtypes.int32)
    # Keep only the first `expert_capacity` tokens of each expert.
    expert_mask *= tf.cast(
        tf.math.less(position_in_expert, self.expert_capacity),
        tf.dtypes.float32)
    # (num_tokens, num_experts, 1): 1 at a kept token's expert, 0 otherwise.
    kept_expert = tf.expand_dims(tf.squeeze(expert_mask, 1), -1)
    # (num_tokens, num_experts, expert_capacity): one-hot of the slot index.
    slot_one_hot = tf.squeeze(
        tf.one_hot(position_in_expert, depth=self.expert_capacity), 1)
    # Dispatch is a pure 0/1 indicator so expert inputs are copied unscaled;
    # the gate probability is applied exactly once, when combining outputs.
    dispatch_tensor = kept_expert * slot_one_hot
    combined_tensor = tf.expand_dims(expert_gate, -1) * dispatch_tensor
    return dispatch_tensor, combined_tensor

In [None]:
class Switch(layers.Layer):
  """Switch (top-1 Mixture-of-Experts) feed-forward layer.

  Every token is routed by a `Router` to one of `num_experts` feed-forward
  experts; each expert processes at most `expert_capacity` tokens per call.

  Args:
    num_experts: number of feed-forward experts.
    embed_dim: token width; each expert maps embed_dim -> embed_dim.
    num_tokens_per_batch: total tokens (batch * tokens per example) the layer is
      called with; inputs are flattened to exactly this many rows.
    capacity_factor: multiplier on the even split
      num_tokens_per_batch / num_experts that sets each expert's capacity.
  """

  def __init__(self, num_experts, embed_dim, num_tokens_per_batch, capacity_factor=1):
    # Initialise Keras Layer state before assigning sub-layers/attributes.
    super(Switch, self).__init__()
    self.num_experts = num_experts
    self.embed_dim = embed_dim
    self.num_tokens_per_batch = num_tokens_per_batch
    self.experts = [create_feedforward_network(embed_dim) for _ in range(num_experts)]
    self.expert_capacity = int(num_tokens_per_batch * capacity_factor // num_experts)
    self.router = Router(self.num_experts, self.expert_capacity)

  def call(self, inputs):
    """Apply the routed experts to every token of `inputs`.

    Args:
      inputs: presumably (batch, tokens_per_example, embed_dim) — TODO confirm.
        It must hold exactly `num_tokens_per_batch` tokens, otherwise the
        flattening reshape below fails.

    Returns:
      Tensor of shape (batch, tokens_per_example, embed_dim).
    """
    input_shape = tf.shape(inputs)
    batch, seq_len = input_shape[0], input_shape[1]
    tokens = tf.reshape(inputs, [self.num_tokens_per_batch, self.embed_dim])
    dispatch_tensor, combine_tensor = self.router(tokens)
    # (num_experts, expert_capacity, embed_dim): tokens gathered into slots.
    expert_inputs = tf.einsum("ab,acd->cdb", tokens, dispatch_tensor)
    expert_inputs = tf.reshape(
        expert_inputs, [self.num_experts, self.expert_capacity, self.embed_dim])
    expert_output_list = [
        expert(expert_input)
        for expert, expert_input in zip(self.experts, tf.unstack(expert_inputs, axis=0))
    ]
    # (expert_capacity, num_experts, embed_dim)
    expert_outputs = tf.stack(expert_output_list, axis=1)
    # Scatter slot outputs back to tokens, weighted by the router gate.
    expert_outputs_combined = tf.einsum("abc,xba->xc", expert_outputs, combine_tensor)
    return tf.reshape(expert_outputs_combined, [batch, seq_len, self.embed_dim])

In [None]:
class TransformerBlock(layers.Layer):
    """Post-norm transformer block: self-attention, then a feed-forward sublayer.

    Each sublayer's output goes through dropout, is added to its input
    (residual) and is layer-normalised.

    Args:
        embed_dim: passed as `key_dim` to the multi-head attention layer.
        num_heads: number of attention heads.
        ffn: feed-forward sublayer; its output must match its input shape for
            the residual add.
        dropout_rate: dropout applied to each sublayer output.
    """

    def __init__(self, embed_dim, num_heads, ffn, dropout_rate=0.1):
        super(TransformerBlock, self).__init__()
        self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        # The ffn can be either a standard feedforward network or a switch
        # layer with a Mixture of Experts.
        self.ffn = ffn
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(dropout_rate)
        self.dropout2 = layers.Dropout(dropout_rate)

    def call(self, inputs, training=None):
        # `training` defaults to None so the block can also be invoked without it
        # (previously a required positional argument).
        attn_output = self.att(inputs, inputs)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)

In [None]:
def create_classifier(dropout_rate=0.25):
    """Build the Switch-Transformer binary sentiment classifier.

    Reads the notebook-level hyperparameters num_experts, embed_dim, num_heads,
    ff_dim, num_tokens_per_batch, num_tokens_per_example and vocab_size.

    Args:
        dropout_rate: rate of the two Dropout layers in the classification head
            (previously referenced as an undefined global).

    Returns:
        An uncompiled `keras.Model` mapping (batch, num_tokens_per_example) token
        ids to 2-class softmax probabilities.
    """
    switch = Switch(num_experts, embed_dim, num_tokens_per_batch)
    # The block's first argument is the token width (used as attention key_dim),
    # i.e. embed_dim, not the classification head's ff_dim.
    transformer_block = TransformerBlock(embed_dim, num_heads, switch)

    inputs = layers.Input(shape=(num_tokens_per_example,))
    embedding_layer = TokenAndPositionEmbedding(
        num_tokens_per_example, vocab_size, embed_dim
    )
    x = embedding_layer(inputs)
    x = transformer_block(x)
    x = layers.GlobalAveragePooling1D()(x)
    x = layers.Dropout(dropout_rate)(x)
    x = layers.Dense(ff_dim, activation="relu")(x)
    x = layers.Dropout(dropout_rate)(x)
    outputs = layers.Dense(2, activation="softmax")(x)

    classifier = keras.Model(inputs=inputs, outputs=outputs)
    return classifier


In [None]:
def run_experiment(classifier):
    """Compile `classifier` (Adam, sparse categorical cross-entropy) and train it.

    Uses the notebook-level data (x_train/y_train, x_val/y_val) and the
    learning_rate, batch_size and num_epochs hyperparameters.

    Returns:
        Whatever `classifier.fit` returns (the Keras training History).
    """
    optimizer = keras.optimizers.Adam(learning_rate)
    classifier.compile(
        optimizer=optimizer,
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )
    return classifier.fit(
        x_train,
        y_train,
        validation_data=(x_val, y_val),
        epochs=num_epochs,
        batch_size=batch_size,
    )


classifier = create_classifier()
# Keep the returned History (per-epoch loss/accuracy) for later inspection
# instead of discarding it.
history = run_experiment(classifier)
