# Code implementation
### Author: Oscar Escudero Arnanz

## Transformers "final version"
### Version: 2

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from sklearn.model_selection import KFold
import numpy as np
import matplotlib.pyplot as plt

# Model hyperparameters that stay fixed across every experiment
input_shape = (32, 32, 3)  # (height, width, channels) of one input image
patch_size = 4  # side length, in pixels, of each square patch
# Patches per image: count along the height times count along the width.
# For square inputs this equals the squared height-only count, but it also
# stays correct if input_shape is ever changed to a non-square image.
num_patches = (input_shape[0] // patch_size) * (input_shape[1] // patch_size)
projection_dim = 64  # embedding width of every patch token
num_heads = 4  # attention heads per encoder layer
transformer_layers = 4  # number of stacked encoder layers
mlp_head_units = [128, 64]  # only the last entry is read by VisionTransformer
num_classes = 10
dropout_rate = 0.1
num_epochs = 50

# Different numbers of experts to evaluate
num_experts_list = [5, 6, 7, 8]

# Sizes of the random subsets drawn from the CIFAR-10 train / test splits
train_subset_size = 5000
test_subset_size = 1000

# Load and preprocess CIFAR-10 data: convert to tensors, then shift/scale
# each channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# random_split returns (subset, remainder); the remainder is discarded.
# NOTE(review): no generator/seed is passed, so the subsets differ from run
# to run — confirm whether reproducible subsets are wanted.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainset, _ = torch.utils.data.random_split(
    trainset, [train_subset_size, len(trainset) - train_subset_size]
)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testset, _ = torch.utils.data.random_split(
    testset, [test_subset_size, len(testset) - test_subset_size]
)

# Define the MoE (Mixture of Experts) class
class MoE(nn.Module):
    """Densely gated mixture-of-experts layer.

    Each expert is a ``Linear -> GELU -> Dropout`` stack mapping
    ``input_dim`` features to ``output_dim`` features. A linear gate followed
    by a softmax produces one weight per expert, and the layer output is the
    gate-weighted sum of all expert outputs (every expert is evaluated for
    every input).

    Args:
        input_dim: Size of the last dimension of the input.
        output_dim: Size of the last dimension of the output.
        num_experts: Number of expert networks.
        dropout_rate: Dropout probability used inside each expert.
    """

    def __init__(self, input_dim, output_dim, num_experts, dropout_rate):
        super().__init__()

        # Experts are created before the gate so that parameter
        # initialisation happens in the same order as attribute declaration.
        expert_stacks = []
        for _ in range(num_experts):
            expert_stacks.append(
                nn.Sequential(
                    nn.Linear(input_dim, output_dim),
                    nn.GELU(),
                    nn.Dropout(dropout_rate),
                )
            )
        self.experts = nn.ModuleList(expert_stacks)

        gate_projection = nn.Linear(input_dim, num_experts)
        self.gating_network = nn.Sequential(gate_projection, nn.Softmax(dim=-1))

    def forward(self, x):
        """Return the gate-weighted combination of all expert outputs.

        Args:
            x: Tensor whose last dimension has size ``input_dim``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``output_dim``.
        """
        # (..., num_experts): softmax weights, one per expert.
        mixing_weights = self.gating_network(x)

        # (..., output_dim, num_experts): expert outputs stacked on a new
        # trailing axis.
        per_expert = [expert(x) for expert in self.experts]
        stacked = torch.stack(per_expert, dim=-1)

        # Broadcast the weights over the feature axis, then collapse the
        # expert axis.
        weighted = stacked * mixing_weights.unsqueeze(-2)
        return weighted.sum(dim=-1)

# Patch Embedding
class PatchEmbedding(nn.Module):
    """Linear patch projection plus a learned position embedding.

    Args:
        num_patches: Number of rows in the position-embedding table.
        projection_dim: Width of the produced token embeddings.
        patch_dim: Number of features in one flattened input patch.
    """

    def __init__(self, num_patches, projection_dim, patch_dim):
        super().__init__()
        self.projection = nn.Linear(patch_dim, projection_dim)
        self.position_embedding = nn.Embedding(num_patches, projection_dim)

    def forward(self, patches):
        """Embed ``patches`` and add the embedding of each patch index.

        Args:
            patches: Tensor whose dimension 1 indexes the patches and whose
                last dimension has size ``patch_dim``.

        Returns:
            The projected patches with the position embedding for indices
            ``0 .. patches.size(1) - 1`` added, broadcast over dimension 0.
        """
        patch_count = patches.size(1)
        # Index tensor lives on the same device as the input.
        position_ids = torch.arange(patch_count, device=patches.device)
        projected = self.projection(patches)
        # Add a leading axis so the position table broadcasts over the batch.
        return projected + self.position_embedding(position_ids).unsqueeze(0)

