In [2]:
import os
import sys
sys.path.append('../src')
from util import load_balls, load_ball_dims, load_semcor_data, format_sense_id, get_eval_pair
from dataset import nballDataset
from model import nballBertNorm, nballBertDirection

In [3]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm
import torch.nn.functional as F

In [4]:
# Collect the name of every root that has an n-ball file in the data directory.
directory_path = '../data/entity_multi'
nball_suffix = '_nball.txt'

# os.listdir() returns entries in arbitrary, platform-dependent order. Code
# further down walks the tree from each root in turn and lets the first root
# that reaches a node claim it, so the order of `roots` changes the results.
# Sorting the file names keeps every run reproducible.
roots = []
for file in sorted(os.listdir(directory_path)):
    # Only files named <root>_nball.txt describe a root.
    if file.endswith(nball_suffix):
        # Keep the <root> part in front of the suffix.
        root = file.split(nball_suffix)[0]
        roots.append(root)

print(f"Numbers of the roots:{len(roots)}")

Numbers of the roots:81


In [5]:
# Read original Tree
def read_tree(wsChildrenFile = None):
    """Read a parent -> children table from a whitespace-separated text file.

    Each non-empty line is split on whitespace; the first token becomes the
    key and the remaining tokens become its list of children. When a key
    occurs on several lines, the last line wins.

    Args:
        wsChildrenFile: Path of the file to read. ``None`` or a path that is
            not an existing file yields an empty dictionary.

    Returns:
        dict mapping the first token of every line to the list of the
        remaining tokens on that line (possibly empty).
    """
    wsChildrenDic = dict()
    # os.path.isfile(None) raises TypeError, so the default must be handled
    # before touching the file system.
    if wsChildrenFile is None or not os.path.isfile(wsChildrenFile):
        return wsChildrenDic

    with open(wsChildrenFile, 'r') as chfh:
        for ln in chfh:
            # split() already discards the trailing newline; slicing off the
            # last character would corrupt a final line that has no newline.
            wlst = ln.split()
            if not wlst:
                # Blank line: nothing to record.
                continue
            wsChildrenDic[wlst[0]] = wlst[1:]
    return wsChildrenDic

# Load the parent -> children table stored next to the n-ball files.
# The path is relative, like every other data path in this notebook, so the
# notebook does not depend on one particular machine's directory layout.
wsChildrenFile = '../data/entity_multi/wsChildren.txt'
wsChildrenDic = read_tree(wsChildrenFile)

# read_tree() returns an empty dict when the file is missing; fail here rather
# than letting every later cell silently work on an empty tree.
assert wsChildrenDic, f"No tree entries could be read from {wsChildrenFile}"

In [6]:
# Locations of the training corpus (SemCor) and of the combined evaluation
# set (ALL): one XML data file and one gold-key file each.
train_semcor_path = dict(
    xml='../data/WSD_Evaluation_Framework/Training_Corpora/Semcor/semcor.data.xml',
    gold_key='../data/WSD_Evaluation_Framework/Training_Corpora/Semcor/semcor.gold.key.txt',
)

eval_all_path = dict(
    xml='../data/WSD_Evaluation_Framework/Evaluation_Datasets/ALL/ALL.data.xml',
    gold_key='../data/WSD_Evaluation_Framework/Evaluation_Datasets/ALL/ALL.gold.key.txt',
)

# Load the training corpus with the project helper.
training_corpra = load_semcor_data(train_semcor_path)

In [7]:
# Add a 'formatted_sense_id' column derived from the raw 'sense_id' column.
# One cache dict is shared by all calls to format_sense_id.
sense_id_cache = {}


def _format_with_cache(raw_sense_id):
    """Format one raw sense id, passing along the shared cache."""
    return format_sense_id(raw_sense_id, sense_id_cache)


training_corpra['formatted_sense_id'] = training_corpra['sense_id'].apply(_format_with_cache)

In [8]:
# Calculate roots level
from collections import deque

