# A4: LSTMs and Transformers for Word Sense Disambiguation

by Nikolai Ilinykh, Adam Ek, and others.

The lab is an exploration and learning exercise to be done in a group and also in discussion with the teachers and other students.

Write all your answers and the code in the appropriate boxes below.


A problem with static distributional vectors is the difficulty of distinguishing between different *word senses*. We will continue our exploration of word vectors by considering *trainable vectors* or *word embeddings* for Word Sense Disambiguation (WSD). We will work with both LSTMs and transformer models, e.g. BERT. The purpose of the assignment is to learn use representations neural models in a downstream task of word sense disambiguation.


## Word Sense Disambiguation Task

The goal of word sense disambiguation is to train a model to find the sense of a word (homonyms of a word-form). For example, the word "bank" can mean "sloping land" or "financial institution". 

(a) "I deposited my money in the **bank**" (financial institution)

(b) "I swam from the river **bank**" (sloping land)

In case a) and b), we can determine the meaning of "bank" based on the *context*. To utilize context in a semantic model, we use *contextualized word representations*.

Previously, we worked with *static word representations*, i.e., the representation does not depend on the context. To illustrate, we can consider sentences (a) and (b), where the word **bank** would have the same static representation in both sentences, which means that it becomes difficult for us to predict its sense. What we want is to create representations that depend on the context, i.e., *contextualized embeddings*.

As we have discussed in the class, contextualized representations can come in the form of pre-training the model for some "general" task and then fine-tuning it for some downstream task. Here we will do the following:

(1) Train and test LSTM model directly for word sense disambiguation. We will learn contextualized representations within this model.

(2) Take BERT that was pre-trained on masked language modeling and next sentence prediction. Fine-tune it on our data and test it for the word sense disambiguation on the task dataset. The idea for you is to explore how pre-trained contextualized representations from BERT can be updated and used for the downstream task of word sense disambiguation.

Your overall task in this lab is to create a neural network model that can disambiguate the word sense of 30 different words.

In [1]:
# -*- coding: utf-8 -*-

In [2]:
# first we import some packages that we need

# here add any package that you will need later
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
#import spacy #tokenize by craft, or `tokenizer = spacy.load("en_core_web_sm")`(cost more time...)
#import torchtext # User warning: Torchtext is deprecated
#from torchtext.data import get_tokenizer
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import LabelEncoder
from torch.nn.utils.rnn import pad_sequence
from collections import defaultdict

# our hyperparameters (add more when/if you need them)
device = torch.device('cuda:0')


In [3]:
#!pip install spacy # install necessary packages before import

In [4]:
#!python -m spacy download en_core_web_sm

# 1. Working with Data

A central part of any machine learning system is the data we're working with.

In this section, we will split the data (the dataset is in `wsd_data.txt`) into a training set and a test set.


## Data

The dataset we will use contains different word senses for 30 different words. The data is organized as follows (values separated by tabs), where each line is a separate item in the dataset:

- Column 1: word-sense, e.g., keep%2:42:07::
- Column 2: word-form, e.g., keep.v
- Column 3: index of word, e.g., 15
- Column 4: white-space tokenized context, e.g., Action by the Committee In pursuance of its mandate , the Committee will continue to keep under review the situation relating to the question of Palestine and participate in relevant meetings of the General Assembly and the Security Council . The Committee will also continue to monitor the situation on the ground and draw the attention of the international community to urgent developments in the Occupied Palestinian Territory , including East Jerusalem , requiring international action .


**[Understand the data]**

It is a snippet from WordNet: 

`word-sense`: Sense Key, lemma%pos:sense_number:lexicographer_file_number

`index` represents the position of the word in the context.

e.g., `S: (n) support (support%1:21:00::), keep (keep%1:21:00::), livelihood (livelihood%1:21:00::), living (living%1:21:00::), bread and butter (bread_and_butter%1:21:00::), sustenance (sustenance%1:21:00::) (the financial means whereby one lives) "each child was expected to pay for their keep"; "he applied to the state for support"; "he could no longer earn his own livelihood"`

POS: noun(n)

Sysets: support, keep, livelihood, living, bread and butter, sustenance

Definition: the financial means wherebyone lives


Example sentences: "each child was expected to pay for their keep"; "he applied to the state for support"; "he could no longer earn his own livelihood"

ref: https://wordnet.princeton.edu/

### Splitting the Data

Your first task is to separate the data into a *training set* and a *test set*.

The training set should contain 80% of the examples, and the test set the remaining 20%.

The examples for the test/training set should be selected **randomly**.

Save each dataset into a .csv file for loading later.

**[2 marks]**

In [5]:
def data_split(dataset_path):
    """
    Split the dataset into a training set and a test set, and save them into separate .csv files.

    Parameters:
        dataset_path (str): The file path to the dataset.

    Returns:
        pd.DataFrame, pd.DataFrame: A tuple containing the training set and the test set as pandas DataFrames.
    """
    # Load dataset
    data = pd.read_csv(dataset_path, sep='\t', header=None)
    data.columns = ['word_sense', 'word_form', 'index', 'context']

    # Shuffle dataset
    data_shuffled = data.sample(frac=1, random_state=42)

    # Calculate number of examples for training and test sets
    num_examples = len(data_shuffled)
    train_size = int(0.8 * num_examples)

    # Split dataset
    train_split = data_shuffled.iloc[:train_size]
    test_split = data_shuffled.iloc[train_size:]

    # Save splits to CSV files
    train_split.to_csv('train_split.csv', index=False)
    test_split.to_csv('test_split.csv', index=False)
    
    return train_split, test_split

In [6]:
train_set, test_set = data_split('wsd_data.txt')

### Creating a Baseline

Your second task is to create a *baseline* for the task.

A baseline is a "reality check" for a model. Given a very simple heuristic/algorithmic/model solution to the problem, can our neural network perform better than this? Baselines are important as they give us a point of comparison for the actual models. They are commonly used in NLP. Sometimes baseline models are not simple models but previous state-of-the-art.

In this exercise, we will have a simple baseline model that is the "most common sense" (MCS) baseline. For each word form, find the most commonly assigned sense to the word and label a word with that sense. In a fictional dataset, "bank" has two senses: "financial institution," which occurs 5 times, and "side of the river," which occurs 3 times. Thus, all 8 occurrences of "bank" are labeled "financial institution," yielding an MCS accuracy of 5/8 = 62.5%. If a model obtains a higher score than this, we can conclude that the model *at least* is better than selecting the most frequent word sense.

Your task is to write the code for this baseline, train, and test it. The baseline has the knowledge about labels and their frequency only from the train data. You evaluate it on the test data by comparing the ground-truth sense with the one that the model predicts. A good "dumb" baseline in this case is the one that performs quite badly. Expect the model to perform around 0.30 in terms of accuracy. You should use accuracy as your main metric; you can also compute the F1-score.

**[2 marks]**


In [11]:
def mcs_baseline(train_data, test_data):
    """
    Most Common Sense (MCS) baseline for word sense disambiguation.

    Parameters:
        train_data (pd.DataFrame): The DataFrame containing the training dataset with columns: 
                                   'word_sense', 'word_form', 'index', 'context'.
        test_data (pd.DataFrame): The DataFrame containing the test dataset with the same columns.

    Returns:
        float: Accuracy of the baseline model.
        dict: Per-word-form accuracy (weighted).
    """
    # Calculate the most common sense for each word form in the training data
    most_common_senses = train_data.groupby('word_form')['word_sense'].agg(lambda x: x.value_counts().idxmax())

    # Predict the most common sense for each example in the test data
    predictions = test_data['word_form'].map(most_common_senses)

    # Calculate overall accuracy
    accuracy = (predictions == test_data['word_sense']).mean()

    # Calculate per-word-form accuracy (weighted)
    per_word_form_accuracy = {}
    for word_form, predicted_sense in zip(test_data['word_form'], predictions):
        actual_sense = test_data[test_data['word_form'] == word_form]['word_sense'].iloc[0]
        num_senses = len(train_data[train_data['word_form'] == word_form]['word_sense'].unique())
        per_word_form_accuracy[word_form] = {
        'accuracy': int(predicted_sense == actual_sense) / num_senses,
        'sense_count': num_senses
    }

    return accuracy, per_word_form_accuracy


