### KG-BERT
https://arxiv.org/pdf/1909.03193v2.pdf

### WN18RR
https://arxiv.org/pdf/1707.01476.pdf

In [1]:
%%capture
!pip install transformers datasets nltk

In [2]:
import numpy as np
import pandas as pd
import datasets
import nltk
from nltk.corpus import wordnet as wn
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    DataCollatorWithPadding, 
    TrainingArguments, 
    Trainer,
    )

In [3]:
# load wordnet data
nltk.download("omw-1.4")
nltk.download("wordnet")

[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


True

https://github.com/villmow/datasets_knowledge_embedding

In [4]:
data_files = dict(
    train="https://raw.githubusercontent.com/villmow/datasets_knowledge_embedding/master/WN18RR/text/train.txt",
    dev="https://raw.githubusercontent.com/villmow/datasets_knowledge_embedding/master/WN18RR/text/valid.txt",
    test="https://raw.githubusercontent.com/villmow/datasets_knowledge_embedding/master/WN18RR/text/test.txt"
    )

In [5]:
!wget "https://raw.githubusercontent.com/villmow/datasets_knowledge_embedding/master/WN18RR/text/train.txt"
!head -n 5 train.txt

--2022-11-04 11:15:21--  https://raw.githubusercontent.com/villmow/datasets_knowledge_embedding/master/WN18RR/text/train.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.108.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4091114 (3.9M) [text/plain]
Saving to: ‘train.txt.1’


2022-11-04 11:15:21 (301 MB/s) - ‘train.txt.1’ saved [4091114/4091114]

land_reform.n.01	_hypernym	reform.n.01
cover.v.01	_derivationally_related_form	covering.n.02
botany.n.02	_derivationally_related_form	botanize.v.01
kamet.n.01	_instance_hypernym	mountain_peak.n.01
question.n.01	_derivationally_related_form	ask.v.01


In [6]:
# designed to be used with Dataset or DatasetDict `map` method
def add_definitions(ds, in_name, out_name, def_dict):
  senses = ds[in_name]
  definitions = []
  for sense in senses:
    definitions.append(def_dict[sense])
  return {out_name: definitions}
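As a quick sanity check, the mapping function can be tried on a plain dictionary shaped like the batch that `map(..., batched=True)` passes in. The sketch below restates the function with a list comprehension so that it runs on its own. The `toy_batch` and `toy_defs` names are made up for illustration; the notebook builds the real lookup from WordNet.

```python
def add_definitions(batch, in_name, out_name, def_dict):
    # With batched=True, batch[in_name] is a list of values for one column.
    return {out_name: [def_dict[sense] for sense in batch[in_name]]}


# Hypothetical two-row batch and a hand-made definition lookup.
toy_batch = {"sense1": ["botany.n.02", "question.n.01"]}
toy_defs = {
    "botany.n.02": "the branch of biology that studies plants",
    "question.n.01": "an instance of questioning",
}

result = add_definitions(toy_batch, "sense1", "def1", toy_defs)
print(result)
```

The returned dictionary holds only the new column; `map` merges it into the existing ones.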

In [7]:
# the whole preprocessing
def preprocess_splits(files_dict):
  dict_to_init = dict()
  unique_senses = set()
  for key, value in files_dict.items():
    # load data from urls
    cur_data = pd.read_csv(value, sep="\t", names=["sense1", "relation", "sense2"])
    dict_to_init[key] = datasets.Dataset.from_pandas(cur_data)
    # save unique senses that appear in this split
    for col in ["sense1", "sense2"]:
      unique_senses.update(cur_data[col].unique())
  # retrieve definitions for every unique sense
  def_dict = {sense: wn.synset(sense).definition() for sense in unique_senses}
  # initialize DatasetDict class from gathered data
  dataset = datasets.DatasetDict(dict_to_init)
  # replace `str` label names with corresponding `int` values
  dataset = dataset.class_encode_column("relation")
  # add appropriate definitions for every sense in the dataset splits
  for i in range(2):
    dataset = dataset.map(
        add_definitions,
        fn_kwargs=dict(
            in_name=f"sense{i + 1}",
            out_name=f"def{i + 1}",
            def_dict=def_dict,
            ),
        batched=True,
        )
  return dataset

In [8]:
dataset = preprocess_splits(data_files)

In [9]:
MODEL_NAME = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

In [10]:
def tokenize(dataset, tokenizer):
  return tokenizer(dataset["def1"], dataset["def2"])

In [11]:
dataset = dataset.map(tokenize, batched=True, fn_kwargs=dict(tokenizer=tokenizer))

In [12]:
# the model expects a column named `label` specifically
dataset = dataset.rename_column("relation", "label")

In [14]:
# this is not strictly necessary, but it helps with readability
# of the final model's configuration
label_names = dataset["train"].features["label"].names
num_labels = len(label_names)
label2id = {key: value for key, value in zip(label_names, range(num_labels))}
id2label = {value: key for key, value in label2id.items()}

# initialize transformer model
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    label2id=label2id,
    id2label=id2label,
    num_labels=num_labels,
    )

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_projector.bias', 'vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier

In [16]:
label_names

['_also_see',
 '_derivationally_related_form',
 '_has_part',
 '_hypernym',
 '_instance_hypernym',
 '_member_meronym',
 '_member_of_domain_region',
 '_member_of_domain_usage',
 '_similar_to',
 '_synset_domain_topic_of',
 '_verb_group']

In [17]:
dataset["train"].to_pandas().head()

| | sense1 | label | sense2 | def1 | def2 | input_ids | attention_mask |
|---|---|---|---|---|---|---|---|
| 0 | land_reform.n.01 | 3 | reform.n.01 | a redistribution of agricultural land (especia... | a change for the better as a result of correct... | [101, 1037, 25707, 1997, 4910, 2455, 1006, 292... | [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... |
| 1 | cover.v.01 | 1 | covering.n.02 | provide with a covering or cause to be covered | an artifact that covers something else (usuall... | [101, 3073, 2007, 1037, 5266, 2030, 3426, 2000... | [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... |
| 2 | botany.n.02 | 1 | botanize.v.01 | the branch of biology that studies plants | collect and study plants | [101, 1996, 3589, 1997, 7366, 2008, 2913, 4264... | [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] |
| 3 | kamet.n.01 | 4 | mountain_peak.n.01 | a mountain in the Himalayas in northern India ... | the summit of a mountain | [101, 1037, 3137, 1999, 1996, 26779, 1999, 264... | [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... |
| 4 | question.n.01 | 1 | ask.v.01 | an instance of questioning | inquire about | [101, 2019, 6013, 1997, 11242, 102, 1999, 1554... | [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] |


https://huggingface.co/transformers/v4.8.1/_modules/transformers/data/data_collator.html#DataCollatorWithPadding

In [19]:
# pads the batches and transforms them to pytorch format
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
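To make the padding step concrete, here is a minimal pure-Python sketch of the idea: every example in a batch is padded to the length of the longest one, and the attention mask marks padded positions with 0. The function name `pad_batch` and the toy token ids are assumptions for illustration. The real collator delegates to `tokenizer.pad` and returns tensors rather than lists.

```python
def pad_batch(features, pad_token_id=0):
    """Pad each example to the longest sequence in this batch (right padding)."""
    max_len = max(len(f["input_ids"]) for f in features)
    batch = {"input_ids": [], "attention_mask": []}
    for f in features:
        n_pad = max_len - len(f["input_ids"])
        batch["input_ids"].append(f["input_ids"] + [pad_token_id] * n_pad)
        # Padded positions get mask value 0 so the model ignores them.
        batch["attention_mask"].append(f["attention_mask"] + [0] * n_pad)
    return batch


# Two made-up tokenized examples of different lengths.
toy_features = [
    {"input_ids": [101, 7, 8, 102, 9, 102], "attention_mask": [1] * 6},
    {"input_ids": [101, 7, 102, 9, 10, 11, 12, 102], "attention_mask": [1] * 8},
]

padded = pad_batch(toy_features)
print(padded["input_ids"])
print(padded["attention_mask"])
```

Because padding happens per batch rather than once for the whole dataset, short batches stay short.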

In [20]:
training_args = TrainingArguments(
    output_dir="./results",
    learning_rate=2e-5,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    num_train_epochs=1,
    weight_decay=0.01,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["dev"],
    tokenizer=tokenizer,
    data_collator=data_collator,
)

Graph completion problems are commonly framed as ranking rather than classification problems. Accordingly, ranking metrics such as **Mean Reciprocal Rank** or **Hits@k** are often used. In this particular case, however, they add little: we predict the relation for a given (head, tail) pair, and each pair has exactly one correct relation. We're going to use **Accuracy** instead.
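For reference, the ranking metrics mentioned above can be sketched in a few lines of plain Python. The helper names and the toy scores are assumptions for illustration and are not part of the notebook. With a single correct label per example and no tied scores, Hits@1 coincides with accuracy, which is one way to see why accuracy is a reasonable choice here.

```python
def true_label_rank(scores, true_label):
    """1-based rank of the true label, with labels ordered by descending score."""
    true_score = scores[true_label]
    # Every label scored strictly higher pushes the true label one rank down.
    return 1 + sum(1 for score in scores if score > true_score)


def ranking_metrics(all_scores, true_labels, k):
    ranks = [true_label_rank(s, t) for s, t in zip(all_scores, true_labels)]
    n = len(ranks)
    return {
        "mrr": sum(1.0 / r for r in ranks) / n,
        f"hits@{k}": sum(r <= k for r in ranks) / n,
    }


# Made-up scores for three examples over three labels.
toy_scores = [
    [0.1, 2.0, 0.3],  # true label 1 is ranked first
    [1.5, 0.2, 0.9],  # true label 2 is ranked second
    [0.4, 0.8, 0.6],  # true label 0 is ranked third
]
toy_labels = [1, 2, 0]

metrics_at_1 = ranking_metrics(toy_scores, toy_labels, k=1)

# Plain accuracy: does the highest-scoring label match the true one?
argmax_preds = [max(range(len(s)), key=s.__getitem__) for s in toy_scores]
accuracy = sum(p == t for p, t in zip(argmax_preds, toy_labels)) / len(toy_labels)

print(metrics_at_1, accuracy)
```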

In [21]:
# the number of unique (head, tail) pairs is equal to the length of the dataset
# meaning that there are no pairs that have more than one relation between them
len(set(zip(dataset["dev"]["sense1"], dataset["dev"]["sense2"]))) == len(dataset["dev"])

True

In [22]:
# naturally, this is also the case for the test split
len(set(zip(dataset["test"]["sense1"], dataset["test"]["sense2"]))) == len(dataset["test"])

True

Let's make sure that our trained model actually learns something useful by evaluating several basic baselines.

In [23]:
# predictions before any training
preds, true_labels, _ = trainer.predict(dataset["dev"])
predicted_labels = np.argmax(preds, axis=-1)
base_accuracy = np.sum(predicted_labels == true_labels) / true_labels.shape[0]
print(f"Accuracy without fine-tuning is {base_accuracy * 100:.4f} %")

The following columns in the test set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: sense1, def2, def1, sense2. If sense1, def2, def1, sense2 are not expected by `DistilBertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Prediction *****
  Num examples = 3034
  Batch size = 64
You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Accuracy without fine-tuning is 1.3184 %


In [24]:
# random predictions with uniform distribution sampling
uniform_random_preds = np.random.randint(0, len(dataset["dev"].features["label"].names), len(dataset["dev"]))
uniform_random_accuracy = np.sum(uniform_random_preds == true_labels) / true_labels.shape[0]
print(f"Accuracy of uniform random prediction is  {uniform_random_accuracy * 100:.4f} %")

Accuracy of uniform random prediction is  9.3276 %


In [25]:
# random predictions with weighted sampling
weights = [value / len(dataset["dev"]) for value in dataset["dev"].to_pandas()["label"].value_counts().sort_index().tolist()]
weighted_random_preds = np.random.choice(a=len(dataset["dev"].features["label"].names), size=len(dataset["dev"]), p=weights)
weighted_random_accuracy = np.sum(weighted_random_preds == true_labels) / true_labels.shape[0]
print(f"Accuracy of weighted random prediction is  {weighted_random_accuracy * 100:.4f} %")

Accuracy of weighted random prediction is  28.6750 %
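The two random baselines also have simple closed-form expectations, which makes the sampled numbers easy to cross-check. Uniform guessing over K labels is right with probability 1/K (about 9.09 % for the 11 labels here, close to the sampled value above), while sampling from the label distribution is right with probability equal to the sum of squared class frequencies. The sketch below also adds a majority-class baseline that the notebook does not compute; the class counts are made up for illustration.

```python
def expected_baseline_accuracies(class_counts):
    total = sum(class_counts)
    probs = [count / total for count in class_counts]
    return {
        # Guess each of the K labels with probability 1/K.
        "uniform": 1.0 / len(probs),
        # Guess label i with probability p_i; it is correct with probability p_i.
        "weighted": sum(p * p for p in probs),
        # Always predict the most frequent label.
        "majority": max(probs),
    }


# Hypothetical label counts for a three-class problem.
toy_counts = [50, 30, 20]
expected = expected_baseline_accuracies(toy_counts)
print(expected)
```

The majority-class baseline is never worse than weighted sampling, so it is usually the stricter sanity check.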


In [26]:
trainer.train()

The following columns in the training set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: sense1, def2, def1, sense2. If sense1, def2, def1, sense2 are not expected by `DistilBertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 86835
  Num Epochs = 1
  Instantaneous batch size per device = 64
  Total train batch size (w. parallel, distributed & accumulation) = 64
  Gradient Accumulation steps = 1
  Total optimization steps = 1357
  Number of trainable parameters = 66961931


| Step | Training Loss |
|---|---|
| 500 | 0.7312 |
| 1000 | 0.2985 |


Saving model checkpoint to ./results/checkpoint-500
Configuration saved in ./results/checkpoint-500/config.json
Model weights saved in ./results/checkpoint-500/pytorch_model.bin
tokenizer config file saved in ./results/checkpoint-500/tokenizer_config.json
Special tokens file saved in ./results/checkpoint-500/special_tokens_map.json
Saving model checkpoint to ./results/checkpoint-1000
Configuration saved in ./results/checkpoint-1000/config.json
Model weights saved in ./results/checkpoint-1000/pytorch_model.bin
tokenizer config file saved in ./results/checkpoint-1000/tokenizer_config.json
Special tokens file saved in ./results/checkpoint-1000/special_tokens_map.json


Training completed. Do not forget to share your model on huggingface.co/models =)




TrainOutput(global_step=1357, training_loss=0.44539203236185637, metrics={'train_runtime': 515.0223, 'train_samples_per_second': 168.604, 'train_steps_per_second': 2.635, 'total_flos': 1688279187416136.0, 'train_loss': 0.44539203236185637, 'epoch': 1.0})

In [27]:
preds, true_labels, _ = trainer.predict(dataset["dev"])
predicted_labels = np.argmax(preds, axis=-1)
finetuned_accuracy = np.sum(predicted_labels == true_labels) / true_labels.shape[0]
print(f"Accuracy after fine-tuning is {finetuned_accuracy * 100:.4f} %")

The following columns in the test set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: sense1, def2, def1, sense2. If sense1, def2, def1, sense2 are not expected by `DistilBertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Prediction *****
  Num examples = 3034
  Batch size = 64


Accuracy after fine-tuning is 93.3421 %


In [28]:
preds, true_labels, _ = trainer.predict(dataset["test"])
predicted_labels = np.argmax(preds, axis=-1)
finetuned_accuracy = np.sum(predicted_labels == true_labels) / true_labels.shape[0]
print(f"Final test accuracy {finetuned_accuracy * 100:.4f} %")

The following columns in the test set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: sense1, def2, def1, sense2. If sense1, def2, def1, sense2 are not expected by `DistilBertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Prediction *****
  Num examples = 3134
  Batch size = 64


Final test accuracy 93.7141 %


**NB!** There are a number of caveats to be aware of. First, there is no "no relation" label in the dataset, so this model has no way of identifying pairs of nodes that have no edge between them. To account for that, you would need to add synthetic negative examples to the dataset.

Second, this model doesn't take the structural properties of the graph into account in any way. This may be a significant downside, depending on the data.
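To illustrate the first caveat, one possible way to build synthetic negatives is to pair known heads with random entities and keep only the pairs that do not appear in the data. This is a rough sketch under stated assumptions, not something the notebook does: the function name, the `_no_relation` label, and the sampling scheme are all hypothetical. Note also that a missing edge in the dataset does not guarantee that two synsets are truly unrelated, so such negatives are noisy.

```python
import random


def sample_negatives(triples, num_negatives, seed=0, no_relation="_no_relation"):
    """Sample (head, no_relation, tail) triples for pairs absent from `triples`."""
    rng = random.Random(seed)
    # Treat a pair as known in both directions to avoid labelling
    # a reversed edge as "no relation".
    known_pairs = {(h, t) for h, _, t in triples} | {(t, h) for h, _, t in triples}
    entities = sorted({e for h, _, t in triples for e in (h, t)})
    negatives = set()
    attempts = 0
    # Bound the number of attempts so the loop ends even on dense graphs.
    while len(negatives) < num_negatives and attempts < 100 * num_negatives:
        attempts += 1
        head = rng.choice(triples)[0]
        tail = rng.choice(entities)
        if head != tail and (head, tail) not in known_pairs:
            negatives.add((head, no_relation, tail))
    return sorted(negatives)


# The first five training triples shown earlier in the notebook.
toy_triples = [
    ("land_reform.n.01", "_hypernym", "reform.n.01"),
    ("cover.v.01", "_derivationally_related_form", "covering.n.02"),
    ("botany.n.02", "_derivationally_related_form", "botanize.v.01"),
    ("kamet.n.01", "_instance_hypernym", "mountain_peak.n.01"),
    ("question.n.01", "_derivationally_related_form", "ask.v.01"),
]

negatives = sample_negatives(toy_triples, num_negatives=3)
print(negatives)
```

The sampled triples would then be appended to each split before `class_encode_column` is applied, so that the extra label gets its own class id.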