# Modular Neural Networks - Moefying a ViT

Why should we use modular neural networks?

It is not always possible to enlarge a model without running into hardware limits; just consider edge computing. In these cases, a simple and effective way to increase the capacity of the model without a proportional increase in per-token computation is to use a Mixture of Experts (MoE).

This dynamic architecture splits a network vertically into an arbitrary number of experts. This type of neural network is also referred to as a Sparse Neural Network, because only a pre-established number _k_ out of _E_ experts is used for each token, selected by a routing mechanism implemented through the _gate layer_.

## The MoE Layer

Let $\{E_e\}$ be the set of experts and $p_e$ the routing probability assigned to expert $e$. The output of the layer is defined as:

$$
f(x) = \sum_{e} p_e E_e(x)
$$
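The formula can be evaluated directly in PyTorch. The sketch below is a dense illustration under our own assumptions (small two-layer experts, random routing probabilities, names chosen for the example): it runs every expert on every token and then takes the weighted sum, which is exactly the work a sparse MoE avoids.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

num_tokens, d_model, num_experts = 5, 8, 4

# Hypothetical experts: small feed-forward networks with equal input/output size
experts = nn.ModuleList(
    [nn.Sequential(nn.Linear(d_model, 16), nn.GELU(), nn.Linear(16, d_model))
     for _ in range(num_experts)]
)

x = torch.randn(num_tokens, d_model)

# Routing probabilities p_e for every token (each row sums to 1); random, for illustration only
p = torch.softmax(torch.randn(num_tokens, num_experts), dim=-1)

# Evaluate every expert on every token: shape (num_tokens, num_experts, d_model)
expert_outputs = torch.stack([expert(x) for expert in experts], dim=1)

# f(x) = sum_e p_e * E_e(x), computed for all tokens at once
f_x = torch.einsum("te,ted->td", p, expert_outputs)
print(f_x.shape)  # torch.Size([5, 8])
```

With Top-k routing most entries of `p` are zero, so the corresponding expert outputs never need to be computed.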

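Since $p_e$ is non-zero only for the experts selected by the gate, only those experts have to be evaluated for a given token. A schematic representation of an MoE layer is the following: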

<center>
<img src="https://drive.google.com/uc?id=1HSvgNOHX5W-v5ttTWotpSwWbTddzWrN4" width="600" height="400">
</center>

## The Gate Layer

<center>
<img src="https://drive.google.com/uc?id=1LBOykfoMvusGshHSF77Nj0s7xL0VrbfB" width="800" height="300">
</center>

The most naïve way to obtain the routing probabilities is a linear layer $W$ that projects each token from the token dimension _N_ to the experts dimension _E_. Applying a *softmax* then gives a probability for each expert, and for each token we keep only the _Top-k_ experts:

$$
p_e = \Big[\mathrm{Top}_k\big(\mathrm{softmax}(Wx + \epsilon)\big)\Big]_e
$$

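$\epsilon$ is Gaussian noise added to the gate logits, typically only during training, to encourage exploration and help balance the load across experts. It plays a role similar to the Gumbel noise in the Gumbel-Softmax trick, although here the Top-k selection itself is not differentiable: gradients reach the gate only through the softmax weights of the selected experts.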
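A minimal PyTorch sketch of this noisy Top-k gate, written to follow the formula above term by term. `NoisyTopKGate` and `noise_std` are our own hypothetical names, not part of FastMoE; the noise is applied only in training mode.

```python
import torch
import torch.nn as nn

class NoisyTopKGate(nn.Module):
    """Hypothetical gate computing p = TopK(softmax(Wx + eps))."""

    def __init__(self, d_model, num_experts, top_k, noise_std=1.0):
        super().__init__()
        self.w = nn.Linear(d_model, num_experts, bias=False)  # projects tokens to expert logits
        self.top_k = top_k
        self.noise_std = noise_std

    def forward(self, x):
        logits = self.w(x)
        if self.training:
            # Gaussian noise eps on the logits, only at training time
            logits = logits + self.noise_std * torch.randn_like(logits)
        probs = torch.softmax(logits, dim=-1)
        # Keep the k largest probabilities per token and zero out the others
        top_vals, top_idx = probs.topk(self.top_k, dim=-1)
        p = torch.zeros_like(probs).scatter(-1, top_idx, top_vals)
        return p, top_idx

gate = NoisyTopKGate(d_model=8, num_experts=4, top_k=2)
p, top_idx = gate(torch.randn(5, 8))
print((p > 0).sum(dim=-1))  # two non-zero routing weights per token
```

The sketch keeps the order written in the formula (softmax, then Top-k), so the kept weights do not sum to one. Many implementations instead renormalize them, or apply the softmax only over the k selected logits.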




# Getting started with FastMoE

The full installation guide for the library can be found [here](https://github.com/laekov/fastmoe/blob/master/doc/installation-guide.md).

There are two installation options, depending on whether we plan to use distributed training or not; in this tutorial we consider the non-distributed one.

In general, *FastMoE* supports two distributed modes, data parallelism and expert parallelism, which can be used independently or (preferably) together: one of the advantages of an MoE is that its experts can be spread across many GPUs.

<div>
  <br>
  <img src="https://drive.google.com/uc?id=1HzedJ7RqziWK4Z1QASle9bpu1y7Elv6o" width="350" height="230" hspace="30">

  <img src="https://drive.google.com/uc?id=1HzedJ7RqziWK4Z1QASle9bpu1y7Elv6o" width="350" height="230">
</div>

```python
# clone the repository
!git clone https://github.com/laekov/fastmoe.git

# move into the folder
%cd ./fastmoe

# install requirements and utilities
!pip install ninja dm-tree einops

# install FastMoE; setting USE_NCCL to 0 disables the distributed features
!USE_NCCL=0 pip install .

# try to import
import torch
import torch.nn as nn
import fmoe
```

```
Cloning into 'fastmoe'...
remote: Enumerating objects: 2823, done.
remote: Counting objects: 100% (796/796), done.
remote: Compressing objects: 100% (485/485), done.
remote: Total 2823 (delta 325), reused 348 (delta 311), pack-reused 2027
Receiving objects: 100% (2823/2823), 964.69 KiB | 18.20 MiB/s, done.
Resolving deltas: 100% (1908/1908), done.
/content/fastmoe
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting ninja
  Downloading ninja-1.11.1-py2.py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (145 kB)
Collecting einops
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
Installing collected packages: ninja, einops
Successfully installed einops-0.6.1 ninja-1.11.
```

# Building a FastMoE ViT

The code is based on the basic ViT implementation from this repository: https://github.com/lucidrains/vit-pytorch.

To obtain an MoE we replace the feed-forward network of each Transformer block with a mixture of feed-forward networks. The parameters of this layer are:
- **_num\_expert_**: the number of experts on each worker (i.e. each GPU; in our case just one)
- **_d\_model_**: the input and output dimension of the layer
- **_d\_hidden_**: the hidden dimension of each expert, i.e. the output dimension of its first linear layer
- **_top\_k_**: the number of experts selected for each token
- **_gate_**: the routing strategy; all available gates can be found here ([gates](https://github.com/laekov/fastmoe/tree/master/fmoe/gates))

To switch from a ViT to an MoE ViT we use the flag _moefy_ and specify the parameters defined above; nothing else needs to change.

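Note that FastMoE already implements this MoE MLP as `FMoETransformerMLP`, so it does not have to be written by hand; we only need to remember that its **output dimension** equals its **input dimension**.
```python
import torch.nn as nn
from fmoe.transformer import FMoETransformerMLP

class MoEMLP(nn.Module):  # hypothetical wrapper, for illustration
    def __init__(self, num_expert, dim, hidden_dim, output_dim, top_k, gate):
        super().__init__()
        # FMoE Transformer MLP: its output dimension equals its input dimension (d_model)
        self.mlp = FMoETransformerMLP(
            num_expert=num_expert,
            d_model=dim,
            d_hidden=hidden_dim,
            top_k=top_k,
            gate=gate,
        )
        # if we want an output dimension different from the input one, just add a linear layer
        self.projection = nn.Linear(dim, output_dim)

    def forward(self, x):
        return self.projection(self.mlp(x))
```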
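Conceptually, an MoE MLP such as `FMoETransformerMLP` sends each token only to its selected experts and sums their weighted outputs. The pure-PyTorch sketch below is our own simplified illustration, not FastMoE's implementation (FastMoE uses dedicated CUDA kernels and, in the distributed case, cross-GPU communication); `SimpleMoEMLP` and its plain linear gate without noise are hypothetical. Like `FMoETransformerMLP`, it returns a tensor with the same shape as its input.

```python
import torch
import torch.nn as nn

class SimpleMoEMLP(nn.Module):
    """Hypothetical MoE MLP: top-k routing plus sparse dispatch to expert MLPs."""

    def __init__(self, num_expert, d_model, d_hidden, top_k):
        super().__init__()
        self.top_k = top_k
        self.gate = nn.Linear(d_model, num_expert)  # simple linear gate, no noise
        self.experts = nn.ModuleList(
            [nn.Sequential(nn.Linear(d_model, d_hidden), nn.GELU(), nn.Linear(d_hidden, d_model))
             for _ in range(num_expert)]
        )

    def forward(self, x):
        original_shape = x.shape
        x = x.reshape(-1, original_shape[-1])  # flatten (batch, tokens, d_model) into tokens
        probs = torch.softmax(self.gate(x), dim=-1)
        top_vals, top_idx = probs.topk(self.top_k, dim=-1)  # (num_tokens, top_k)
        out = torch.zeros_like(x)
        for e, expert in enumerate(self.experts):
            # (token, slot) pairs that were routed to expert e
            token_ids, slot_ids = (top_idx == e).nonzero(as_tuple=True)
            if token_ids.numel() == 0:
                continue
            # Run expert e only on its tokens and accumulate the weighted outputs
            weighted = top_vals[token_ids, slot_ids].unsqueeze(-1) * expert(x[token_ids])
            out.index_add_(0, token_ids, weighted)
        return out.reshape(original_shape)

moe = SimpleMoEMLP(num_expert=4, d_model=32, d_hidden=32, top_k=2)
y = moe(torch.randn(2, 65, 32))
print(y.shape)  # torch.Size([2, 65, 32])
```

Each expert only sees the tokens routed to it, so the work per token grows with `top_k`, not with the total number of experts.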

```python
#@title MoE FeedForward
from fmoe.layers import FMoE
from fmoe.linear import FMoELinear
from fmoe.gates.gshard_gate import GShardGate

class _Expert(nn.Module):
    def __init__(self, num_expert, d_model, d_hidden, activation, dropout=0., rank=0):
        super().__init__()
        self.linear1 = FMoELinear(num_expert, d_model, d_hidden, bias=True, rank=rank)
        self.linear2 = FMoELinear(num_expert, d_hidden, d_model, bias=True, rank=rank)
        self.dropout = nn.Dropout(dropout)
        self.activation = activation

    def forward(self, inp, fwd_expert_count):
        x = self.linear1(inp, fwd_expert_count)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.linear2(x, fwd_expert_count)
        x = self.dropout(x)
        return x

class FeedForwardMoE(FMoE):
    r"""
    A complete MoE MLP module in a Transformer block, with dropout inside each expert.
    * `activation` is the activation function to be used in MLP in each expert.
    * `d_hidden` is the dimension of the MLP layer.
    """

    def __init__(
        self,
        num_expert=32,
        d_model=1024,
        d_hidden=4096,
        activation=torch.nn.GELU(),
        dropout=0.,
        expert_dp_comm="none",
        expert_rank=0,
        **kwargs
    ):
        def one_expert(d_model):
            # builds a single expert; FMoE calls this once per expert
            return _Expert(1, d_model, d_hidden, activation, dropout=dropout, rank=0)

        expert = one_expert
        super().__init__(num_expert=num_expert, d_model=d_model, expert=expert, **kwargs)
        self.mark_parallel_comm(expert_dp_comm)

    def forward(self, inp: torch.Tensor):
        r"""
        Flattens the input into a batch of tokens, routes it through the FMoE
        layer and restores the original shape. The residual connection and layer
        normalization are applied outside this module (PreNorm and Transformer).
        """
        original_shape = inp.shape
        inp = inp.reshape(-1, self.d_model)
        output = super().forward(inp)
        return output.reshape(original_shape)
```

```python
#@title ViT Implementation
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(
        self,
        dim,
        depth,
        heads,
        dim_head,
        mlp_dim,
        dropout = 0.,
        moefy = False,
        num_expert = None,
        top_k = None,
        gate = None
        ):

        super().__init__()
        if moefy:
            assert num_expert and top_k and gate, (
                "If 'moefy' is True, none of the following arguments should be None: "
                "num_expert={}, top_k={}, gate={}".format(num_expert, top_k, gate)
            )

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                # MoE feed-forward when moefy is set (using the requested gate), standard one otherwise
                PreNorm(dim, FeedForwardMoE(num_expert, dim, mlp_dim, dropout = dropout, top_k = top_k, gate = gate)
                        if moefy else FeedForward(dim, mlp_dim, dropout = dropout))
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

class ViT(nn.Module):
    def __init__(
        self, *,
        image_size,
        patch_size,
        num_classes,
        dim, depth,
        heads,
        mlp_dim,
        pool = 'cls',
        channels = 3,
        dim_head = 64,
        dropout = 0.,
        emb_dropout = 0.,
        moefy = False,
        num_expert = None,
        top_k = None,
        gate = None
        ):

        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, moefy, num_expert, top_k, gate)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)
```

# Number of parameters comparison

```python
#@title ViT
vit = ViT(
    image_size=32,
    patch_size=4,
    num_classes=10,
    dim=32,
    depth=2,
    heads=2,
    mlp_dim=32
    ).cuda()

print("Number of parameters in standard ViT:", sum(p.numel() for p in vit.parameters() if p.requires_grad))
```

```
Number of parameters in standard ViT: 41546
```


```python
#@title MoE ViT
moe_vit = ViT(
    image_size=32,
    patch_size=4,
    num_classes=10,
    dim=32,
    depth=2,
    heads=2,
    mlp_dim=32,
    moefy = True,
    num_expert = 4,
    top_k = 2,
    gate = GShardGate
    ).cuda()

print("Number of parameters in MoE ViT:", sum(p.numel() for p in moe_vit.parameters() if p.requires_grad))
```

```
Number of parameters in MoE ViT: 54482
```
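The two counts can be reproduced by hand. The sketch below is our own bookkeeping (`vit_params` and its helpers are hypothetical names) and assumes the layer layout of the code above, plus a gate that is a single linear layer from `d_model` to `num_expert` with bias; under that assumption the totals match the printed ones.

```python
def layernorm(d):
    return 2 * d  # weight + bias

def linear(d_in, d_out, bias=True):
    return d_in * d_out + (d_out if bias else 0)

def vit_params(image_size=32, patch_size=4, channels=3, num_classes=10, dim=32,
               depth=2, heads=2, dim_head=64, mlp_dim=32, num_expert=None):
    num_patches = (image_size // patch_size) ** 2
    patch_dim = channels * patch_size ** 2
    embed = layernorm(patch_dim) + linear(patch_dim, dim) + layernorm(dim)
    embed += (num_patches + 1) * dim + dim  # positional embedding + cls token
    inner = heads * dim_head
    attn = layernorm(dim) + linear(dim, 3 * inner, bias=False) + linear(inner, dim)
    expert = linear(dim, mlp_dim) + linear(mlp_dim, dim)  # one two-layer MLP
    if num_expert is None:
        ff = layernorm(dim) + expert
    else:
        # assumption: gate = one linear layer dim -> num_expert with bias
        ff = layernorm(dim) + linear(dim, num_expert) + num_expert * expert
    head = layernorm(dim) + linear(dim, num_classes)
    return embed + depth * (attn + ff) + head

print(vit_params())              # 41546
print(vit_params(num_expert=4))  # 54482
```

Per block, the feed-forward part grows from 2,176 to 8,644 parameters, yet each token is processed by only `top_k = 2` experts. Since each expert here has the same size as the original feed-forward network, `top_k = 2` roughly doubles the feed-forward computation per token compared with the dense ViT; keeping it constant would require `top_k = 1` or smaller experts.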