def calculate_node_levels(tree, roots):
    """Assign a level to every node reachable from the given roots.

    A breadth-first walk is started from each root in the order given, with
    the root itself at level 1 and each child one level below its parent.
    A node keeps the level of its first visit: nodes already reached from an
    earlier root (including later roots themselves) are neither re-levelled
    nor expanded again.

    Args:
        tree: dict mapping a node to the list of its children.
        roots: iterable of start nodes, processed in order.

    Returns:
        dict mapping each reached node to its level.
    """
    levels = {}

    for start in roots:
        # Pending (node, level) pairs for the walk rooted at `start`.
        pending = deque([(start, 1)])

        while pending:
            node, depth = pending.popleft()

            if node not in levels:
                # First visit: record the level and schedule the children.
                levels[node] = depth
                pending.extend((kid, depth + 1) for kid in tree.get(node, []))

    return levels

# Level of every node reachable from the roots; "entity.n.01" is pinned to
# level 0 regardless of what the walk assigned to it.
root_levels = calculate_node_levels(wsChildrenDic, roots)
root_levels.update({"entity.n.01": 0})

# List the nodes from the lowest level number to the highest.
sorted_nodes = sorted(root_levels.items(), key=lambda pair: pair[1])
for node, level in sorted_nodes:
    print(f"Node: {node}, Level: {level}")


Node: entity.n.01, Level: 0
Node: ability.n.02, Level: 1
Node: abstraction.n.06, Level: 1
Node: adult.n.01, Level: 1
Node: advocate.n.01, Level: 1
Node: amerindian.n.01, Level: 1
Node: animal.n.01, Level: 1
Node: artifact.n.01, Level: 1
Node: asian.n.01, Level: 1
Node: capitalist.n.02, Level: 1
Node: communicator.n.01, Level: 1
Node: contestant.n.01, Level: 1
Node: creator.n.02, Level: 1
Node: entertainer.n.01, Level: 1
Node: aptitude.n.01, Level: 2
Node: art.n.03, Level: 2
Node: bilingualism.n.01, Level: 2
Node: capacity.n.08, Level: 2
Node: creativity.n.01, Level: 2
Node: faculty.n.01, Level: 2
Node: hand.n.04, Level: 2
Node: intelligence.n.01, Level: 2
Node: know-how.n.01, Level: 2
Node: leadership.n.04, Level: 2
Node: originality.n.01, Level: 2
Node: skill.n.01, Level: 2
Node: skill.n.02, Level: 2
Node: attribute.n.02, Level: 2
Node: cognition.n.01, Level: 2
Node: communication.n.02, Level: 2
Node: event.n.01, Level: 2
Node: group.n.01, Level: 2
Node: measure.n.02, Level: 2
Node: m

In [23]:
def find_subtree(tree, root, roots):
    """Collect the nodes below ``root`` without descending into other roots.

    The walk starts at ``root`` and follows ``tree`` breadth-first. Any node
    that is listed in ``roots`` (other than ``root`` itself) is included in
    the result but its children are not followed.

    Args:
        tree: dict mapping a node to the list of its children.
        root: node whose subtree is wanted.
        roots: iterable of nodes that act as boundaries of the walk.

    Returns:
        set of all nodes reached, excluding ``root`` itself.
    """
    # A set makes the boundary test O(1); `roots` is usually a list.
    boundary = frozenset(roots) - {root}

    # Mark nodes as seen when they are queued, not when they are dequeued, so
    # a node with several parents is queued and expanded only once.
    seen = {root}
    queue = deque([root])

    while queue:
        current_node = queue.popleft()

        # Another root: keep it in the result but do not expand it.
        if current_node in boundary:
            continue

        for child in tree.get(current_node, []):
            if child not in seen:
                seen.add(child)
                queue.append(child)

    # The root itself is not part of its own subtree.
    seen.discard(root)

    return seen


def find_subtree_all(tree, root):
    """Collect every node reachable below ``root``.

    Unlike a bounded walk, nothing stops the traversal: every child listed in
    ``tree`` is followed, whether or not it is a root elsewhere.

    Args:
        tree: dict mapping a node to the list of its children.
        root: node whose full subtree is wanted.

    Returns:
        set of all nodes reached, excluding ``root`` itself.
    """
    # Mark nodes as seen when they are queued, not when they are dequeued, so
    # a node with several parents is queued and expanded only once.
    seen = {root}
    queue = deque([root])

    while queue:
        current_node = queue.popleft()

        for child in tree.get(current_node, []):
            if child not in seen:
                seen.add(child)
                queue.append(child)

    # The root itself is not part of its own subtree.
    seen.discard(root)

    return seen