In [12]:
baseline_accuracy, baseline_per_word_form_accuracy = mcs_baseline(train_set, test_set)
print("MCS baseline accuracy:", baseline_accuracy)

MCS baseline accuracy: 0.3193293885601578


### Creating Data Iterators

To train a neural network, we first need to prepare the data. This involves converting words (and labels) to a number and organizing the data into batches. We also want the ability to shuffle the examples such that they appear in a random order.

Your task is to create a dataloader for the training and test set you created previously.

You are encouraged to adjust your own dataloader you built for previous assignments. Some things to take into account:

1. Tokenize inputs, keep a dictionary of word-to-IDs and IDs-to-words (vocabulary), fix paddings. You might need to consider doing these for each of the four fields in the dataset.
2. Your dataloader probably has a function to process data. Process each column in the dataset.
3. You might want to clean the data a bit. For example, the first column has some symbols, which might be unnecessary. It is up to you whether you want to remove them and clean this column or keep labels the way they are. In any case, you must provide an explanation of your decision and how you think it will affect the performance of your model. Data and its preprocessing matters, so motivate your decisions.
4. Organize your dataset into batches and shuffle them. You should have something akin to data iterators so that your model can take them.

Implement the dataloader and perform necessary preprocessings.

[**2 marks**]

In [9]:
class CustomDataset(Dataset):
    """
    A custom PyTorch dataset for word sense disambiguation.

    Attributes:
        data (DataFrame): The DataFrame containing the dataset.
        word_to_idx (dict): A dictionary mapping words to their indices.
        label_encoder (LabelEncoder): An instance of sklearn's LabelEncoder for encoding word senses.
        word_form_sense_counts (dict): A dictionary to track sense counts for each word form.

    Args:
        data (DataFrame): The DataFrame containing the dataset.
        word_to_idx (dict): A dictionary mapping words to their indices.
        label_encoder (LabelEncoder): An instance of sklearn's LabelEncoder for encoding word senses.
    """
    def __init__(self, data, word_to_idx, label_encoder):
        self.data = data
        # print("data sample:",data.head(10))
        self.word_to_idx = word_to_idx
        self.label_encoder = label_encoder
        self.word_form_sense_counts = defaultdict(set)  # Initialize as defaultdict to handle unseen word forms
        
        # Count senses for each word form and calculate frequencies
        for idx in range(len(data)):
            word_form = data.iloc[idx]['word_form']
            word_sense = data.iloc[idx]['word_sense']
            if word_form not in self.word_form_sense_counts:
                self.word_form_sense_counts[word_form] = set()  # Use a set to avoid duplicates
            self.word_form_sense_counts[word_form].add(word_sense)


    def __len__(self):
        """
        Get the length of the dataset.

        Returns:
            int: The number of samples in the dataset.
        """
        return len(self.data)
        

    def __getitem__(self, idx):
        """
        Get a sample from the dataset.

        Args:
            idx (int): The index of the sample to retrieve.`idx` is passed during calling Dataloader.

        Returns:
            dict: A dictionary containing the sample data with keys 'word_form', 'index', 'context', 'word_sense' and 'word_sense_counts'.
        """
        row = self.data.iloc[idx] 
        
        word_form = row['word_form']
        index = row['index']
        context = row['context']
        word_sense = row['word_sense']
        
        # Tokenize and convert to indices
        tokenized_context = []
        for token in context.lower().split():
            if token in self.word_to_idx:
                tokenized_context.append(self.word_to_idx[token])
            else:
                tokenized_context.append(self.word_to_idx['<UNK>'])

        target_word = word_form.split('.')[0]
        # print("target_word:",target_word)
        # print("context:",context)
        mapped_target_word = self.word_to_idx[target_word] if target_word in self.word_to_idx else self.word_to_idx['<UNK>']
        
        # Encode word sense
        encoded_word_sense = self.label_encoder.transform([word_sense])[0]

        # Count number of senses for this word form
        word_sense_counts = len(self.word_form_sense_counts[word_form])

        return {
            'word_form': mapped_target_word,
            'index': index,
            'context': tokenized_context,
            'word_sense': encoded_word_sense,
            'word_sense_counts': word_sense_counts
        }


def collate_batch(batch):
    """
    Collate function for processing each batch in the DataLoader.

    Args:
        batch (list): A list of samples, where each sample is a dictionary containing keys 'context', 'index', 'word_form', 'word_sense' and 'word_sense_counts'.

    Returns:
        dict: A dictionary containing the batch data with keys 'context', 'target_index', 'word_form', 'word_sense' and 'word_sense_counts'.
            'context' is a padded tensor representing the contexts of the samples, 'target_index' is a tensor containing the indices of the target words, 
            'word_form' is a tensor containing the indices of the target words, 'word_sense' is a tensor containing the encoded word senses,
            and 'word_sense_counts' is a tensor containing the number of senses for each word form.
    """
    contexts = [torch.tensor(item['context'], dtype=torch.long) for item in batch]
    word_indice = torch.tensor([item['index'] for item in batch], dtype=torch.long)
    word_forms = torch.tensor([item['word_form'] for item in batch], dtype=torch.long)
    word_senses = torch.tensor([item['word_sense'] for item in batch], dtype=torch.long)
    word_sense_counts = torch.tensor([item['word_sense_counts'] for item in batch], dtype=torch.long)  # for later evaluation

    contexts_padded = pad_sequence(contexts, batch_first=True, padding_value=1)

    return {
        'context': contexts_padded,
        'target_index': word_indice,
        'word_form': word_forms,
        'word_sense': word_senses,
        'word_sense_counts': word_sense_counts
    }

def build_word_to_idx(data):
    """
    Build a dictionary mapping words to their corresponding indices based on the given dataset.

    Args:
        data (DataFrame): DataFrame containing the dataset, with a column named 'context' containing text data.

    Returns:
        dict: A dictionary mapping words to their indices. Special tokens: <PAD> for padding the contexts, <UNK> for unseen words in test set.
    """
    word_to_idx = defaultdict(lambda: len(word_to_idx))  # Use defaultdict for automatic indexing

    # Add special tokens
    word_to_idx["<PAD>"]  # Ensure padding token is in the dictionary

    for text in data['context']:
        for token in text.lower().split():  # Lowercase and split on whitespace
            word_to_idx[token]

    # Convert defaultdict to regular dict
    word_to_idx = dict(word_to_idx)
    
    return word_to_idx
    
