In [1]:
import sys
sys.path.append('../src')
from util import load_balls, preprocess_data
from dataset import nballDataset
from model import nballBertNorm, nballBertDirection

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm
import torch.nn.functional as F

OSError: [WinError 127] 找不到指定的程序。 Error loading "D:\Anaconda\envs\annotated-transformer\lib\site-packages\torch\lib\shm.dll" or one of its dependencies.

In [2]:
# Load data
train_semcor_path = {
    "xml":'../data/WSD_Evaluation_Framework/Training_Corpora/Semcor/semcor.data.xml',
    "gold_key":'../data/WSD_Evaluation_Framework/Training_Corpora/Semcor/semcor.gold.key.txt',}

eval_all_path = {
    "xml":'../data/WSD_Evaluation_Framework/Evaluation_Datasets/ALL/ALL.data.xml',
    "gold_key":'../data/WSD_Evaluation_Framework/Evaluation_Datasets/ALL/ALL.gold.key.txt',}

nball_path = '../data/sample_mammal/nball.txt'

nball = load_balls(nball_path)
train_semcor, lemma_pair = preprocess_data(train_semcor_path, nball)

loading balls....
388 balls are loaded



In [3]:
# Problem about lemma
display(train_semcor.head())
print(len(lemma_pair))
print(len(nball.keys()))

Unnamed: 0,lemma,word,sentence_text,lemma_idx,formatted_sense_id,sense_idx,sense_group
0,man,men,Whereas the eighteenth century had been a time...,131,homo.n.02,132,[132]
1,horse,horse,One wrote : `` [ I am so hungry ] I could eat ...,255,horse.n.01,261,[261]
2,wildcat,wildcat,"She is well-educated and refined , all wildcat...",53,wildcat.n.03,54,[54]
3,lion,Lions,"To find out , we traveled throughout that part...",61,lion.n.01,62,[62]
4,elephant,elephants,The Prince visited the hospital of Operation B...,103,elephant.n.01,104,[104]


375
388


In [4]:
# Loading nball embeddings
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sense_labels = list(nball.keys())
nball_embeddings = [nball[label].center for label in sense_labels]
nball_embeddings = torch.tensor(np.array(nball_embeddings), dtype=torch.float32).to(device)
nball_norms = [nball[label].distance for label in sense_labels]
nball_norms = torch.tensor(np.array(nball_norms), dtype=torch.float32).to(device)

In [5]:
# Model list
bert_models = {"BERT-Base": ["bert-base-uncased", 768], #768, 12L, 12A
               "BERT-Large": ["bert-large-uncased", 1024], #1024, 24L, 16A
               "BERT-Medium": ["google/bert_uncased_L-8_H-512_A-8", 512], #512, 8L, 8A
               "BERT-Small": ["google/bert_uncased_L-4_H-256_A-4", 256], #256, 4L, 4A
               "BERT-Mini": ["google/bert_uncased_L-4_H-128_A-2", 128],#128, 4L, 2A
               "BERT-Tiny": ["google/bert_uncased_L-2_H-128_A-2", 128]}#128, 2L, 2A

max_length = 512
batch_size = 32
model_url = bert_models["BERT-Small"][0]

In [6]:
dataset_train = nballDataset(train_semcor, nball, model_url, max_length)

Tokenizing sentences...
Tokenizing finished.
Calculating word indices...
Tokenizing finished.


In [7]:
dataloader_train = DataLoader(dataset_train, batch_size, shuffle=True)

