### Build model and test follow operation

In [1]:
import os
import sys 
from pathlib import Path 

%load_ext autoreload
%autoreload 2


base_dir = Path(os.getcwd()).parents[0]
sys.path.append(str(base_dir))

from utils import data_utils

### Load Triplets

In [2]:
data_dir = base_dir/'data'
kg_path = data_dir/'kb.txt'
os.listdir(data_dir)
assert kg_path.exists()

triplets, entity_to_idx, relation_to_idx, idx_to_entity, idx_to_relation = data_utils.load_triplets(kg_path)
num_entities = len(entity_to_idx)

num entities: 43234
num relations: 9
num triplets  134741


### Load question transformer

In [3]:
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F

#Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0] #First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

# Sentences we want sentence embeddings for
sentences = ['This is an example sentence', 'Each sentence is converted']

# Load model from HuggingFace Hub
tokenizer = AutoTokenizer.from_pretrained('microsoft/MiniLM-L12-H384-uncased')
model = AutoModel.from_pretrained('microsoft/MiniLM-L12-H384-uncased')

### Load dataset 

In [4]:
import warnings

qa_dir = data_dir/'1-hop-20220731T054004Z-001/1-hop/vanilla'
qa_paths = os.listdir(qa_dir)

qa_path = qa_dir/qa_paths[0]

print(qa_path)
val_pairs = data_utils.load_qa_pairs(qa_path)
data_utils.santity_check(val_pairs, entity_to_idx)

C:\Users\Sidhant\Documents\projects_research\graph_qa\kgqa\data\1-hop-20220731T054004Z-001\1-hop\vanilla\qa_dev.txt
Num qa pairs loaded: 9992
number of entities missing 0: []


In [5]:
# [len(row[2]) for row in val_pairs]

### Tokenize dataset

In [6]:
# # Tokenize sentences
sentences = [row[0] for row in val_pairs]
val_tokens = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')


Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


### Create Torch Dataset

In [7]:
class QADataset(Dataset):
    def __init__(self, qa_pairs, q_tokens, entity_to_idx):
        self.qa_pairs = qa_pairs
        self.entity_to_idx = entity_to_idx
        self.q_tokens = q_tokens

    def __len__(self):
        return len(self.qa_pairs)

    def __getitem__(self, idx):
        pad = 5
        token_sample = {key: self.q_tokens[key][idx, :] for key in self.q_tokens}
        
        head_entities = self.qa_pairs[idx][2]
        head_entities = [self.entity_to_idx[entity] for entity in head_entities]
                
        # padded
        head_entities = self.create_onehot_entity_vector(head_entities)
#         head_entities = torch.LongTensor(head_entities)
        
        tail_entities = self.qa_pairs[idx][1]
        tail_entities = [self.entity_to_idx[entity] for entity in tail_entities]
        
        tail_entities = self.create_onehot_entity_vector(tail_entities)
#         tail_entities = torch.LongTensor(tail_entities)
        
        return token_sample, head_entities, tail_entities
        
    
    def pad_entity(self, list1, pad=5):
        # max length should be this 
        if len(list1) > pad:
            list1 = list1[:pad]
        
        assert len(list1) > 0
        
        list1 = list1 + [-1] * (pad-len(list1))
        return list1
    
    def create_onehot_entity_vector(self, entities):
        """
        Inputs: entities: list of ints 
        
        """
        num_entities = len(self.entity_to_idx)
        
        entity_tensor = torch.zeros(num_entities)
        entity_tensor[entities] = 1
        
        return entity_tensor
        
        
        
val_dataset = QADataset(val_pairs, val_tokens, entity_to_idx)


dl = DataLoader(val_dataset, batch_size=10)

for batch in dl:
    break
    
batch[0]['input_ids'].shape

torch.Size([10, 23])

In [8]:
## Test forward pass of transfomer 

with torch.no_grad():
    out = model(**batch[0])
    out.pooler_output