# Quick check on one root: size of its bounded subtree and which other roots
# lie inside it.
test_root = 'mammal.n.01'
subnodes_entity = find_subtree(wsChildrenDic, test_root, roots)
subroots = subnodes_entity & set(roots)
print(len(subnodes_entity), len(subroots), subroots, sep="\n")

386
0
set()


In [24]:
def filter_corpra(corpra, tree, root, roots):
    """Keep the corpus rows whose sense lies in the bounded subtree of ``root``.

    Steps:
      1. Find the subtree of ``root`` without expanding other roots.
      2. For every other root found inside it (a "subroot"), relabel all rows
         whose ``formatted_sense_id`` lies anywhere below that subroot with
         the subroot itself.
      3. Keep only the rows whose (relabelled) ``formatted_sense_id`` is a
         node of the subtree from step 1.

    The input frame is not modified; relabelling happens on a private copy of
    the ``formatted_sense_id`` column.

    Args:
        corpra: DataFrame with a ``formatted_sense_id`` column.
        tree: dict mapping a node to the list of its children.
        root: node whose subtree selects the rows.
        roots: collection of all root nodes.

    Returns:
        A new DataFrame with the selected rows, the relabelled
        ``formatted_sense_id`` column and a fresh 0..n-1 index.
    """
    # Nodes below `root`, with other roots kept as unexpanded leaves.
    subtree_nodes = find_subtree(tree, root, roots)
    subroots = subtree_nodes.intersection(roots)

    # Relabel on a copy of the column so the caller's frame stays untouched.
    sense_ids = corpra['formatted_sense_id'].copy()

    # Later relabels see the result of earlier ones, so when one subroot lies
    # below another the outcome depends on the order. Iterate in sorted order
    # to make the result deterministic instead of set-order dependent.
    for subroot in sorted(subroots):
        subtree_nodes_of_subroot = find_subtree_all(tree, subroot)
        sense_ids.loc[sense_ids.isin(subtree_nodes_of_subroot)] = subroot

    # Select the rows that now carry a label inside the subtree of `root`.
    keep = sense_ids.isin(subtree_nodes)
    filtered_corpra = corpra.loc[keep].assign(formatted_sense_id=sense_ids.loc[keep])

    return filtered_corpra.reset_index(drop=True)

# Try the filter on the current test root and compare sizes before and after.
filtered_corpra = filter_corpra(training_corpra.copy(), wsChildrenDic, test_root, roots)
print(f"corpra length:{len(training_corpra)}, filtered corpra length:{len(filtered_corpra)}")

# First ten noun rows of the full corpus, then first ten rows that were kept.
noun_rows = training_corpra.loc[training_corpra['pos'] == 'NOUN']
display(noun_rows.head(10))
display(filtered_corpra.head(10))

corpra length:226036, filtered corpra length:338


Unnamed: 0,sentence_id,instance_id,lemma,pos,word,sentence_text,sense_id,formatted_sense_id
3,d000.s000,d000.s000.t003,objective,NOUN,objectives,How long has it been since you reviewed the ob...,objective%1:09:00::,aim.n.02
4,d000.s000,d000.s000.t004,benefit,NOUN,benefit,How long has it been since you reviewed the ob...,benefit%1:21:00::,benefit.n.01
5,d000.s000,d000.s000.t005,service,NOUN,service,How long has it been since you reviewed the ob...,service%1:04:07::,service.n.05
6,d000.s000,d000.s000.t006,program,NOUN,program,How long has it been since you reviewed the ob...,program%1:09:01::,program.n.02
9,d000.s001,d000.s001.t002,giveaway,NOUN,giveaway,Have you permitted it to become a giveaway pro...,giveaway%1:21:00::,giveaway.n.01
10,d000.s001,d000.s001.t003,program,NOUN,program,Have you permitted it to become a giveaway pro...,program%1:09:01::,program.n.02
13,d000.s001,d000.s001.t006,goal,NOUN,goal,Have you permitted it to become a giveaway pro...,goal%1:09:00::,goal.n.01
15,d000.s001,d000.s001.t008,employee,NOUN,employee,Have you permitted it to become a giveaway pro...,employee%1:18:00::,employee.n.01
16,d000.s001,d000.s001.t009,morale,NOUN,morale,Have you permitted it to become a giveaway pro...,morale%1:26:00::,morale.n.01
19,d000.s001,d000.s001.t012,productivity,NOUN,productivity,Have you permitted it to become a giveaway pro...,productivity%1:07:00::,productiveness.n.01