In [8]:
def train(model, dataloader, optimizer, scheduler, num_epochs, mode):
    last_loss = 0
    if mode == "norm":
        loss_fn = log_cosh_loss
    elif mode == "direction":
        loss_fn = nn.CosineEmbeddingLoss()
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        progress_bar = tqdm(dataloader, desc=f'Epoch {epoch+1}/{num_epochs}', leave=True, position=0)
        for batch in progress_bar:
            # Unpack batch and send to device
            batch_input_ids, batch_attention_masks, batch_word_indices, \
            batch_sense_indices, _, _ = [b.to(device) for b in batch]

            optimizer.zero_grad()
        
            # Forward pass
            output = model(input_ids=batch_input_ids, attention_mask=batch_attention_masks, \
                            word_index=batch_word_indices).squeeze()

            # print(output)
            if mode == "norm":
                batch_norms = nball_norms[batch_sense_indices]
                loss = loss_fn(output, batch_norms * 1e-5)
            elif mode == "direction":
                labels = torch.ones(output.size(0), device=device)
                batch_senses = nball_embeddings[batch_sense_indices]
                loss = loss_fn(output, batch_senses, labels)
                
            loss.backward()
            optimizer.step()
        
            total_loss += loss.item()
            progress_bar.set_postfix(loss=loss.item())

        # Logging average metrics per epoch
        avg_loss = total_loss / len(dataloader)
        improvement = (last_loss - total_loss) / len(dataloader) if last_loss != 0 else 0
        print(f'Epoch {epoch + 1}, Avg Loss: {avg_loss}, Improvement: {improvement}')
        last_loss = total_loss

        scheduler.step()

In [9]:
# Train norm
def log_cosh_loss(norm_u, norm_v):
    return torch.log(torch.cosh(norm_u - norm_v)).mean()

output_dim = 160
model_norm = nballBertNorm(model_url, output_dim).to(device)
loss_fn_norm = log_cosh_loss
optimizer_norm = optim.Adam([
    {'params': model_norm.parameters()}
], lr=2e-5)
scheduler_norm = StepLR(optimizer_norm, step_size=50, gamma=0.1)

train(model=model_norm, \
      dataloader=dataloader_train,\
      optimizer=optimizer_norm,\
      scheduler=scheduler_norm,\
      num_epochs=50, \
      mode="norm")

Epoch 1/50: 100%|███████████████████████████████████████████████████████████| 11/11 [00:01<00:00,  5.65it/s, loss=6.61]


Epoch 1, Avg Loss: 6.756277517838911, Improvement: 0


Epoch 2/50: 100%|███████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 13.90it/s, loss=6.92]


Epoch 2, Avg Loss: 6.2669667330655185, Improvement: 0.4893107847733931


Epoch 3/50: 100%|███████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 13.94it/s, loss=5.65]


Epoch 3, Avg Loss: 5.971220168200406, Improvement: 0.2957465648651123


Epoch 4/50: 100%|███████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 13.92it/s, loss=6.39]


Epoch 4, Avg Loss: 5.760668017647483, Improvement: 0.21055215055292303


Epoch 5/50: 100%|███████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 13.39it/s, loss=6.39]


Epoch 5, Avg Loss: 5.542945775118741, Improvement: 0.217722242528742


Epoch 6/50: 100%|███████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 11.46it/s, loss=4.14]


Epoch 6, Avg Loss: 5.301914366808805, Improvement: 0.24103140830993652


Epoch 7/50: 100%|███████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 13.17it/s, loss=6.34]


Epoch 7, Avg Loss: 5.0364978313446045, Improvement: 0.26541653546420013


Epoch 8/50: 100%|███████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.31it/s, loss=3.99]


Epoch 8, Avg Loss: 4.745988737453114, Improvement: 0.2905090938914906


Epoch 9/50: 100%|███████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.35it/s, loss=6.01]


Epoch 9, Avg Loss: 4.501619122245095, Improvement: 0.24436961520801892


Epoch 10/50: 100%|██████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.20it/s, loss=5.47]


Epoch 10, Avg Loss: 4.3124634569341485, Improvement: 0.18915566531094638


Epoch 11/50: 100%|██████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.36it/s, loss=3.69]


Epoch 11, Avg Loss: 4.188019882548939, Improvement: 0.12444357438520952


Epoch 12/50: 100%|██████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.38it/s, loss=4.02]


Epoch 12, Avg Loss: 4.111698779192838, Improvement: 0.0763211033561013


Epoch 13/50: 100%|██████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.41it/s, loss=3.17]


Epoch 13, Avg Loss: 4.00158149545843, Improvement: 0.1101172837344083


Epoch 14/50: 100%|██████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.30it/s, loss=4.59]


Epoch 14, Avg Loss: 3.966487624428489, Improvement: 0.03509387102994052


