In [1]:
import sys
sys.path.append('../src')
from util import load_balls, preprocess_data
from dataset import nballDataset
from model import nballBertNorm, nballBertDirection

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm
import torch.nn.functional as F

In [2]:
# Load data
train_semcor_path = {
    "xml":'../data/WSD_Evaluation_Framework/Training_Corpora/Semcor/semcor.data.xml',
    "gold_key":'../data/WSD_Evaluation_Framework/Training_Corpora/Semcor/semcor.gold.key.txt',}

eval_all_path = {
    "xml":'../data/WSD_Evaluation_Framework/Evaluation_Datasets/ALL/ALL.data.xml',
    "gold_key":'../data/WSD_Evaluation_Framework/Evaluation_Datasets/ALL/ALL.gold.key.txt',}

nball_path = '../data/sample_entity/nball.txt'

nball = load_balls(nball_path)
train_semcor, lemma_pair = preprocess_data(train_semcor_path, nball)

loading balls....
43667 balls are loaded



In [3]:
# Problem about lemma
display(train_semcor.head())
print(len(lemma_pair))
print(len(nball.keys()))

Unnamed: 0,lemma,word,sentence_text,lemma_idx,formatted_sense_id,sense_idx,sense_group
0,objective,objectives,How long has it been since you reviewed the ob...,4849,aim.n.02,5720,"[5720, 11008]"
1,benefit,benefit,How long has it been since you reviewed the ob...,894,benefit.n.01,20775,"[1007, 14577, 20775]"
2,service,service,How long has it been since you reviewed the ob...,8210,service.n.05,11463,"[10367, 10754, 10769, 10898, 11463, 12006, 122..."
3,program,program,How long has it been since you reviewed the ob...,3824,program.n.02,4552,"[4552, 8006, 9367, 11149]"
4,giveaway,giveaway,Have you permitted it to become a giveaway pro...,9705,giveaway.n.01,20555,"[12752, 20555]"


30165
43667


In [10]:
# Loading nball embeddings
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sense_labels = list(nball.keys())
nball_embeddings = [nball[label].center for label in sense_labels]
nball_norms = [nball[label].distance for label in sense_labels]
nball_radius = [nball[label].radius for label in sense_labels]

In [11]:
# Data checking:
# Sorting the lists
nball_norms_sorted = sorted(nball_norms)
nball_radius_sorted = sorted(nball_radius)

# Getting the smallest 10 from each list
smallest_10_norms = nball_norms_sorted[:10]
smallest_10_radius = nball_radius_sorted[:10]

# Getting the largest 10 from each list
largest_10_norms = nball_norms_sorted[-10:]
largest_10_radius = nball_radius_sorted[-10:]

# Printing the results
print("Smallest 10 elements in nball_norms:", smallest_10_norms)
print("Largest 10 elements in nball_norms:", largest_10_norms)

print("Smallest 10 elements in nball_radius:", smallest_10_radius)
print("Largest 10 elements in nball_radius:", largest_10_radius)

Smallest 10 elements in nball_norms: [1000000.0, 1000000.0, 1000000.0, 1000000.0, 1000000.0, 1000000.0, 1000000.0, 1030316.9252039694, 1030316.9252039694, 1030316.9252039694]
Largest 10 elements in nball_norms: [9.271626654035237e+260, 9.272199640406976e+260, 9.273105763720828e+260, 9.276812880845236e+260, 9.291212959815524e+260, 9.391083976630392e+260, 9.421969111692328e+260, 9.481569869909228e+260, 9.610572799981588e+260, 9.788260408452398e+260]
Smallest 10 elements in nball_radius: [0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.00010303169252039694, 0.00010303169252039694, 0.00010303169252039694]
Largest 10 elements in nball_radius: [1.592188358137684e+259, 1.6796639919064024e+259, 2.0786414176209017e+259, 5.092086803778581e+259, 1.0510336007745004e+260, 1.1778868338387624e+260, 1.994060113116811e+260, 3.38293997519766e+260, 5.766343679988953e+260, 9.788260408452398e+260]