# Transformer Encoder with MoE
class TransformerEncoder(nn.Module):
    """Pre-norm transformer encoder layer with an MoE feed-forward block.

    The layer applies ``LayerNorm -> self-attention -> dropout`` and
    ``LayerNorm -> MoE -> dropout``, each wrapped in a residual connection.

    Args:
        projection_dim: Embedding width of the tokens.
        num_heads: Number of attention heads.
        ff_dim: Output width of the MoE block. It must equal
            ``projection_dim`` because the MoE output is added to the
            residual stream.
        num_experts: Number of experts in the MoE block.
        dropout_rate: Dropout probability used in the attention, the MoE
            experts and both residual branches.
    """

    def __init__(self, projection_dim, num_heads, ff_dim, num_experts, dropout_rate):
        super(TransformerEncoder, self).__init__()
        self.layer_norm1 = nn.LayerNorm(projection_dim)
        # batch_first=True: inputs are (batch, tokens, dim). With the default
        # (batch_first=False) the first axis is interpreted as the sequence,
        # so attention would be computed across the samples of a batch
        # instead of across the tokens of each sample.
        self.attention = nn.MultiheadAttention(
            embed_dim=projection_dim,
            num_heads=num_heads,
            dropout=dropout_rate,
            batch_first=True,
        )
        self.dropout1 = nn.Dropout(dropout_rate)

        self.layer_norm2 = nn.LayerNorm(projection_dim)
        self.moe = MoE(projection_dim, ff_dim, num_experts, dropout_rate)
        self.dropout2 = nn.Dropout(dropout_rate)

    def forward(self, x):
        """Apply self-attention and the MoE block with residual connections.

        Args:
            x: Tensor of shape ``(batch, tokens, projection_dim)``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        # Self-attention branch (query, key and value are the same tensor).
        # The attention weights are discarded, so skip computing them.
        x_norm1 = self.layer_norm1(x)
        attention_output, _ = self.attention(
            x_norm1, x_norm1, x_norm1, need_weights=False
        )
        x = x + self.dropout1(attention_output)

        # Mixture-of-experts feed-forward branch.
        x_norm2 = self.layer_norm2(x)
        moe_output = self.moe(x_norm2)
        return x + self.dropout2(moe_output)

# Vision Transformer with MoE model
class VisionTransformer(nn.Module):
    """Vision Transformer classifier whose feed-forward parts are MoE layers.

    Images are cut into non-overlapping square patches, embedded, passed
    through a stack of encoder layers, normalised, flattened and classified
    by an MoE head followed by a linear layer.

    Args:
        input_shape: ``(height, width, channels)`` of the input images; only
            the channel count is read here.
        patch_size: Side length of each square patch.
        num_patches: Number of patches per image.
        projection_dim: Embedding width of each patch token.
        num_heads: Attention heads per encoder layer.
        transformer_layers: Number of encoder layers.
        mlp_head_units: Sequence of head widths; only the last entry is used,
            as the output width of the MoE head.
        num_classes: Number of output classes.
        dropout_rate: Dropout probability passed to the sub-modules.
        num_experts: Number of experts in every MoE layer.
    """

    def __init__(self, input_shape, patch_size, num_patches, projection_dim, num_heads, transformer_layers, mlp_head_units, num_classes, dropout_rate, num_experts):
        super(VisionTransformer, self).__init__()
        self.patch_size = patch_size
        # Features in one flattened patch: channels * patch height * width.
        patch_dim = input_shape[2] * patch_size * patch_size

        self.patch_embedding = PatchEmbedding(num_patches, projection_dim, patch_dim)

        self.transformer_layers = nn.ModuleList([
            TransformerEncoder(projection_dim, num_heads, projection_dim, num_experts, dropout_rate)
            for _ in range(transformer_layers)
        ])

        self.layer_norm = nn.LayerNorm(projection_dim)
        # The head consumes all patch tokens concatenated into one vector.
        self.mlp_head = MoE(projection_dim * num_patches, mlp_head_units[-1], num_experts, dropout_rate)
        self.classifier = nn.Linear(mlp_head_units[-1], num_classes)

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: Tensor of shape ``(batch, channels, height, width)``.

        Returns:
            Tensor of shape ``(batch, num_classes)`` with unnormalised logits.
        """
        x = self.extract_patches(x)
        x = self.patch_embedding(x)

        for layer in self.transformer_layers:
            x = layer(x)

        x = self.layer_norm(x)
        x = x.flatten(1)  # (batch, num_patches * projection_dim)
        x = self.mlp_head(x)
        x = self.classifier(x)
        return x

    def extract_patches(self, images):
        """Split images into flattened, non-overlapping square patches.

        Args:
            images: Tensor of shape ``(batch, channels, height, width)``.

        Returns:
            Tensor of shape ``(batch, patches, channels * patch_size ** 2)``.
            Patches are ordered row by row; each patch vector holds the
            pixels of that single spatial patch for all channels, in
            ``(channel, row, column)`` order.
        """
        batch_size, channels = images.size(0), images.size(1)
        size = self.patch_size
        # Two unfolds give (batch, channels, rows, cols, size, size).
        patches = images.unfold(2, size, size).unfold(3, size, size)
        # Move the channel axis next to the in-patch axes before flattening.
        # Without this, one flattened "patch" would mix pixels from several
        # neighbouring patches of a single channel.
        patches = patches.permute(0, 2, 3, 1, 4, 5)
        return patches.reshape(batch_size, -1, channels * size * size)