Epoch 15/50: 100%|██████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.40it/s, loss=2.56]


Epoch 15, Avg Loss: 3.9252442771738227, Improvement: 0.04124334725466641


Epoch 16/50: 100%|██████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.46it/s, loss=3.71]


Epoch 16, Avg Loss: 3.906999176198786, Improvement: 0.01824510097503662


Epoch 17/50: 100%|██████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.35it/s, loss=3.65]


Epoch 17, Avg Loss: 3.854695515199141, Improvement: 0.05230366099964489


Epoch 18/50: 100%|██████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.53it/s, loss=1.58]


Epoch 18, Avg Loss: 3.8122210285880347, Improvement: 0.04247448661110618


Epoch 19/50: 100%|██████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.35it/s, loss=4.44]


Epoch 19, Avg Loss: 3.739484331824563, Improvement: 0.07273669676347212


Epoch 20/50: 100%|██████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.46it/s, loss=3.82]


Epoch 20, Avg Loss: 3.66066499189897, Improvement: 0.07881933992559259


Epoch 21/50: 100%|██████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.39it/s, loss=4.11]


Epoch 21, Avg Loss: 3.570920142260465, Improvement: 0.08974484963850542


Epoch 22/50: 100%|██████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 13.96it/s, loss=2.96]


Epoch 22, Avg Loss: 3.456747044216503, Improvement: 0.11417309804396196


Epoch 23/50: 100%|██████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.01it/s, loss=2.24]


Epoch 23, Avg Loss: 3.2905440547249536, Improvement: 0.1662029894915494


Epoch 24/50: 100%|██████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 13.88it/s, loss=2.09]


Epoch 24, Avg Loss: 3.19614932753823, Improvement: 0.09439472718672319


Epoch 25/50: 100%|██████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.17it/s, loss=3.38]


Epoch 25, Avg Loss: 3.0974092158404263, Improvement: 0.09874011169780385


Epoch 26/50: 100%|██████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.13it/s, loss=2.28]


Epoch 26, Avg Loss: 2.946654428135265, Improvement: 0.15075478770516135


Epoch 27/50: 100%|███████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.33it/s, loss=2.1]


Epoch 27, Avg Loss: 2.823667461221868, Improvement: 0.12298696691339667


Epoch 28/50: 100%|██████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.41it/s, loss=1.97]


Epoch 28, Avg Loss: 2.693918824195862, Improvement: 0.12974863702600653


Epoch 29/50: 100%|██████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.28it/s, loss=3.55]


Epoch 29, Avg Loss: 2.606867389245467, Improvement: 0.08705143495039507


Epoch 30/50: 100%|██████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.39it/s, loss=1.15]


Epoch 30, Avg Loss: 2.489155509255149, Improvement: 0.11771187999031761


Epoch 31/50: 100%|██████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.46it/s, loss=4.16]


Epoch 31, Avg Loss: 2.3992429158904334, Improvement: 0.08991259336471558


Epoch 32/50: 100%|███████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.34it/s, loss=1.7]


Epoch 32, Avg Loss: 2.322426915168762, Improvement: 0.07681600072167137


Epoch 33/50: 100%|██████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.19it/s, loss=2.01]


Epoch 33, Avg Loss: 2.2410902326757256, Improvement: 0.08133668249303644


Epoch 34/50: 100%|██████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.56it/s, loss=1.55]


Epoch 34, Avg Loss: 2.1906347166408193, Improvement: 0.05045551603490656


Epoch 35/50: 100%|███████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.40it/s, loss=4.2]


Epoch 35, Avg Loss: 2.1024714491584082, Improvement: 0.08816326748241078


Epoch 36/50: 100%|██████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.34it/s, loss=1.98]


Epoch 36, Avg Loss: 2.034699938514016, Improvement: 0.06777151064439253


Epoch 37/50: 100%|██████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.45it/s, loss=1.44]


Epoch 37, Avg Loss: 1.968110518022017, Improvement: 0.06658942049199884


Epoch 38/50: 100%|██████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.24it/s, loss=1.46]


Epoch 38, Avg Loss: 1.921114206314087, Improvement: 0.04699631170793013


