### Build model and test follow operation

In [1]:
import os
import sys 
from pathlib import Path 

%load_ext autoreload
%autoreload 2


base_dir = Path(os.getcwd()).parents[0]
sys.path.append(str(base_dir))

from utils import data_utils

### Load Triplets

In [2]:
data_dir = base_dir/'data'
kg_path = data_dir/'kb.txt'
os.listdir(data_dir)
assert kg_path.exists()

triplets, entity_to_idx, relation_to_idx, idx_to_entity, idx_to_relation = data_utils.load_triplets(kg_path)
num_entities = len(entity_to_idx)

num entities: 43234
num relations: 9
num triplets  134741


### Load question transformer

In [3]:
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F

#Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0] #First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

# Sentences we want sentence embeddings for
sentences = ['This is an example sentence', 'Each sentence is converted']

# # Load model from HuggingFace Hub
# tokenizer = AutoTokenizer.from_pretrained('microsoft/MiniLM-L12-H384-uncased')
# model = AutoModel.from_pretrained('microsoft/MiniLM-L12-H384-uncased')




tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
trans_model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
trans_model.train()


for param in trans_model.parameters():
    param.requires_grad = False
    
for layer in trans_model.encoder.layer[4].parameters():
    layer.requires_grad_ = True

for param in trans_model.encoder.layer[5].parameters():
    param.requires_grad = True
    
for param in trans_model.pooler.parameters(): 
    param.requires_grad = True

### Load dataset 

In [4]:
import warnings

qa_dir = data_dir/'1-hop-20220731T054004Z-001/1-hop/vanilla'
qa_paths = os.listdir(qa_dir)

qa_train = list(filter(lambda x: 'train' in x, qa_paths))[0]
qa_test = list(filter(lambda x: 'test' in x, qa_paths))[0]
qa_val = list(filter(lambda x: 'dev' in x, qa_paths))[0]

val_pairs = data_utils.load_qa_pairs(qa_dir/qa_val)
train_pairs = data_utils.load_qa_pairs(qa_dir/qa_train)
test_pairs = data_utils.load_qa_pairs(qa_dir/qa_test)


data_utils.santity_check(val_pairs, entity_to_idx)

Num qa pairs loaded: 9992
Num qa pairs loaded: 96106
Num qa pairs loaded: 9947
number of entities missing 0: []


### Tokenize dataset

In [5]:
# # Tokenize sentences
val_tokens = tokenizer([row[0] for row in val_pairs], padding=True, truncation=True, max_length=100, return_tensors='pt')
train_tokens = tokenizer([row[0] for row in train_pairs], padding=True, truncation=True, max_length=100, return_tensors='pt')
test_tokens = tokenizer([row[0] for row in test_pairs], padding=True, truncation=True, max_length=100, return_tensors='pt')

### Create Torch Dataset

In [23]:
class QADataset(Dataset):
    def __init__(self, qa_pairs, q_tokens, entity_to_idx):
        self.qa_pairs = qa_pairs
        self.entity_to_idx = entity_to_idx
        self.q_tokens = q_tokens

    def __len__(self):
        return len(self.qa_pairs)

    def __getitem__(self, idx):
        pad = 5
        token_sample = {key: self.q_tokens[key][idx, :] for key in self.q_tokens}
        
        head_entities = self.qa_pairs[idx][2]
        head_entities = [self.entity_to_idx[entity] for entity in head_entities]
                
        head_entities = self.create_onehot_entity_vector(head_entities)
        
        tail_entities = self.qa_pairs[idx][1]
        tail_entities = [self.entity_to_idx[entity] for entity in tail_entities]
        
        tail_entities = self.create_onehot_entity_vector(tail_entities)
        
        return token_sample, head_entities, tail_entities
        
    
    def pad_entity(self, list1, pad=5):
        # max length should be this 
        if len(list1) > pad:
            list1 = list1[:pad]
        
        assert len(list1) > 0
        
        list1 = list1 + [-1] * (pad-len(list1))
        return list1
    
    def create_onehot_entity_vector(self, entities):
        """
        Inputs: entities: list of ints 
        
        """
        num_entities = len(self.entity_to_idx)
        
        entity_tensor = torch.zeros(num_entities)
        entity_tensor[entities] = 1
        
        return entity_tensor
        
        
        
val_dataset = QADataset(val_pairs, val_tokens, entity_to_idx)
test_dataset = QADataset(test_pairs, test_tokens, entity_to_idx)
train_dataset = QADataset(train_pairs, train_tokens, entity_to_idx)

val_dl = DataLoader(val_dataset, batch_size=30)
train_dl = DataLoader(train_dataset, batch_size=30, shuffle=True)
test_dl = DataLoader(test_dataset, batch_size=10)

for batch in train_dl:
    break
    