In [None]:
# Wrap as tensor
nball_embeddings = torch.tensor(np.array(nball_embeddings), dtype=torch.float32).to(device)
nball_norms = torch.tensor(np.array(nball_norms), dtype=torch.float32).to(device)
nball_radius = torch.tensor(np.array(nball_radius), dtype=torch.float32).to(device)

In [5]:
# Model list
bert_models = {"BERT-Base": ["bert-base-uncased", 768], #768, 12L, 12A
               "BERT-Large": ["bert-large-uncased", 1024], #1024, 24L, 16A
               "BERT-Medium": ["google/bert_uncased_L-8_H-512_A-8", 512], #512, 8L, 8A
               "BERT-Small": ["google/bert_uncased_L-4_H-256_A-4", 256], #256, 4L, 4A
               "BERT-Mini": ["google/bert_uncased_L-4_H-128_A-2", 128],#128, 4L, 2A
               "BERT-Tiny": ["google/bert_uncased_L-2_H-128_A-2", 128]}#128, 2L, 2A

max_length = 512
batch_size = 32
model_url = bert_models["BERT-Small"][0]

In [6]:
dataset_train = nballDataset(train_semcor, nball, model_url, max_length)

Tokenizing sentences...
Tokenizing finished.
Calculating word indices...
Tokenizing finished.


In [7]:
dataloader_train = DataLoader(dataset_train, batch_size, shuffle=True)

In [8]:
def train(model, dataloader, optimizer, scheduler, num_epochs, mode):
    last_loss = 0
    if mode == "norm":
        loss_fn = log_cosh_loss
    elif mode == "direction":
        loss_fn = nn.CosineEmbeddingLoss()
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        progress_bar = tqdm(dataloader, desc=f'Epoch {epoch+1}/{num_epochs}', leave=True, position=0)
        for batch in progress_bar:
            # Unpack batch and send to device
            batch_input_ids, batch_attention_masks, batch_word_indices, \
            batch_sense_indices, _, _ = [b.to(device) for b in batch]

            optimizer.zero_grad()
        
            # Forward pass
            output = model(input_ids=batch_input_ids, attention_mask=batch_attention_masks, \
                            word_index=batch_word_indices).squeeze()

            # print(output)
            if mode == "norm":
                batch_norms = nball_norms[batch_sense_indices]
                loss = loss_fn(output, batch_norms * 1e-5)
            elif mode == "direction":
                labels = torch.ones(output.size(0), device=device)
                batch_senses = nball_embeddings[batch_sense_indices]
                loss = loss_fn(output, batch_senses, labels)
                
            loss.backward()
            optimizer.step()
        
            total_loss += loss.item()
            progress_bar.set_postfix(loss=loss.item())

        # Logging average metrics per epoch
        avg_loss = total_loss / len(dataloader)
        improvement = (last_loss - total_loss) / len(dataloader) if last_loss != 0 else 0
        print(f'Epoch {epoch + 1}, Avg Loss: {avg_loss}, Improvement: {improvement}')
        last_loss = total_loss

        scheduler.step()

In [9]:
# Train norm
def log_cosh_loss(norm_u, norm_v):
    return torch.log(torch.cosh(norm_u - norm_v)).mean()

output_dim = 160
model_norm = nballBertNorm(model_url, output_dim).to(device)
loss_fn_norm = log_cosh_loss
optimizer_norm = optim.Adam([
    {'params': model_norm.parameters()}
], lr=2e-5)
scheduler_norm = StepLR(optimizer_norm, step_size=50, gamma=0.1)