Epoch 39/50: 100%|██████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.45it/s, loss=1.44]


Epoch 39, Avg Loss: 1.8320451758124612, Improvement: 0.08906903050162575


Epoch 40/50: 100%|██████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.40it/s, loss=1.33]


Epoch 40, Avg Loss: 1.7379827824505893, Improvement: 0.09406239336187189


Epoch 41/50: 100%|██████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.35it/s, loss=1.93]


Epoch 41, Avg Loss: 1.671442692930048, Improvement: 0.0665400895205411


Epoch 42/50: 100%|██████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.33it/s, loss=1.32]


Epoch 42, Avg Loss: 1.590345404364846, Improvement: 0.0810972885652022


Epoch 43/50: 100%|█████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.30it/s, loss=0.726]


Epoch 43, Avg Loss: 1.5506694696166299, Improvement: 0.03967593474821611


Epoch 44/50: 100%|██████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.49it/s, loss=1.22]


Epoch 44, Avg Loss: 1.4241426316174595, Improvement: 0.12652683799917047


Epoch 45/50: 100%|██████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.38it/s, loss=1.39]


Epoch 45, Avg Loss: 1.302644664591009, Improvement: 0.12149796702645042


Epoch 46/50: 100%|█████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.27it/s, loss=0.896]


Epoch 46, Avg Loss: 1.199525773525238, Improvement: 0.10311889106577093


Epoch 47/50: 100%|██████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.51it/s, loss=2.33]


Epoch 47, Avg Loss: 1.0946883179924705, Improvement: 0.10483745553276756


Epoch 48/50: 100%|██████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.29it/s, loss=0.78]


Epoch 48, Avg Loss: 0.9956685385920785, Improvement: 0.09901977940039201


Epoch 49/50: 100%|█████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.27it/s, loss=0.407]


Epoch 49, Avg Loss: 0.9077390784567053, Improvement: 0.0879294601353732


Epoch 50/50: 100%|█████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.52it/s, loss=0.179]

Epoch 50, Avg Loss: 0.8224796002561395, Improvement: 0.08525947820056569





In [10]:
# Train direction
model_d = nballBertDirection(model_url, output_dim).to(device)
optimizer_d = optim.Adam([
    {'params': model_d.parameters()}
], lr=2e-5)
scheduler_d = StepLR(optimizer_d, step_size=50, gamma=0.1)

train(model=model_d, \
      dataloader=dataloader_train,\
      optimizer=optimizer_d,\
      scheduler=scheduler_d,\
      num_epochs=20, \
      mode="direction")

Epoch 1/20: 100%|██████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 13.87it/s, loss=0.762]


Epoch 1, Avg Loss: 0.9002834450114857, Improvement: 0


Epoch 2/20: 100%|██████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.08it/s, loss=0.519]


Epoch 2, Avg Loss: 0.6333872188221324, Improvement: 0.26689622618935327


Epoch 3/20: 100%|██████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.37it/s, loss=0.346]


Epoch 3, Avg Loss: 0.4295814877206629, Improvement: 0.20380573110146957


Epoch 4/20: 100%|███████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.20it/s, loss=0.24]


Epoch 4, Avg Loss: 0.2891637547449632, Improvement: 0.14041773297569968


Epoch 5/20: 100%|███████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.26it/s, loss=0.16]


Epoch 5, Avg Loss: 0.192511802369898, Improvement: 0.0966519523750652


Epoch 6/20: 100%|██████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.37it/s, loss=0.107]


Epoch 6, Avg Loss: 0.12691779163750735, Improvement: 0.06559401073239067


Epoch 7/20: 100%|█████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.30it/s, loss=0.0715]


Epoch 7, Avg Loss: 0.08590154485269026, Improvement: 0.041016246784817086


Epoch 8/20: 100%|█████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.25it/s, loss=0.0515]


Epoch 8, Avg Loss: 0.06093714013695717, Improvement: 0.024964404715733093


Epoch 9/20: 100%|█████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.30it/s, loss=0.0412]


Epoch 9, Avg Loss: 0.04754723913290284, Improvement: 0.013389901004054329


