# In[31]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import Optional, Union, Callable

# In[34]:
class TransformerEncoderLayer(nn.Module):
    r"""TransformerEncoderLayer is made up of self-attn and feedforward network.

    This standard encoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.

    TransformerEncoderLayer can handle either traditional torch.tensor inputs,
    or Nested Tensor inputs.  Derived classes are expected to similarly accept
    both input formats.  (Not all combinations of inputs are currently
    supported by TransformerEncoderLayer while Nested Tensor is in prototype
    state.)

    If you are implementing a custom layer, you may derive it either from
    the Module or TransformerEncoderLayer class.  If your custom layer
    supports both torch.Tensors and Nested Tensors inputs, make its
    implementation a derived class of TransformerEncoderLayer. If your custom
    Layer supports only torch.Tensor inputs, derive its implementation from
    Module.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of the intermediate layer, can be a string
            ("relu" or "gelu") or a unary callable. Default: relu
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        norm_first: if ``True``, layer norm is done prior to attention and feedforward
            operations, respectively. Otherwise it's done after. Default: ``False`` (after).
        bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
            bias. Default: ``True``.

    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> out = encoder_layer(src)

    Alternatively, when ``batch_first`` is ``True``:
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
        >>> src = torch.rand(32, 10, 512)
        >>> out = encoder_layer(src)

    Fast path:
        forward() will use a special optimized implementation described in
        `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`_ if all of the following
        conditions are met:

        - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor
          argument ``requires_grad``
        - training is disabled (using ``.eval()``)
        - batch_first is ``True`` and the input is batched (i.e., ``src.dim() == 3``)
        - activation is one of: ``"relu"``, ``"gelu"``, ``torch.functional.relu``, or ``torch.functional.gelu``
        - at most one of ``src_mask`` and ``src_key_padding_mask`` is passed
        - if src is a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_, neither ``src_mask``
          nor ``src_key_padding_mask`` is passed
        - the two ``LayerNorm`` instances have a consistent ``eps`` value (this will naturally be the case
          unless the caller has manually modified one without modifying the other)

        If the optimized implementation is in use, a
        `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be
        passed for ``src`` to represent padding more efficiently than using a padding
        mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ will be
        returned, and an additional speedup proportional to the fraction of the input that
        is padding can be expected.

        .. _`FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`:
         https://arxiv.org/abs/2205.14135

    """

    __constants__ = ['norm_first']

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 bias: bool = True, device=None, dtype=None, num_experts: int = 8, num_experts_per_tok: int = 2) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout,
                                            bias=bias, batch_first=batch_first,
                                            **factory_kwargs)
        
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)

        self.norm_first = norm_first
        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        #creating the experts - 
        self.experts = [expert(d_model , dim_feedforward, bias) for i in range(num_experts)]
        #creating gatingnetwork
        self.gatingNetwork = GatingNetwork(input_dim , num_experts)
        self.num_experts_per_tok = num_experts_per_tok
        # Legacy string support for activation function.
        if isinstance(activation, str):
            activation = _get_activation_fn(activation)

        # We can't test self.activation in forward() in TorchScript,
        # so stash some information about it instead.
        if activation is F.relu or isinstance(activation, torch.nn.ReLU):
            self.activation_relu_or_gelu = 1
        elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
            self.activation_relu_or_gelu = 2
        else:
            self.activation_relu_or_gelu = 0
        self.activation = activation

    def __setstate__(self, state):
        super().__setstate__(state)
        if not hasattr(self, 'activation'):
            self.activation = F.relu


    def forward(
            self,
            src: Tensor,
            src_mask: Optional[Tensor] = None,
            src_key_padding_mask: Optional[Tensor] = None,
            is_causal: bool = False) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
            is_causal: If specified, applies a causal mask as ``src mask``.
                Default: ``False``.
                Warning:
                ``is_causal`` provides a hint that ``src_mask`` is the
                causal mask. Providing incorrect hints can result in
                incorrect execution, including forward and backward
                compatibility.

        Shape:
            see the docs in :class:`~torch.nn.Transformer`.
        """
        src_key_padding_mask = F._canonical_mask(
            mask=src_key_padding_mask,
            mask_name="src_key_padding_mask",
            other_type=F._none_or_dtype(src_mask),
            other_name="src_mask",
            target_type=src.dtype
        )

        src_mask = F._canonical_mask(
            mask=src_mask,
            mask_name="src_mask",
            other_type=None,
            other_name="",
            target_type=src.dtype,
            check_other=False,
        )

        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
        x = src
        if self.norm_first:
            x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask, is_causal=is_causal)
            x = x + self._ff_block(self.norm2(x))
        else:
            x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask, is_causal=is_causal))
            x = self.norm2(x + self._ff_block(x))

        return x


    # self-attention block
    def _sa_block(self, x: Tensor,
                  attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
        x = self.self_attn(x, x, x,
                           attn_mask=attn_mask,
                           key_padding_mask=key_padding_mask,
                           need_weights=False, is_causal=is_causal)[0]
        return self.dropout1(x)

    # feed forward block
    def _ff_block(self, x: Tensor , num_experts_per_tok) -> Tensor:
        gating_scores = F.softmax(self.gatingnetwork(x) , dim = -1)
        top_scores, top_indices = gating_scores.topk(self.num_experts_per_tok , dim = -1 , sorted = False) 
        
        top_mask = torch.zeros_like(gating_scores)
        top_mask = top_mask.scatter(-1, top_indices, 1)
        output = []
        
        for i in range(len(top_mask)):
            if(top_mask[i] == 1):
                output.append(top_scores[i]*self.experts[i](x))
        
        result = torch.sum(torch.stack(output) , dim = 0)
        return result
#         x = self.linear2(self.dropout(self.activation(self.linear1(x))))
#         return self.dropout2(x)


    class GatingNetwork(nn.Module):
        def __init__(self, input_dim, num_experts):
            super(GatingNetwork, self).__init__()
            self.gate = nn.Linear(input_dim, num_experts)

        def forward(self, x):
            return F.softmax(self.gate(x), dim=-1)
    class expert(nn.Module):
        def __init__(self,d_model , dim_feedforward , bias):
            factory_kwargs = {'device': device, 'dtype': dtype}
            self.linear1 = Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs)
            self.dropout = Dropout(dropout)
            self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)
            self.dropout2 = Dropout(dropout)
        def forward(self ,x):
            x = self.linear2(self.dropout(self.activation(self.linear1(x))))
            return self.dropout2(x)

# In[18]:
# """
# This model integrates the MoE concept within a Transformer architecture. Each token's
# representation is processed by a subset of experts, determined by the gating mechanism.
# This architecture allows for efficient and specialized handling of different aspects of the
# data, aiming for the adaptability and efficiency noted in the Mixtral 8x7B model's design
# philosophy. The model activates only a fraction of the available experts for each token,
# significantly reducing the computational resources needed compared to activating all experts
# for all tokens.
# """

