## Code implementation: Mixture-of-Experts (MoE) classifier on MNIST
### Version: 0 [17-08-2024]

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from sklearn.model_selection import KFold
import numpy as np

# Define the Expert Layer
class Expert(nn.Module):
    '''
    Class Name: Expert
    Description:
        One expert of the Mixture-of-Experts (MoE) model: a single fully connected
        layer followed by a ReLU non-linearity. Every expert receives the same input;
        the gating network decides how much each expert's output contributes.

    Attributes:
        fc (torch.nn.Linear): Projects `input_dim` input features to `output_dim` features.

    Methods:
        forward(x):
            Description: Returns relu(fc(x)), i.e. the linear projection of `x` with
                         negative entries set to zero.
            Input: x (torch.Tensor): Input features; the last dimension must equal `input_dim`.
            Output: torch.Tensor: Activated features with last dimension `output_dim`.
    '''

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        projected = self.fc(x)
        return torch.relu(projected)

# Define the Gating Network
class GatingNetwork(nn.Module):
    '''
    Class Name: GatingNetwork
    Description:
        Router of the Mixture-of-Experts (MoE) model. It projects each input to one
        logit per expert and normalizes those logits with a softmax. The result is a
        per-input set of expert weights that are non-negative and sum to 1.

    Attributes:
        fc (torch.nn.Linear): Maps `input_dim` features to `num_experts` logits.

    Methods:
        forward(x):
            Description: Computes the expert logits with `fc`, then applies a softmax
                         over the last dimension.
            Input: x (torch.Tensor): Input features; the last dimension must equal `input_dim`.
            Output: torch.Tensor: Expert weights with last dimension `num_experts`.
    '''

    def __init__(self, input_dim, num_experts):
        super().__init__()
        self.fc = nn.Linear(input_dim, num_experts)

    def forward(self, x):
        logits = self.fc(x)
        return torch.softmax(logits, dim=-1)

# Define the Mixture-of-Experts (MoE) Model
class MoEModel(nn.Module):
    '''
    Class Name: MoEModel
    Description:
        The `MoEModel` class represents the overall Mixture-of-Experts (MoE) model.
        It combines the outputs of several `Expert` instances based on the gating network's decision.
        Processing an input works as follows:
            1. The gating network assigns a softmax weight to each expert.
            2. Every expert processes the same input.
            3. The expert outputs are combined as a gate-weighted sum.
            4. The combined output is passed through a final linear layer for classification.

    Attributes:
        experts (torch.nn.ModuleList): `num_experts` `Expert` instances, each mapping
            `input_dim` -> `hidden_dim` features.
        gating_network (GatingNetwork): Produces one weight per expert for each input.
        fc_out (torch.nn.Linear): Final layer mapping `hidden_dim` -> `output_dim`.

    Methods:
        forward(x):
            Description: Computes the gate weights for `x`, applies each expert to `x`,
                         sums the expert outputs weighted by their gate values, and
                         applies `fc_out` to the result.
            Input: x (torch.Tensor): Shape (..., input_dim), e.g. (batch, input_dim).
            Output: torch.Tensor: Shape (..., output_dim) of unnormalized class scores.
    '''
    def __init__(self, input_dim, hidden_dim, output_dim, num_experts):
        super(MoEModel, self).__init__()
        self.experts = nn.ModuleList([Expert(input_dim, hidden_dim) for _ in range(num_experts)])
        self.gating_network = GatingNetwork(input_dim, num_experts)
        self.fc_out = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # Gate weights: (..., num_experts).
        gate_outputs = self.gating_network(x)
        # Stack the expert outputs on a new expert axis placed just before the feature
        # axis: (..., num_experts, hidden_dim). The expert axis must be the
        # second-to-last dimension to line up with gate_outputs.unsqueeze(-1), which is
        # (..., num_experts, 1). Stacking after the feature axis would only broadcast
        # when hidden_dim == num_experts.
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=-2)
        # Weight each expert's features by its gate value, then sum over experts:
        # (..., num_experts, hidden_dim) -> (..., hidden_dim).
        weighted_expert_output = torch.sum(expert_outputs * gate_outputs.unsqueeze(-1), dim=-2)
        return self.fc_out(weighted_expert_output)