Epoch 10/20: 100%|████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.33it/s, loss=0.0374]


Epoch 10, Avg Loss: 0.038943444124676964, Improvement: 0.008603795008225874


Epoch 11/20: 100%|████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.43it/s, loss=0.0319]


Epoch 11, Avg Loss: 0.03294910355047746, Improvement: 0.005994340574199503


Epoch 12/20: 100%|█████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.28it/s, loss=0.028]


Epoch 12, Avg Loss: 0.030301975763656876, Improvement: 0.0026471277868205852


Epoch 13/20: 100%|████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.25it/s, loss=0.0256]


Epoch 13, Avg Loss: 0.02766367251222784, Improvement: 0.0026383032514290376


Epoch 14/20: 100%|████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.23it/s, loss=0.0252]


Epoch 14, Avg Loss: 0.02569798430935903, Improvement: 0.0019656882028688083


Epoch 15/20: 100%|█████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.35it/s, loss=0.023]


Epoch 15, Avg Loss: 0.023842876438390125, Improvement: 0.0018551078709689054


Epoch 16/20: 100%|████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.18it/s, loss=0.0208]


Epoch 16, Avg Loss: 0.022837365384806286, Improvement: 0.0010055110535838387


Epoch 17/20: 100%|█████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.19it/s, loss=0.022]


Epoch 17, Avg Loss: 0.021583566272800617, Improvement: 0.0012537991120056672


Epoch 18/20: 100%|████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 14.32it/s, loss=0.0194]


Epoch 18, Avg Loss: 0.020147095519033344, Improvement: 0.0014364707537672737


Epoch 19/20: 100%|████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 13.90it/s, loss=0.0194]


Epoch 19, Avg Loss: 0.019093484363772652, Improvement: 0.0010536111552606928


Epoch 20/20: 100%|████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 13.89it/s, loss=0.0187]

Epoch 20, Avg Loss: 0.018734653903679413, Improvement: 0.0003588304600932381





In [11]:
# Evaluation

def evaluation(model_n, model_d, dataloader):
    model_n.eval()
    model_d.eval()

    loss_fn_n = log_cosh_loss
    loss_fn_d = nn.CosineEmbeddingLoss()

    loss_n = 0.0
    loss_d = 0.0
    accuracy = 0.0
    correct_predictions = 0
    pred_indices = {}
    size = 0
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluation", leave=True, position=0):
            batch_input_ids, batch_attention_masks, batch_word_indices, \
            batch_sense_indices, batch_lemma_indices, batch_idices = [b.to(device) for b in batch]

            # Forward pass
            norms = model_n(input_ids=batch_input_ids, attention_mask=batch_attention_masks, \
                            word_index=batch_word_indices).squeeze()
            directions = model_d(input_ids=batch_input_ids, attention_mask=batch_attention_masks, \
                                 word_index=batch_word_indices).squeeze()

            target_norms = nball_norms[batch_sense_indices]
            target_directions = nball_embeddings[batch_sense_indices]

            labels = torch.ones(directions.size(0), device=device)
            loss_n_batch = loss_fn_n(norms, target_norms * 1e-5)
            loss_d_batch = loss_fn_d(directions, target_directions, labels)
            loss_n += loss_n_batch.item()
            loss_d += loss_d_batch.item()

            # Predict
            # Get candidate group norms and directions based on lemma_pair mapping
            candidate_indices = [lemma_pair[l.item()] for l in batch_lemma_indices]
            candidate_norms = [nball_norms[idx] * 1e-5 for idx in candidate_indices]
            candidate_directions = [nball_embeddings[idx] for idx in candidate_indices]
            output_embeddings = directions * norms.unsqueeze(-1)
            # Calculate distances and predictions
            for i in range(len(output_embeddings)):
                size += 1
                output_emb = output_embeddings[i]
                distances = [F.pairwise_distance(output_emb.unsqueeze(0), 
                                                  (cand_dir * cand_norm.unsqueeze(-1)).unsqueeze(0)).item()
                             for cand_dir, cand_norm in zip(candidate_directions[i], candidate_norms[i])]
                min_distance_idx = distances.index(min(distances))
                predicted_index = candidate_indices[i][min_distance_idx]
                pred_indices[batch_idices[i].item()] = predicted_index
                if predicted_index == batch_sense_indices[i].item():
                    correct_predictions += 1
                

        total_batches = len(dataloader)
        loss_n /= total_batches
        loss_d /= total_batches
        accuracy = correct_predictions / size

    return loss_n, loss_d, accuracy, pred_indices