Unnamed: 0,sentence_id,instance_id,lemma,pos,word,sentence_text,sense_id,formatted_sense_id
0,d003.s049,d003.s049.t010,man,NOUN,men,Whereas the eighteenth century had been a time...,man%1:05:01::,homo.n.02
1,d006.s007,d006.s007.t007,horse,NOUN,horse,One wrote : `` [ I am so hungry ] I could eat ...,horse%1:05:00::,horse.n.01
2,d006.s013,d006.s013.t003,wildcat,NOUN,wildcat,"She is well-educated and refined , all wildcat...",wildcat%1:05:00::,wildcat.n.03
3,d007.s061,d007.s061.t017,lion,NOUN,Lions,"To find out , we traveled throughout that part...",lion%1:05:00::,lion.n.01
4,d007.s118,d007.s118.t012,elephant,NOUN,elephants,The Prince visited the hospital of Operation B...,elephant%1:05:00::,elephant.n.01
5,d011.s002,d011.s002.t009,bronc,NOUN,broncs,Although they were forced to maintain a sharpe...,bronc%1:05:00::,bronco.n.01
6,d011.s092,d011.s092.t002,horse,NOUN,horses,Neither spoke till they reached their horses .,horse%1:05:00::,horse.n.01
7,d011.s112,d011.s112.t003,bronc,NOUN,bronc,"Giving the other a dark look , he hauled his b...",bronc%1:05:00::,bronco.n.01
8,d011.s130,d011.s130.t002,bronc,NOUN,bronc,Clapping spurs to the bronc he set off at a sh...,bronc%1:05:00::,bronco.n.01
9,d011.s145,d011.s145.t001,pony,NOUN,pony,"Leading his pony , he hurried that way , not r...",pony%1:05:02::,pony.n.01


In [11]:
# Check whether the senses of the first noun rows shown above lie below
# "abstraction.n.06"; one True/False is printed per sense.
sub_all_abs = find_subtree_all(wsChildrenDic, "abstraction.n.06")
check = [
    "aim.n.02",
    "benefit.n.01",
    "service.n.05",
    "program.n.02",
    "giveaway.n.01",
    "goal.n.01",
    "morale.n.01",
    "productiveness.n.01",
]
print(*(sense in sub_all_abs for sense in check), sep="\n")

True
True
True
True
True
True
True
True


In [12]:
def preprocess_data(corpra, nball):
    """Attach sense, lemma and sense-group indices to a filtered corpus.

    The corpus is expected to have been filtered and relabelled for the root
    that ``nball`` belongs to. The input frame is left untouched; all new
    columns are added to a copy, which avoids chained-assignment warnings on
    filtered frames and keeps re-running the calling cell idempotent.

    Args:
        corpra: DataFrame with at least the columns ``lemma``, ``word``,
            ``sentence_text`` and ``formatted_sense_id``.
        nball: mapping whose keys are the sense labels; the position of a key
            defines its ``sense_idx``.

    Returns:
        Tuple ``(frame, lemma_pair)``: ``frame`` holds the columns ``lemma``,
        ``word``, ``sentence_text``, ``lemma_idx``, ``formatted_sense_id``,
        ``sense_idx`` and ``sense_group``; ``lemma_pair`` is passed through
        unchanged from ``get_eval_pair``.

    Raises:
        KeyError: if the part of a ``formatted_sense_id`` before the first
            ``'.'`` is not among the lemma labels from ``get_eval_pair``.
    """
    sense_labels = list(nball.keys())
    sense_index = {sense: idx for idx, sense in enumerate(sense_labels)}
    lemma_labels, lemma_pair, eval_pair = get_eval_pair(sense_index)
    lemma_index = {lemma: idx for idx, lemma in enumerate(lemma_labels)}

    sense_ids = corpra['formatted_sense_id']
    prepared = corpra.assign(
        # Senses that are not keys of `nball` become NaN here.
        sense_idx=sense_ids.map(sense_index),
        # The lemma index comes from the head of the sense id (text before the
        # first '.'), not from the corpus 'lemma' column, which the original
        # author flagged as problematic for this lookup.
        lemma_idx=sense_ids.apply(lambda sense: lemma_index[sense.split('.')[0]]),
    )
    prepared['sense_group'] = prepared['sense_idx'].map(eval_pair)

    # Columns kept for the dataset for now.
    keys_to_keep = ['lemma', 'word', 'sentence_text', 'lemma_idx', 'formatted_sense_id', 'sense_idx', 'sense_group']

    return prepared[keys_to_keep], lemma_pair
    