train(model=model_norm, \
      dataloader=dataloader_train,\
      optimizer=optimizer_norm,\
      scheduler=scheduler_norm,\
      num_epochs=50, \
      mode="norm")

  attn_output = torch.nn.functional.scaled_dot_product_attention(
Epoch 1/50: 100%|████████████████████████████████████████████████████████| 2506/2506 [01:58<00:00, 21.18it/s, loss=nan]


Epoch 1, Avg Loss: nan, Improvement: 0


Epoch 2/50: 100%|████████████████████████████████████████████████████████| 2506/2506 [02:00<00:00, 20.83it/s, loss=nan]


Epoch 2, Avg Loss: nan, Improvement: nan


Epoch 3/50: 100%|████████████████████████████████████████████████████████| 2506/2506 [02:08<00:00, 19.54it/s, loss=nan]


Epoch 3, Avg Loss: nan, Improvement: nan


Epoch 4/50: 100%|████████████████████████████████████████████████████████| 2506/2506 [02:03<00:00, 20.35it/s, loss=nan]


Epoch 4, Avg Loss: nan, Improvement: nan


Epoch 5/50:  82%|██████████████████████████████████████████████          | 2062/2506 [01:47<00:23, 19.27it/s, loss=nan]


KeyboardInterrupt: 

In [10]:
# Train direction
model_d = nballBertDirection(model_url, output_dim).to(device)
optimizer_d = optim.Adam([
    {'params': model_d.parameters()}
], lr=2e-5)
scheduler_d = StepLR(optimizer_d, step_size=50, gamma=0.1)

train(model=model_d, \
      dataloader=dataloader_train,\
      optimizer=optimizer_d,\
      scheduler=scheduler_d,\
      num_epochs=20, \
      mode="direction")

Epoch 1/20: 100%|██████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 11.40it/s, loss=0.656]


Epoch 1, Avg Loss: 0.7961865121668036, Improvement: 0


Epoch 2/20: 100%|██████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 12.89it/s, loss=0.422]


Epoch 2, Avg Loss: 0.5248720320788297, Improvement: 0.2713144800879739


Epoch 3/20: 100%|██████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 13.52it/s, loss=0.268]


Epoch 3, Avg Loss: 0.331489080732519, Improvement: 0.1933829513463107


Epoch 4/20: 100%|██████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 12.91it/s, loss=0.176]


Epoch 4, Avg Loss: 0.21626012433658948, Improvement: 0.1152289563959295


Epoch 5/20: 100%|██████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 11.26it/s, loss=0.116]


Epoch 5, Avg Loss: 0.1408380540934476, Improvement: 0.07542207024314186


Epoch 6/20: 100%|█████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.19it/s, loss=0.0778]


Epoch 6, Avg Loss: 0.09525820680639961, Improvement: 0.045579847287047996


Epoch 7/20: 100%|█████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 13.94it/s, loss=0.0576]


Epoch 7, Avg Loss: 0.06695131280205467, Improvement: 0.02830689400434494


Epoch 8/20: 100%|█████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 13.72it/s, loss=0.0456]


Epoch 8, Avg Loss: 0.05028599839318882, Improvement: 0.01666531440886584


Epoch 9/20: 100%|█████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 12.36it/s, loss=0.0362]


Epoch 9, Avg Loss: 0.0400114387951114, Improvement: 0.010274559598077427


Epoch 10/20: 100%|████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 13.38it/s, loss=0.0338]


Epoch 10, Avg Loss: 0.03469809449531815, Improvement: 0.005313344299793243


Epoch 11/20: 100%|████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 11.89it/s, loss=0.0293]


Epoch 11, Avg Loss: 0.030054971575737, Improvement: 0.004643122919581153


Epoch 12/20: 100%|████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 11.70it/s, loss=0.0268]


Epoch 12, Avg Loss: 0.027172402902082962, Improvement: 0.002882568673654036


Epoch 13/20: 100%|████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 11.73it/s, loss=0.0249]


Epoch 13, Avg Loss: 0.025503649630329826, Improvement: 0.0016687532717531378


Epoch 14/20: 100%|████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 13.56it/s, loss=0.0222]


Epoch 14, Avg Loss: 0.023086079650304535, Improvement: 0.0024175699800252914


Epoch 15/20: 100%|████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 13.61it/s, loss=0.0211]