def data_load(batch_size=32, shuffle=True):
    """
    Load and preprocess the dataset, and create data loaders for training and testing.

    Args:
        batch_size (int, optional): Batch size for training and testing. Defaults to 32.
        shuffle (bool, optional): Whether to shuffle the training data. Defaults to True.

    Returns:
        tuple: A tuple containing the training data loader, testing data loader,
               word-to-idx dictionary, and the output dimension.
    """
    # Load dataset
    print("load split dataset...")
    train_data = pd.read_csv('train_split.csv', sep=',', header=0)  # pay attention to parameters here
    test_data = pd.read_csv('test_split.csv', sep=',', header=0)

    train_data.columns = ['word_sense', 'word_form', 'index', 'context']
    test_data.columns = ['word_sense', 'word_form', 'index', 'context']
    print("length of train set:",len(train_data))
    print("length of test set:",len(test_data))

    # Build word_to_idx dictionary
    print("build word_to_idx dictionary...")
    word_to_idx = build_word_to_idx(train_data)
    word_to_idx['<UNK>'] = len(word_to_idx)
    print("vocab_size:", len(word_to_idx))

    # Encode word senses
    print("encode word senses...")
    combined_data = pd.concat([train_data, test_data]) 
    label_encoder = LabelEncoder()
    label_encoder.fit(combined_data['word_sense'])
    output_dim = len(label_encoder.classes_)
    print("number of distinguished word senses(output_dim):", output_dim)

    # Preprocess data
    print("preprocess data...")
    train_dataset = CustomDataset(train_data, word_to_idx, label_encoder)
    test_dataset = CustomDataset(test_data, word_to_idx, label_encoder)

    # Load batches
    print("load batches...")
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_batch)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_batch)

    print("loaded!")

    return train_loader, test_loader, word_to_idx, output_dim



In [10]:
# train_data = pd.read_csv('train_split.csv', sep=',', header=0)
# train_data.columns = ['word_sense', 'word_form', 'index', 'context']
# # summarize by word_sense
# word_sense_summary = train_data['word_sense'].value_counts()
# print(word_sense_summary)

# 2.1 LSTM for Word Sense Disambiguation

In this section, we will train an LSTM model to predict word senses based on *contextualized representations*.

You can read more about LSTMs [here](https://colah.github.io/posts/2015-08-Understanding-LSTMs/).


### Model

We will use a **bidirectional** Long Short-Term Memory (LSTM) network to create a representation for the sentences and a **linear** classifier to predict the sense of each word.

As we discussed in the lecture, bidirectional LSTM is using **two** hidden states: one that goes in the left-to-right direction, and another one that goes in the right-to-left direction. PyTorch documentation on LSTMs can be found [here](https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html). It says that if the bidirectional parameter is set to True, then "h_n will contain a concatenation of the final forward and reverse hidden states, respectively." Keep it in mind because you will have to ensure that your linear layer for prediction takes input of that size.

When we initialize the model, we need a few things:

1) An embedding layer: a dictionary from which we can obtain word embeddings
2) A LSTM-module to obtain contextual representations
3) A classifier that computes scores for each word-sense given *some* input

The general procedure is the following:

1) For each word in the sentence, obtain word embeddings
2) Run the embedded sentences through the LSTM
3) Select the appropriate hidden state
4) Predict the word-sense 

**Suggestion for efficiency:** *Use a low dimensionality (32) for word embeddings and the LSTM when developing and testing the code, then scale up when running the full training/tests*

Your tasks will be to create **two different models** (both follow the two outlines described above).

-----

Your first model should make a prediction from the LSTM's representation of the target word.

In particular, you run your LSTM on the context in which the target word is used. LSTM will produce a sequence of hidden states. Each hidden state corresponds to a single word from the input context. For example, you should be able to get 37 hidden states for a context that has 37 words/elements in it. Next, take the LSTM's representation of the target word. For example, it can be hidden state number 5, because the fifth word in your context is the target word that you want to predict the meaning for. This target's word representation is the input to your linear layer that makes the final prediction.

**[5 marks]**

In [11]:
class WSDModel_approach1(nn.Module):
    """
    Word Sense Disambiguation Model using LSTM with a bidirectional architecture.
    You should make a prediction from the LSTM's representation of the target word.

    Args:
        vocab_size (int): Size of the vocabulary.
        embedding_dim (int): Dimensionality of word embeddings.
        hidden_dim (int): Dimensionality of the hidden state of the LSTM.
        output_dim (int): Dimensionality of the output.

    Attributes:
        embeddings (nn.Embedding): Embedding layer.
        rnn (nn.LSTM): LSTM module.
        classifier (nn.Linear): Linear classifier.

    Methods:
        forward(context, target_index): Forward pass through the model.

    """
    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim):
        super(WSDModel_approach1, self).__init__()
        
        # Embedding layer
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        
        # LSTM module
        self.rnn = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, bidirectional=True) # `bidirectional=True` => output.shape: (num_layers * num_directions, batch_size, hidden_size)
        
        # Linear classifier
        self.classifier = nn.Linear(hidden_dim * 2, output_dim)   # why `*2`: concatenating forward and backward hidden states
    
    def forward(self, context, target_index):
        """
        Forward pass of the WSDModel_approach1.
    
        Args:
            context (torch.Tensor): Tensor of shape (batch_size, seq_length) containing the input word indices.
            target_index (torch.Tensor): Tensor of shape (batch_size,) containing the indices of the target words within the context.
    
        Returns:
            torch.Tensor: Tensor of shape (batch_size, output_dim) containing the model's predictions for each target word.
        """
        
        # Obtain word embeddings
        embedded_context = self.embeddings(context)
        
        # Run the embedded context through the LSTM
        output, (_, _) = self.rnn(embedded_context) # output, (hidden,cell)
        
        # Select the appropriate hidden state (representation of the target word)
        target_hidden = output[torch.arange(output.size(0)), target_index, :]  # get the target hidden state
        
        # Predict the word-sense
        predictions = self.classifier(target_hidden)
        
        return predictions


Your second model should make a prediction from the final hidden state of your LSTM.

In particular, do the same first steps as in the first approach. But then to make a prediction with your linear layer, you will need to take the last hidden state that your LSTM produces for the whole sequence.

**[5 marks]**

In [12]:
class WSDModel_approach2(nn.Module):
    """
    Word Sense Disambiguation Model using LSTM with a bidirectional architecture.
    You should make a prediction from the final hidden state of your LSTM.

    Args:
        vocab_size (int): Size of the vocabulary.
        embedding_dim (int): Dimensionality of word embeddings.
        hidden_dim (int): Dimensionality of the hidden state of the LSTM.
        output_dim (int): Dimensionality of the output.

    Attributes:
        embeddings (nn.Embedding): Embedding layer.
        rnn (nn.LSTM): LSTM module.
        classifier (nn.Linear): Linear classifier.

    Methods:
        forward(context): Forward pass through the model.

    """
    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim):
        super(WSDModel_approach2, self).__init__()
        
        # Embedding layer
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        
        # LSTM module
        self.rnn = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, bidirectional=True)
        
        # Linear classifier
        self.classifier = nn.Linear(hidden_dim * 2, output_dim)  # why `*2`: concatenating forward and backward hidden states
    
    def forward(self, batch):
        """
        Forward pass of the WSDModel_approach2.
    
        Args:
            batch (torch.Tensor): Tensor of shape (batch_size, seq_length) containing the input word indices.

        Returns:
            torch.Tensor: Tensor of shape (batch_size, output_dim) containing the model's predictions for each final hidden state.
        """
        # Extract relevant data from the batch
        context = batch # here `batch` is `inputs`

        # Obtain word embeddings
        embedded_context = self.embeddings(context)
        
        # Run the embedded context through the LSTM
        _, (hidden, _) = self.rnn(embedded_context) 
        final_hidden = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1) # take the last hidden state 
        # print(final_hidden.shape)

        # Predict the word-sense
        predictions = self.classifier(final_hidden)
        
        return predictions


### Training and Testing the Model

Now we are ready to train and test our model. What we need now is a loss function, an optimizer, and our data. 

- First, create the loss function and the optimizer.
- Next, iterate over the number of epochs (i.e., how many times we let the model see our data). 
- For each epoch, iterate over the dataset to obtain batches. Use the batch as input to the model, and let the model output scores for the different word senses.
- For each model output, calculate the loss (and print the loss) on the output and update the model parameters.
- Reset the gradients and repeat.
- After all epochs are done, test your trained model on the test set and calculate the total and per-word-form accuracy of your model.