In [32]:
def train(model, dataloader, optimizer, scheduler, num_epochs, mode, nball_target, norm_scale=1e-10):
    """Train ``model`` against n-ball targets for ``num_epochs`` epochs.

    Every batch is moved to the notebook-level ``device``, run through the
    model, and compared with the rows of ``nball_target`` selected by the
    batch's sense indices. The scheduler is stepped once per epoch and the
    average loss per epoch is printed.

    Args:
        model: module called as ``model(input_ids=..., attention_mask=...,
            word_index=...)``.
        dataloader: yields batches of six tensors; the first four are used as
            input ids, attention masks, word indices and sense indices.
        optimizer: optimizer updating the parameters of ``model``.
        scheduler: learning-rate scheduler, stepped after every epoch.
        num_epochs: number of passes over ``dataloader``.
        mode: ``"norm"`` (log-cosh loss against scaled target norms) or
            ``"direction"`` (cosine embedding loss against target vectors).
        nball_target: tensor indexed by sense index that provides the targets.
        norm_scale: factor applied to the targets in ``"norm"`` mode before
            the loss is computed. Defaults to the previously hard-coded 1e-10.

    Raises:
        ValueError: if ``mode`` is not one of the two supported values, or if
            ``dataloader`` contains no batches.
    """
    if mode == "norm":
        loss_fn = log_cosh_loss
    elif mode == "direction":
        loss_fn = nn.CosineEmbeddingLoss()
    else:
        # Without this branch an unknown mode only failed later with an
        # unbound-variable error inside the batch loop.
        raise ValueError(f'Unknown mode {mode!r}; expected "norm" or "direction"')

    num_batches = len(dataloader)
    if num_batches == 0:
        # The per-epoch average would otherwise divide by zero.
        raise ValueError("dataloader contains no batches; nothing to train on")

    # None marks "no previous epoch"; a numeric sentinel such as 0 would be
    # confused with an epoch whose total loss really is 0.
    last_loss = None
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        progress_bar = tqdm(dataloader, desc=f'Epoch {epoch+1}/{num_epochs}', leave=True, position=0)
        for batch in progress_bar:
            # Unpack the batch and send it to the device; the last two
            # entries are not needed for training.
            batch_input_ids, batch_attention_masks, batch_word_indices, \
            batch_sense_indices, _, _ = [b.to(device) for b in batch]

            optimizer.zero_grad()

            # Forward pass.
            # NOTE(review): squeeze() without a dim also removes the batch
            # axis when a batch holds a single sample - confirm output shapes.
            output = model(input_ids=batch_input_ids, attention_mask=batch_attention_masks, \
                            word_index=batch_word_indices).squeeze()

            if mode == "norm":
                batch_norms = nball_target[batch_sense_indices]
                loss = loss_fn(output, batch_norms * norm_scale)
            else:
                # A target of +1 asks the loss to pull output and target
                # vectors towards the same direction.
                labels = torch.ones(output.size(0), device=device)
                batch_senses = nball_target[batch_sense_indices]
                loss = loss_fn(output, batch_senses, labels)

            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            total_loss += batch_loss
            progress_bar.set_postfix(loss=batch_loss)

        # Average loss of this epoch and its change relative to the last one.
        avg_loss = total_loss / num_batches
        improvement = (last_loss - total_loss) / num_batches if last_loss is not None else 0
        print(f'Epoch {epoch + 1}, Avg Loss: {avg_loss}, Improvement: {improvement}')
        last_loss = total_loss

        scheduler.step()

In [14]:
def log_cosh_loss(norm_u, norm_v):
    """Mean of ``log(cosh(norm_u - norm_v))`` over all elements.

    The value is computed in a numerically stable form. Evaluating
    ``cosh`` directly overflows float32 once the difference exceeds roughly
    89 in magnitude, which turns the loss (and its gradients) into inf/nan.

    Args:
        norm_u: floating-point tensor of predictions.
        norm_v: tensor of targets, broadcastable against ``norm_u``.

    Returns:
        Scalar tensor holding the mean log-cosh of the differences.
    """
    abs_diff = torch.abs(norm_u - norm_v)
    log_two = 0.6931471805599453
    # log(cosh(x)) == |x| + log(1 + exp(-2|x|)) - log(2). The exponent is
    # never positive, so nothing can overflow.
    return (abs_diff + F.softplus(-2.0 * abs_diff) - log_two).mean()