# Configure the device: prefer the second GPU (index 1) when the machine has
# one, fall back to the first GPU on single-GPU machines (where "cuda:1"
# would be an invalid device), and to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    device = torch.device(f"cuda:{gpu_index}")
else:
    device = torch.device("cpu")

# Number of folds for cross-validation
n_splits = 5
kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)

# Materialise the training subset as NumPy arrays. Each sample is fetched
# once, so the dataset transform runs a single time per image instead of
# once for the image and again for the label.
samples = [trainset[i] for i in range(len(trainset))]
X_data = np.array([image.numpy() for image, _ in samples])
y_data = np.array([label for _, label in samples])

# Evaluate for each value of num_experts
def _run_epoch(model, loader, criterion, optimizer=None):
    """Run one full pass over ``loader``.

    When ``optimizer`` is given, every batch performs a gradient step;
    otherwise the pass only evaluates, with gradient tracking disabled. The
    caller is responsible for calling ``model.train()`` / ``model.eval()``.

    Returns:
        ``(mean_batch_loss, accuracy_percent)`` for the pass.
    """
    is_training = optimizer is not None
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    with torch.set_grad_enabled(is_training):
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)

            if is_training:
                optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            if is_training:
                loss.backward()
                optimizer.step()

            loss_sum += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()

    # Loss is averaged over batches, accuracy over samples.
    return loss_sum / len(loader), 100 * n_correct / n_seen


results = []