In [21]:
loss_n, loss_d, accurancy, pred_indices = evaluation(model_n=model_norm, \
                                                     model_d=model_d,\
                                                     dataloader=dataloader_train)

print(f"Average loss of norm:{loss_n}, average loss of dorection:{loss_d}")
print(f"Accurancy:{accurancy*100}%")
train_semcor['pred_sense_idx'] = train_semcor.index.map(pred_indices)
display(train_semcor.head())

Evaluation: 100%|██████████████████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 17.10it/s]


Average loss of norm:0.656003089113669, average loss of dorection:0.0026112972674044695
Accurancy:99.43181818181817%


Unnamed: 0,lemma,word,sentence_text,lemma_idx,formatted_sense_id,sense_idx,sense_group,pred_sense_id,pred_sense_idx
0,man,men,Whereas the eighteenth century had been a time...,131,homo.n.02,132,[132],132,132
1,horse,horse,One wrote : `` [ I am so hungry ] I could eat ...,255,horse.n.01,261,[261],261,261
2,wildcat,wildcat,"She is well-educated and refined , all wildcat...",53,wildcat.n.03,54,[54],54,54
3,lion,Lions,"To find out , we traveled throughout that part...",61,lion.n.01,62,[62],62,62
4,elephant,elephants,The Prince visited the hospital of Operation B...,103,elephant.n.01,104,[104],104,104


In [22]:
count_single_element = (train_semcor['sense_group'].apply(lambda x: len(x) == 1)).sum()
mismatch_count = (train_semcor['sense_idx'] != train_semcor['pred_sense_idx']).sum()
count_multiple_element = set_length-count_single_element

set_length = len(train_semcor)
print(f"Length of dataset:{set_length}")
print(f"Length of have only one element:{count_single_element}")
print(f"Length of mismatch:{mismatch_count}")
print(f"Original accurancy:{count_single_element / set_length *100}%")
print(f"Filtered accurancy:{(count_multiple_element-mismatch_count) /count_multiple_element  *100}%")


Length of dataset:352
Length of have only one element:330
Length of mismatch:2
Original accurancy:93.75%
Filtered accurancy:90.9090909090909%


In [19]:
eval_all, lemma_pair = preprocess_data(eval_all_path, nball)
dataset_eval = nballDataset(eval_all, nball, model_url, max_length)
dataloader_eval = DataLoader(dataset_eval, batch_size, shuffle=True)
loss_n, loss_d, accurancy, pred_indices = evaluation(model_n=model_norm, \
                                                     model_d=model_d,\
                                                     dataloader=dataloader_eval)

eval_all['pred_sense_idx'] = eval_all.index.map(pred_indices)
print(f"Average loss of norm:{loss_n}, average loss of dorection:{loss_d}")
print(f"Accurancy:{accurancy}")
display(eval_all.head())

Tokenizing sentences...
Tokenizing finished.
Calculating word indices...
Tokenizing finished.


Evaluation: 100%|████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  4.37it/s]

Average loss of norm:0.5518102645874023, average loss of dorection:0.0034370184876024723
Accurancy:1.0





Unnamed: 0,lemma,word,sentence_text,lemma_idx,formatted_sense_id,sense_idx,sense_group,pred_sense_idx
0,mouse,mice,His doubts stemmed from the fact that several ...,183,mouse.n.01,184,[184],184
1,cat,cats,"The cats are fine , although nervous .",54,cat.n.01,55,[55],55
2,mouse,mouse,The quivers move through my house every few mi...,183,mouse.n.01,184,[184],184
3,cow,cows,After long stretches of this attendant ground ...,99,cow.n.01,310,"[100, 310]",310
4,world,world,But it was seen as an important advance in a n...,130,world.n.08,131,[131],131