In [7]:
# ## Test forward pass of transfomer 

# with torch.no_grad():
#     out = model(**batch[0])
#     out.pooler_output

### Create differentiable KG

In [8]:
subject_matrix, rel_matrix, object_matrix = data_utils.create_differentiable_kg(triplets, entity_to_idx, relation_to_idx)
object_matrix = torch.transpose(object_matrix, 0, 1)


### Create model

In [24]:
import models_onehop

net = models_onehop.GNNLightning(trans_model, subject_matrix, rel_matrix, object_matrix)

In [25]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

max_epochs = 1 
USE_GPU = True 

early_stop_callback = EarlyStopping(monitor="val_loss", 
                                    min_delta=0.0005,
                                    patience=1, 
                                    verbose=False, 
                                    mode="max")

checkpoint_callback = ModelCheckpoint(monitor=f"val_loss",
                                    save_top_k=1,
                                      dirpath='checkpoints',
                                      mode='max',
                                    )

callbacks=[early_stop_callback, checkpoint_callback]

if USE_GPU:
    net.object_matrix = net.object_matrix.cuda()
    net.subject_matrix = net.subject_matrix.cuda()
    net.rel_matrix = net.rel_matrix.cuda()
    gpus = 1 
else:
    net = net.cpu()
    net.object_matrix = net.object_matrix.cpu()
    net.subject_matrix = net.subject_matrix.cpu()
    net.rel_matrix = net.rel_matrix.cpu()
    
    gpus = 0
    
trainer = pl.Trainer(max_epochs=max_epochs, gpus=gpus)
trainer.fit(net, train_dl, val_dl)

CKPT_PATH = trainer.checkpoint_callback.best_model_path

# net = net.load_from_checkpoint(CKPT_PATH, input_dim=dataset.num_node_features, hidden_dim=300, output_dim=dataset.num_classes, split_idx=None)
# res = trainer.validate(net, loader)

# test_res = trainer.test(net, loader)
# valid_metrics.append(test_res[0]['test_acc'])

# print('Number of epochs ran ',  trainer.current_epoch)
# print('test metrics ', test_res)


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type              | Params
--------------------------------------------------
0 | trans_model | BertModel         | 22.7 M
1 | decoder     | Linear            | 3.5 K 
2 | loss        | BCEWithLogitsLoss | 0     
--------------------------------------------------
22.7 M    Trainable params
0         Non-trainable params
22.7 M    Total params
90.867    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

tensor(0.6931, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
tensor(0.6932, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
tensor(0.6931, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
tensor(0.6931, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
tensor(0.6931, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
tensor(0.6931, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
tensor(0.6931, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
tensor(0.6931, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
tensor(0.6931, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
tensor(0.6931, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
tensor(0.6931, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
tensor(0.6931, device='cuda:0',


In [28]:
net = net.cpu()
net.object_matrix = net.object_matrix.cpu()
net.subject_matrix = net.subject_matrix.cpu()
net.rel_matrix = net.rel_matrix.cpu()

In [29]:
net = net.train()
for idx, batch in enumerate(train_dl):
    with torch.no_grad():
        trans_input, subject_vector, object_labels = batch
        subject_vector2 = torch.transpose(subject_vector, 0, 1)

        predictions = net(trans_input, subject_vector2)
        loss = torch.nn.BCEWithLogitsLoss()(predictions, object_labels)
        
        print('predictions sum', predictions.sum())
        
        if idx > 5:
            break
    
    
print('done')

predictions sum tensor(-632.9448)
predictions sum tensor(-276.6538)
predictions sum tensor(-477.4096)
predictions sum tensor(-601.9507)
predictions sum tensor(-418.9074)
predictions sum tensor(-511.5621)
predictions sum tensor(-323.6661)
done


In [31]:
net = net.cpu()
net.object_matrix = net.object_matrix.cpu()
net.subject_matrix = net.subject_matrix.cpu()
net.rel_matrix = net.rel_matrix.cpu()

In [33]:
import transformers

net = net.train()
optimizer = transformers.AdamW(lr=5e-4, params=net.parameters())

for idx, batch in enumerate(train_dl):
    trans_input, subject_vector, object_labels = batch
    subject_vector2 = torch.transpose(subject_vector, 0, 1)

    predictions = net(trans_input, subject_vector2)
    loss = torch.nn.BCEWithLogitsLoss()(predictions, object_labels)
    
    
    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    print('loss ', loss)
    
    
    if idx > 5:
        break
    
    
print('done')

loss  tensor(0.6931, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
loss  tensor(0.6930, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
loss  tensor(0.6931, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
loss  tensor(0.6930, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
loss  tensor(0.6931, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
loss  tensor(0.6931, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
loss  tensor(0.6931, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
done