Implement the training and testing of the model.

**[4 marks]**

**Suggestion for efficiency:** *When developing your model, try training and testing the model on one or two batches (for each epoch) of data to make sure everything works! It's very annoying if you train for N epochs to find out that something went wrong when testing the model, or to find that something goes wrong when moving from epoch 0 to epoch 1.*

Do not forget to save your best models as .pickle files. The results should be reproducible for us to evaluate your models.


In [13]:
learning_rate = 0.0005
epochs = 5
batch_size = 8
embedding_dim = 256 # 32 - use a low dimensionality when developing and testing the code
hidden_dim = 256

In [15]:
# load data
train_loader, test_loader, word_to_idx, output_dim = data_load(batch_size=batch_size, shuffle=True)

load split dataset...
length of train set: 60839
length of test set: 15210
build word_to_idx dictionary...
vocab_size: 70479
encode word senses...
number of distinguished word senses(output_dim): 222
preprocess data...
load batches...
loaded!


In [16]:
# train_iter = iter(train_loader)
# train_batch = next(train_iter)

# print("train_loader sample:",train_batch) # the first batch (8 sents)

In [17]:
def test_model_1(model, test_loader):
    """
    Test the WSDModel_approach1 on the test data.

    Args:
        model (WSDModel_approach1): The trained WSD model.
        test_loader (DataLoader): DataLoader for the test dataset.

    Returns:
        tuple: (float, dict) Accuracy of the model on the test dataset, 
               and per-word-form accuracy

    """
    model.eval()
    correct = 0
    total = 0

    # Dictionary to store correct and total predictions for each word form
    word_form_stats = {}

    with torch.no_grad():
        for batch in test_loader:
            inputs = batch['context']
            target_indices = batch['target_index']
            labels = batch['word_sense']
            word_forms = batch['word_form']
            word_sense_counts = batch['word_sense_counts']  # Number of senses for each word form
            
            inputs = inputs.to(device)
            target_indices = target_indices.to(device)
            labels = labels.to(device)
            word_forms = word_forms.to(device)
            word_sense_counts = word_sense_counts.to(device)
      
            outputs = model(inputs, target_indices)
            _, predicted = torch.max(outputs.data, 1)
            
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Update word form stats
            for word_form, label, prediction, sense_count in zip(word_forms, labels, predicted, word_sense_counts):
                word_form = word_form.item()  # convert tensor to scalar
                sense_count = sense_count.item()
                if word_form not in word_form_stats:
                    word_form_stats[word_form] = {'correct': 0, 'total': 0, 'sense_count': sense_count}
                word_form_stats[word_form]['total'] += 1
                if label == prediction:
                    word_form_stats[word_form]['correct'] += 1

    accuracy = 100 * correct / total

    # Calculate per-word-form accuracy with weighting
    per_word_form_accuracy = {}
    for word_form, stats in word_form_stats.items():
        raw_accuracy = stats['correct'] / stats['total']
        weighted_accuracy = raw_accuracy / stats['sense_count']
        per_word_form_accuracy[word_form] = {
        'accuracy': 100 * weighted_accuracy,
        'sense_count': stats['sense_count']
    }
    
    return accuracy, per_word_form_accuracy


In [18]:
print("build model...")
loss_function = nn.CrossEntropyLoss()
model_1 = WSDModel_approach1(len(word_to_idx), embedding_dim, hidden_dim, output_dim)
model_1.to(device)
optimizer_1 = optim.Adam(model_1.parameters(), lr=learning_rate)

#best_accuracy_1 = 0.0 # set the threshold value when best_model has existed
best_accuracy_1 = 73.41
best_model_path_1 = 'best_model_1.pickle'

print("training model 1...")
for epoch in range(epochs):
    total_loss = 0.0
    
    # train model
    model_1.train()  # train mode
    for batch in train_loader:

        # reset gradient
        optimizer_1.zero_grad()
        
        # forward
        inputs = batch['context']
        target_indices = batch['target_index']
        labels = batch['word_sense']

        inputs = inputs.to(device)
        target_indices = target_indices.to(device)
        labels = labels.to(device)
        
        outputs = model_1(inputs, target_indices)

        # calculate loss
        loss = loss_function(outputs, labels)
        
        # Backpropagation
        loss.backward()
        
        # update parameters
        optimizer_1.step()
        
        # accumulate losses
        total_loss += loss.item()
    
    # calculate average loss
    avg_loss = total_loss / len(train_loader)
    print(f"epoch {epoch + 1}, average loss: {avg_loss:.4f}")

    # monitor the total accuracy during training
    accuracy, _ = test_model_1(model_1, test_loader)
    print(f"\taccuracy on test set: {accuracy:.2f}%")
    if accuracy > best_accuracy_1:
        best_accuracy_1 = accuracy
        torch.save(model_1.state_dict(), best_model_path_1)
        print("\tBest model saved!")

print("training finished!")


build model...
training model 1...
epoch 1, average loss: 1.1483
	accuracy on test set: 70.30%
epoch 2, average loss: 0.6502
	accuracy on test set: 71.89%
epoch 3, average loss: 0.3751
	accuracy on test set: 72.81%
epoch 4, average loss: 0.1753
	accuracy on test set: 73.53%
	Best model saved!
epoch 5, average loss: 0.0881
	accuracy on test set: 73.16%
training finished!


In [19]:
# Load the best model and evaluate it on the test set
best_model_1 = WSDModel_approach1(len(word_to_idx), embedding_dim, hidden_dim, output_dim)
best_model_1.load_state_dict(torch.load(best_model_path_1))
best_model_1.to(device)
best_model_1.eval()  # Set the model to evaluation mode

# generate idx_to_word dictionary
idx_to_word = {idx: word for word, idx in word_to_idx.items()}

with torch.no_grad():
    final_accuracy, per_word_idx_accuracy = test_model_1(best_model_1, test_loader)
    print(f"final test accuracy for model 1: {final_accuracy:.2f}%")

    per_word_form_accuracy_1 = {}
    for idx, stats in per_word_idx_accuracy.items():
        accuracy,sense_count = stats.values()
        word = idx_to_word[idx] # here per word maps one word-form
        per_word_form_accuracy_1[word] = {
        'accuracy': accuracy,
        'sense_count': sense_count
    }

    # sort by weighted accuracy 
    sorted_per_word_form_accuracy_1 = dict(sorted(per_word_form_accuracy_1.items(), key=lambda item: item[1]['accuracy'], reverse=True))

# print
output_1 = '; '.join([f"'{word}': {stats['accuracy']:.2f}%, {stats['sense_count']}" for word, stats in sorted_per_word_form_accuracy_1.items()])
print(f"per_word_form_accuracy(sorted by weighted accuracy): {output_1}") 
    
## record of final accuracy on test set
# Note here: The same combination of hyperparameters does not always achieve the same accuracy! => To save the best model each time is needed...
# learning_rate = 0.001, epochs = 3, batch_size = 8, embedding_dim = 32, hidden_dim = 64 => 68.65%
# learning_rate = 0.001, epochs = 5, batch_size = 8, embedding_dim = 128, hidden_dim = 64 => 70.93%
# learning_rate = 0.0005, epochs = 4, batch_size = 8, embedding_dim = 256, hidden_dim = 256 => 73.53%