In [19]:
# train one tree
def train_one_tree(model_norm, root, training_data, device, num_epochs=2, lr=2e-5):
    """Train the norm model on the corpus rows that belong to one root.

    Loads the n-balls of ``root``, prepares the corpus, prints the extreme
    values of the ball norms and radii as a sanity check, and runs the
    ``"norm"`` training loop.

    Relies on the notebook-level settings ``model_url``, ``max_length`` and
    ``batch_size`` for building the dataset and the data loader.

    Args:
        model_norm: model to train; its parameters are handed to Adam.
        root: root name; the balls are read from
            ``../data/entity_multi/<root>_nball.txt``.
        training_data: corpus already filtered for ``root``.
        device: torch device that receives the target tensor.
        num_epochs: number of training epochs (default 2, as before).
        lr: Adam learning rate (default 2e-5, as before).
    """
    # Preprocess data; the lemma pairs returned alongside are not needed here.
    nball_path = '../data/entity_multi/' + root + "_nball.txt"
    nball = load_balls(nball_path)
    train_semcor, _ = preprocess_data(training_data, nball)

    # Per-sense norm (the ball's `distance`) and radius, in key order of `nball`.
    sense_labels = list(nball.keys())
    nball_norms = [nball[label].distance for label in sense_labels]
    nball_radius = [nball[label].radius for label in sense_labels]

    # Data checking: show the ten smallest and ten largest values of each list.
    for name, values in (("nball_norms", nball_norms), ("nball_radius", nball_radius)):
        ordered = sorted(values)
        print(f"Smallest 10 elements in {name}:", ordered[:10])
        print(f"Largest 10 elements in {name}:", ordered[-10:])

    # Only the norms are used as training targets.
    nball_norms = torch.tensor(np.array(nball_norms), dtype=torch.float32).to(device)

    # Set dataset
    dataset_train = nballDataset(train_semcor, nball.keys(), model_url, max_length)
    dataloader_train = DataLoader(dataset_train, batch_size, shuffle=True)

    # Train the norm model.
    optimizer_norm = optim.Adam([{'params': model_norm.parameters()}], lr=lr)
    scheduler_norm = StepLR(optimizer_norm, step_size=50, gamma=0.1)
    train(model=model_norm,
          dataloader=dataloader_train,
          optimizer=optimizer_norm,
          scheduler=scheduler_norm,
          num_epochs=num_epochs,
          mode="norm",
          nball_target=nball_norms)

    # TODO: training of the direction model (mode="direction") is not wired
    # up yet; it needs its own model, optimizer/scheduler and a target tensor.


In [16]:
# Map every root to the value load_ball_dims reports for its n-ball file
# (presumably the dimensionality of the stored balls).
directory_path = '../data/entity_multi'
ball_dims = {
    root: load_ball_dims('../data/entity_multi/' + root + "_nball.txt")
    for root in roots
}

print(ball_dims)