### Create differentiable KG

In [9]:
subject_matrix, rel_matrix, object_matrix = data_utils.create_differentiable_kg(triplets, entity_to_idx, relation_to_idx)
object_matrix = torch.transpose(object_matrix, 0, 1)


### Create model

In [12]:
import models_onehop

net = models_onehop.GNNLightning(model, subject_matrix, rel_matrix, object_matrix)

### Test Follow Operation

In [15]:
# def create_sparse_subject():
#     bs = batch[1].shape[0]
#     bs_ids = torch.Tensor(range(0, bs))
#     indices = torch.cat([subject_ids[None, :], bs_ids[None, :]])
#     values = torch.FloatTensor([1] * len(bs_ids))
#     subject_vector = torch.sparse_coo_tensor(indices, values, size=(len(entity_to_idx), bs))
    
#     return subject_vector

# # subject_vector = create_sparse_subject()
# subject_vector = torch.transpose(subject_vector, 0, 1)

In [16]:
subject_vector = batch[1]
bs = batch[1].shape[0]
relation = torch.randn(bs, 9)

print(subject_vector.shape, subject_matrix.shape)
print('relation: ', relation.shape)
print('subject_vector:', subject_vector.shape)
print('subject_matrix:', subject_matrix.shape)
print('rel_matrix :', rel_matrix.shape)
print('object_matrix1:', object_matrix.shape)

subject_vector = torch.transpose(subject_vector, 0, 1)
out = net.follow(subject_vector, relation, subject_matrix, rel_matrix, object_matrix)
print(out.shape)

torch.Size([10, 43234]) torch.Size([134741, 43234])
relation:  torch.Size([10, 9])
subject_vector: torch.Size([10, 43234])
subject_matrix: torch.Size([134741, 43234])
rel_matrix : torch.Size([134741, 9])
object_matrix1: torch.Size([43234, 134741])
torch.Size([10, 43234])


### Perform Follow operation on entity and get results 

In [17]:
from collections import Counter

def get_sample(id_num, triplets):
    data = list(filter(lambda x: x[0] == id_num, triplets))
        
    # get the relation that leads to multiple entities 
    relation_name = Counter([x[1] for x in data]).most_common(5)[0][0]
    
    objects = []
    for row in data: 
        if row[1] == relation_name:
            objects.append(row[2])
            
    print('num objects ', len(objects))
    return relation_name, objects

subject_id = 0
relation_id, object_ids = get_sample(subject_id, triplets)

def get_entity_names(head_id, rel_id, object_ids):
    head_name = idx_to_entity[head_id]
    rel_name = idx_to_relation[rel_id]
    
    object_name = [idx_to_entity[idx] for idx in object_ids]
    
    print(f'Head {head_name} with relation {rel_name} leads to {object_name}')
    
  
print(f'For head_id {subject_id}, these are the expected results')
get_entity_names(subject_id, relation_id, object_ids)

num objects  4
For head_id 0, these are the expected results
Head Kismet with relation starred_actors leads to ['Marlene Dietrich', 'Edward Arnold', 'Ronald Colman', 'James Craig']


In [18]:
import pdb

def get_subject_relation_vector(subject_id, relation_id):
    subject_vector = torch.zeros(1, len(entity_to_idx))
    subject_vector[0, subject_id] = 1 
    
    relation_vector = torch.zeros(1, len(relation_to_idx))
    relation_vector[0, relation_id] = 1 
    
    return subject_vector, relation_vector 

subject_vector, relation_vector = get_subject_relation_vector(subject_id, relation_id)
subject_vector = torch.transpose(subject_vector, 0, 1) # required size
out = net.follow(subject_vector, relation_vector, subject_matrix, rel_matrix, object_matrix)