final test accuracy for model 1: 73.53%
per_word_form_accuracy(sorted by weighted accuracy): 'bad': 19.98%, 4; 'professional': 16.53%, 5; 'major': 16.15%, 4; 'common': 16.08%, 4; 'active': 15.68%, 5; 'order': 15.66%, 5; 'critical': 15.14%, 5; 'time': 14.49%, 5; 'positive': 14.32%, 5; 'national': 12.72%, 6; 'security': 12.31%, 7; 'physical': 11.71%, 6; 'position': 11.01%, 6; 'place': 11.01%, 7; 'force': 10.28%, 8; 'point': 9.70%, 8; 'extend': 8.99%, 7; 'line': 8.80%, 11; 'regular': 8.62%, 8; 'life': 8.23%, 9; 'bring': 7.85%, 8; 'case': 7.79%, 8; 'see': 7.42%, 11; 'serve': 7.40%, 9; 'keep': 7.23%, 11; 'find': 6.72%, 10; 'lead': 6.71%, 8; 'hold': 6.29%, 11; 'follow': 5.93%, 11; 'build': 4.88%, 10


In [20]:
def test_model_2(model, test_loader):
    """
    Test the WSDModel_approach2 on the test data.

    Args:
        model (WSDModel_approach2): The trained WSD model.
        test_loader (DataLoader): DataLoader for the test dataset.

    Returns:
        tuple: (float, dict) Accuracy of the model on the test dataset, 
               and per-word-form accuracy

    """
    model.eval()
    correct = 0
    total = 0

    # Dictionary to store correct and total predictions for each word form
    word_form_stats = {}

    with torch.no_grad():
        for batch in test_loader:
            inputs = batch['context']
            labels = batch['word_sense']
            word_forms = batch['word_form']
            word_sense_counts = batch['word_sense_counts']  # Number of senses for each word form

            inputs = inputs.to(device)
            labels = labels.to(device)
            word_forms = word_forms.to(device)
            word_sense_counts = word_sense_counts.to(device)
            
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

             # Update word form stats
            for word_form, label, prediction, sense_count in zip(word_forms, labels, predicted, word_sense_counts):
                word_form = word_form.item()  # convert tensor to scalar
                sense_count = sense_count.item()
                if word_form not in word_form_stats:
                    word_form_stats[word_form] = {'correct': 0, 'total': 0, 'sense_count': sense_count}
                word_form_stats[word_form]['total'] += 1
                if label == prediction:
                    word_form_stats[word_form]['correct'] += 1

    accuracy = 100 * correct / total

    # Calculate per-word-form accuracy with weighting
    per_word_form_accuracy = {}
    for word_form, stats in word_form_stats.items():
        raw_accuracy = stats['correct'] / stats['total']
        weighted_accuracy = raw_accuracy / stats['sense_count']
        per_word_form_accuracy[word_form] = {
        'accuracy': 100 * weighted_accuracy,
        'sense_count': stats['sense_count']
    }
        
    return accuracy, per_word_form_accuracy

In [21]:
# try model_2
model_2 = WSDModel_approach2(len(word_to_idx), embedding_dim, hidden_dim, output_dim)
model_2.to(device)
optimizer_2 = optim.Adam(model_2.parameters(), lr=learning_rate)

best_accuracy_2 = 58.42 # set the threshold value when best_model has existed
best_model_path_2 = 'best_model_2.pickle'

print("training model 2...")
for epoch in range(epochs):
    total_loss = 0.0
    
    # train model
    model_2.train()  # train mode
    for batch in train_loader:
        # reset gradient
        optimizer_2.zero_grad()
        
        # forward
        inputs = batch['context']
        labels = batch['word_sense']

        inputs = inputs.to(device)
        labels = labels.to(device)
        
        outputs = model_2(inputs)

        # calculate loss
        loss = loss_function(outputs, labels)
        
        # Backpropagation
        loss.backward()
        
        # update parameters
        optimizer_2.step()
        
        # accumulate losses
        total_loss += loss.item()
    
    # calculate average loss
    avg_loss = total_loss / len(train_loader)
    print(f"epoch {epoch + 1}, average loss: {avg_loss:.4f}")

    # monitor the total accuracy during training
    accuracy, _ = test_model_2(model_2, test_loader)
    print(f"\taccuracy on test set: {accuracy:.2f}%")
    if accuracy > best_accuracy_2:
        best_accuracy_2 = accuracy
        torch.save(model_2.state_dict(), best_model_path_2)
        print("\tBest model saved!")

print("training finished!")


training model 2...
epoch 1, average loss: 4.7453
	accuracy on test set: 15.47%
epoch 2, average loss: 3.8658
	accuracy on test set: 29.11%
epoch 3, average loss: 2.3848
	accuracy on test set: 47.34%
epoch 4, average loss: 1.3691
	accuracy on test set: 53.99%
epoch 5, average loss: 0.9272
	accuracy on test set: 57.58%
training finished!


In [22]:
# Load the best model and evaluate it on the test set
best_model_2 = WSDModel_approach2(len(word_to_idx), embedding_dim, hidden_dim, output_dim)
best_model_2.load_state_dict(torch.load(best_model_path_2))
best_model_2.to(device)
best_model_2.eval()  # Set the model to evaluation mode

with torch.no_grad():
    final_accuracy, per_word_idx_accuracy = test_model_2(best_model_2, test_loader)
    print(f"final test accuracy for model 2: {final_accuracy:.2f}%")

    per_word_form_accuracy_2 = {}
    for idx, stats in per_word_idx_accuracy.items():
        accuracy,sense_count = stats.values()
        word = idx_to_word[idx] # here per word maps one word-form
        per_word_form_accuracy_2[word] = {
        'accuracy': accuracy,
        'sense_count': sense_count
    }

    # sort by weighted accuracy 
    sorted_per_word_form_accuracy_2 = dict(sorted(per_word_form_accuracy_2.items(), key=lambda item: item[1]['accuracy'], reverse=True))

# print
output_2 = '; '.join([f"'{word}': {stats['accuracy']:.2f}%, {stats['sense_count']}" for word, stats in sorted_per_word_form_accuracy_2.items()])
print(f"per_word_form_accuracy(sorted by weighted accuracy): {output_2}") 

## record of final accuracy on test set
# adding epochs here can bring obvious improvement => why? learning more to counteract the negative impacts brought by the "crude" prediction target...
# learning_rate = 0.001, epochs = 3, batch_size = 8, embedding_dim = 32, hidden_dim = 64 => 16.96%
# learning_rate = 0.001, epochs = 8, batch_size = 8, embedding_dim = 128, hidden_dim = 64 => 53.10% (a great improvement!)
# learning_rate = 0.0005, epochs = 7, batch_size = 8, embedding_dim = 256, hidden_dim = 256 => 58.42%

final test accuracy for model 2: 58.42%
per_word_form_accuracy(sorted by weighted accuracy): 'bad': 17.52%, 4; 'professional': 13.27%, 5; 'order': 12.70%, 5; 'positive': 11.91%, 5; 'active': 11.58%, 5; 'common': 11.54%, 4; 'critical': 9.66%, 5; 'major': 9.34%, 4; 'physical': 9.07%, 6; 'position': 8.95%, 6; 'time': 8.65%, 5; 'security': 8.62%, 7; 'line': 8.44%, 11; 'point': 8.42%, 8; 'force': 8.37%, 8; 'place': 8.32%, 7; 'national': 7.82%, 6; 'extend': 7.37%, 7; 'see': 6.91%, 11; 'life': 6.66%, 9; 'keep': 6.60%, 11; 'regular': 6.23%, 8; 'serve': 5.97%, 9; 'lead': 5.17%, 8; 'bring': 5.16%, 8; 'find': 4.96%, 10; 'case': 4.82%, 8; 'hold': 4.67%, 11; 'follow': 4.31%, 11; 'build': 2.90%, 10