Epoch 15, Avg Loss: 0.022033502771095795, Improvement: 0.001052576879208738


Epoch 16/20: 100%|████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 13.73it/s, loss=0.0189]


Epoch 16, Avg Loss: 0.020354048602960327, Improvement: 0.0016794541681354697


Epoch 17/20: 100%|████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 11.08it/s, loss=0.0186]


Epoch 17, Avg Loss: 0.019648710265755653, Improvement: 0.000705338337204673


Epoch 18/20: 100%|████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 11.86it/s, loss=0.0193]


Epoch 18, Avg Loss: 0.019208111715587704, Improvement: 0.0004405985501679507


Epoch 19/20: 100%|████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 13.26it/s, loss=0.0172]


Epoch 19, Avg Loss: 0.017973837358030407, Improvement: 0.0012342743575572968


Epoch 20/20: 100%|████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 13.45it/s, loss=0.0171]

Epoch 20, Avg Loss: 0.017240411686626347, Improvement: 0.0007334256714040583





In [12]:
# Load tree structure
wsChildrenFile = '../data/sample_mammal/wsChildren.txt'
def load_tree(wsChildrenFile = None):
    """
    Read the file of word2vec and wsChildren, save them into dictionary

    :param word2vecFile:path of word2vecFile
    :param wsChildrenFile:path of wsChildrenFile
    :return:two dictionary of word2vecDic and wsChildrenDic
    """
    wsChildrenDic = dict()
    with open(wsChildrenFile, 'r') as chfh:
        for ln in chfh:
            wlst = ln[:-1].split()
            wsChildrenDic[wlst[0]] = wlst[1:]
    return wsChildrenDic

tree = load_tree(wsChildrenFile)

In [26]:
label_to_index = {label: idx for idx, label in enumerate(sense_labels)}
index_to_label = {idx: label for idx, label in enumerate(sense_labels)}

def predict_with_tree(tree, output_emb, candidate_indices, candidate_norms, candidate_directions, candidate_radii, nball_norms, nball_embeddings, nball_radius, index_to_label, label_to_index):
    current_nodes = [index_to_label[idx] for idx in candidate_indices]
    visited_nodes = set()
    
    while current_nodes:
        distances = [F.pairwise_distance(output_emb.unsqueeze(0), (cand_dir * cand_norm.unsqueeze(-1)).unsqueeze(0)).item()
                     for cand_dir, cand_norm in zip(candidate_directions, candidate_norms)]
        within_radius_indices = [idx for idx, (distance, radius) in enumerate(zip(distances, candidate_radii)) if distance <= radius.item()]
        
        if within_radius_indices:
            return label_to_index[current_nodes[within_radius_indices[0]]]  # Return the first one within the radius
        
        visited_nodes.update(current_nodes)  # Add current nodes to visited set
        
        # Move to the parent nodes, avoiding revisiting nodes
        parent_nodes = set()
        for node in current_nodes:
            for parent, children in tree.items():
                if node in children and parent not in visited_nodes:
                    parent_nodes.add(parent)
        
        current_nodes = list(parent_nodes)
        candidate_norms = [nball_norms[label_to_index[node]] * 1e-5 for node in current_nodes]
        candidate_directions = [nball_embeddings[label_to_index[node]] for node in current_nodes]
        candidate_radii = [nball_radius[label_to_index[node]] for node in current_nodes]
    
    return None  # Return None if no suitable candidate is found

In [29]:
# Evaluation

