In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from sklearn.metrics import classification_report
from transformers import BertTokenizer, BertModel
from transformers.optimization import AdamW
from transformers import get_linear_schedule_with_warmup

In [3]:
# Define the MoE layer
class MoE(nn.Module):
    def __init__(self, input_dim, expert_dim, num_experts):
        super(MoE, self).__init__()
        self.input_dim = input_dim
        self.expert_dim = expert_dim
        self.num_experts = num_experts
        
        self.experts = nn.ModuleList([nn.Linear(input_dim, expert_dim) for _ in range(num_experts)])
        self.gating_network = nn.Linear(input_dim, num_experts)

    def forward(self, x):
        gating_scores = F.softmax(self.gating_network(x), dim=1)
        expert_outputs = [expert(x) for expert in self.experts]
        weighted_expert_outputs = [score.unsqueeze(2) * expert_output for score, expert_output in zip(gating_scores.split(1, dim=1), expert_outputs)]
        output = torch.sum(torch.stack(weighted_expert_outputs, dim=0), dim=0)
        return output

# Define the LSTM-MoE model
class LSTM_MoE(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, num_experts):
        super(LSTM_MoE, self).__init__()
        self.lstm1 = nn.LSTM(input_dim, hidden_dim, num_layers=num_layers, batch_first=True)
        self.moe = MoE(hidden_dim, hidden_dim, num_experts)
        self.lstm2 = nn.LSTM(hidden_dim, output_dim, num_layers=num_layers, batch_first=True)

    def forward(self, x):
        lstm1_output, _ = self.lstm1(x)
        moe_output = self.moe(lstm1_output)
        lstm2_output, _ = self.lstm2(moe_output)
        return lstm2_output


In [4]:
# Load CoNLL-2003 dataset
def load_conll_data(file_path):
    sentences = []
    labels = []
    with open(file_path, 'r') as file:
        words = []
        tags = []
        for line in file:
            line = line.strip()
            if line == '':
                sentences.append(words)
                labels.append(tags)
                words = []
                tags = []
            else:
                word, _, _, tag = line.split(' ')
                words.append(word)
                tags.append(tag)
    return sentences, labels

# Tokenize and encode sentences
def tokenize_and_encode(sentences, labels, tokenizer, max_length):
    tokenized_sentences = []
    encoded_labels = []
    for sent, labs in zip(sentences, labels):
        tokenized_sent = tokenizer.tokenize(' '.join(sent))
        tokenized_sentences.append(tokenized_sent)
        encoded_labels.append([label_map[label] for label in labs])
    input_ids = [tokenizer.convert_tokens_to_ids(tokens) for tokens in tokenized_sentences]
    input_ids = pad_sequence([torch.tensor(ids) for ids in input_ids], batch_first=True, padding_value=tokenizer.pad_token_id)
    attention_masks = input_ids != tokenizer.pad_token_id
    return input_ids, attention_masks, encoded_labels

# Define evaluation function
def evaluate_model(model, dataloader):
    model.eval()
    all_labels = []
    all_preds = []
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 2)
            all_labels.extend(labels.cpu().numpy().flatten())
            all_preds.extend(preds.cpu().numpy().flatten())
    return all_labels, all_preds



In [6]:
# Load data
train_sentences, train_labels = load_conll_data('C:\Users\Vimarsh\Desktop\SAiDL-Assignment-2024\Q3_MoE\datasets\conll2003\train.txt')
val_sentences, val_labels = load_conll_data('C:\Users\Vimarsh\Desktop\SAiDL-Assignment-2024\Q3_MoE\datasets\conll2003\valid.txt')

# Tokenize and encode data
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
input_ids_train, attention_masks_train, encoded_labels_train = tokenize_and_encode(train_sentences, train_labels, tokenizer, max_length=128)
input_ids_val, attention_masks_val, encoded_labels_val = tokenize_and_encode(val_sentences, val_labels, tokenizer, max_length=128)


SyntaxError: (unicode error) 'unicodeescape' codec can't decode bytes in position 2-3: truncated \UXXXXXXXX escape (3370427490.py, line 2)