<a href="https://colab.research.google.com/github/parvathysarat/kg-qa/blob/master/qa_task_metaqa.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Size of our transG embeddings : 50 (both entities and relations)

(contains : 
- utility functions to get entities, relations, documents from MetaQA processed files
- training functions
- embedding functions
- main
<br> Needs to be split into util, training and main files. Also need to add config file with specifications).

In [0]:
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import os
import time, datetime
import sys
import json
import os

In [0]:
!git clone https://github.com/parvathysarat/kg-qa
# ./kg-qa/data/transg/ has the embeddings obtained by TransG for MetaQA dataset

Cloning into 'kg-qa'...
remote: Enumerating objects: 46, done.[K
remote: Counting objects: 100% (46/46), done.[K
remote: Compressing objects: 100% (42/42), done.[K
remote: Total 46 (delta 16), reused 16 (delta 2), pack-reused 0[K
Unpacking objects: 100% (46/46), done.


In [0]:
os.chdir('./kg-qa/data/transg/')

### Loading MetaQA datasets - entities, relations

In [0]:
import glob
import shutil
for file in glob.glob('./kg-qa/data/*.txt'):
  shutil.move(file,'./kg-qa/data/transg/')

In [5]:
# entities.txt from MetaQa : list of entities

def get_entities_relns():
  with open('../MetaQA/entities.txt') as f:
    entities = {id:line.strip() for id,line in enumerate(f)}
    print(len(entities),entities[0])
  with open('../MetaQA/relations.txt') as f:
    relations = {id:line.strip() for id,line in enumerate(f)}
    print(len(relations),relations[0])
    return entities, relations    
  
entities, relations = get_entities_relns()

43234 Kismet
9 has_imdb_rating


### Loading pretrained entity weights (transG) for MetaQA

In [0]:
# entity.txt storing TransG embeddings of MetaQA

def get_num_entity(param='num'):
   with open('entity.txt') as f:
    if param=='num':
      return sum([1 for line in f])
    if param=='weights':
      arr = []
      weights_dict = {}
      ct=0
      for line in f:
        entity = entities[ct]
        ct+=1
        for idx,el in enumerate(line.split()):
          # numeric entities are not vectorized
          if el.isnumeric() and idx==0: 
            continue

          if (' '.join(line.split()[:idx]))==entity:
            arr.append(np.array(line.split()[idx:],dtype=np.float32))
            weights_dict[entity] = np.array(line.split()[idx:],dtype=np.float32)
            break
      print(len(weights_dict),ct)
    pretrained_weights = torch.FloatTensor(arr)  
    return weights_dict, pretrained_weights
num_entities = get_num_entity()

In [7]:
num_entities = get_num_entity()
entity_weights_dict, pretrained_entity_weights = get_num_entity('weights')

43233 43234


###Store pretrained weights to embedding vectors

In [0]:
def initialize_embeddings(pretrained_weights):
  embeddings = nn.Embedding.from_pretrained(pretrained_weights)  
  # entity_embeddings = nn.Embedding(num_embeddings=num_entities+1, embedding_dim=50,padding_idx=num_entities)
  # entity_embeddings.weight = nn.Parameter(get_num_entity('weights'))
  embeddings.weight.requires_grad = False  
  return embeddings
entity_embeddings = initialize_embeddings(pretrained_entity_weights)

In [9]:
entity_embeddings

Embedding(43233, 50)

### Loading pretrained relation weights

In [0]:
# function to return the cluster number for each relation
# (n_cluster # of embeddings for each relation from TransG based on GMM)

def get_relations_clust():
  relations_cluster = {}
  with open('weight.txt') as f:
    for line in f:
      relations_cluster[line.split()[0]] = np.argmax(np.array(line.split()[1:], dtype = np.float32))
  return relations_cluster
rel_clusters = get_relations_clust()

(transG uses GMM model hence relations have a mixture of embedding vectors to represent multiple semantic relations, weights of each mixture/cluster stored in weights.txt, loaded into rel_clusters. Here we have used num_clusters=4)

In [11]:
rel_clusters

{'directed_by': 2,
 'has_genre': 3,
 'has_imdb_rating': 3,
 'has_imdb_votes': 3,
 'has_tags': 3,
 'in_language': 2,
 'release_year': 3,
 'starred_actors': 2,
 'written_by': 2}

In [0]:
def get_rel_embeddings(relation):
  with open('relation_'+relation+'.txt') as f:
    for line in f:
      if int(line[0])==rel_clusters[relation]:
        return np.array(line.split()[1:],dtype=np.float32)

rel_weights_dict= {}
pretrained_relation_weights = []
for rel in rel_clusters:
  rel_weights_dict[rel] = get_rel_embeddings(rel)
  pretrained_relation_weights.append(rel_weights_dict[rel])
pretrained_relation_weights = torch.FloatTensor(pretrained_relation_weights)  

In [0]:
relation_embeddings = initialize_embeddings(pretrained_relation_weights)

In [14]:
print(entity_embeddings,relation_embeddings)

Embedding(43233, 50) Embedding(9, 50)


Mapping vocabulary, entities and relations to ids. 
dicts of {value:id} format

In [0]:
def map_ids(file):
  id_map = {}
  with open('../MetaQA/'+file+'.txt') as f:
    for line in f:
      id_map[line.strip()] = len(id_map)
  return id_map

vocab_ids = map_ids('vocab')
relation_ids = map_ids('relations')
entitiy_ids = map_ids('entities')

In [22]:
from pprint import pprint
import tqdm
def map_documents(file):
  documents = {}
  with open(file,encoding='utf-8') as f:
    for line in f:
      
      line = json.loads(line)
      documents[line['documentId']] = document['document']['text'] 
      if 'title' in line:
        documents[line['documentId']] += " / " + document['title']['text']

      break

map_documents('../MetaQA/documents.json')

{'document': {'entities': [{'end': 6,
                            'kb_id': 1721,
                            'start': 4,
                            'text': 'romantic comedy'},
                           {'end': 6,
                            'kb_id': 4958,
                            'start': 4,
                            'text': 'Romantic Comedy'},
                           {'end': 11,
                            'kb_id': 12463,
                            'start': 9,
                            'text': 'john turturro'},
                           {'end': 11,
                            'kb_id': 15468,
                            'start': 9,
                            'text': 'John Turturro'},
                           {'end': 16,
                            'kb_id': 27467,
                            'start': 14,
                            'text': 'Brandon Cole'},
                           {'end': 1,
                            'kb_id': 27466,
                            'star

In [0]:
?tqdm