## Data

In [None]:
import ast
import json

import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm

In [None]:
# Fetch the TextGraphs17 shared-task repository; the train/test TSV files read below live under its data/tsv folder.
!git clone https://github.com/uhh-lt/TextGraphs17-shared-task.git

Cloning into 'TextGraphs17-shared-task'...
remote: Enumerating objects: 484, done.[K
remote: Counting objects: 100% (4/4), done.[K
remote: Compressing objects: 100% (2/2), done.[K
remote: Total 484 (delta 3), reused 2 (delta 2), pack-reused 480[K
Receiving objects: 100% (484/484), 36.63 MiB | 19.18 MiB/s, done.
Resolving deltas: 100% (54/54), done.


In [None]:
# Read both tab-separated splits in one pass; the order of the paths fixes which
# frame ends up in which variable (train first, test second).
data_train, data_test = (
    pd.read_csv(tsv_path, sep='\t')
    for tsv_path in (
        'TextGraphs17-shared-task/data/tsv/train.tsv',
        'TextGraphs17-shared-task/data/tsv/test.tsv',
    )
)

data_train.head()

Unnamed: 0,sample_id,question,questionEntity,answerEntity,groundTruthAnswerEntity,answerEntityId,questionEntityId,groundTruthAnswerEntityId,correct,graph
0,0,Whst is the name of the head of state and high...,Iran,Ruhollah Khomeini's return to Iran,Office of the Supreme Leader of Iran,Q7293530,Q794,Q16045000,False,"{'nodes': [{'type': 'QUESTIONS_ENTITY', 'name_..."
1,1,Whst is the name of the head of state and high...,Iran,Ruhollah Khomeini's letter to Mikhail Gorbachev,Office of the Supreme Leader of Iran,Q5952984,Q794,Q16045000,False,"{'nodes': [{'type': 'INTERNAL', 'name_': 'Q417..."
2,2,Whst is the name of the head of state and high...,Iran,Ruhollah Khomeini,Office of the Supreme Leader of Iran,Q38823,Q794,Q16045000,False,"{'nodes': [{'type': 'QUESTIONS_ENTITY', 'name_..."
3,3,Whst is the name of the head of state and high...,Iran,Office of the Supreme Leader of Iran,Office of the Supreme Leader of Iran,Q16045000,Q794,Q16045000,True,"{'nodes': [{'type': 'QUESTIONS_ENTITY', 'name_..."
4,4,Whst is the name of the head of state and high...,Iran,Mohammad Reza Pahlavi and Soraya,Office of the Supreme Leader of Iran,Q63195813,Q794,Q16045000,False,"{'nodes': [{'type': 'QUESTIONS_ENTITY', 'name_..."


In [None]:
# Preview the test split. Its columns do not include `correct` or the ground-truth answer fields.
data_test.head(20)

Unnamed: 0,sample_id,question,questionEntity,answerEntity,questionEntityId,answerEntityId,graph
0,0,"After publishing A Time to Kill, which book di...",A Time to Kill,A Clash of Kings,Q1213715,Q300370,"{'nodes': [{'type': 'INTERNAL', 'name_': 'Q30'..."
1,1,"After publishing A Time to Kill, which book di...",A Time to Kill,A Feast for Crows,Q1213715,Q1764445,"{'nodes': [{'type': 'INTERNAL', 'name_': 'Q30'..."
2,2,"After publishing A Time to Kill, which book di...",A Time to Kill,Fear and Loathing in Las Vegas,Q1213715,Q772435,"{'nodes': [{'type': 'INTERNAL', 'name_': 'Q30'..."
3,3,"After publishing A Time to Kill, which book di...",A Time to Kill,In Cold Blood,Q1213715,Q1142887,"{'nodes': [{'type': 'INTERNAL', 'name_': 'Q30'..."
4,4,"After publishing A Time to Kill, which book di...",A Time to Kill,Into the Woods,Q1213715,Q1118244,"{'nodes': [{'type': 'INTERNAL', 'name_': 'Q30'..."
5,5,"After publishing A Time to Kill, which book di...",A Time to Kill,Kongenes kamp,Q1213715,Q19377881,"{'nodes': [{'type': 'INTERNAL', 'name_': 'Q30'..."
6,6,"After publishing A Time to Kill, which book di...",A Time to Kill,No Country for Old Men,Q1213715,Q611689,"{'nodes': [{'type': 'INTERNAL', 'name_': 'Q30'..."
7,7,"After publishing A Time to Kill, which book di...",A Time to Kill,No Country for Old Men,Q1213715,Q60411383,"{'nodes': [{'type': 'INTERNAL', 'name_': 'Q186..."
8,8,"After publishing A Time to Kill, which book di...",A Time to Kill,Slaughterhouse-Five,Q1213715,Q265954,"{'nodes': [{'type': 'INTERNAL', 'name_': 'Q30'..."
9,9,"After publishing A Time to Kill, which book di...",A Time to Kill,The Firm,Q1213715,Q1212467,"{'nodes': [{'type': 'INTERNAL', 'name_': 'Q30'..."


In [None]:
# How many distinct question entities do the two splits share?
train_entities = set(data_train['questionEntity'])
test_entities = set(data_test['questionEntity'])

N = len(train_entities.intersection(test_entities))
train = len(train_entities)
test = len(test_entities)

f'Intersection is {N}, while train_len = {train} and test_len = {test}'

'Intersection is 173, while train_len = 2366 and test_len = 815'

In [None]:
def _parse_graph(raw_graph):
    """Turn the stringified graph dict from the TSV into a real Python object.

    ast.literal_eval only accepts Python literals, so unlike eval it cannot execute
    code embedded in the downloaded file. Values that are not strings (i.e. already
    parsed) are returned unchanged, which makes re-running this cell harmless.
    """
    if isinstance(raw_graph, str):
        return ast.literal_eval(raw_graph)
    return raw_graph


data_train['graph'] = data_train['graph'].apply(_parse_graph)
data_test['graph'] = data_test['graph'].apply(_parse_graph)

# LLM

### Imports

In [None]:
import os

from tqdm import tqdm
from tqdm.auto import tqdm

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
from torch.utils.data import Dataset, DataLoader

# Tokenizer used for the token-length scan and handed to the training Dataset below.
# NOTE(review): the Model section later rebinds `tokenizer` via AutoTokenizer for the same checkpoint name.
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

In [None]:
# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Model input for each row: the question followed by the candidate answer, space-separated.
for split_frame in (data_train, data_test):
    split_frame['texts'] = split_frame['question'] + ' ' + split_frame['answerEntity']

In [None]:
# Numeric training target: the boolean `correct` flag cast to float (True -> 1.0, False -> 0.0).
data_train['labels'] = data_train['correct'].astype(float)

In [None]:
# Scan the training texts for the longest tokenized sequence; this informs the
# padding length chosen for the Dataset.
max_length = 0
for text in tqdm(data_train['texts']):
    token_count = tokenizer(text, return_tensors='pt')['input_ids'].shape[1]
    max_length = max(max_length, token_count)

print("\nMaximum tokenized text length:", max_length)

  0%|          | 0/37672 [00:00<?, ?it/s]


Maximum tokenized text length: 89


### Dataset

In [None]:
class MyDataset(Dataset):
    """Tokenizes (text, label) pairs on access for sequence classification.

    Args:
        texts: sequence of strings (list, tuple, pandas Series, ...).
        labels: sequence of numeric labels aligned position-by-position with ``texts``.
        tokenizer: callable accepting ``padding``, ``truncation``, ``max_length`` and
            ``return_tensors`` and returning a mapping with ``input_ids`` and
            ``attention_mask`` tensors.
        max_length: length every example is padded / truncated to.

    Raises:
        ValueError: if ``texts`` and ``labels`` differ in length.
    """

    def __init__(self, texts, labels, tokenizer, max_length):
        # Materialise as plain lists so that ``self.texts[idx]`` is always positional.
        # Indexing a pandas Series with an int is a *label* lookup and fails (or returns
        # the wrong row) as soon as the frame was filtered, shuffled or re-indexed.
        self.texts = list(texts)
        self.labels = list(labels)
        if len(self.texts) != len(self.labels):
            raise ValueError(
                f"texts and labels must have the same length, "
                f"got {len(self.texts)} and {len(self.labels)}"
            )
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the ``idx``-th example as a dict of 1-D tensors plus its label."""
        text = self.texts[idx]
        label = self.labels[idx]
        encoding = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt',
        )
        return {
            # The tokenizer returns a leading batch dimension of 1; flatten it away so
            # the DataLoader can stack examples itself.
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }

In [None]:
# Training batch size and the fixed padded sequence length handed to the Dataset.
BATCH_SIZE = 64
max_length = 100

dataset = MyDataset(
    texts=data_train['texts'],
    labels=data_train['labels'],
    tokenizer=tokenizer,
    max_length=max_length,
)
dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)

### Model

In [None]:
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
distilbert_model = AutoModel.from_pretrained(model_name)

# Freeze the whole pretrained encoder: none of its parameters receive gradients.
distilbert_model.requires_grad_(False)


class CustomDistilBERT(nn.Module):
    """Encoder followed by a two-layer MLP head that emits one logit per sequence.

    Args:
        distilbert_model: encoder module. It must expose ``config.hidden_size`` and,
            when called with ``input_ids`` / ``attention_mask``, return an object with
            a ``last_hidden_state`` tensor.
    """

    def __init__(self, distilbert_model):
        super().__init__()
        self.distilbert = distilbert_model
        self.hidden_size = distilbert_model.config.hidden_size
        expanded_size = self.hidden_size * 4
        self.linear1 = nn.Linear(self.hidden_size, expanded_size)
        self.linear2 = nn.Linear(expanded_size, 1)
        self.activ = nn.ReLU()

    def forward(self, input_ids, attention_mask=None):
        """Return one logit per input sequence, shaped ``(batch, 1)``."""
        encoded = self.distilbert(input_ids=input_ids, attention_mask=attention_mask)
        # The hidden state at position 0 (the [CLS] token) summarises the sequence.
        first_token_state = encoded.last_hidden_state[:, 0, :]
        head_hidden = self.activ(self.linear1(first_token_state))
        return self.linear2(head_hidden)

model = CustomDistilBERT(distilbert_model)

# Only the two head layers are meant to train; state that explicitly for both.
for head_layer in (model.linear1, model.linear2):
    for param in head_layer.parameters():
        param.requires_grad = True

print("Trainable parameters in the linear layer:")
trainable_names = [
    param_name
    for param_name, tensor in model.named_parameters()
    if tensor.requires_grad
]
for name in trainable_names:
    print(name)

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

Trainable parameters in the linear layer:
linear1.weight
linear1.bias
linear2.weight
linear2.bias


In [None]:
# Move the model to the selected device; as the cell's last expression this also displays the module tree.
model.to(device)

CustomDistilBERT(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
            (lin1): 

### Training

In [None]:
# Mount Google Drive at /content/drive; the checkpoint folder configured below sits under `drive/MyDrive`.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Checkpoint folder on the mounted Drive (relative to the current working directory).
save_dir = "drive/MyDrive/Deep Learning NLP/saved_models/"
# exist_ok=True lets this cell be re-run when the folder already exists.
os.makedirs(save_dir, exist_ok=True)

In [None]:
N_EPOCHS = 100
LEARNING_RATE = 0.001

# The model emits ONE logit per example and the targets are 0/1 floats, i.e. this is
# binary classification, so the matching loss is BCEWithLogitsLoss.
# CrossEntropyLoss is wrong here: given the squeezed (batch,) logits and (batch,) float
# targets it treats the whole batch as a single unbatched sample with `batch` classes
# and soft targets, which yields a loss of roughly (#positives * ln(batch)) that barely
# moves during training.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

In [None]:
# Train the model for N_EPOCHS, logging the mean batch loss per epoch and writing
# checkpoints to `save_dir`.
loss_list = []  # mean training loss of every finished epoch
for epoch in range(N_EPOCHS):
    model.train()
    total_loss = 0
    for batch in tqdm(dataloader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()

        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

        # Drop only the trailing logit dimension. A bare squeeze() would also remove
        # the batch dimension when a batch holds a single example, so the logits would
        # no longer line up with the (1,)-shaped labels.
        loss = criterion(outputs.squeeze(-1), labels.float())

        total_loss += loss.item()
        loss.backward()
        optimizer.step()

    average_loss = total_loss / len(dataloader)
    loss_list.append(average_loss)
    print(f'Epoch {epoch + 1}/{N_EPOCHS}, Loss: {average_loss:.4f}')

    # Checkpoint on the existing every-10-epochs schedule and, additionally, after the
    # final epoch: otherwise the fully trained weights are never written to disk.
    is_last_epoch = epoch == N_EPOCHS - 1
    if epoch % 10 == 0 or is_last_epoch:
        # The file name carries the 1-based epoch number; the stored 'epoch' is 0-based.
        model_save_path = os.path.join(save_dir, f"model_epoch_{epoch + 1}.pt")
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss_history': loss_list,
        }, model_save_path)

  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 1/100, Loss: 26.4781


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 2/100, Loss: 26.2979


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 3/100, Loss: 26.2122


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 4/100, Loss: 26.1507


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 5/100, Loss: 26.1431


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 6/100, Loss: 26.1022


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 7/100, Loss: 26.1138


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 8/100, Loss: 26.0732


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 9/100, Loss: 26.0388


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 10/100, Loss: 26.0334


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 11/100, Loss: 26.0507


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 12/100, Loss: 26.0135


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 13/100, Loss: 26.0299


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 14/100, Loss: 25.9778


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 15/100, Loss: 26.0028


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 16/100, Loss: 25.9571


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 17/100, Loss: 25.9987


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 18/100, Loss: 26.0042


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 19/100, Loss: 25.9437


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 20/100, Loss: 25.9564


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 21/100, Loss: 26.0269


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 22/100, Loss: 25.9985


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 23/100, Loss: 25.9850


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 24/100, Loss: 26.0318


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 25/100, Loss: 25.9346


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 26/100, Loss: 25.9432


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 27/100, Loss: 25.9430


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 28/100, Loss: 25.9581


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 29/100, Loss: 25.9589


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 30/100, Loss: 25.9547


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 31/100, Loss: 25.8841


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 32/100, Loss: 25.8714


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 33/100, Loss: 25.8567


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 34/100, Loss: 25.8663


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 35/100, Loss: 25.8417


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 36/100, Loss: 25.8293


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 37/100, Loss: 25.8295


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 38/100, Loss: 25.8219


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 39/100, Loss: 25.8048


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 40/100, Loss: 25.7928


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 41/100, Loss: 25.8098


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 42/100, Loss: 25.7646


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 43/100, Loss: 25.8211


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 44/100, Loss: 25.7902


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 45/100, Loss: 25.7471


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 46/100, Loss: 25.7442


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 47/100, Loss: 25.7520


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 48/100, Loss: 25.7428


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 49/100, Loss: 25.7700


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 50/100, Loss: 25.7195


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 51/100, Loss: 25.7481


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 52/100, Loss: 25.7504


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 53/100, Loss: 25.6800


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 54/100, Loss: 25.6952


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 55/100, Loss: 25.6422


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 56/100, Loss: 25.7091


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 57/100, Loss: 25.6420


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 58/100, Loss: 25.6474


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 59/100, Loss: 25.6589


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 60/100, Loss: 25.6507


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 61/100, Loss: 25.6610


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 62/100, Loss: 25.6385


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 63/100, Loss: 25.6276


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 64/100, Loss: 25.6302


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 65/100, Loss: 25.6189


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 66/100, Loss: 25.6820


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 67/100, Loss: 25.5815


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 68/100, Loss: 25.6194


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 69/100, Loss: 25.6347


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 70/100, Loss: 25.5784


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 71/100, Loss: 25.5787


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 72/100, Loss: 25.6151


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 73/100, Loss: 25.6021


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 74/100, Loss: 25.5529


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 75/100, Loss: 25.5684


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 76/100, Loss: 25.5485


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 77/100, Loss: 25.4795


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 78/100, Loss: 25.5512


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 79/100, Loss: 25.4991


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 80/100, Loss: 25.5508


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 81/100, Loss: 25.5392


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 82/100, Loss: 25.5288


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 83/100, Loss: 25.4823


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 84/100, Loss: 25.4948


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 85/100, Loss: 25.4548


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 86/100, Loss: 25.5008


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 87/100, Loss: 25.4491


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 88/100, Loss: 25.4896


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 89/100, Loss: 25.4578


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 90/100, Loss: 25.4407


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 91/100, Loss: 25.4533


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 92/100, Loss: 25.4743


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 93/100, Loss: 25.4416


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 94/100, Loss: 25.3940


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 95/100, Loss: 25.4532


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 96/100, Loss: 25.3788


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 97/100, Loss: 25.3809


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 98/100, Loss: 25.3774


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 99/100, Loss: 25.3819


  0%|          | 0/589 [00:00<?, ?it/s]

Epoch 100/100, Loss: 25.3774


### Inference

In [None]:
# Same Drive mount as in the Training section.
# NOTE(review): presumably repeated so inference can run in a fresh session — confirm.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Folder the training loop writes its `model_epoch_<n>.pt` checkpoints to.
checkpoint_dir = 'drive/MyDrive/Deep Learning NLP/saved_models/'
checkpoint_prefix, checkpoint_suffix = 'model_epoch_', '.pt'

# Pick the checkpoint with the highest epoch number instead of hard-coding a file
# name: the previously hard-coded `model_epoch_99.pt` is never produced by the
# training loop, so loading it fails with a missing-file error.
saved_epochs = sorted(
    int(file_name[len(checkpoint_prefix):-len(checkpoint_suffix)])
    for file_name in os.listdir(checkpoint_dir)
    if file_name.startswith(checkpoint_prefix)
    and file_name.endswith(checkpoint_suffix)
    and file_name[len(checkpoint_prefix):-len(checkpoint_suffix)].isdigit()
)
if not saved_epochs:
    raise FileNotFoundError(
        f"No '{checkpoint_prefix}<n>{checkpoint_suffix}' checkpoint found in {checkpoint_dir!r}"
    )
model_save_path = os.path.join(
    checkpoint_dir, f'{checkpoint_prefix}{saved_epochs[-1]}{checkpoint_suffix}'
)

# Load onto the CPU first; load_state_dict then copies the weights into the model on
# whatever device the model currently lives on.
checkpoint = torch.load(model_save_path, map_location=torch.device('cpu'))

model.load_state_dict(checkpoint['model_state_dict'])

# Inference mode: disables dropout.
model.eval()

In [None]:
# Re-check the test frame before inference; it now also carries the `texts` column.
data_test.head(10)

Unnamed: 0,sample_id,question,questionEntity,answerEntity,questionEntityId,answerEntityId,graph,texts
0,0,"After publishing A Time to Kill, which book di...",A Time to Kill,A Clash of Kings,Q1213715,Q300370,"{'nodes': [{'type': 'INTERNAL', 'name_': 'Q30'...","After publishing A Time to Kill, which book di..."
1,1,"After publishing A Time to Kill, which book di...",A Time to Kill,A Feast for Crows,Q1213715,Q1764445,"{'nodes': [{'type': 'INTERNAL', 'name_': 'Q30'...","After publishing A Time to Kill, which book di..."
2,2,"After publishing A Time to Kill, which book di...",A Time to Kill,Fear and Loathing in Las Vegas,Q1213715,Q772435,"{'nodes': [{'type': 'INTERNAL', 'name_': 'Q30'...","After publishing A Time to Kill, which book di..."
3,3,"After publishing A Time to Kill, which book di...",A Time to Kill,In Cold Blood,Q1213715,Q1142887,"{'nodes': [{'type': 'INTERNAL', 'name_': 'Q30'...","After publishing A Time to Kill, which book di..."
4,4,"After publishing A Time to Kill, which book di...",A Time to Kill,Into the Woods,Q1213715,Q1118244,"{'nodes': [{'type': 'INTERNAL', 'name_': 'Q30'...","After publishing A Time to Kill, which book di..."
5,5,"After publishing A Time to Kill, which book di...",A Time to Kill,Kongenes kamp,Q1213715,Q19377881,"{'nodes': [{'type': 'INTERNAL', 'name_': 'Q30'...","After publishing A Time to Kill, which book di..."
6,6,"After publishing A Time to Kill, which book di...",A Time to Kill,No Country for Old Men,Q1213715,Q611689,"{'nodes': [{'type': 'INTERNAL', 'name_': 'Q30'...","After publishing A Time to Kill, which book di..."
7,7,"After publishing A Time to Kill, which book di...",A Time to Kill,No Country for Old Men,Q1213715,Q60411383,"{'nodes': [{'type': 'INTERNAL', 'name_': 'Q186...","After publishing A Time to Kill, which book di..."
8,8,"After publishing A Time to Kill, which book di...",A Time to Kill,Slaughterhouse-Five,Q1213715,Q265954,"{'nodes': [{'type': 'INTERNAL', 'name_': 'Q30'...","After publishing A Time to Kill, which book di..."
9,9,"After publishing A Time to Kill, which book di...",A Time to Kill,The Firm,Q1213715,Q1212467,"{'nodes': [{'type': 'INTERNAL', 'name_': 'Q30'...","After publishing A Time to Kill, which book di..."


In [None]:
# Hand-built probe: the question of row 0 paired with the candidate answer of row 5.
probe_question = data_test['question'][0]
probe_answer = data_test['answerEntity'][5]
text = ' '.join((probe_question, probe_answer))
tokenized_text = tokenizer(text)
text

'After publishing A Time to Kill, which book did its author begin working on immediately? Kongenes kamp'

In [None]:
# Score the probe text with the trained model (no gradients needed at inference).
with torch.no_grad():
    # Build the batch on the model's own device: the model was moved to `device`
    # earlier, so a CPU tensor would raise a device-mismatch error on a GPU runtime.
    # The variable is also no longer called `input`, which shadowed the builtin.
    probe_ids = torch.tensor(tokenized_text['input_ids']).unsqueeze(0)
    probe_ids = probe_ids.to(next(model.parameters()).device)
    output = model(probe_ids)
output.item()

4.27785587310791

In [None]:
# Start with every candidate marked incorrect; the inference loop below flips the selected rows to True.
data_test['correct'] = False

In [None]:
# For every question, score each candidate answer and mark the single best-scoring
# row as correct.
model_device = next(model.parameters()).device  # inputs must live where the model does

# groupby(sort=False) visits the questions in order of first appearance and hands over
# each question's rows directly, instead of re-filtering the whole frame per question.
for question, lines in tqdm(data_test.groupby('question', sort=False)):
    candidates = torch.zeros(len(lines))
    for i, answer in enumerate(lines['answerEntity']):
        candidate = question + ' ' + answer
        # truncation=True keeps an unusually long text within the model's maximum input length.
        encoded = tokenizer(candidate, truncation=True, return_tensors='pt')

        with torch.no_grad():
            output = model(
                encoded['input_ids'].to(model_device),
                attention_mask=encoded['attention_mask'].to(model_device),
            )

        candidates[i] = output.item()

    best_position = torch.argmax(candidates).item()
    right_answer = lines['answerEntity'].iloc[best_position]
    # Mark exactly the winning row via its index label. Matching on the answer text
    # would also flag other candidates that share the same label but are different
    # entities (different answerEntityId).
    data_test.loc[lines.index[best_position], 'correct'] = True
    print(question + ' ' + right_answer)

  0%|          | 0/1000 [00:00<?, ?it/s]

After publishing A Time to Kill, which book did its author begin working on immediately? The Firm
Among the European Union countries, which one has the largest land area? Italy
Among the Final Fantasy games, which installment achieved the highest worldwide sales? Final Fantasy
At which Academy Awards was the Leonardo DiCaprio nominated for the first time? Academy Award for Best Picture
Between Ulysses S Grant and Abe Lincoln, which one was president first? Ulysses S. Grant
Did more Americans die in World War II or the American Civil War? civil war
Directed, written, produced, and co-edited by James Cameron, in which movie Kate Winslet and Leonardo DiCaprio starred together?  Avatar
Donald John Trump's first wife was born in? England
Due to which disease did the composer of La campanella died from? tuberculosis
During WWI, which country was a part of Triple Alliance ? Germany
During what war did Spanish Empire lose Philippines? Spanish American Wars of Independence
First female Pharaoh 

In [None]:
# Inspect the test frame now that the `correct` column has been filled in.
data_test.head()

Unnamed: 0,sample_id,question,questionEntity,answerEntity,questionEntityId,answerEntityId,graph,texts,correct
0,0,"After publishing A Time to Kill, which book di...",A Time to Kill,A Clash of Kings,Q1213715,Q300370,"{'nodes': [{'type': 'INTERNAL', 'name_': 'Q30'...","After publishing A Time to Kill, which book di...",False
1,1,"After publishing A Time to Kill, which book di...",A Time to Kill,A Feast for Crows,Q1213715,Q1764445,"{'nodes': [{'type': 'INTERNAL', 'name_': 'Q30'...","After publishing A Time to Kill, which book di...",False
2,2,"After publishing A Time to Kill, which book di...",A Time to Kill,Fear and Loathing in Las Vegas,Q1213715,Q772435,"{'nodes': [{'type': 'INTERNAL', 'name_': 'Q30'...","After publishing A Time to Kill, which book di...",False
3,3,"After publishing A Time to Kill, which book di...",A Time to Kill,In Cold Blood,Q1213715,Q1142887,"{'nodes': [{'type': 'INTERNAL', 'name_': 'Q30'...","After publishing A Time to Kill, which book di...",False
4,4,"After publishing A Time to Kill, which book di...",A Time to Kill,Into the Woods,Q1213715,Q1118244,"{'nodes': [{'type': 'INTERNAL', 'name_': 'Q30'...","After publishing A Time to Kill, which book di...",False


## Saving Results

In [None]:
# Convert the boolean `correct` flag into a 0/1 integer `prediction` column.
data_test['prediction'] = data_test['correct'].astype(int)

In [None]:
# Write only the sample id and its 0/1 prediction as a tab-separated file, without the DataFrame index.
test_pred_path = "test_result.tsv"
data_test[["sample_id", "prediction"]].to_csv(test_pred_path, sep='\t', index=False)