In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MoE(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_experts):
        super(MoE, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_experts = num_experts
        self.expert_networks = nn.ModuleList([nn.LSTM(input_dim, hidden_dim) for _ in range(num_experts)])
        self.gating_network = nn.Linear(input_dim, num_experts)

    def forward(self, x):
        gating_logits = self.gating_network(x)
        expert_weights = F.softmax(gating_logits, dim=-1)
        expert_outputs = torch.stack([expert(x) for expert in self.expert_networks], dim=-1)
        output = torch.sum(expert_weights.unsqueeze(-1) * expert_outputs, dim=-1)
        return output

In [6]:
class BaselineModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, num_classes):
        super(BaselineModel, self).__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        _, (hidden, _) = self.lstm(x)
        output = self.fc(hidden.squeeze(0))
        return output

class MoEModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, num_experts, num_classes):
        super(MoEModel, self).__init__()
        self.lstm1 = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
        self.moe = MoE(hidden_dim, hidden_dim, num_experts)
        self.lstm2 = nn.LSTM(hidden_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        _, (hidden1, _) = self.lstm1(x)
        moe_output = self.moe(hidden1.squeeze(0))
        _, (hidden2, _) = self.lstm2(moe_output.unsqueeze(0))
        output = self.fc(hidden2.squeeze(0))
        return output

In [15]:
from transformers import DataCollatorWithPadding

In [7]:
import numpy as np
import matplotlib.pyplot as plt
from datasets import load_dataset
from torch.utils.data import DataLoader

def evaluate(model, dataloader, device, criterion):
    model.eval()
    total_loss = 0.0
    total_correct = 0
    with torch.no_grad():
        for batch in dataloader:
            input_ids, attention_mask, labels = batch['input_ids'], batch['attention_mask'], batch['labels']
            input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)
            outputs = model(input_ids, attention_mask=attention_mask)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            _, predicted = torch.max(outputs, dim=-1)
            total_correct += (predicted == labels).sum().item()
    accuracy = total_correct / len(dataloader.dataset)
    avg_loss = total_loss / len(dataloader)
    return avg_loss, accuracy

def plot_results(baseline_losses, baseline_accuracies, moe_losses, moe_accuracies):
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.plot(range(len(baseline_losses)), baseline_losses, label='Baseline')
    plt.plot(range(len(moe_losses)), moe_losses, label='MoE-augmented')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.subplot(1, 2, 2)
    plt.plot(range(len(baseline_accuracies)), baseline_accuracies, label='Baseline')
    plt.plot(range(len(moe_accuracies)), moe_accuracies, label='MoE-augmented')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.show()

In [19]:
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

In [23]:
def tokenize_and_encode(examples):
    encoded_inputs = tokenizer(examples['tokens'], truncation=True, padding='longest', max_length=128)
    return {'input_ids': encoded_inputs['input_ids'], 'attention_mask': encoded_inputs['attention_mask']}

# Load datasets
conll_dataset = load_dataset('conll2003', split='train')
squad_dataset = load_dataset('squad', split='train')


conll_dataset = conll_dataset.map(tokenize_and_encode, batched=True)
conll_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'ner_tags'])

Map:   0%|          | 0/14041 [00:00<?, ? examples/s]


ValueError: too many values to unpack (expected 2)

In [None]:
def main():

    # Define hyperparameters
    input_dim = 768  # Hidden size of the pre-trained model
    hidden_dim = 256
    num_layers = 2
    num_classes = 9  # Number of named entities in CoNLL-2003
    num_experts = 4
    batch_size = 32
    num_epochs = 10
    learning_rate = 1e-3
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Initialize models
    baseline_model = BaselineModel(input_dim, hidden_dim, num_layers, num_classes).to(device)
    moe_model = MoEModel(input_dim, hidden_dim, num_layers, num_experts, num_classes).to(device)

    # Define dataloaders
    # conll_dataloader = DataLoader(conll_dataset, batch_size=batch_size, shuffle=True)
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding=True, max_length=128)
    conll_dataloader = DataLoader(conll_dataset, batch_size=batch_size, shuffle=True, collate_fn=data_collator)
    squad_dataloader = DataLoader(squad_dataset, batch_size=batch_size, shuffle=True)

    # Define criteria and optimizers
    criterion = nn.CrossEntropyLoss()
    baseline_optimizer = torch.optim.Adam(baseline_model.parameters(), lr=learning_rate)
    moe_optimizer = torch.optim.Adam(moe_model.parameters(), lr=learning_rate)

    # Train and evaluate models
    baseline_losses, baseline_accuracies = [], []
    moe_losses, moe_accuracies = [], []
    for epoch in range(num_epochs):
        # Train baseline model
        baseline_model.train()
        for batch in conll_dataloader:
            input_ids, attention_mask, labels = batch['input_ids'], batch['attention_mask'], batch['labels']
            input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)
            baseline_optimizer.zero_grad()
            outputs = baseline_model(input_ids, attention_mask=attention_mask)
            loss = criterion(outputs, labels)
            loss.backward()
            baseline_optimizer.step()

        # Evaluate baseline model
        baseline_loss, baseline_accuracy = evaluate(baseline_model, conll_dataloader, device, criterion)
        baseline_losses.append(baseline_loss)
        baseline_accuracies.append(baseline_accuracy)

        # Train MoE-augmented model
        moe_model.train()
        for batch in squad_dataloader:
            input_ids, attention_mask, labels = batch['input_ids'], batch['attention_mask'], batch['answers']
            input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)
            moe_optimizer.zero_grad()
            outputs = moe_model(input_ids, attention_mask=attention_mask)
            loss = criterion(outputs, labels)
            loss.backward()
            moe_optimizer.step()

        # Evaluate MoE-augmented model
        moe_loss, moe_accuracy = evaluate(moe_model, squad_dataloader, device, criterion)
        moe_losses.append(moe_loss)
        moe_accuracies.append(moe_accuracy)

    # Plot results
    plot_results(baseline_losses, baseline_accuracies, moe_losses, moe_accuracies)

In [None]:
main()

ValueError: You should supply an encoding or a list of encodings to this method that includes input_ids, but you provided ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags']