In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from transformers import BertTokenizer, BertModel, AdamW
import numpy as np
import pandas as pd
from tqdm import tqdm
import xml.etree.ElementTree as ET
import nltk
from nltk.corpus import wordnet as wn
from collections import namedtuple

In [2]:
import numpy as np

def generate_balls(num_layers, current_idx, current_center, current_radius, balls, dim):
    if num_layers == 0:
        return

    # Define the distance between the centers of two sibling balls
    separation = 2 * current_radius / 3

    # Choose a random direction for the first child ball's center
    direction = np.random.randn(dim)
    direction = direction / np.linalg.norm(direction)  # Normalize the direction vector

    # Calculate the centers for the two child balls
    center1 = current_center + direction * separation
    center2 = current_center - direction * separation

    # Calculate the radius for the child balls
    child_radius = current_radius / 3

    # Assign the child balls to the dictionary with updated indices
    left_idx = 2 * current_idx + 1
    right_idx = 2 * current_idx + 2
    balls[left_idx] = {'center': center1, 'radius': child_radius}
    balls[right_idx] = {'center': center2, 'radius': child_radius}

    # Recursive call for each child
    generate_balls(num_layers - 1, left_idx, center1, child_radius, balls, dim)
    generate_balls(num_layers - 1, right_idx, center2, child_radius, balls, dim)

# Root ball parameters
dim = 3
layer = 4
initial_radius = 10
initial_center = np.zeros(dim)

# Initialize the dictionary to hold all the balls
balls = {0: {'center': initial_center, 'radius': initial_radius}}

# Generate the balls recursively
generate_balls(layer, 0, initial_center, initial_radius, balls, dim)

# Check how many balls we have generated
ball_num = len(balls)
print(ball_num)

31


In [3]:
# Settings for loading training data
semcor_training_xml_path = 'WSD_Evaluation_Framework/Training_Corpora/Semcor/semcor.data.xml'
semcor_training_gk_path = 'WSD_Evaluation_Framework/Training_Corpora/Semcor/semcor.gold.key.txt'

In [4]:
def load_xml_data(xml_file_path=''):
    tree = ET.parse(xml_file_path)
    root = tree.getroot()
    data = []
    for text in root.findall('text'):
        for sentence in text.findall('sentence'):
            sentence_id = sentence.get('id')
            sentence_text = ' '.join([element.text for element in sentence])
            for instance in sentence.findall('instance'):
                instance_id = instance.get('id')
                lemma = instance.get('lemma')
                pos = instance.get('pos')
                word = instance.text
                data.append([sentence_id, instance_id, lemma, pos, word, sentence_text])
    columns = ['sentence_id', 'instance_id', 'lemma', 'pos', 'word', 'sentence_text']
    xml_data = pd.DataFrame(data, columns=columns)
    return xml_data


# Load xml training data from semcor
semcor_training_xml = load_xml_data(semcor_training_xml_path)
display(semcor_training_xml.head())

Unnamed: 0,sentence_id,instance_id,lemma,pos,word,sentence_text
0,d000.s000,d000.s000.t000,long,ADJ,long,How long has it been since you reviewed the ob...
1,d000.s000,d000.s000.t001,be,VERB,been,How long has it been since you reviewed the ob...
2,d000.s000,d000.s000.t002,review,VERB,reviewed,How long has it been since you reviewed the ob...
3,d000.s000,d000.s000.t003,objective,NOUN,objectives,How long has it been since you reviewed the ob...
4,d000.s000,d000.s000.t004,benefit,NOUN,benefit,How long has it been since you reviewed the ob...


In [5]:
def load_gold_keys(gold_key_file_path=''):
    gold_key_data = []
    with open(gold_key_file_path, 'r') as file:
        for line in file:
            parts = line.strip().split()
            instance_id = parts[0]
            sense_id = parts[1]
            gold_key_data.append([instance_id, sense_id])

    # Create a DataFrame
    gold_key_columns = ['instance_id', 'sense_id']
    gold_key_df = pd.DataFrame(gold_key_data, columns=gold_key_columns)
    return gold_key_df

# Load gold key training data from semcor
semcor_training_gk = load_gold_keys(semcor_training_gk_path)
display(semcor_training_gk.head())

Unnamed: 0,instance_id,sense_id
0,d000.s000.t000,long%3:00:02::
1,d000.s000.t001,be%2:42:03::
2,d000.s000.t002,review%2:31:00::
3,d000.s000.t003,objective%1:09:00::
4,d000.s000.t004,benefit%1:21:00::


In [6]:
# Merge data from two file
semcor_training_merged = pd.merge(semcor_training_xml, semcor_training_gk, on='instance_id', how='inner')
display(semcor_training_merged.head())

Unnamed: 0,sentence_id,instance_id,lemma,pos,word,sentence_text,sense_id
0,d000.s000,d000.s000.t000,long,ADJ,long,How long has it been since you reviewed the ob...,long%3:00:02::
1,d000.s000,d000.s000.t001,be,VERB,been,How long has it been since you reviewed the ob...,be%2:42:03::
2,d000.s000,d000.s000.t002,review,VERB,reviewed,How long has it been since you reviewed the ob...,review%2:31:00::
3,d000.s000,d000.s000.t003,objective,NOUN,objectives,How long has it been since you reviewed the ob...,objective%1:09:00::
4,d000.s000,d000.s000.t004,benefit,NOUN,benefit,How long has it been since you reviewed the ob...,benefit%1:21:00::


In [7]:
sense_id_cache = {}
def format_sense_id(sense_id):
    if sense_id not in sense_id_cache:
        sense_id_cache[sense_id] = wn.lemma_from_key(sense_id).synset().name()
    return sense_id_cache[sense_id]
    
semcor_training_merged['formatted_sense_id'] = semcor_training_merged['sense_id'].apply(format_sense_id)

# We keep those columns for now
keys_to_keep = ['lemma', 'word', 'sentence_text', 'formatted_sense_id']
semcor_training_merged = semcor_training_merged[keys_to_keep]

display(semcor_training_merged.head())

Unnamed: 0,lemma,word,sentence_text,formatted_sense_id
0,long,long,How long has it been since you reviewed the ob...,long.a.01
1,be,been,How long has it been since you reviewed the ob...,be.v.01
2,review,reviewed,How long has it been since you reviewed the ob...,review.v.01
3,objective,objectives,How long has it been since you reviewed the ob...,aim.n.02
4,benefit,benefit,How long has it been since you reviewed the ob...,benefit.n.01


In [8]:
unique_sense_ids = semcor_training_merged['formatted_sense_id'].unique()
print(len(unique_sense_ids))
sampled_ids = pd.Series(unique_sense_ids).sample(n=ball_num, random_state=1)
print(len(sampled_ids))
semcor_training = semcor_training_merged[semcor_training_merged['formatted_sense_id'].isin(sampled_ids)]
print(len(semcor_training))

25916
31
229


In [9]:
# Setting for the model choice
models = {
    "BERT-Base": "bert-base-uncased",
    "BERT-Large": "bert-large-uncased",
    "BERT-Medium": "google/bert_uncased_L-8_H-512_A-8",
    "BERT-Small": "google/bert_uncased_L-4_H-256_A-4",
    "BERT-Mini": "google/bert_uncased_L-4_H-128_A-2",
    "BERT-Tiny": "google/bert_uncased_L-2_H-128_A-2"
}

# With our nball dimention 162, we choose bert small with 256 dimentions
model_name = models["BERT-Small"]

In [10]:
# Get sense index
sense_labels = list(sampled_ids)
sense_index_num = list(balls.keys())
sense_index = {sense: idx for idx, sense in zip(sense_index_num, sense_labels)}
semcor_training.loc[:,'sense_idx'] = semcor_training['formatted_sense_id'].map(sense_index)

original_dim = len(balls[sense_index_num[0]]['center'])
target_dim = 3  # Dimension of BERT-Small
padding_size = target_dim - original_dim  # Adjust padding size based on the new dimension

# Process each embedding
padded_embeddings = []
nball_radius = []
scalling_factor = 1
for index in sense_index_num:
    # Pad the trimmed embedding to match the target dimension
    padded_embedding = np.pad(balls[index]['center'], (0, padding_size), 'constant', constant_values=0)
    # Add to the list of padded embeddings
    padded_embeddings.append(scalling_factor*padded_embedding)
    nball_radius.append(scalling_factor*balls[index]['radius'])

sense_embeddings = torch.tensor(np.array(padded_embeddings), dtype=torch.float64)
nball_radius = torch.tensor(np.array(nball_radius), dtype=torch.float64)

print(f'Total data of nball embeddings:{len(sense_embeddings)}')
print(f'The length after padding: {len(sense_embeddings[0])}')

Total data of nball embeddings:31
The length after padding: 3


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self.obj[key] = value


In [11]:
# Tokenize the sentence
# Initialize BERT tokenizer
tokenizer = BertTokenizer.from_pretrained(model_name)

# Could be problem here, as we always fid
def find_word_index(sentence_ids, word):
    word_tokens = tokenizer.tokenize(word)
    word_ids = tokenizer.convert_tokens_to_ids(word_tokens)
    for i in range(len(sentence_ids) - len(word_tokens) + 1):
        if sentence_ids[i:i+len(word_tokens)].tolist() == word_ids:
            return i
    return -1

def tokenize_data(df):
    # Tokenize all sentences
    print("Tokenizing sentences...")
    tokenized_data = tokenizer(list(df['sentence_text']), padding=True, truncation=True, return_tensors="pt", max_length=512)
    input_ids = tokenized_data['input_ids']
    attention_masks = tokenized_data['attention_mask']

    # Progress bar for calculating word indices
    print("Calculating word indices...")
    # pbar = tqdm(total=df.shape[0], desc="Calculating word indices")
    word_indices = []
    for sentence_ids, word in zip(input_ids, df['word']):
        word_indices.append(find_word_index(sentence_ids, word))
        # pbar.update(1)  # Update progress for each word index found

    # print(f"Length input_ids:{len(input_ids)}\n Length attention_mask:{len(attention_masks)}\n Length \
    # word_index:{len(word_indices)}\n Length dataframe:{len(df)}")
    df.loc[:, 'input_ids'] = input_ids.tolist()
    df.loc[:, 'attention_mask'] = attention_masks.tolist()
    df.loc[:, 'word_index'] = word_indices

    print('Tokenizing finished!')
    # pbar.close()  # Close the progress bar after completion


tokenize_data(semcor_training)
display(semcor_training.head())

Tokenizing sentences...
Calculating word indices...
Tokenizing finished!


Unnamed: 0,lemma,word,sentence_text,formatted_sense_id,sense_idx,input_ids,attention_mask,word_index
985,hot,hot,Latest models serve hot meals at reasonable pr...,hot.a.01,11,"[101, 6745, 4275, 3710, 2980, 12278, 2012, 960...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",4
2702,partly,partly,The trouble was at least partly Juet 's doing .,partially.r.01,18,"[101, 1996, 4390, 2001, 2012, 2560, 6576, 1841...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, ...",6
3176,experience,experience,"On the one hand , the major European nations h...",experience.n.03,20,"[101, 2006, 1996, 2028, 2192, 1010, 1996, 2350...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",66
6558,enclosure,enclosure,Sometimes these servants wrote or dictated for...,enclosure.n.02,9,"[101, 2823, 2122, 8858, 2626, 2030, 23826, 200...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",8
8084,hot,hot,"Eyes like hot honey , eyes that sizzled .",hot.a.01,11,"[101, 2159, 2066, 2980, 6861, 1010, 2159, 2008...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, ...",3


In [12]:
# Setting, small batch size here so that do not out of cuda memory
batch_size = 32

# Create tensor dataset to speed learning
# Convert lists of lists into tensors
all_input_ids = torch.stack([torch.tensor(ids, dtype=torch.long) for ids in semcor_training['input_ids']])
all_attention_masks = torch.stack([torch.tensor(mask, dtype=torch.long) for mask in semcor_training['attention_mask']])
all_word_indices = torch.tensor(semcor_training['word_index'].tolist(), dtype=torch.long)
all_senses = torch.tensor(semcor_training['sense_idx'].tolist(), dtype=torch.long)

# Create a TensorDataset
dataset = TensorDataset(all_input_ids, all_attention_masks, all_word_indices, all_senses)

# Use DataLoader to handle batching
dataloader = DataLoader(dataset, batch_size, shuffle=True)

In [13]:
print(nball_radius.median())
print(len(nball_radius))

tensor(0.1235, dtype=torch.float64)
31


In [14]:
# Training the model
# Loss
from torch.optim.lr_scheduler import StepLR
import math

# Set the device to CUDA if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# top_values, top_indices = torch.topk(nball_radius, 10)
# top_indices = top_indices.to(device)

def ball_inclusion_loss(word_embeddings, centers, radius, labels, epsilon=1e-7):
    distances = torch.norm(word_embeddings - centers, dim=1)
    # # Compute a binary mask for embeddings within the ball radius
    # # print(f"difference:{radius-distances}")
    within_ball = (distances <= radius).float()
    within_ball = torch.clamp(within_ball, epsilon, 1 - epsilon)
    loss = F.binary_cross_entropy(within_ball, labels.float(), reduction='mean')
    return loss, distances

def ball_inclution_loss_soft_margin(word_embeddings, centers, radius):
    distances = torch.norm(word_embeddings - centers, dim=1)
    soft_margin = torch.clamp(distances - radius, min=0)
    loss = torch.mean(soft_margin ** 2)
    return loss, distances
 

class CustomBertModel(nn.Module):
    def __init__(self, bert_model_name):
        super(CustomBertModel, self).__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.projection = nn.Linear(self.bert.config.hidden_size, 3)
        
    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state
        projected_hidden_states = self.projection(hidden_states)
        return projected_hidden_states


# Model setting
# model = BertModel.from_pretrained(model_name).to(device)
model = CustomBertModel("google/bert_uncased_L-2_H-128_A-2").to(device)
model.train() 
sense_embeddings = sense_embeddings.to(device)  # Move sense embeddings to GPU
# loss_fn = ball_inclusion_loss
# loss_fn = nn.MSELoss()
loss_fn = ball_inclution_loss_soft_margin
optimizer = optim.Adam([
    {'params': model.parameters()}
    # {'params': sense_embeddings},
    # {'params': nball_radius}
], lr=2e-3)

scheduler = StepLR(optimizer, step_size=10, gamma=0.1)


# Training loop
num_epochs = 50
last_loss = 0
for epoch in range(num_epochs):
    total_loss = 0
    total_dis = 0
    total_radius = 0
    progress_bar = tqdm(dataloader, desc=f'Epoch {epoch+1}/{num_epochs}', leave=True, position=0)
    for batch in progress_bar:
        # Send batch data to the device (GPU)
        batch_input_ids, batch_attention_masks, batch_word_indices, batch_sense_indices = [b.to(device) for b in batch]

        optimizer.zero_grad()
        
        # Forward pass
        # outputs = model(input_ids=batch_input_ids, attention_mask=batch_attention_masks)
        # hidden_states = outputs.last_hidden_state
        hidden_states = model(input_ids=batch_input_ids, attention_mask=batch_attention_masks)
        # Retrieve embeddings for specific word indices
        word_embeddings = torch.stack([hidden_states[i, idx, :] for i, idx in enumerate(batch_word_indices)])
        
        # Retrieve the corresponding sense embeddings
        center = sense_embeddings[batch_sense_indices].to(dtype=torch.float32).to(device)
        radius = nball_radius[batch_sense_indices].to(device)

        # Labels tensor indicating that embeddings should be inside the ball
        labels = torch.ones(word_embeddings.size(0), device=device, requires_grad=True)
        # print("word_embeddings requires grad:", word_embeddings.requires_grad)
        # print("centers requires grad:", center.requires_grad)

        # # You can also check for the outputs from the model
        # print("Model labels requires grad:", labels.requires_grad)

        # Calculate loss
        # loss, dis = loss_fn(word_embeddings, center, radius, labels, epsilon=1e-2)
        loss, dis = loss_fn(word_embeddings, center, radius)
        # loss = loss_fn(word_embeddings, center)
        # dis = nn.MSELoss()(word_embeddings, center)
        # distances = torch.norm(word_embeddings - center, dim=1)
        # Backward pass
        loss.backward()

        # for name, param in model.named_parameters():
        #     if param.requires_grad and param.grad is not None:
        #         print(f"Gradient of {name}: {param.grad.abs().mean().item()}")
        #     elif param.requires_grad:
        #         print(f"No gradient for {name}")

        
        optimizer.step()
        
        total_loss += loss.item()
        total_dis += dis.sum().item()
        total_radius += radius.sum().item()
        progress_bar.set_postfix(loss=loss.item())

    
    print(f'Epoch {epoch + 1}, Loss: {total_loss / len(dataloader)}, improve:{(last_loss - total_loss) / len(dataloader)}')
    print(f'distance:{total_dis / len(dataloader)}, radius:{total_radius / len(dataloader)}, differences:{(total_dis-total_radius) / len(dataloader)}')
    last_loss = total_loss

Using device: cuda


Epoch 1/50: 100%|█████████████████████████████████████████████████████████████| 8/8 [00:01<00:00,  6.31it/s, loss=7.24]


Epoch 1, Loss: 27.774158314282776, improve:-27.774158314282776
distance:160.05121088027954, radius:18.194444444444446, differences:141.85676643583508


Epoch 2/50: 100%|█████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 44.05it/s, loss=3.72]


Epoch 2, Loss: 7.437861846845153, improve:20.336296467437624
distance:73.268807888031, radius:18.194444444444446, differences:55.07436344358656


Epoch 3/50: 100%|█████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 43.87it/s, loss=2.49]


Epoch 3, Loss: 3.6296523213751186, improve:3.8082095254700343
distance:64.33247995376587, radius:18.194444444444446, differences:46.138035509321426


Epoch 4/50: 100%|█████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 44.13it/s, loss=1.51]


Epoch 4, Loss: 2.3117747907551296, improve:1.317877530619989
distance:49.003089904785156, radius:18.194444444444446, differences:30.80864546034071


Epoch 5/50: 100%|█████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 45.08it/s, loss=1.17]


Epoch 5, Loss: 1.6859347023719664, improve:0.6258400883831632
distance:45.12551212310791, radius:18.194444444444446, differences:26.931067678663464


Epoch 6/50: 100%|████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 44.88it/s, loss=0.142]


Epoch 6, Loss: 0.8923737839332646, improve:0.7935609184387018
distance:34.51242238283157, radius:18.194444444444446, differences:16.317977938387127


Epoch 7/50: 100%|█████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 45.74it/s, loss=1.03]


Epoch 7, Loss: 0.7525221226148792, improve:0.13985166131838545
distance:30.067126750946045, radius:18.194444444444446, differences:11.872682306501599


Epoch 8/50: 100%|███████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 45.97it/s, loss=0.0601]


Epoch 8, Loss: 0.5249439200620636, improve:0.22757820255281558
distance:27.442778885364532, radius:18.194444444444446, differences:9.248334440920086


Epoch 9/50: 100%|████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 45.36it/s, loss=0.212]


Epoch 9, Loss: 0.31464284767475675, improve:0.21030107238730683
distance:20.464122414588928, radius:18.194444444444443, differences:2.2696779701444854


Epoch 10/50: 100%|███████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 45.42it/s, loss=0.568]


Epoch 10, Loss: 0.3221906928565076, improve:-0.007547845181750834
distance:20.949612259864807, radius:18.194444444444443, differences:2.7551678154203643


Epoch 11/50: 100%|████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 43.58it/s, loss=0.15]


Epoch 11, Loss: 0.244926046097301, improve:0.07726464675920658
distance:20.342198312282562, radius:18.194444444444446, differences:2.147753867838116


Epoch 12/50: 100%|█████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 44.34it/s, loss=0.00639]


Epoch 12, Loss: 0.14782714252858006, improve:0.09709890356872095
distance:18.574446499347687, radius:18.194444444444446, differences:0.38000205490324035


Epoch 13/50: 100%|███████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 43.08it/s, loss=0.049]


Epoch 13, Loss: 0.14291960643602802, improve:0.0049075360925520395
distance:17.47724038362503, radius:18.194444444444443, differences:-0.7172040608194123


Epoch 14/50: 100%|██████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 42.99it/s, loss=0.0217]


