Implementation of the Sparsely-Gated Mixture-of-Experts layer: https://arxiv.org/pdf/1701.06538

## Gating Network
For each expert, the gate outputs a score that determines whether the expert is activated. If the gate's output for expert $i$ is 0, $E_i(x)$ is not computed.
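
To make the "zero gate means no computation" point concrete, here is a minimal sketch of the weighted combination $y = \sum_i G(x)_i E_i(x)$ from the paper. The experts (small linear layers), the sizes, the hand-written gate values, and the helper name `sparse_combine` are all assumptions chosen for illustration, not part of this notebook's implementation.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dim, num_experts = 8, 4
# Hypothetical experts: one small linear layer each (illustrative only)
experts = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_experts)])


def sparse_combine(x, gates):
    """y = sum_i G(x)_i * E_i(x), skipping experts whose gate is 0 for every input."""
    y = torch.zeros_like(x)
    evaluated = []
    for i, expert in enumerate(experts):
        rows = gates[:, i].nonzero(as_tuple=True)[0]  # inputs routed to expert i
        if rows.numel() == 0:
            continue  # gate is 0 everywhere, so E_i(x) is never computed
        weighted = gates[rows, i].unsqueeze(-1) * expert(x[rows])
        y = y.index_add(0, rows, weighted)
        evaluated.append(i)
    return y, evaluated


x = torch.randn(3, dim)
# Hand-written gate values (assumed): experts 1 and 3 are never selected
gates = torch.tensor([[0.7, 0.0, 0.3, 0.0],
                      [0.0, 0.0, 1.0, 0.0],
                      [0.4, 0.0, 0.6, 0.0]])
y, evaluated = sparse_combine(x, gates)
print(evaluated)  # [0, 2]
```

Each expert only runs on the rows whose gate value is non-zero, which is where the compute savings of a sparse gate come from.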

Softmax Gating: $G_\sigma(x) = \mathrm{Softmax}(x \cdot W_g)$

In [None]:
import torch
import torch.nn as nn


class SoftmaxGating(nn.Module):
    def __init__(self, input_emb_dim, num_experts):
        super().__init__()
        self.gate = nn.Linear(input_emb_dim, num_experts)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        return self.softmax(self.gate(x)) 
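
A quick check of what plain softmax gating produces may help here. The sketch below applies the same linear-then-softmax computation functionally; the sizes (8-dimensional inputs, 4 experts, batch of 3) are arbitrary assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
gate = nn.Linear(8, 4)  # hypothetical sizes: 8-dim input, 4 experts
x = torch.randn(3, 8)   # batch of 3 inputs
probs = torch.softmax(gate(x), dim=-1)

print(probs.shape)              # torch.Size([3, 4])
print(probs.sum(dim=-1))        # each row sums to 1 (up to float error)
print(bool((probs > 0).all()))  # whether every expert received a positive weight
```

Because softmax assigns a positive weight to every expert for ordinary logits, this gate is dense: no expert can be skipped, which is the motivation for the top-k variant.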

Noisy Top-K Gating:

$$ G(x) = \mathrm{Softmax}(\mathrm{KeepTopK}(H(x), k)) $$

$$ H(x)_i = (x \cdot W_g)_i + \mathrm{StandardNormal}() \cdot \mathrm{Softplus}((x \cdot W_{noise})_i) $$

$$ \mathrm{KeepTopK}(v, k)_i = \begin{cases} v_i & \text{if } v_i \text{ is in the top } k \text{ elements of } v, \\ -\infty & \text{otherwise.} \end{cases} $$
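
The KeepTopK step followed by the softmax can be sketched in isolation. This is one possible way to write it with `torch.topk` and `scatter`; the function name `keep_top_k` and the example logits are assumptions for illustration only.

```python
import torch


def keep_top_k(v: torch.Tensor, k: int) -> torch.Tensor:
    """Keep the k largest entries along the last dim; set all others to -inf."""
    top_vals, top_idx = torch.topk(v, k, dim=-1)
    masked = torch.full_like(v, float("-inf"))
    return masked.scatter(-1, top_idx, top_vals)


# Hypothetical noisy logits H(x) for one input and 4 experts
h = torch.tensor([[2.0, 0.5, 1.0, -1.0]])
gates = torch.softmax(keep_top_k(h, k=2), dim=-1)
print(gates)  # approximately [[0.731, 0.000, 0.269, 0.000]]
```

Since $e^{-\infty} = 0$, the masked experts receive a gate value of exactly 0, while the softmax renormalizes over the remaining $k$ experts.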

In [None]:
class NoisyTopKGating(nn.Module):
    def __init__(self, input_emb_dim, num_experts, k):
        super().__init__()
        self.input_emb_dim = input_emb_dim
        self.num_experts = num_experts
        self.k = k

        self.gate = nn.Linear(input_emb_dim, num_experts)
        self.noise = nn.Linear(input_emb_dim, num_experts)

        self.softmax = nn.Softmax(dim=-1)
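        # nn.Softplus is element-wise and takes no `dim` argument
        self.softplus = nn.Softplus()

    def forward(self, x):
        # Clean logits: x · W_g
        gate_output = self.gate(x)
        # One standard-normal sample per expert logit (same shape as gate_output)
        noise_tensor = torch.randn_like(gate_output)
        # Scale element-wise by Softplus(x · W_noise), as in H(x)
        tuneable_noise = noise_tensor * self.softplus(self.noise(x))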