for num_experts in num_experts_list:
    print(f"\nEvaluating num_experts = {num_experts}\n")

    # One per-epoch curve is appended to each list for every fold.
    all_train_accuracies, all_val_accuracies = [], []
    all_train_losses, all_val_losses = [], []

    for fold, (train_idx, val_idx) in enumerate(kf.split(X_data)):
        print(f"Fold {fold + 1}/{n_splits}")

        # Tensors for this fold's training and validation partitions
        X_train_fold = torch.tensor(X_data[train_idx], dtype=torch.float32)
        y_train_fold = torch.tensor(y_data[train_idx], dtype=torch.long)
        X_val_fold = torch.tensor(X_data[val_idx], dtype=torch.float32)
        y_val_fold = torch.tensor(y_data[val_idx], dtype=torch.long)

        # Only the training loader reshuffles between epochs
        train_dataset_fold = torch.utils.data.TensorDataset(X_train_fold, y_train_fold)
        val_dataset_fold = torch.utils.data.TensorDataset(X_val_fold, y_val_fold)
        trainloader_fold = torch.utils.data.DataLoader(train_dataset_fold, batch_size=8, shuffle=True)
        valloader_fold = torch.utils.data.DataLoader(val_dataset_fold, batch_size=8, shuffle=False)

        # A fresh model per fold; num_experts is the hyperparameter under test
        model = VisionTransformer(
            input_shape=input_shape,
            patch_size=patch_size,
            num_patches=num_patches,
            projection_dim=projection_dim,
            num_heads=num_heads,
            transformer_layers=transformer_layers,
            mlp_head_units=mlp_head_units,
            num_classes=num_classes,
            dropout_rate=dropout_rate,
            num_experts=num_experts,
        )
        model.to(device)

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=0.001)

        # Per-epoch history for this fold
        train_losses, val_losses = [], []
        train_accuracies, val_accuracies = [], []

        for epoch in range(num_epochs):
            # Optimisation pass over the fold's training data
            model.train()
            train_loss, train_accuracy = _run_epoch(model, trainloader_fold, criterion, optimizer)
            train_losses.append(train_loss)
            train_accuracies.append(train_accuracy)

            # Evaluation pass over the fold's validation data
            model.eval()
            val_loss, val_accuracy = _run_epoch(model, valloader_fold, criterion)
            val_losses.append(val_loss)
            val_accuracies.append(val_accuracy)

        all_train_accuracies.append(train_accuracies)
        all_val_accuracies.append(val_accuracies)
        all_train_losses.append(train_losses)
        all_val_losses.append(val_losses)

    # Summarise this setting by the last-epoch accuracy averaged over folds
    avg_train_accuracy = np.mean([fold_curve[-1] for fold_curve in all_train_accuracies])
    avg_val_accuracy = np.mean([fold_curve[-1] for fold_curve in all_val_accuracies])
    results.append((num_experts, avg_train_accuracy, avg_val_accuracy))
    print(f"num_experts = {num_experts}, Avg Train Accuracy: {avg_train_accuracy:.2f}%, Avg Val Accuracy: {avg_val_accuracy:.2f}%")

    # Fold-averaged learning curves
    epochs = range(1, num_epochs + 1)
    mean_train_loss = np.mean(all_train_losses, axis=0)
    mean_val_loss = np.mean(all_val_losses, axis=0)
    mean_train_accuracy = np.mean(all_train_accuracies, axis=0)
    mean_val_accuracy = np.mean(all_val_accuracies, axis=0)

    plt.figure(figsize=(12, 4))

    # Left panel: loss
    plt.subplot(1, 2, 1)
    plt.plot(epochs, mean_train_loss, label='Train loss')
    plt.plot(epochs, mean_val_loss, label='Validation loss')
    plt.title(f'Loss for num_experts = {num_experts}')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()

    # Right panel: accuracy
    plt.subplot(1, 2, 2)
    plt.plot(epochs, mean_train_accuracy, label='Train accuracy')
    plt.plot(epochs, mean_val_accuracy, label='Validation accuracy')
    plt.title(f'Accuracy for num_experts = {num_experts}')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.legend()

    plt.show()

# Display results for all evaluated num_experts
print("\nFinal Results for different values of num_experts:")
for num_experts, train_acc, val_acc in results:
    print(f"num_experts = {num_experts}: Train Accuracy = {train_acc:.2f}%, Validation Accuracy = {val_acc:.2f}%")

# Select the best number of experts based on validation accuracy
best_num_experts = max(results, key=lambda x: x[2])[0]
print(f"\nBest number of experts based on validation accuracy: {best_num_experts}")

# Build the final model with the selected number of experts
model = VisionTransformer(
    input_shape=input_shape,
    patch_size=patch_size,
    num_patches=num_patches,
    projection_dim=projection_dim,
    num_heads=num_heads,
    transformer_layers=transformer_layers,
    mlp_head_units=mlp_head_units,
    num_classes=num_classes,
    dropout_rate=dropout_rate,
    num_experts=best_num_experts  # Use the best number of experts
)
model.to(device)

# Train the final model on the whole training subset before testing it.
# The cross-validation models are not reused, so without this step the test
# set would be scored by a randomly initialised network. The settings mirror
# the ones used inside the cross-validation folds.
full_trainloader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(
        torch.tensor(X_data, dtype=torch.float32),
        torch.tensor(y_data, dtype=torch.long),
    ),
    batch_size=8,
    shuffle=True,
)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(num_epochs):
    model.train()
    for inputs, labels in full_trainloader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()

# Evaluate the trained model on the held-out test subset
testloader = torch.utils.data.DataLoader(testset, batch_size=8, shuffle=False)
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in testloader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

test_accuracy = 100 * correct / total
print(f"\nAccuracy of the best model on the 1000 test images: {test_accuracy:.2f}%")


Files already downloaded and verified
Files already downloaded and verified

Evaluating num_experts = 5

Fold 1/5