In [23]:
## try to figure out the optimal combination of hyperparameters for approach 1 (considering time cost, only run batch_size=8 completely)
# define model training and evaluation functions
# def train_model(model, train_loader, optimizer, loss_function, device):
#     model.train()
#     total_loss = 0.0

#     for batch in train_loader:
#         inputs = batch['context']
#         target_indices = batch['target_index']
#         labels = batch['word_sense']

#         inputs = inputs.to(device)
#         target_indices = target_indices.to(device)
#         labels = labels.to(device)

#         optimizer.zero_grad()
#         outputs = model(inputs, target_indices)
#         loss = loss_function(outputs, labels)
#         loss.backward()
#         optimizer.step()
#         total_loss += loss.item()

#     avg_loss = total_loss / len(train_loader)
#     return avg_loss

# # evaluation: test_model_1()

# # define grid search function
# def grid_search(params, train_loader, test_loader, device):

#     for embedding_dim in params['embedding_dim']:
#         for hidden_dim in params['hidden_dim']:
#             for learning_rate in params['learning_rate']:
#                 print(f'Training with embedding_dim={embedding_dim}, hidden_dim={hidden_dim}, learning_rate={learning_rate}...')

#                 model = WSDModel_approach1(vocab_size=len(word_to_idx), embedding_dim=embedding_dim, hidden_dim=hidden_dim, output_dim=output_dim).to(device)
#                 optimizer = optim.Adam(model.parameters(), lr=learning_rate)
#                 loss_function = nn.CrossEntropyLoss()

#                 for epoch in range(params['epochs']):
#                     train_loss = train_model(model, train_loader, optimizer, loss_function, device)
#                     print(f'epoch {epoch + 1}, loss: {train_loss}')

#                 accuracy = test_model_1(model, test_loader)
#                 print(f'\tAccuracy: {accuracy:.2f}%')

#                 if round(accuracy, 2) > round(best_accuracy, 2):
#                     print(f'\tNew best accuracy: {accuracy:.2f}% (previous: {best_accuracy:.2f}%). Saving model...')
#                     torch.save(model.state_dict(), best_model_path_1)
#                     print("\tBest model saved!")
#                     best_accuracy = accuracy
#                     best_params = {
#                         'embedding_dim': embedding_dim,
#                         'hidden_dim': hidden_dim,
#                         'learning_rate': learning_rate,
#                         'accuracy': best_accuracy
#                     }

#     return best_params


In [24]:
# # define parameter grid
# params = {
#     'batch_size': [8,16,32,64],
#     'embedding_dim': [64, 128, 256],
#     'hidden_dim': [64, 128, 256],
#     'learning_rate': [0.001, 0.0005, 0.0001],
#     'epochs': 5
# }

# best_params = None
# #best_accuracy = 0.0
# best_accuracy = 70.93 # Initialization(pay attention to the scale!)

# # load data for each batch_size
# for batch_size in params['batch_size']:
#     print(f'Loading data with batch_size = {batch_size}...')
#     train_loader, test_loader, word_to_idx, output_dim = data_load(batch_size = batch_size, shuffle=True)

#     # run grid search
#     current_best_params = grid_search(params, train_loader, test_loader, device)
#     print(f'\tBest Parameters for batch_size={batch_size}: {current_best_params}')

#     # Update global best parameters if current batch size yields better accuracy
#     if current_best_params and current_best_params['accuracy'] > best_accuracy:
#         best_accuracy = current_best_params['accuracy']
#         best_params = current_best_params

# # After all batch sizes are processed, print the overall best parameters
# print(f'Overall Best Parameters: {best_params}')

# ## record of grid search
# # Best Parameters for batch_size=8: {'embedding_dim': 256, 'hidden_dim': 256, 'learning_rate': 0.0005, 'accuracy': 73.64891518737673}

# 2.2 Fine-tuning and Testing BERT for Word Sense Disambiguation

