In [1]:
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertModel, AdamW
import numpy as np
import pandas as pd
import xml.etree.ElementTree as ET
import nltk
from nltk.corpus import wordnet as wn

In [2]:
# Load the SemCor training corpus (XML) and its gold sense keys, then join
# the two on instance_id.
semcor_training_xml = 'WSD_Evaluation_Framework/Training_Corpora/Semcor/semcor.data.xml'
semcor_training_gold_key = 'WSD_Evaluation_Framework/Training_Corpora/Semcor/semcor.gold.key.txt'

# --- XML side: one row per <instance>, carrying the text of its whole sentence ---
tree = ET.parse(semcor_training_xml)
root = tree.getroot()
data = []
for text in root.findall('text'):
    for sentence in text.findall('sentence'):
        sentence_id = sentence.get('id')
        # Every child element of the sentence contributes its text, space-separated.
        sentence_text = ' '.join(element.text for element in sentence)
        data.extend(
            [sentence_id, instance.get('id'), instance.get('lemma'),
             instance.get('pos'), instance.text, sentence_text]
            for instance in sentence.findall('instance')
        )

columns = ['sentence_id', 'instance_id', 'lemma', 'pos', 'word', 'sentence_text']
xml_data = pd.DataFrame(data, columns=columns)

# --- Gold key side: the first two whitespace-separated fields of every line ---
with open(semcor_training_gold_key, 'r') as file:
    gold_data = [[fields[0], fields[1]] for fields in map(str.split, file)]

gold_columns = ['instance_id', 'sense_id']
gold_df = pd.DataFrame(gold_data, columns=gold_columns)

# Inner join: only instances present in both sources survive.
merged_data = pd.merge(xml_data, gold_df, on='instance_id', how='inner')

def format_sense_id(sense_id):
    """Convert a sense key such as ``be%2:42:05::`` into ``be.v.05``.

    The key is split once on ``%`` into the lemma and a colon-separated tail.
    The first tail field selects the part of speech (1=n, 2=v, 3=a, 4=r) and
    the third tail field, zero-padded to two digits, becomes the numeric
    suffix of the result.

    NOTE(review): the result has the *shape* of a WordNet synset name, but it
    is built from the key's third field rather than looked up in WordNet, so
    it is not guaranteed to name a real synset -- confirm against whatever
    vocabulary it is matched with.

    :param sense_id: sense key string.
    :return: ``"<lemma>.<pos>.<NN>"``, or None when the input is not a string,
        does not contain exactly one ``%``, has fewer than three
        colon-separated fields after the ``%``, or has an unmapped POS code.
    """
    pos_mapping = {'1': 'n', '2': 'v', '3': 'a', '4': 'r'}

    # Non-string values (e.g. NaN from a DataFrame column) cannot be parsed.
    if not isinstance(sense_id, str):
        return None

    parts = sense_id.split('%')
    if len(parts) != 2:
        return None  # Invalid format

    lemma = parts[0]
    sense_info = parts[1].split(':')

    # Index 2 is read below, so at least three fields are required.
    if len(sense_info) < 3:
        return None  # Invalid format

    pos = pos_mapping.get(sense_info[0])
    if pos is None:
        return None  # Invalid POS

    return f"{lemma}.{pos}.{sense_info[2].zfill(2)}"

# def is_valid_sense(sense_id):
#     formatted_sense_id = format_sense_id(sense_id)
#     if not formatted_sense_id:
#         return False
#     try:
#         wn.synset(formatted_sense_id)
#         return True
#     except nltk.corpus.reader.wordnet.WordNetError:
#         return False

# Derive a formatted id for every gold sense key in the merged frame; the
# new column holds whatever format_sense_id returns for each row.
merged_data['formatted_sense_id'] = merged_data['sense_id'].map(format_sense_id)
# merged_data['valid_sense'] = merged_data['formatted_sense_id'].apply(is_valid_sense)