def interpred_follow_output(out, threshold=0.5):
    """
    Interpret the follow result from the model
    """
    bs, num_entities = out.shape 
        
    outputs_names = []
    output_ids = []
    for idx in range(bs):
        object_probs = out[idx, :]
        
        # set all values less than threshold to be 0
        condition = object_probs >= threshold
        object_preds = object_probs.where(condition, torch.tensor(0.))
        # get indices where value is non zero (true preds)
        pred_object_ids = object_preds.nonzero(as_tuple=True)[0]
        
        pred_object_names = [idx_to_entity[object_id.item()] for object_id in pred_object_ids]
        outputs_names.append(pred_object_names)
        output_ids.append(pred_object_ids.tolist())
            
    return outputs_names, output_ids
        
object_names, object_ids = interpred_follow_output(out, threshold=0.5)  
get_entity_names(subject_id, relation_id, object_ids[0])

Head Kismet with relation starred_actors leads to ['Marlene Dietrich', 'Edward Arnold', 'Ronald Colman', 'James Craig']


In [45]:
subject_vector, relation_vector = get_subject_relation_vector(subject_id, relation_id)
subject_vector = torch.transpose(subject_vector, 0, 1) # required size

# relation_vector[relation_vector==1] = 0.9
# relation_vector[relation_vector==0] = 0.1

relation_vector = torch.Tensor([-1.2259e+00, -6.9046e-01, -2.4655e+00, -3.0674e-02,  1.9986e+00,
         -3.3054e-01, -5.4131e-01, -2.4845e-02,  3.6497e-01])[None, :]

out = net.follow(subject_vector, relation_vector, subject_matrix, rel_matrix, object_matrix)
out

tensor([[ 0.0000, -1.2259, -0.6905,  ...,  0.0000,  0.0000,  0.0000]])

### Test forward pass

In [20]:
net = models_onehop.GNNLightning(model, subject_matrix, rel_matrix, object_matrix)

In [21]:
def transform_subject_matrix(subject_ids, num_entities):
    """
    Transforms subjectids in required format for the model. Meant only for single hop models 
    
    Inputs: 
        - subject_ids- tensor of subject ids. of shape [bs, num_subjects]. Since this is one subject model,
            we always take only first subject
    """
    
    # single subject 
    subject_ids = subject_ids[:, 0]    
    batch_size = subject_ids.shape[0]
    
    subject_vector = torch.zeros(batch_size, num_entities, device=subject_ids.device)
    subject_vector[range(0, subject_vector.shape[0]), subject_ids] = 1
    subject_vector = torch.transpose(subject_vector, 0, 1)
    
    return subject_vector
    
    

In [34]:
with torch.no_grad():
    trans_input = batch[0]
    subject_vector = batch[1]
    object_labels = batch[2]
    subject_vector2 = torch.transpose(subject_vector, 0, 1)
    
    batch_size = subject_vector2.shape[1]
    num_entities = subject_matrix.shape[1]
    
    # check shape
    assert num_entities == len(entity_to_idx)
    
    out = net.follow(subject_vector2, relation, subject_matrix, rel_matrix, object_matrix)
    
    predictions = net(trans_input, subject_vector2)
    
    loss = torch.nn.BCEWithLogitsLoss()(predictions, object_labels)
    print(loss)
    
    
print('done')

torch.Size([10, 9])
tensor(0.6931)
done


In [29]:
predictions.sum()

tensor(0.)

In [23]:
### Test

In [24]:
# def prepare_subject_vector():
target = torch.ones([10, 64], dtype=torch.float32)  # 64 classes, batch size = 10
# output = torch.full([10, 64], 1.5)  # A prediction (logit)
output = torch.randn([10, 64])

print(output.shape, target.shape)
criterion = torch.nn.BCEWithLogitsLoss()
criterion(output, target)

torch.Size([10, 64]) torch.Size([10, 64])


tensor(0.8148)

In [25]:
# subject_ids = subject_ids[:, 0]
# batch_size = subject_ids.shape[0]

# subject_vector = create_subject_vector(subject_ids, num_entities)