# ---- Hyperparameters --------------------------------------------------------
# MNIST images are 28x28 and are flattened into one feature vector per image.
input_dim = 28 * 28
hidden_dim = 128
output_dim = 10  # one output score per MNIST digit class
num_experts_list = [2, 4, 6, 8]  # candidate expert counts compared by cross-validation

# ---- Data -------------------------------------------------------------------
# Normalize((0.5,), (0.5,)) computes (x - 0.5) / 0.5 on the tensor produced by ToTensor.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)
dataset = datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)

# ---- K-fold cross-validation over the number of experts ---------------------
k_folds = 5
# A fixed random_state makes every kfold.split() call yield the same folds, so each
# candidate expert count is scored on identical train/validation splits. With
# random_state=None each call reshuffles, which adds noise to the comparison.
kfold = KFold(n_splits=k_folds, shuffle=True, random_state=42)
num_epochs = 5


def _train_for_epochs(model, loader, criterion, optimizer, num_epochs):
    """Train `model` in place for `num_epochs` full passes over `loader`."""
    for _ in range(num_epochs):
        model.train()
        for images, labels in loader:
            images = images.view(-1, input_dim)  # flatten each image to input_dim features

            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


def _mean_loss(model, loader, criterion):
    """Return the per-sample mean of `criterion` over `loader` (NaN if it is empty)."""
    model.eval()
    total_loss = 0.0
    total_samples = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.view(-1, input_dim)
            outputs = model(images)
            # CrossEntropyLoss defaults to reduction='mean' (a per-batch mean).
            # Weighting by batch size keeps a short final batch from counting as much
            # as a full batch.
            batch_size = labels.size(0)
            total_loss += criterion(outputs, labels).item() * batch_size
            total_samples += batch_size
    return total_loss / total_samples if total_samples else float('nan')


best_num_experts = None
best_loss = float('inf')

for num_experts in num_experts_list:
    fold_losses = []

    for fold, (train_idx, val_idx) in enumerate(kfold.split(dataset)):
        print(f'FOLD {fold + 1}/{k_folds} - Num Experts: {num_experts}')

        # Restrict each loader to this fold's indices, visited in random order.
        train_subsampler = torch.utils.data.SubsetRandomSampler(train_idx)
        val_subsampler = torch.utils.data.SubsetRandomSampler(val_idx)
        train_loader = DataLoader(dataset, batch_size=64, sampler=train_subsampler)
        val_loader = DataLoader(dataset, batch_size=64, sampler=val_subsampler)

        # A fresh model and optimizer per fold, so no weights carry over between folds.
        model = MoEModel(input_dim, hidden_dim, output_dim, num_experts)
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

        _train_for_epochs(model, train_loader, criterion, optimizer, num_epochs)
        val_loss = _mean_loss(model, val_loader, criterion)

        fold_losses.append(val_loss)
        print(f'Fold {fold + 1} - Validation Loss: {val_loss:.4f}')

    # Average validation loss over all folds for this expert count.
    avg_loss = np.mean(fold_losses)
    print(f'Average Validation Loss for {num_experts} experts: {avg_loss:.4f}')

    if avg_loss < best_loss:
        best_loss = avg_loss
        best_num_experts = num_experts

print(f'Best number of experts: {best_num_experts} with average validation loss: {best_loss:.4f}')

# ---- Final training on the full training set with the selected expert count --
model = MoEModel(input_dim, hidden_dim, output_dim, best_num_experts)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
train_loader = DataLoader(dataset, batch_size=64, shuffle=True)

num_epochs = 5
model.train()  # modules start in training mode; set it explicitly for clarity
for epoch in range(num_epochs):
    epoch_loss = 0.0
    epoch_samples = 0
    for images, labels in train_loader:
        images = images.view(-1, input_dim)  # flatten each image to input_dim features

        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Keep a sample-weighted running total so the printed value is the epoch's
        # mean training loss, not just the loss of the last mini-batch.
        epoch_loss += loss.item() * labels.size(0)
        epoch_samples += labels.size(0)

    print(f'Epoch [{epoch + 1}/{num_epochs}], Avg Loss: {epoch_loss / max(epoch_samples, 1):.4f}')

print("Final training completed with the best number of experts.")

# Save the final model's parameters (state_dict) for later reloading into MoEModel.
torch.save(model.state_dict(), 'moe_best_model.pth')