def load_ball_embeddings(bFile=''):
    """
    Collect the first whitespace-separated token of every non-blank line.

    Only the leading token of each line is kept; the remaining fields on the
    line are ignored. Blank or whitespace-only lines are skipped instead of
    raising IndexError.

    :param bFile: path of the text file to read
    :return: list of leading tokens, in file order
    """
    print("loading balls....")
    nball_list = []
    with open(bFile, 'r') as w2v:
        # Iterate the handle directly so the file is streamed line by line.
        for line in w2v:
            wlst = line.split()
            if wlst:
                nball_list.append(wlst[0])
    print(len(nball_list),' balls are loaded\n')
    return nball_list

# Forward slash instead of '\': the original literal contained the invalid
# escape sequence '\w' and only resolved on Windows. A forward slash works on
# every platform and matches the other paths in this notebook.
nball_list = load_ball_embeddings('training_set/word2vec.txt')
# Make sure the two senses of "be" used in this experiment are in the vocabulary.
nball_list.append('be.v.05')
nball_list.append('be.v.00')
# Keep only the rows whose formatted sense id is in the vocabulary.
small_set = merged_data[merged_data['formatted_sense_id'].isin(nball_list)]
print(len(small_set))
display(small_set.head())

loading balls....
7586  balls are loaded

1576


Unnamed: 0,sentence_id,instance_id,lemma,pos,word,sentence_text,sense_id,formatted_sense_id
52,d000.s008,d000.s008.t000,be,VERB,Are,"Are there other , cheaper communications techn...",be%2:42:00::,be.v.00
554,d000.s067,d000.s067.t010,be,VERB,are,But even if that other plant employs the same ...,be%2:42:00::,be.v.00
564,d000.s070,d000.s070.t002,be,VERB,are,In what section of the country are you located ?,be%2:42:05::,be.v.05
566,d000.s071,d000.s071.t000,be,VERB,Are,Are you in a rural or urban area ?,be%2:42:05::,be.v.05
893,d000.s110,d000.s110.t000,be,VERB,are,There are two sides of a coin for this decision .,be%2:42:00::,be.v.00


In [3]:
# Keep only the columns used from here on. The explicit copy makes small_set
# an independent frame, so the column assignments made in later cells do not
# go through a slice of another DataFrame (SettingWithCopyWarning).
keys_to_keep = ['lemma', 'word', 'sentence_text', 'formatted_sense_id']
small_set = small_set[keys_to_keep].copy()

def generate_nball_dict(nball_file='', senses=('be.v.05', 'be.v.00'), dim=128, seed=None):
    """Build placeholder embeddings: one random float32 vector per sense label.

    The vectors are drawn uniformly from [0, 1); nothing is read from disk.

    :param nball_file: accepted for interface compatibility; currently unused.
    :param senses: sense labels to create vectors for, in dictionary order.
    :param dim: length of each vector. NOTE(review): the default of 128 has to
        match the hidden size of the encoder these vectors are compared
        against -- confirm if the model changes.
    :param seed: when given, vectors come from a private generator seeded with
        this value, so repeated calls return identical vectors. When None,
        NumPy's global random state is used, as before.
    :return: dict mapping each sense label to a ``(dim,)`` float32 array.
    """
    rng = np.random if seed is None else np.random.RandomState(seed)
    return {sense: rng.rand(dim).astype(np.float32) for sense in senses}

nball_embeddings = generate_nball_dict()

In [4]:
# Preview the first rows of small_set.
small_set.head()

Unnamed: 0,lemma,word,sentence_text,formatted_sense_id
52,be,Are,"Are there other , cheaper communications techn...",be.v.00
554,be,are,But even if that other plant employs the same ...,be.v.00
564,be,are,In what section of the country are you located ?,be.v.05
566,be,Are,Are you in a rural or urban area ?,be.v.05
893,be,are,There are two sides of a coin for this decision .,be.v.00


In [5]:
# Initialize BERT tokenizer
# Loaded from the "prajjwal1/bert-tiny" identifier; the encoder loaded for
# training further down uses the same identifier.
tokenizer = BertTokenizer.from_pretrained("prajjwal1/bert-tiny")

