In [15]:
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertModel, AdamW
import numpy as np
import pandas as pd

In [16]:
texts = ["The bank will not approve my loan.", "We sat on the river bank."]
target_words = ["bank", "bank"]
contexts = [(text, word) for text, word in zip(texts, target_words)]
contexts

[('The bank will not approve my loan.', 'bank'),
 ('We sat on the river bank.', 'bank')]

In [17]:
class ContextDataset(Dataset):
    def __init__(self, contexts):
        self.contexts = contexts
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def __len__(self):
        return len(self.contexts)

    def __getitem__(self, idx):
        text, word = self.contexts[idx]
        inputs = self.tokenizer(text, return_tensors="pt", padding='max_length', truncation=True, max_length=128)
        word_tokens = self.tokenizer.tokenize(word)
        word_index = inputs["input_ids"].squeeze().tolist().index(self.tokenizer.convert_tokens_to_ids(word_tokens)[0])
        return {
            'input_ids': inputs['input_ids'].squeeze(),
            'attention_mask': inputs['attention_mask'].squeeze(),
            'word_index': word_index
        }

# Custom collate function
def custom_collate(batch):
    input_ids = torch.stack([item['input_ids'] for item in batch])
    attention_mask = torch.stack([item['attention_mask'] for item in batch])
    word_indices = torch.tensor([item['word_index'] for item in batch])
    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'word_index': word_indices
    }

dataset = ContextDataset(contexts)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=custom_collate)

# Iterate through the DataLoader
for i, batch in enumerate(dataloader):
    print(f"Batch {i+1}")
    print("Input IDs:\n", batch['input_ids'])
    print("Attention Masks:\n", batch['attention_mask'])
    print("Word Indices:\n", batch['word_index'])
    print("\n")
    # Optionally, limit the number of batches shown
    if i == 1:
        break


Batch 1
Input IDs:
 tensor([[  101,  1996,  2924,  2097,  2025, 14300,  2026,  5414,  1012,   102,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,   

In [18]:
# Example sense embeddings for the word "bank"
sense_embeddings = {
    "financial_institution": np.random.rand(768).astype(np.float32),
    "river_side": np.random.rand(768).astype(np.float32)
}
sense_labels = np.array(list(sense_embeddings.keys()))
sense_tensor = torch.tensor([sense_embeddings[label] for label in sense_labels])
sense_labels

array(['financial_institution', 'river_side'], dtype='<U21')

In [19]:
import torch.nn as nn
import torch.optim as optim

In [20]:
# Load BERT model
model = BertModel.from_pretrained('bert-base-uncased')

# Define the loss function and optimizer
loss_fn = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=2e-5)

# Example: {'bank': 0, 'bank': 1} means the first 'bank' corresponds to the first sense, the second 'bank' to the second sense
correct_sense_indices = {'bank': [0, 1]}  # This should come from your labeled dataset

# Training loop
for epoch in range(3):  # Train for 3 epochs
    model.train()
    for batch_idx, batch in enumerate(dataloader):
        optimizer.zero_grad()

        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        word_index = batch['word_index']
        print(input_ids.size())
        # Forward pass
        outputs = model(input_ids, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state

        # Get the word embeddings
        word_embeddings = hidden_states[range(hidden_states.size(0)), word_index].float()

        # Calculate the distance to each sense embedding
        distances = torch.cdist(word_embeddings.unsqueeze(0), sense_tensor.unsqueeze(0), p=2).squeeze()

        # Get the correct sense indices for the current batch
        batch_correct_sense_indices = [correct_sense_indices[target_words[idx]][batch_idx] for idx in range(len(word_index))]
        batch_correct_sense_indices = torch.tensor(batch_correct_sense_indices, dtype=torch.long)

        # Select the distances corresponding to the correct senses
        correct_sense_distances = distances[range(distances.size(0)), batch_correct_sense_indices]

        # Calculate the loss
        target = torch.zeros_like(correct_sense_distances)  # Target distances are zero
        loss = loss_fn(correct_sense_distances, target)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1}, Loss: {loss.item()}")

torch.Size([2, 128])
Epoch 1, Loss: 483.49725341796875
torch.Size([2, 128])
Epoch 2, Loss: 410.11407470703125
torch.Size([2, 128])
Epoch 3, Loss: 363.18157958984375


In [21]:
def match_sense(text, word, model, tokenizer, sense_tensor, sense_labels):
    model.eval()
    with torch.no_grad():
        inputs = tokenizer(text, return_tensors="pt", padding='max_length', truncation=True, max_length=128)
        outputs = model(**inputs)
        hidden_states = outputs.last_hidden_state

        # Get the word index
        word_tokens = tokenizer.tokenize(word)
        word_index = inputs["input_ids"].squeeze().tolist().index(tokenizer.convert_tokens_to_ids(word_tokens)[0])

        # Get the word embedding
        word_embedding = hidden_states[0, word_index, :].unsqueeze(0).float()

        # Calculate distances to sense embeddings
        distances = torch.cdist(word_embedding, sense_tensor.unsqueeze(0), p=2).squeeze()
        closest_sense_index = torch.argmin(distances).item()

    return sense_labels[closest_sense_index]

# Example usage
text = "The bank will not approve my loan."
word = "bank"
sense = match_sense(text, word, model, tokenizer, sense_tensor, sense_labels)
print(f"The sense of '{word}' in the context is: {sense}")


NameError: name 'tokenizer' is not defined