Epoch 14, Loss: 0.12330030771446021, improve:0.019619298721567813
distance:17.23867779970169, radius:18.194444444444443, differences:-0.9557666447427522


Epoch 15/50: 100%|██████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 43.70it/s, loss=0.0207]


Epoch 15, Loss: 0.1203740900754646, improve:0.002926217638995615
distance:16.280475914478302, radius:18.194444444444446, differences:-1.9139685299661444


Epoch 16/50: 100%|███████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 43.12it/s, loss=0.153]


Epoch 16, Loss: 0.10787436136204911, improve:0.012499728713415481
distance:15.919270813465118, radius:18.194444444444443, differences:-2.2751736309793245


Epoch 17/50: 100%|██████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 43.36it/s, loss=0.0532]


Epoch 17, Loss: 0.094685576428816, improve:0.01318878493323311
distance:15.852753043174744, radius:18.19444444444445, differences:-2.3416914012697063


Epoch 18/50: 100%|███████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 37.29it/s, loss=0.018]


Epoch 18, Loss: 0.07104108967097061, improve:0.02364448675784539
distance:14.934168696403503, radius:18.194444444444446, differences:-3.260275748040943


Epoch 19/50: 100%|██████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 44.27it/s, loss=0.0291]


Epoch 19, Loss: 0.06773449082903797, improve:0.0033065988419326425
distance:14.915937542915344, radius:18.194444444444446, differences:-3.278506901529102


Epoch 20/50: 100%|███████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 34.02it/s, loss=0.245]


Epoch 20, Loss: 0.07506831161387864, improve:-0.0073338207848406695
distance:14.191771149635315, radius:18.194444444444443, differences:-4.002673294809128


Epoch 21/50: 100%|███████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 44.43it/s, loss=0.054]


Epoch 21, Loss: 0.06508046396431155, improve:0.009987847649567089
distance:15.01462110877037, radius:18.194444444444446, differences:-3.179823335674076


Epoch 22/50: 100%|█████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 34.56it/s, loss=0.00949]


Epoch 22, Loss: 0.04625046435477125, improve:0.0188299996095403
distance:13.493551075458527, radius:18.194444444444446, differences:-4.70089336898592


Epoch 23/50: 100%|██████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 40.84it/s, loss=0.0517]


Epoch 23, Loss: 0.04439749267709939, improve:0.001852971677671865
distance:14.302472487092018, radius:18.194444444444443, differences:-3.8919719573524247


Epoch 24/50: 100%|██████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 45.71it/s, loss=0.0124]


Epoch 24, Loss: 0.03353376751361974, improve:0.010863725163479644
distance:13.148218497633934, radius:18.194444444444446, differences:-5.046225946810512


Epoch 25/50: 100%|██████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 45.14it/s, loss=0.0116]


Epoch 25, Loss: 0.048732732233104564, improve:-0.015198964719484821
distance:14.024336993694305, radius:18.194444444444446, differences:-4.170107450750141


Epoch 26/50: 100%|██████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 45.54it/s, loss=0.0627]


Epoch 26, Loss: 0.04569713907082391, improve:0.0030355931622806573
distance:13.191233843564987, radius:18.194444444444446, differences:-5.003210600879459


Epoch 27/50: 100%|██████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 45.60it/s, loss=0.0223]


Epoch 27, Loss: 0.04626580171202871, improve:-0.0005686626412048021
distance:13.629263669252396, radius:18.194444444444446, differences:-4.565180775192051


Epoch 28/50: 100%|███████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 44.81it/s, loss=0]


Epoch 28, Loss: 0.038225071083711384, improve:0.008040730628317325
distance:13.402493357658386, radius:18.19444444444445, differences:-4.791951086786064


Epoch 29/50: 100%|███████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 34.58it/s, loss=0.149]


Epoch 29, Loss: 0.047706441453819845, improve:-0.00948137037010846
distance:12.651562750339508, radius:18.194444444444446, differences:-5.542881694104938


Epoch 30/50: 100%|███████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 39.93it/s, loss=0.091]


Epoch 30, Loss: 0.06680024250857425, improve:-0.019093801054754403
distance:14.386580556631088, radius:18.194444444444446, differences:-3.807863887813358