{'ability.n.02': 157, 'abstraction.n.06': 155, 'act.n.02': 159, 'action.n.01': 159, 'activity.n.01': 160, 'adult.n.01': 158, 'advocate.n.01': 155, 'amerindian.n.01': 156, 'animal.n.01': 158, 'artifact.n.01': 159, 'asian.n.01': 155, 'attribute.n.02': 159, 'capitalist.n.02': 158, 'change.n.03': 158, 'chemical.n.01': 159, 'chordate.n.01': 154, 'cognition.n.01': 159, 'communication.n.02': 159, 'communicator.n.01': 158, 'condition.n.01': 158, 'content.n.05': 159, 'contestant.n.01': 157, 'covering.n.02': 158, 'creation.n.02': 159, 'creator.n.02': 157, 'device.n.01': 159, 'discipline.n.01': 157, 'diversion.n.01': 158, 'entertainer.n.01': 158, 'entity.n.01': 158, 'event.n.01': 160, 'expert.n.01': 159, 'female.n.02': 156, 'gathering.n.01': 156, 'genus.n.02': 153, 'group.n.01': 162, 'happening.n.01': 158, 'implement.n.01': 158, 'inhabitant.n.01': 156, 'instrumentality.n.03': 161, 'intellectual.n.01': 157, 'language.n.01': 158, 'leader.n.01': 157, 'location.n.01': 160, 'mammal.n.01': 160, 'materi

In [34]:
# Switch the test root and repeat the subtree check and the corpus filtering.
test_root = 'ability.n.02'

subnodes_entity = find_subtree(wsChildrenDic, test_root, roots)
subroots = subnodes_entity & set(roots)
filtered_corpra = filter_corpra(training_corpra.copy(), wsChildrenDic, test_root, roots)

print(len(subnodes_entity), len(subroots), subroots, sep="\n")

369
0
set()


In [35]:
# Available BERT checkpoints: display name -> [model id, hidden size].
bert_models = {
    "BERT-Base": ["bert-base-uncased", 768],                    # 768, 12L, 12A
    "BERT-Large": ["bert-large-uncased", 1024],                 # 1024, 24L, 16A
    "BERT-Medium": ["google/bert_uncased_L-8_H-512_A-8", 512],  # 512, 8L, 8A
    "BERT-Small": ["google/bert_uncased_L-4_H-256_A-4", 256],   # 256, 4L, 4A
    "BERT-Mini": ["google/bert_uncased_L-4_H-128_A-2", 128],    # 128, 4L, 2A
    "BERT-Tiny": ["google/bert_uncased_L-2_H-128_A-2", 128],    # 128, 2L, 2A
}

# Run settings; train_one_tree reads model_url, max_length and batch_size
# from the notebook namespace.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
max_length = 512
batch_size = 32
model_url = bert_models["BERT-Small"][0]

# The model's output size is the ball dimension recorded for the test root.
output_dim = ball_dims[test_root]
model_norm = nballBertNorm(model_url, output_dim)
model_norm = model_norm.to(device)
train_one_tree(model_norm, test_root, filtered_corpra, device)

loading balls....
371 balls are loaded

Smallest 10 elements in nball_norms: [1000000.0, 1000000.0, 1000000.0, 1000000.0, 1000000.0, 1000253.7134719943, 1050574.900727731, 1050574.900727731, 1050633.9614014437, 1050952.2580102002]
Largest 10 elements in nball_norms: [41091969.115863234, 41091969.115863234, 41091969.115863234, 42037461.95414666, 42048329.80491032, 42095926.466105804, 42170830.46417026, 42488832.319365725, 43050345.80513795, 47847915.97634286]
Smallest 10 elements in nball_radius: [0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.00010505749007277309, 0.00010505749007277309, 0.00011086365831506724, 0.00011721868918547535, 0.00011721868918547535]
Largest 10 elements in nball_radius: [1503900.3829487448, 1516117.3656599831, 1658496.209105272, 2366441.4549796684, 2889502.3717317907, 5162067.0717816185, 23139354.02628052, 26681987.743101574, 42092728.28523759, 46988174.82434979]
Tokenizing sentences...
Tokenizing finished.
Calculating word indices...
Tokenizing finished.


Epoch 1/50: 100%|█████████████████████████████████████████████████████████| 17/17 [00:01<00:00, 12.26it/s, loss=0.0139]


Epoch 1, Avg Loss: 0.015977193053592655, Improvement: 0


Epoch 2/50: 100%|████████████████████████████████████████████████████████| 17/17 [00:01<00:00, 13.09it/s, loss=0.00604]


Epoch 2, Avg Loss: 0.007604574127232327, Improvement: 0.008372618926360327


Epoch 3/50: 100%|████████████████████████████████████████████████████████| 17/17 [00:01<00:00, 13.07it/s, loss=0.00345]


Epoch 3, Avg Loss: 0.004437891204896218, Improvement: 0.0031666829223361086


Epoch 4/50: 100%|████████████████████████████████████████████████████████| 17/17 [00:01<00:00, 13.00it/s, loss=0.00319]


Epoch 4, Avg Loss: 0.003215851659393486, Improvement: 0.0012220395455027327


Epoch 5/50:  35%|████████████████████                                     | 6/17 [00:00<00:00, 11.02it/s, loss=0.00224]


KeyboardInterrupt: 

In [None]:
# Largest value load_ball_dims reports across all roots.
dims = [
    load_ball_dims('../data/entity_multi/' + root + "_nball.txt")
    for root in roots
]

print(max(dims))