In [1]:
import os
import nltk
from nltk.corpus import wordnet as wn
import torch
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
import logging
logging.basicConfig(level=logging.WARNING)

In [2]:
# Read the files
def load_tree(wsChildrenFile = None):
    wsChildrenDic = {}
    if os.path.isfile(wsChildrenFile):
        with open(wsChildrenFile, 'r') as chfh:
            for ln in chfh:
                wlst = ln[:-1].split()
                wsChildrenDic[wlst[0]] = wlst[1:]
    return wsChildrenDic

wsChildrenFile = '../data/sample_mammal/wsChildren.txt'
tree = load_tree(wsChildrenFile)

In [3]:
len(tree.keys())

389

In [4]:
def build_parent_map(tree):
    parent_map = {}
    for parent, children in tree.items():
        for child in children:
            parent_map[child] = parent
    return parent_map

# Example usage
parent_map = build_parent_map(tree)

In [5]:
def get_ancestors(node, parent_map):
    ancestors = []
    while node in parent_map:
        node = parent_map[node]
        ancestors.append(node)
    return ancestors

def determine_hyper_relationships(tree, parent_map):
    relationship_dict = {}
    nodes = list(tree.keys())

    for i in range(len(nodes)):
        for j in range(i + 1, len(nodes)):
            node1, node2 = nodes[i], nodes[j]
            if node1 in parent_map and parent_map[node1] == node2:
                relationship_dict[(node1, node2)] = ('child', 0)
                relationship_dict[(node2, node1)] = ('parent', 0)
            elif node2 in parent_map and parent_map[node2] == node1:
                relationship_dict[(node1, node2)] = ('parent', 0)
                relationship_dict[(node2, node1)] = ('child', 0)
            elif node1 in parent_map and node2 in parent_map and parent_map[node1] == parent_map[node2]:
                relationship_dict[(node1, node2)] = ('sibling', 0)
                relationship_dict[(node2, node1)] = ('sibling', 0)
            else:
                ancestors1 = get_ancestors(node1, parent_map)
                ancestors2 = get_ancestors(node2, parent_map)
                
                common_ancestor = None
                for k, ancestor in enumerate(ancestors1):
                    if ancestor in ancestors2:
                        common_ancestor = ancestor
                        break
                
                if common_ancestor:
                    k2 = ancestors2.index(common_ancestor)
                    relationship_dict[(node1, node2)] = ('hyper_sibling', k + 1)
                    relationship_dict[(node2, node1)] = ('hyper_sibling', k + 1)
                else:
                    if node1 in ancestors2:
                        relationship_dict[(node1, node2)] = ('hyper_child', ancestors2.index(node1) + 1)
                        relationship_dict[(node2, node1)] = ('hyper_parent', ancestors2.index(node1) + 1)
                    elif node2 in ancestors1:
                        relationship_dict[(node1, node2)] = ('hyper_parent', ancestors1.index(node2) + 1)
                        relationship_dict[(node2, node1)] = ('hyper_child', ancestors1.index(node2) + 1)
                    else:
                        relationship_dict[(node1, node2)] = ('unrelated', -1)
                        relationship_dict[(node2, node1)] = ('unrelated', -1)

    return relationship_dict

# Example usage
relationship_dict = determine_hyper_relationships(tree, parent_map)
print(len(relationship_dict))

150932


In [6]:
def print_unrelated_nodes(relationship_dict):
    unrelated_pairs = [pair for pair, relationship in relationship_dict.items() if relationship[0] == 'unrelated']
    for pair in unrelated_pairs:
        print(f"Unrelated nodes: {pair[0]} and {pair[1]}")
    print(f"{len(unrelated_pairs)} unrelated pair printed")

print_unrelated_nodes(relationship_dict)

0 unrelated pair printed


In [7]:
wn.synset("entity.n.01").definition()

'that which is perceived or known or inferred to have its own distinct existence (living or nonliving)'

In [8]:
def get_explanation(synset_name):
    if synset_name == "*root*":
        return "root node"
    synset = wn.synset(synset_name)
    return synset.definition()


def create_training_data(relationship_dict):
    training_data = []
    for (lemma1, lemma2), (relationship, hyper_value) in relationship_dict.items():
        explanation1 = get_explanation(lemma1)
        explanation2 = get_explanation(lemma2)
        sentence = f"[CLS]{lemma1}:{explanation1}[SEP]{lemma2}:{explanation2}"
        label = relationship
        training_data.append((sentence, label))
    return training_data


training_data = create_training_data(relationship_dict)


In [9]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

class SynsetDataset(Dataset):
    def __init__(self, data, tokenizer, label_map, max_length=128):
        self.data = data
        self.tokenizer = tokenizer
        self.label_map = label_map
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sentence, label = self.data[idx]
        inputs = self.tokenizer(sentence, truncation=True, padding='max_length', max_length=self.max_length, return_tensors="pt")
        inputs = {key: val.squeeze(0) for key, val in inputs.items()}
        inputs['labels'] = torch.tensor(self.label_map[label], dtype=torch.long)
        return inputs


# Define the label map
label_map = {
    'sibling': 0,
    'parent': 1,
    'child': 2,
    'hyper_sibling': 3,
    'hyper_parent': 4,
    'hyper_child': 5,
    'unrelated': 6
}

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Create a dataset
dataset = SynsetDataset(training_data, tokenizer, label_map)

# Split into training and validation sets
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

In [10]:

model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=len(label_map))

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [13]:
training_args = TrainingArguments(
    output_dir='./results',          # output directory
    num_train_epochs=1,              # total number of training epochs
    per_device_train_batch_size=64,   # batch size for training
    per_device_eval_batch_size=64,    # batch size for evaluation
    warmup_steps=500,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
    logging_steps=10,
)

trainer = Trainer(
    model=model,                         # the instantiated 🤗 Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=train_dataset,         # training dataset
    eval_dataset=val_dataset             # evaluation dataset
)

# Train the model
trainer.train()

Step,Training Loss
10,0.2188
20,0.1992
30,0.1319
40,0.1323
50,0.1704
60,0.1423
70,0.1693
80,0.1467
90,0.1648
100,0.1249


TrainOutput(global_step=1887, training_loss=0.07833661152435853, metrics={'train_runtime': 2033.1688, 'train_samples_per_second': 59.388, 'train_steps_per_second': 0.928, 'total_flos': 7942692650016000.0, 'train_loss': 0.07833661152435853, 'epoch': 1.0})

In [14]:
eval_result = trainer.evaluate()
print(eval_result)

{'eval_loss': 0.03326768800616264, 'eval_runtime': 221.5223, 'eval_samples_per_second': 136.271, 'eval_steps_per_second': 2.131, 'epoch': 1.0}


In [15]:
print(len(train_dataset))
print(len(val_dataset))
print(len(dataset))

120745
30187
150932


In [17]:
print(389*389-389)

150932