Epoch 31/50: 100%|██████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 42.55it/s, loss=0.0587]


Epoch 31, Loss: 0.05117414321256929, improve:0.01562609929600496
distance:14.068081736564636, radius:18.194444444444443, differences:-4.126362707879807


Epoch 32/50: 100%|█████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 33.67it/s, loss=0.00441]


Epoch 32, Loss: 0.04418143308447935, improve:0.006992710128089934
distance:12.86346447467804, radius:18.194444444444446, differences:-5.330979969766407


Epoch 33/50: 100%|████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 37.11it/s, loss=0.000157]


Epoch 33, Loss: 0.023953528874333798, improve:0.020227904210145555
distance:11.938200682401657, radius:18.194444444444446, differences:-6.256243762042789


Epoch 34/50: 100%|████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 37.22it/s, loss=0.000193]


Epoch 34, Loss: 0.019020372252049823, improve:0.004933156622283975
distance:11.790087461471558, radius:18.194444444444446, differences:-6.404356982972889


Epoch 35/50: 100%|██████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 39.06it/s, loss=0.0222]


Epoch 35, Loss: 0.022493676092665685, improve:-0.003473303840615862
distance:12.08046606183052, radius:18.194444444444443, differences:-6.113978382613922


Epoch 36/50: 100%|███████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 34.69it/s, loss=0.019]


Epoch 36, Loss: 0.02249523437120121, improve:-1.5582785355267026e-06
distance:11.65102243423462, radius:18.194444444444446, differences:-6.543422010209827


Epoch 37/50: 100%|██████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 43.37it/s, loss=0.0447]


Epoch 37, Loss: 0.028165478740215973, improve:-0.005670244369014761
distance:12.226624131202698, radius:18.194444444444446, differences:-5.967820313241749


Epoch 38/50: 100%|██████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 35.65it/s, loss=0.0636]


Epoch 38, Loss: 0.024139089975124116, improve:0.004026388765091857
distance:12.04725307226181, radius:18.194444444444446, differences:-6.147191372182636


Epoch 39/50: 100%|█████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 43.05it/s, loss=0.00839]


Epoch 39, Loss: 0.024332482882035857, improve:-0.00019339290691174071
distance:12.059675604104996, radius:18.194444444444446, differences:-6.134768840339451


Epoch 40/50: 100%|█████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 34.29it/s, loss=0.00123]


Epoch 40, Loss: 0.021561109748646314, improve:0.002771373133389543
distance:11.748284876346588, radius:18.194444444444446, differences:-6.446159568097858


Epoch 41/50: 100%|█████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 38.85it/s, loss=0.00241]


Epoch 41, Loss: 0.027171086681065746, improve:-0.0056099769324194325
distance:12.038091614842415, radius:18.194444444444443, differences:-6.156352829602028


Epoch 42/50: 100%|█████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 40.83it/s, loss=0.00384]


Epoch 42, Loss: 0.020048395707399354, improve:0.0071226909736663915
distance:11.410722836852074, radius:18.194444444444446, differences:-6.783721607592373


Epoch 43/50: 100%|███████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 36.36it/s, loss=0]


Epoch 43, Loss: 0.020137077165880556, improve:-8.86814584812011e-05
distance:12.098889589309692, radius:18.194444444444446, differences:-6.095554855134754


Epoch 44/50: 100%|███████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 36.99it/s, loss=0.133]


Epoch 44, Loss: 0.03604886182215484, improve:-0.015911784656274283
distance:12.0089191198349, radius:18.194444444444443, differences:-6.185525324609543


Epoch 45/50: 100%|█████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 39.82it/s, loss=0.00509]


Epoch 45, Loss: 0.029102939362595642, improve:0.0069459224595591965
distance:12.460154592990875, radius:18.194444444444446, differences:-5.734289851453571


Epoch 46/50: 100%|█████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 35.64it/s, loss=0.00258]


Epoch 46, Loss: 0.014413507769980145, improve:0.014689431592615498
distance:11.408422231674194, radius:18.194444444444446, differences:-6.786022212770252