# Tokenize sentences to get input_ids and attention masks
def tokenize_and_find_index(row):
    """Tokenize a row's sentence and locate its target word among the tokens.

    The sentence is tokenized to a fixed length of 512 (padded or truncated).
    The target word is tokenized on its own, and the first position where its
    token ids appear contiguously in the sentence ids is reported.

    NOTE(review): this is the *first* match, which is not necessarily the
    annotated occurrence when the word appears more than once in the sentence.

    :param row: mapping with 'sentence_text' and 'word' entries; both are
        coerced with ``str`` so NaN/None do not break tokenization.
    :return: ``(input_ids, attention_mask, index)`` where the first two are the
        tokenizer's tensors and ``index`` is the start position of the match,
        or -1 when the word has no tokens or is not found.
    """
    sentence = str(row['sentence_text'])
    word = str(row['word'])

    tokens = tokenizer(sentence, padding='max_length', max_length=512, truncation=True, return_tensors="pt")

    # Convert both sides to plain lists once instead of on every loop step.
    sentence_ids = tokens['input_ids'][0].tolist()
    word_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(word))
    span = len(word_ids)

    # An empty id list would compare equal to the empty slice at position 0
    # and report the first token as a match, so it is treated as "not found".
    if span:
        for start in range(len(sentence_ids) - span + 1):
            if sentence_ids[start:start + span] == word_ids:
                return tokens['input_ids'], tokens['attention_mask'], start

    # Word not found (e.g. truncated away): signal with -1.
    return tokens['input_ids'], tokens['attention_mask'], -1


small_set[['input_ids', 'attention_mask', 'word_index']] = small_set.apply(tokenize_and_find_index, axis=1, result_type='expand')

# tokenize_and_find_index reports -1 when the target word is not found. Used
# as a position later on, -1 would silently select the last token of the
# sequence, so those rows are removed here rather than trained on.
word_found = small_set['word_index'] != -1
if not word_found.all():
    print(f"Dropping {(~word_found).sum()} rows whose target word was not found after tokenization")
small_set = small_set[word_found].copy()


In [6]:
# Fix an ordering of the sense labels: row i of sense_embeddings is the
# vector for sense_labels[i], and sense_index maps a label back to i.
sense_labels = [label for label in nball_embeddings]
sense_vectors = [nball_embeddings[label] for label in sense_labels]
sense_embeddings = torch.tensor(np.array(sense_vectors), dtype=torch.float32)
sense_index = dict(zip(sense_labels, range(len(sense_labels))))

# Attach the integer index of each row's gold sense.
small_set['sense_idx'] = small_set['formatted_sense_id'].map(sense_index)


In [7]:
# small_set.head()

In [8]:
from torch.utils.data import TensorDataset, DataLoader

# Join the per-row tensors stored in the frame into one tensor per field.
all_input_ids, all_attention_masks = (
    torch.cat(small_set[column].tolist()) for column in ('input_ids', 'attention_mask')
)
# The scalar columns become one tensor each.
all_word_indices, all_senses = (
    torch.tensor(small_set[column].tolist()) for column in ('word_index', 'sense_idx')
)

# Each sample is (input_ids, attention_mask, word_index, sense_idx).
dataset = TensorDataset(all_input_ids, all_attention_masks, all_word_indices, all_senses)

# Shuffled mini-batches of two samples.
dataloader = DataLoader(dataset, shuffle=True, batch_size=2)


In [9]:
import torch.nn as nn
import torch.optim as optim

# Prefer the GPU when torch can see one, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Pretrained encoder, switched to training mode and moved to the device.
model = BertModel.from_pretrained("prajjwal1/bert-tiny").to(device)
model.train()
# The target vectors must live on the same device as the model outputs.
sense_embeddings = sense_embeddings.to(device)

# Regress the hidden state at the target word onto its sense vector.
loss_fn = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=2e-5)