In this section of the lab, you'll try out the transformer, specifically the BERT model. For this, we'll use the Hugging Face library ([https://huggingface.co/](https://huggingface.co/)).

You can find the documentation for the BERT model [here](https://huggingface.co/transformers/model_doc/bert.html) and a general usage guide [here](https://huggingface.co/transformers/v2.9.1/quickstart.html).

What we're going to do is *fine-tune* the BERT model, i.e., update the weights of a pre-trained model. That is, we have a model that is pre-trained on masked language modeling and next sentence prediction (kind of basic, general tasks which are useful for a lot of more specific tasks), but now we apply it to word sense disambiguation with the word representations it has learned.

We'll use the same data splits for training and testing as before, but this time you will use a different dataloader.

Now you create an iterator that collects N sentences (where N is the batch size) then use the BertTokenizer to transform the sentence into integers. For your dataloader, remember to:
* Shuffle the data in each batch
* Make sure you get a new iterator for each *epoch*
* Create a vocabulary of *sense-labels* so you can calculate accuracy 

We then pass this batch into the BERT model (you must have pre-loaded its weights) and update the weights (fine-tune). The BERT model will encode the sentence, then we send this encoded sentence into a prediction layer and collect what it outputs.

As input to the prediction layer, you are free to play with different types of information. For example, the expected way would be to use CLS representation. You can also use other representations and compare them.

About the hyperparameters and training:
* For BERT, usually a lower learning rate works best, between 0.0001-0.000001.
* BERT takes a lot of resources, running it on CPU will take ages, utilize the GPUs :)
* Since BERT takes a lot of resources, use a small batch size (4-8)
* Computing the BERT representation, make sure you pass the mask

**[12 marks]**

In [5]:
### 0. define hyperparameters
learning_rate = 1e-5 # For BERT, usually a lower learning rate works best, between 0.0001-0.000001.
epochs = 3
batch_size = 8 # Since BERT takes a lot of resources, use a small batch size (4-8)

In [6]:
#!pip install transformers

In [7]:
### 1. create a new dataloader (use the BertTokenizer to transform the sentence into integers)
from transformers import BertTokenizer, BertModel

## 1.1 create data preprocessing class
class CustomBertDataset(Dataset):
    """
    A custom PyTorch dataset for word sense disambiguation using BERT.

    Attributes:
        data (DataFrame): The DataFrame containing the dataset.
        tokenizer (PreTrainedTokenizer): The BERT tokenizer.
        max_length (int): The maximum length for tokenized sequences.
    
    Args:
        data (DataFrame): The DataFrame containing the dataset.
        tokenizer (PreTrainedTokenizer): The BERT tokenizer.
        label_encoder (LabelEncoder): An instance of sklearn's LabelEncoder for encoding word senses.
        max_length (int): The maximum length for tokenized sequences.
    """
    def __init__(self, data, tokenizer, label_encoder, max_length=128):
        """
        Initializes the dataset with data, tokenizer, and maximum sequence length.

        Args:
            data (DataFrame): The DataFrame containing the dataset.
            tokenizer (PreTrainedTokenizer): The BERT tokenizer.
            label_encoder (LabelEncoder): An instance of sklearn's LabelEncoder for encoding word senses.
            max_length (int): The maximum length for tokenized sequences (default is 128).
        """
        self.data = data
        self.tokenizer = tokenizer
        self.label_encoder = label_encoder
        self.max_length = max_length

    def __len__(self):
        """
        Returns the number of samples in the dataset.

        Returns:
            int: The number of samples in the dataset.
        """
        return len(self.data)

    def __getitem__(self, idx):
        """
        Retrieves a sample from the dataset.

        Args:
            idx (int): The index of the sample to retrieve.

        Returns:
            dict: A dictionary containing the sample data with keys 'input_ids', 
                  'attention_mask', and 'word_sense'.
                  - 'input_ids' (torch.Tensor): The token IDs of the input context.
                  - 'attention_mask' (torch.Tensor): The attention mask for the input context.
                  - 'word_sense' (int): The encoded word sense label.
        """
        row = self.data.iloc[idx]
        word_sense = row['word_sense']
        word_form = row['word_form']
        index = row['index']
        context = row['context']

        # Encode word sense
        encoded_word_sense = self.label_encoder.transform([word_sense])[0]
        
        # Tokenize the context
        encoded_dict = self.tokenizer(
            text=context,
            add_special_tokens=True, # Add special tags [CLS] and [SEP]
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt' # Return PyTorch tensors
        )

        # Flatten to 1D tensors
        input_ids = encoded_dict['input_ids'].flatten()
        attention_mask = encoded_dict['attention_mask'].flatten()

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'word_sense': int(encoded_word_sense)
        }

In [8]:
## 1.2 def dataloader function
def dataloader_for_bert(batch_size, shuffle=True):
    """
    Creates a DataLoader for a dataset suitable for BERT-based word sense disambiguation.

    Args:
        batch_size (int): The number of samples per batch to load.
        shuffle (bool): Whether to shuffle the data at every epoch. Default is True.

    Returns:
        tuple: A tuple containing the training data loader, testing data loader, and the output dimension.
    """
    # Load dataset
    print("load split dataset...")
    train_data = pd.read_csv('train_split.csv', sep=',', header=0) 
    test_data = pd.read_csv('test_split.csv', sep=',', header=0)

    train_data.columns = ['word_sense', 'word_form', 'index', 'context']
    test_data.columns = ['word_sense', 'word_form', 'index', 'context']
    
    # Initialize the BERT tokenizer
    print("initialize the BERT tokenizer...")
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    # Encode word senses
    print("encode word senses...")
    combined_data = pd.concat([train_data, test_data]) 
    label_encoder = LabelEncoder()
    label_encoder.fit(combined_data['word_sense'])
    output_dim = len(label_encoder.classes_)
    print("number of distinguished word senses(output_dim):", output_dim)
    
    # Preprocess data
    print("preprocess data...")
    train_dataset = CustomBertDataset(train_data, tokenizer, label_encoder)
    test_dataset = CustomBertDataset(test_data, tokenizer, label_encoder)

    # Load batches
    print("load batches...")
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    print("loaded!")

    return train_loader, test_loader, output_dim


In [9]:
# load data
train_loader, test_loader, output_dim = dataloader_for_bert(batch_size=batch_size, shuffle=True)

load split dataset...
initialize the BERT tokenizer...
encode word senses...
number of distinguished word senses(output_dim): 222
preprocess data...
load batches...
loaded!


In [10]:
### 2. build model
## 2.1 create BERT_WSD class
class BERT_WSD(nn.Module):
    """
    A custom PyTorch model for word sense disambiguation using BERT.

    This model uses a pretrained BERT model to generate representations for the input text,
    and a linear classifier to predict the word sense.

    Attributes:
        bert (BertModel): The pretrained BERT model.
        classifier (nn.Linear): The linear layer for classification.

    Args:
        output_dim (int): The number of output classes for the classifier.
    """
    def __init__(self, output_dim):
        """
        Initializes the BERT_WSD model with a pretrained BERT model and a classifier.

        Args:
            output_dim (int): The number of output classes for the classifier.
        """
        super(BERT_WSD, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.classifier = nn.Linear(self.bert.config.hidden_size, output_dim)
    
    def forward(self, input_ids, attention_mask):
        """
        Forward pass of the BERT_WSD model.

        Args:
            input_ids (torch.Tensor): Tensor of input IDs for BERT (tokenized text).
            attention_mask (torch.Tensor): Tensor of attention masks for BERT.

        Returns:
            torch.Tensor: The output predictions from the classifier.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_output = outputs.last_hidden_state[:, 0, :]  # [CLS] token representation => target indice?
        predictions = self.classifier(cls_output)
        return predictions

In [11]:
## 2.2 initialize model, loss function, and optimizer
model = BERT_WSD(output_dim=output_dim)
model.to(device)
# ensure that all the parameters require gradients
for param in model.parameters():
    param.requires_grad = True
    
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# 2.3 training
print("fine-tuning BERT model...")
for epoch in range(epochs):
    total_loss = 0.0
    
    # train model
    model.train()
    for batch in train_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['word_sense'].to(device)

        # reset gradient
        optimizer.zero_grad()

        # forward
        outputs = model(input_ids, attention_mask)

        # calculate loss
        loss = loss_function(outputs, labels)

        # Backpropagation
        loss.backward()

        # update parameters
        optimizer.step()

        # accumulate losses
        total_loss += loss.item()
    
    # calculate average loss
    avg_loss = total_loss / len(train_loader)
    print(f"epoch {epoch + 1}, average loss: {avg_loss:.4f}")

print('training finished!')


fine-tuning BERT model...
epoch 1, average loss: 2.1864
epoch 2, average loss: 0.9891
epoch 3, average loss: 0.7455
training finished!


In [13]:
### 3. test model after all epochs are completed
model.eval()
correct_predictions = 0
total_predictions = 0

best_bert_accuracy = 74.76 # set the threshold value when best_model has existed
best_bert_model_path = 'best_bert_model.pickle'

with torch.no_grad():
    for batch in test_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['word_sense'].to(device)
        
        outputs = model(input_ids, attention_mask)
        _, predicted = torch.max(outputs, 1)
        
        correct_predictions += (predicted == labels).sum().item()
        total_predictions += labels.size(0)

accuracy = 100 * correct_predictions / total_predictions
print(f'Test Accuracy: {accuracy:.2f}%')

if accuracy > best_bert_accuracy:
    best_bert_accuracy = accuracy
    torch.save(model.state_dict(), best_bert_model_path)
    print("Best model saved!")

## record of final accuracy on test set
# learning_rate = 1e-5, epochs = 3, batch_size = 4 => 74.76%
# learning_rate = 1e-5, epochs = 3, batch_size = 8 => 75.19%

Test Accuracy: 75.19%
Best model saved!


# 3. Evaluation

Q: Explain the difference between the two LSTMs that you have implemented for word sense disambiguation.

Important note: your LSTMs should be nearly the same, but your linear layer must take different inputs. Describe why and how you think this difference will affect the performance of different LSTMs. How does the contextual representation of the whole sequence perform? How does the representation of the target word perform? What is better and for what situations? Why do we observe these differences?

What kind of representations are the different approaches using to predict word senses?

**[4 marks]**

Bi-LSTM models process the input sequence in both forward and backward directions, allowing them to capture contextual information from both past and future tokens. For Bi-LSTMs, the final hidden state often serves as the representation of the whole sequence. 

Representations the different approaches using to predict word senses:- Approach 1: Uses the hidden state of the target word (local context).
- Approach 2: Uses the final hidden state of the sequence (global context).

Comparing different approaches:
- Local Context Approach (Approach 1): Better when the word sense is highly dependent on nearby words (e.g., syntactic dependencies, local collocations).
- Global Context Approach (Approach 2):  More suitable when broader sentence or document-level informatin influence the word sense.

From the results above, the accuracy of Approach 1 is around *73%*, while the result of Approach 2 is around *58%*. For our context here is not so long, it is rational to get a **worse** result using Approach 2. Compared with the baseline result (around *30%*), Approach 2 model could perform better.

Q: Evaluate your model with per-word form *accuracy* and comment on the results you get. How does the model perform in comparison to the baseline, and how do the models compare to each other? 

Expand on the evaluation by sorting the word-forms by the number of senses they have. Are word forms with fewer senses easier to predict? Give a short explanation of the results you get based on the number of senses per word.

**[4 marks]**

In [13]:
# baseline
sorted_baseline_per_word_form_accuracy_1 = dict(sorted(baseline_per_word_form_accuracy.items(), key=lambda item: item[1]['accuracy'], reverse=True))
sorted_baseline_per_word_form_accuracy_2 = dict(sorted(baseline_per_word_form_accuracy.items(), key=lambda item: item[1]['sense_count'], reverse=True))
output_baseline1 = '; '.join([f"'{word}': {stats['accuracy']:.2f}%, {stats['sense_count']}" for word, stats in sorted_baseline_per_word_form_accuracy_1.items()])
output_baseline2 = '; '.join([f"'{word}': {stats['accuracy']:.2f}%, {stats['sense_count']}" for word, stats in sorted_baseline_per_word_form_accuracy_2.items()])

print("Evaluation of MSC Baseline:")
print(f"per_word_form_accuracy(sorted by weighted accuracy): {output_baseline1}") 
print(f"per_word_form_accuracy(sorted by the number of senses): {output_baseline2}")

Evaluation of MSC Baseline:
per_word_form_accuracy(sorted by weighted accuracy): 'professional.a': 0.20%, 5; 'critical.a': 0.20%, 5; 'positive.a': 0.20%, 5; 'point.n': 0.12%, 8; 'life.n': 0.11%, 9; 'line.n': 0.09%, 11; 'see.v': 0.09%, 11; 'lead.v': 0.00%, 8; 'common.a': 0.00%, 4; 'find.v': 0.00%, 10; 'case.n': 0.00%, 8; 'security.n': 0.00%, 7; 'keep.v': 0.00%, 11; 'time.n': 0.00%, 5; 'build.v': 0.00%, 10; 'place.n': 0.00%, 7; 'physical.a': 0.00%, 6; 'major.a': 0.00%, 4; 'bad.a': 0.00%, 4; 'hold.v': 0.00%, 11; 'regular.a': 0.00%, 8; 'bring.v': 0.00%, 8; 'force.n': 0.00%, 8; 'serve.v': 0.00%, 9; 'follow.v': 0.00%, 11; 'extend.v': 0.00%, 7; 'national.a': 0.00%, 6; 'position.n': 0.00%, 6; 'order.n': 0.00%, 5; 'active.a': 0.00%, 5
per_word_form_accuracy(sorted by the number of senses): 'line.n': 0.09%, 11; 'keep.v': 0.00%, 11; 'hold.v': 0.00%, 11; 'see.v': 0.09%, 11; 'follow.v': 0.00%, 11; 'find.v': 0.00%, 10; 'build.v': 0.00%, 10; 'life.n': 0.11%, 9; 'serve.v': 0.00%, 9; 'lead.v': 0.00%, 8

In [25]:
# regenerate the result dict (sort by the number of senses)
sorted_per_word_form_accuracy_11 = dict(sorted(per_word_form_accuracy_1.items(), key=lambda item: item[1]['sense_count'], reverse=True))
output_11 = '; '.join([f"'{word}': {stats['accuracy']:.2f}%, {stats['sense_count']}" for word, stats in sorted_per_word_form_accuracy_11.items()])

sorted_per_word_form_accuracy_22 = dict(sorted(per_word_form_accuracy_2.items(), key=lambda item: item[1]['sense_count'], reverse=True))
output_22 = '; '.join([f"'{word}': {stats['accuracy']:.2f}%, {stats['sense_count']}" for word, stats in sorted_per_word_form_accuracy_22.items()])

print("Evaluation of Approach 1:")
print(f"per_word_form_accuracy(sorted by weighted accuracy): {output_1}") 
print(f"per_word_form_accuracy(sorted by the number of senses): {output_11}")
print("\nEvaluation of Approach 2:")
print(f"per_word_form_accuracy(sorted by weighted accuracy): {output_2}") 
print(f"per_word_form_accuracy(sorted by the number of senses): {output_22}") 

Evaluation of Approach 1:
per_word_form_accuracy(sorted by weighted accuracy): 'bad': 19.98%, 4; 'professional': 16.53%, 5; 'major': 16.15%, 4; 'common': 16.08%, 4; 'active': 15.68%, 5; 'order': 15.66%, 5; 'critical': 15.14%, 5; 'time': 14.49%, 5; 'positive': 14.32%, 5; 'national': 12.72%, 6; 'security': 12.31%, 7; 'physical': 11.71%, 6; 'position': 11.01%, 6; 'place': 11.01%, 7; 'force': 10.28%, 8; 'point': 9.70%, 8; 'extend': 8.99%, 7; 'line': 8.80%, 11; 'regular': 8.62%, 8; 'life': 8.23%, 9; 'bring': 7.85%, 8; 'case': 7.79%, 8; 'see': 7.42%, 11; 'serve': 7.40%, 9; 'keep': 7.23%, 11; 'find': 6.72%, 10; 'lead': 6.71%, 8; 'hold': 6.29%, 11; 'follow': 5.93%, 11; 'build': 4.88%, 10
per_word_form_accuracy(sorted by the number of senses): 'line': 8.80%, 11; 'keep': 7.23%, 11; 'hold': 6.29%, 11; 'see': 7.42%, 11; 'follow': 5.93%, 11; 'find': 6.72%, 10; 'build': 4.88%, 10; 'life': 8.23%, 9; 'serve': 7.40%, 9; 'lead': 6.71%, 8; 'point': 9.70%, 8; 'case': 7.79%, 8; 'regular': 8.62%, 8; 'bring'

- Compared to the baseline, LSTM approaches perform better: even the worst per-word form accuracy value is larger than the best one in MSC baseline model.
- Both approaches show similar trends in per-word form accuracy, with some variations in specific words (e.g.,'major'), suggesting differences in how the models handle certain word forms or senses.
- Considering the weighted accuracy, words with fewer senses tend to have higher accuracy scores compared to those with more senses, indicating that disambiguating between fewer options is generally easier.(the method of weighing?)


Q: How do the LSTMs perform in comparison to BERT? What's the difference between representations obtained by the LSTMs and BERT?

**[4 marks]**

- From the training process and results, we can infer that BERT performs better than LSTMs while more time and computational resources are cost.

- Differences in Representations:
  - LSTMs: Generate contextual embeddings based on sequential processing. In bidirectional LSTMs, representations consider both directions, but still, the context captured is often less comprehensive than BERT’s due to sequential limitations.
  - BERT: Based on transformer architecture, BERT generates deep contextual embeddings for each token, taking into account the entire sentence bidirectionally. Pre-trained on large corpora, BERT's embeddings have a robust understanding of language nuances, allowing superior performance in tasks like word sense disambiguation.  

Q: What could we do to improve all WSD models that we have worked with in this assignment?

**[2 marks]**

Considering a general method of improvement on all WSD models, we suggest **enhanced data pre-processing** :  
- Fine-grained Tokenization: In this assignment, we just tokenize the iuput contexts by blank space, while a fine-grained tokenization may capture the nuances of the language better. This could include handling idiomatic phrases or subword-level tokens (e.g., BPE) more effectively.
- Balancing distribution: Here we noticed that the distribution of word senses is imbalanced, which possibly influence the WSD models' performance. Statistical sampling methods like Bootstrap may be helpful to improve the quality of data.  
- Noise Reduction: We could also clean the data to remove noise, such as irrelevant tokens or misspellings, that could confuse the model.

# Readings

[1] Kågebäck, M., & Salomonsson, H. (2016). Word Sense Disambiguation using a Bidirectional LSTM. arXiv preprint arXiv:1606.03568.

[2] ON WSD: https://web.stanford.edu/~jurafsky/slp3/slides/Chapter18.wsd.pdf

## Statement of contribution

Briefly state how many times you have met for discussions, who was present, to what degree each member contributed to the discussion and the final answers you are submitting.

## Marks

This assignment has a total of 46 marks.