# import torch
# import torch.nn as nn
# import torch.nn.functional as F

# # Define the Expert class
# class Expert(nn.Module):
#     def __init__(self, input_dim, hidden_dim, output_dim):
#         super(Expert, self).__init__()
#         self.fc1 = nn.Linear(input_dim, hidden_dim)
#         self.fc2 = nn.Linear(hidden_dim, output_dim)

#     def forward(self, x):
#         x = F.relu(self.fc1(x))
#         return self.fc2(x)

# # Define the Gating Network class
# class GatingNetwork(nn.Module):
#     def __init__(self, input_dim, num_experts):
#         super(GatingNetwork, self).__init__()
#         self.gate = nn.Linear(input_dim, num_experts)

#     def forward(self, x):
#         return F.softmax(self.gate(x), dim=-1)

# # Define the Mixture of Experts Layer class
# class MoELayer(nn.Module):
#     def __init__(self, input_dim, hidden_dim, output_dim, num_experts):
#         super(MoELayer, self).__init__()
#         self.experts = nn.ModuleList([Expert(input_dim, hidden_dim, output_dim) for _ in range(num_experts)])
#         self.gate = GatingNetwork(input_dim, num_experts)

#     def forward(self, x, num_experts_per_tok):
#         gating_scores = self.gate(x)
#         topk_gating_scores, topk_indices = gating_scores.topk(num_experts_per_tok, dim=2, sorted=False)
#         # Create a mask to zero out the contributions of non-topk experts
#         mask = torch.zeros_like(gating_scores).scatter_(2, topk_indices, 1)
#         # Use the mask to retain only the topk gating scores
#         gating_scores = gating_scores * mask
#         # Normalize the gating scores to sum to 1 across the selected top experts
#         gating_scores = F.normalize(gating_scores, p=1, dim=2)
        
#         expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)
#         expert_outputs = expert_outputs.transpose(1, 2)
#         output = torch.einsum('bte,bteo->bto', gating_scores, expert_outputs)
#         return output

# # Define the overall Transformer model with integrated MoE
# class TransformerWithMoE(nn.Module):
#     def __init__(self, num_layers, dim, head_dim, hidden_dim, n_heads, num_experts, vocab_size, num_experts_per_tok):
#         super(TransformerWithMoE, self).__init__()
#         self.num_experts_per_tok = num_experts_per_tok
#         self.embedding = nn.Embedding(vocab_size, dim)
#         self.layers = nn.ModuleList([nn.TransformerEncoderLayer(d_model=dim, nhead=n_heads) for _ in range(num_layers)])
#         self.moe_layer = MoELayer(dim, hidden_dim, dim, num_experts)
#         self.output_layer = nn.Linear(dim, vocab_size)

#     def forward(self, x):
#         x = self.embedding(x)
#         print('after embedding',x.shape)
#         for layer in self.layers:
#             print('before layers', x.shape)
#             x = layer(x)
#         print('after all layers', x.shape)
#         x = self.moe_layer(x, self.num_experts_per_tok)
#         logits = self.output_layer(x)
#         return logits

# # Initialize the model with configurations matching Mixtral 8x7B
# model = TransformerWithMoE(
#     num_layers=2,              # Number of transformer layers
#     dim=4096,                   # Dimension of the model
#     head_dim=128,               # Dimension of each head in the multi-head attention mechanisms
#     hidden_dim=146,           # Hidden dimensionality in the feed-forward network within the transformer
#     n_heads=32,                 # Number of attention heads
#     num_experts=8,              # Number of experts in the MoE layer
#     vocab_size=32000,           # Vocabulary size for the embedding layer
#     num_experts_per_tok=2       # Number of experts activated per token
# )
#model(torch.randint(10,(40,)))