num_epochs = 3
for epoch in range(num_epochs):
    total_loss = 0
    for batch in dataloader:
        batch_input_ids, batch_attention_masks, batch_word_indices, batch_sense_indices = (
            tensor.to(device) for tensor in batch
        )

        optimizer.zero_grad()

        hidden_states = model(
            input_ids=batch_input_ids,
            attention_mask=batch_attention_masks,
        ).last_hidden_state

        # One vector per sample: the hidden state at that sample's word index.
        batch_rows = torch.arange(hidden_states.size(0), device=hidden_states.device)
        word_embeddings = hidden_states[batch_rows, batch_word_indices]

        # Sense vector belonging to each sample's gold label.
        target_embeddings = sense_embeddings[batch_sense_indices]

        loss = loss_fn(word_embeddings, target_embeddings)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Mean batch loss over the epoch.
    print(f'Epoch {epoch + 1}, Loss: {total_loss / len(dataloader)}')



Using device: cuda
Epoch 1, Loss: 1.008372647356866
Epoch 2, Loss: 0.822199387447483
Epoch 3, Loss: 0.7519104360475153


In [13]:
import torch.nn.functional as F
model.eval()
# Unshuffled pass over `dataset` so predictions line up with the rows of
# small_set. Note this is the same dataset object the training loop used.
prediction_dataloader = DataLoader(dataset, batch_size=2, shuffle=False)

predicted_sense_ids = []

with torch.no_grad():
    for batch in prediction_dataloader:
        batch_input_ids, batch_attention_masks, batch_word_indices, _ = (
            tensor.to(device) for tensor in batch
        )

        hidden_states = model(
            input_ids=batch_input_ids,
            attention_mask=batch_attention_masks,
        ).last_hidden_state

        # One vector per sample: the hidden state at that sample's word index.
        batch_rows = torch.arange(hidden_states.size(0), device=hidden_states.device)
        word_embeddings = hidden_states[batch_rows, batch_word_indices]

        # Broadcast every word vector against every sense vector, giving one
        # similarity per (sample, sense) pair.
        cosine_similarities = F.cosine_similarity(
            word_embeddings.unsqueeze(1), sense_embeddings.unsqueeze(0), dim=-1
        )

        # The most similar sense wins.
        predicted_indices = cosine_similarities.argmax(dim=1)
        predicted_sense_ids.extend(predicted_indices.cpu().numpy())

# Translate the winning indices back into sense labels.
small_set['predicted_sense_id'] = [sense_labels[idx] for idx in predicted_sense_ids]

In [14]:
# Share of rows whose predicted label equals the gold label.
correct_predictions = small_set['formatted_sense_id'].eq(small_set['predicted_sense_id'])
accuracy = correct_predictions.mean()

print(f"Model Accuracy: {accuracy * 100:.2f}%")

Model Accuracy: 82.80%


In [16]:
# Narrow the frame to the sentence plus gold and predicted labels, then keep
# only the rows where the two labels disagree.
selected_columns = small_set[['sentence_text', 'formatted_sense_id', 'predicted_sense_id']]
labels_differ = selected_columns['formatted_sense_id'].ne(selected_columns['predicted_sense_id'])
selected_data = selected_columns[labels_differ]

# Show the misclassified rows.
selected_data

Unnamed: 0,sentence_text,formatted_sense_id,predicted_sense_id
4593,"He did not , however , settle back into acquie...",be.v.00,be.v.05
6735,We got to one house where there were five sece...,be.v.05,be.v.00
7240,What had been the ambassador 's suite was now ...,be.v.00,be.v.05
7457,We were there at a moment when the situation i...,be.v.05,be.v.00
7773,"Mr. Keo , once a diplomat in Paris and Washing...",be.v.00,be.v.05
...,...,...,...
225398,As evening approached and Palmer finished his ...,be.v.00,be.v.05
225403,"But Palmer knew , as did everybody else at Aug...",be.v.00,be.v.05
225425,"On the final round at Pensacola , the luck of ...",be.v.00,be.v.05
225441,"It was a dismal , drizzly day but a good one o...",be.v.00,be.v.05