Epoch 47/50: 100%|█████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 42.91it/s, loss=0.00124]


Epoch 47, Loss: 0.02197619804637136, improve:-0.007562690276391216
distance:11.483596757054329, radius:18.194444444444443, differences:-6.710847687390114


Epoch 48/50: 100%|█████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 33.88it/s, loss=0.00806]


Epoch 48, Loss: 0.01812447097667686, improve:0.0038517270696944994
distance:11.479848027229309, radius:18.194444444444443, differences:-6.714596417215134


Epoch 49/50: 100%|████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 42.76it/s, loss=0.02]


Epoch 49, Loss: 0.020855560831372577, improve:-0.002731089854695716
distance:12.022155493497849, radius:18.194444444444446, differences:-6.172288950946598


Epoch 50/50: 100%|█████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 36.88it/s, loss=0.00469]

Epoch 50, Loss: 0.018154324748705347, improve:0.00270123608266723
distance:11.452143862843513, radius:18.194444444444443, differences:-6.742300581600929





In [15]:
# Assuming that 'dataloader' for evaluation is available as eval_dataloader

# Switch model to evaluation mode
eval_dataloader = dataloader
model.eval()

# Store evaluation results
eval_loss = 0
accuracy = 0  # Placeholder for accuracy calculation

not_correct = []
# Disable gradient computation for evaluation to save memory and computations
with torch.no_grad():
    for batch in tqdm(eval_dataloader, desc="Evaluation", leave=True, position=0):
        # Send batch data to the device (GPU)
        batch_input_ids, batch_attention_masks, batch_word_indices, batch_sense_indices = [b.to(device) for b in batch]

        # Forward pass
        hidden_states = model(input_ids=batch_input_ids, attention_mask=batch_attention_masks)
        # hidden_states = outputs.last_hidden_state

        # Retrieve embeddings for specific word indices
        word_embeddings = torch.stack([hidden_states[i, idx, :] for i, idx in enumerate(batch_word_indices)])

        # Retrieve the corresponding sense embeddings and radius
        center = sense_embeddings[batch_sense_indices].to(dtype=torch.float32).to(device)
        radius = nball_radius[batch_sense_indices].to(device)

        # Labels tensor (usually true for all points in eval since you are evaluating the inclusion)
        labels = torch.ones(word_embeddings.size(0), device=device)

        # Calculate loss
        loss, dis = ball_inclusion_loss(word_embeddings, center, radius, labels)
        eval_loss += loss.item()

        # Optionally calculate accuracy or other metrics
        # For instance, if accuracy is based on correct inclusions
        predictions = (torch.norm(word_embeddings - center, dim=1) <= radius).float()
        accuracy += (predictions == labels).float().mean().item()
        incorrect_indices = batch_sense_indices[predictions != labels]
        not_correct.extend(incorrect_indices.cpu().tolist())

# Compute average loss and accuracy
eval_loss /= len(eval_dataloader)
accuracy /= len(eval_dataloader)

print(f"Average Evaluation Loss: {eval_loss}")
print(f"Accuracy: {accuracy * 100:.2f}%")

Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 125.66it/s]


Average Evaluation Loss: 9.305681765079498
Accuracy: 42.27%


In [16]:
uni_not_correct = set(not_correct)
nball_radius[list(uni_not_correct)]

tensor([0.3704, 0.3704, 0.1235, 0.1235, 0.1235, 0.1235, 0.3704, 0.3704, 0.1235,
        0.1235, 0.1235, 0.1235, 0.3704, 0.3704, 0.1235, 0.1235, 0.1235, 0.1235,
        0.1235, 0.1235, 0.1235, 0.1235], dtype=torch.float64)

In [17]:
nball_radius

tensor([10.0000,  3.3333,  3.3333,  1.1111,  1.1111,  0.3704,  0.3704,  0.1235,
         0.1235,  0.1235,  0.1235,  0.3704,  0.3704,  0.1235,  0.1235,  0.1235,
         0.1235,  1.1111,  1.1111,  0.3704,  0.3704,  0.1235,  0.1235,  0.1235,
         0.1235,  0.3704,  0.3704,  0.1235,  0.1235,  0.1235,  0.1235],
       dtype=torch.float64)