def evaluation(model_n, model_d, dataloader):
    model_n.eval()
    model_d.eval()

    loss_fn_n = log_cosh_loss
    loss_fn_d = nn.CosineEmbeddingLoss()

    loss_n = 0.0
    loss_d = 0.0
    accuracy = 0.0
    correct_predictions = 0
    pred_indices = {}
    size = 0
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluation", leave=True, position=0):
            batch_input_ids, batch_attention_masks, batch_word_indices, \
            batch_sense_indices, batch_lemma_indices, batch_idices = [b.to(device) for b in batch]

            # Forward pass
            norms = model_n(input_ids=batch_input_ids, attention_mask=batch_attention_masks, \
                            word_index=batch_word_indices).squeeze()
            directions = model_d(input_ids=batch_input_ids, attention_mask=batch_attention_masks, \
                                 word_index=batch_word_indices).squeeze()

            target_norms = nball_norms[batch_sense_indices]
            target_directions = nball_embeddings[batch_sense_indices]

            labels = torch.ones(directions.size(0), device=device)
            loss_n_batch = loss_fn_n(norms, target_norms * 1e-5)
            loss_d_batch = loss_fn_d(directions, target_directions, labels)
            loss_n += loss_n_batch.item()
            loss_d += loss_d_batch.item()

            # Predict
            output_embeddings = directions * norms.unsqueeze(-1)
            for i in range(len(output_embeddings)):
                size += 1
                output_emb = output_embeddings[i]
                candidate_indices = lemma_pair[batch_lemma_indices[i].item()]
                candidate_norms = [nball_norms[idx] * 1e-5 for idx in candidate_indices]
                candidate_directions = [nball_embeddings[idx] for idx in candidate_indices]
                candidate_radii = [nball_radius[idx] for idx in candidate_indices]
                predicted_index = predict_with_tree(tree, output_emb, candidate_indices, candidate_norms, candidate_directions, candidate_radii, nball_norms, nball_embeddings, nball_radius, index_to_label, label_to_index)
                pred_indices[batch_idices[i].item()] = predicted_index
                if predicted_index == batch_sense_indices[i].item():
                    correct_predictions += 1
                
        total_batches = len(dataloader)
        loss_n /= total_batches
        loss_d /= total_batches
        accuracy = correct_predictions / size

    return loss_n, loss_d, accuracy, pred_indices


In [30]:
loss_n, loss_d, accurancy, pred_indices = evaluation(model_n=model_norm, \
                                                     model_d=model_d,\
                                                     dataloader=dataloader_train)

print(f"Average loss of norm:{loss_n}, average loss of dorection:{loss_d}")
print(f"Accurancy:{accurancy*100}%")
train_semcor['pred_sense_idx'] = train_semcor.index.map(pred_indices)
display(train_semcor.head())

Evaluation: 100%|██████████████████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 20.95it/s]

Average loss of norm:0.6490630046887831, average loss of dorection:0.002748901363123547
Accurancy:83.52272727272727%





Unnamed: 0,lemma,word,sentence_text,lemma_idx,formatted_sense_id,sense_idx,sense_group,pred_sense_idx
0,man,men,Whereas the eighteenth century had been a time...,131,homo.n.02,132,[132],132
1,horse,horse,One wrote : `` [ I am so hungry ] I could eat ...,255,horse.n.01,261,[261],261
2,wildcat,wildcat,"She is well-educated and refined , all wildcat...",53,wildcat.n.03,54,[54],54
3,lion,Lions,"To find out , we traveled throughout that part...",61,lion.n.01,62,[62],62
4,elephant,elephants,The Prince visited the hospital of Operation B...,103,elephant.n.01,104,[104],104


In [31]:
set_length = len(train_semcor)

count_single_element = (train_semcor['sense_group'].apply(lambda x: len(x) == 1)).sum()
mismatch_count = (train_semcor['sense_idx'] != train_semcor['pred_sense_idx']).sum()
count_multiple_element = set_length-count_single_element

print(f"Length of dataset:{set_length}")
print(f"Length of have only one element:{count_single_element}")
print(f"Length of mismatch:{mismatch_count}")
print(f"Original accurancy:{count_single_element / set_length *100}%")
print(f"Filtered accurancy:{(count_multiple_element-mismatch_count) /count_multiple_element  *100}%")

Length of dataset:352
Length of have only one element:330
Length of mismatch:58
Original accurancy:93.75%
Filtered accurancy:-163.63636363636365%
