<a href="https://colab.research.google.com/github/zhangxingeng/SpikeKG/blob/main/SpikeKG.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

- The Countries dataset used here comes from [KGDatasets/Countries](https://github.com/ZhenfengLei/KGDatasets/tree/master/Countries).

In [None]:
from google.colab import drive
# Mount Google Drive at /content/gdrive; the dataset path used below lives
# under /content/gdrive/MyDrive.
drive.mount('/content/gdrive')

In [None]:
import numpy as np
import os

## Data Loading

In [None]:
def get_rel2id(datapath, fname):
  '''
  Convenience function returning a dictionary that turns relation names into ids.

  datapath: directory that contains the mapping file
  fname: name of the tab-separated mapping file; on every line the first
         field is the integer id and the last field is the relation name
  Returns a dict {relation name: id}.
  Raises ValueError if the first field of a non-empty line is not an integer.
  '''
  rel2id = {}
  # The context manager closes the file handle even if parsing fails.
  with open(os.path.join(datapath, fname)) as file:
    for line in file:
      # Strip only the line terminator, so the last name stays intact even
      # when the file does not end with a newline.
      line = line.rstrip('\r\n')
      if not line:
        continue  # tolerate blank lines, e.g. a trailing empty line
      fields = line.split('\t')
      rel2id[fields[-1]] = int(fields[0])
  return rel2id

def get_id2rel(datapath, fname='relation_ids.del'):
  '''
  Convenience function returning a dictionary that turns ids into relation names.

  datapath: directory that contains the mapping file
  fname: name of the relation mapping file, forwarded to get_rel2id
  Returns a dict {id: relation name}.
  '''
  # get_rel2id needs the file name as well as the directory.
  rel2id = get_rel2id(datapath, fname)
  return {rel_id: name for name, rel_id in rel2id.items()}

def get_ent2id(datapath, fname):
  '''
  Convenience function returning a dictionary that turns entity names into ids.

  datapath: directory that contains the mapping file
  fname: name of the tab-separated mapping file; on every line the first
         field is the integer id and the last field is the entity name
  Returns a dict {entity name: id}.
  Raises ValueError if the first field of a non-empty line is not an integer.
  '''
  ent2id = {}
  # The context manager closes the file handle even if parsing fails.
  with open(os.path.join(datapath, fname)) as file:
    for line in file:
      # Strip only the line terminator, so the last name stays intact even
      # when the file does not end with a newline.
      line = line.rstrip('\r\n')
      if not line:
        continue  # tolerate blank lines, e.g. a trailing empty line
      fields = line.split('\t')
      ent2id[fields[-1]] = int(fields[0])
  return ent2id

def get_id2ent(datapath, fname):
  '''
  Build the inverse entity lookup (entity id -> entity name) from the
  mapping file `fname` inside `datapath`.
  '''
  name_to_id = get_ent2id(datapath, fname)
  return {ent_id: name for name, ent_id in name_to_id.items()}

def load_data(datapath):
  '''
  Load the training and validation triples stored in `datapath`.

  datapath: directory containing 'train.del' and 'valid.del'; every row of
            these files is read as a (subject, predicate, object) id triple
  Returns (train_data, valid_data, num_nodes, num_predicates) where the two
  data arrays are 2-D integer arrays and both counts are derived from the
  training triples only (largest id + 1).
  '''
  def _read_triples(fname):
    # ndmin=2 keeps a file with a single row two-dimensional, so column
    # indexing and row-wise unpacking work for any number of rows.
    return np.array(np.loadtxt(os.path.join(datapath, fname), ndmin=2), dtype=int)

  train_data = _read_triples('train.del')
  valid_data = _read_triples('valid.del')
  # Ids are 0-based, hence the +1 to turn the largest id into a count.
  num_nodes = np.max([np.max(train_data[:,0]), np.max(train_data[:,2])])+1
  num_predicates = np.max(train_data[:,1])+1
  return train_data, valid_data, num_nodes, num_predicates

In [None]:
datapath = '/content/gdrive/MyDrive/SpikeKG/data/Countries_S1'

# Name <-> id lookup tables read from the mapping files.
ent2id = get_ent2id(datapath, 'entity_ids.del')
rel2id = get_rel2id(datapath, 'relation_ids.del')
id2ent = get_id2ent(datapath, 'entity_ids.del')

# Triples plus the entity / relation counts.
train_data, valid_data, num_nodes, num_predicates = load_data(datapath)


def _head(mapping, count=4):
  '''Return a dict holding the first `count` items of `mapping`.'''
  return dict(list(mapping.items())[:count])


# Show a few entries of every lookup table.
print(f"ent2id:{_head(ent2id)}...")
print(f"rel2id:{_head(rel2id)}...")
print(f"id2ent:{_head(id2ent)}...")

# Summarise the sizes of the loaded data.
print(f"train-data shape:{train_data.shape}, valid-data shape: {valid_data.shape}")
print(f"total # ents: {num_nodes}, total # rels: {num_predicates}")

ent2id:{'western_africa': 0, 'africa': 1, 'slovakia': 2, 'ukraine': 3}...
rel2id:{'locatedin': 0, 'neighbor': 1}...
id2ent:{0: 'western_africa', 1: 'africa', 2: 'slovakia', 3: 'ukraine'}...
train-data shape:(1111, 3), valid-data shape: (24, 3)
total # ents: 271, total # rels: 2


## Filters that exclude known true triples (from the training/validation data) from the ranking candidates

In [None]:
def filter_data_sp(train_data, valid_data, num_nodes):
  '''
  Filters for objects.
  Given a triple spo, all objects x are removed that
  appear as spx in the training or validation data.

  train_data: additional known triples (rows of subject, predicate, object)
  valid_data: triples for which a candidate list is built
  num_nodes: number of entities; candidates are drawn from range(num_nodes)
  Returns an object array with one entry per row of valid_data: the true
  object first, followed by every entity that is not a known object for
  the same (subject, predicate) pair.
  '''
  # Index the known objects of every (subject, predicate) pair in a single
  # pass instead of rescanning both datasets for each evaluated triple.
  known_objects = {}
  for triples in (train_data, valid_data):
    for tr in triples:
      known_objects.setdefault((tr[0], tr[1]), set()).add(tr[2])

  all_nodes = set(range(num_nodes))
  filtered_sp = []
  for subj, pred, obj in valid_data:
    # The pair is always indexed because the evaluated triple itself was
    # added above, so its own object is excluded from the candidates too.
    candidates = list(all_nodes.difference(known_objects[(subj, pred)]))
    filtered_sp.append([obj] + candidates)
  return np.array(filtered_sp, dtype=object)

def filter_data_po(train_data, valid_data, num_nodes):
  '''
  Filters for subjects.
  Given a triple spo, all subjects x are removed that
  appear as xpo in the training or validation data.

  train_data: additional known triples (rows of subject, predicate, object)
  valid_data: triples for which a candidate list is built
  num_nodes: number of entities; candidates are drawn from range(num_nodes)
  Returns an object array with one entry per row of valid_data: the true
  subject first, followed by every entity that is not a known subject for
  the same (predicate, object) pair.
  '''
  # Index the known subjects of every (object, predicate) pair in a single
  # pass instead of rescanning both datasets for each evaluated triple.
  known_subjects = {}
  for triples in (train_data, valid_data):
    for tr in triples:
      known_subjects.setdefault((tr[2], tr[1]), set()).add(tr[0])

  all_nodes = set(range(num_nodes))
  filtered_po = []
  for subj, pred, obj in valid_data:
    # The pair is always indexed because the evaluated triple itself was
    # added above, so its own subject is excluded from the candidates too.
    candidates = list(all_nodes.difference(known_subjects[(obj, pred)]))
    filtered_po.append([subj] + candidates)
  return np.array(filtered_po, dtype=object)
  
def filter_data(datapath):
  '''
  Build the ranking filters (candidate lists that leave out known
  statements which could otherwise rank above the evaluated triple)
  for both splits and store them with np.save inside `datapath`.
  '''
  train_data, valid_data, num_nodes, _ = load_data(datapath)

  # For each split, the other split supplies the additional known triples.
  filters = {
    'valid_filter_sp': filter_data_sp(train_data, valid_data, num_nodes),
    'valid_filter_po': filter_data_po(train_data, valid_data, num_nodes),
    'train_filter_sp': filter_data_sp(valid_data, train_data, num_nodes),
    'train_filter_po': filter_data_po(valid_data, train_data, num_nodes),
  }

  # All filters are computed before the first file is written.
  for name, array in filters.items():
    np.save(os.path.join(datapath, name), array)

In [None]:
# Rebuild the filters unless every one of the four arrays written by
# filter_data is already present (np.save stores them with a '.npy' suffix).
# Checking a single file would treat a partially written cache as complete.
filter_names = ['valid_filter_sp', 'valid_filter_po', 'train_filter_sp', 'train_filter_po']
if not all(os.path.exists(os.path.join(datapath, '{}.npy'.format(name))) for name in filter_names):
  filter_data(datapath)
  print('Filtered metrics created.')

## Batch Creation

In [None]:
from copy import deepcopy
class batch_provider:
  '''
  Serves shuffled mini-batches of triples, optionally extended with
  randomly corrupted (negative) triples.
  '''
  def __init__(self, data, batchsize, num_negSamples = 2, seed = 1231245):
    '''
    Helper class to provide data in batches with negative examples.
    data: Training data triples (2-D array, columns subject/predicate/object)
    batchsize: size of the mini-batches
    num_negSamples: number of neg. samples.
    seed: random seed for neg. sample generation (seeds numpy's global RNG)
    '''
    # Work on a private copy so shuffling never reorders the caller's array.
    self.data = deepcopy(data)
    # Entity ids are 0-based, so the entity count is the largest id + 1.
    # Without the +1 the highest id could never be drawn as a negative.
    self.num_nodes = np.max([np.max(data[:,0]), np.max(data[:,2])]) + 1

    np.random.seed(seed)
    np.random.shuffle(self.data)

    self.batchsize = batchsize
    # At least one mini-batch per epoch, even if batchsize exceeds the data
    # size; otherwise the epoch would never reset and batches would run empty.
    self.number_minibatches = max(1, len(self.data) // batchsize)
    self.current_minibatch = 0

    self.num_negSamples = num_negSamples

  def next_batch(self):
    '''
    Return the next mini-batch.
    Data triples are shuffled after each epoch.

    Returns (subj, pred, obj, labels) when num_negSamples > 0,
    otherwise (subj, pred, obj) for the positive triples only.
    '''
    start = self.current_minibatch * self.batchsize
    mbatch = deepcopy(self.data[start:start + self.batchsize])
    self.current_minibatch += 1
    if self.current_minibatch >= self.number_minibatches:
      # Epoch finished: reshuffle and start over.
      np.random.shuffle(self.data)
      self.current_minibatch = 0
    if self.num_negSamples > 0:
      return self.apply_neg_examples(mbatch[:,0], mbatch[:,1], mbatch[:,2])
    return mbatch[:,0], mbatch[:,1], mbatch[:,2]

  def apply_neg_examples(self, subj, pred, obj):
    '''
    Generate neg. samples for a mini-batch.
    Both subject and object neg. samples are generated.

    subj, pred, obj: equally long sequences (lists or arrays) of ids
    Returns (subj, pred, obj, labels); each holds the positives first,
    then the subject-corrupted triples, then the object-corrupted triples.
    labels is +1 for positives and -1 for every corrupted triple.
    '''
    # np.asarray / np.tile behave the same for lists and ndarrays, whereas
    # `sequence * k` repeats a list but multiplies an ndarray element-wise.
    subj, pred, obj = np.asarray(subj), np.asarray(pred), np.asarray(obj)
    vsize = len(subj)
    repeats = self.num_negSamples
    num_corrupted = repeats * vsize

    labels = np.concatenate([np.ones(vsize, dtype=int), -np.ones(2*num_corrupted, dtype=int)])
    # Replacement ids are drawn uniformly from [0, num_nodes).
    neg_subj = np.random.randint(self.num_nodes, size = num_corrupted)
    neg_obj = np.random.randint(self.num_nodes, size = num_corrupted)

    all_subj = np.concatenate([subj, neg_subj, np.tile(subj, repeats)])
    all_pred = np.tile(pred, 2*repeats + 1)
    all_obj = np.concatenate([obj, np.tile(obj, repeats), neg_obj])
    return all_subj, all_pred, all_obj, labels


In [None]:
from argparse import Namespace
# Hyperparameters of the run; `params` exposes the same values as attributes.
params_dict = dict(
  dim=40,
  input_size=40,
  tau=0.5,
  batchsize=64,
  delta=0.001,
  lr=0.1,
  L2=0.0,
  steps=801,
  neg_samples=10,
  maxspan=2,
)
params = Namespace(**params_dict)

# Steps 0, 200, ..., 1000.
eval_points = [200 * i for i in range(6)]
# A fresh random seed for every run.
seed = np.random.randint(1e8)

batcher = batch_provider(train_data, params.batchsize, params.neg_samples, seed)
 

In [None]:
# Draw one mini-batch (this advances the batcher) and show its first
# element, the subject ids.
first_batch = batcher.next_batch()
print(first_batch[0])

[127  53 231 ... 203 167  66]


## How to measure distance

In [None]:
def ASYmmetric_score(s_emb, o_emb, p_emb):
  '''
  TransE score: the negative L1 distance between the embedding
  difference (s_emb - o_emb) and the predicate embedding p_emb.
  '''
  translation = s_emb - o_emb
  distance = F.pairwise_distance(translation, p_emb, p=1)
  return -distance

def SYMmetric_score(s_emb, o_emb, p_emb):
  '''
  Symmetric version of TransE score: the element-wise absolute value of
  (s_emb - o_emb) is compared with p_emb, so swapping subject and object
  gives the same score.
  '''
  abs_difference = (s_emb - o_emb).abs()
  distance = F.pairwise_distance(abs_difference, p_emb, p=1)
  return -distance


## Training

In [None]:
def train(steps, model, optimizer, batcher, data):
  '''
  Run `steps` optimisation steps on mini-batches drawn from `batcher`.

  steps: number of mini-batch updates to perform
  model: object providing score(...), entities.weight_loss() and
         update_embeddings()
  optimizer: optimizer exposing zero_grad() and step()
  batcher: mini-batch source; next_batch() is called once per step and the
           last element of its result is used as the label vector
  data: None, or a sequence whose first three items (train, valid, negative)
        are argument tuples for model.score; only used for progress reports
  '''
  if data is not None:
    train_data, valid_data, neg_data = data[0], data[1], data[2]
  loss_fun = torch.nn.SoftMarginLoss()

  start_time = time.time()
  for k in range(steps):
    if k % 100 == 0:
      # Progress report every 100 steps: remaining time extrapolated from
      # the average duration of the steps done so far.
      estimate_time = (time.time()-start_time)/(k+1)*(steps-k)/60.0
      if data is not None:
        train_score = float(torch.mean(model.score(*train_data)).detach().numpy())
        valid_score = float(torch.mean(model.score(*valid_data)).detach().numpy())
        neg_score = float(torch.mean(model.score(*neg_data)).detach().numpy())
        print('SpikE {}: ETA {}min; train: {}, valid: {}, neg: {}'.format(k, np.round(estimate_time,2), train_score, valid_score, neg_score))
      else:
        print('SpikE {}: ETA {}min'.format(k, np.round(estimate_time,2)))
    optimizer.zero_grad()
    # The batcher method is called next_batch ('next_betch' does not exist).
    databatch = batcher.next_batch()
    # NOTE(review): the whole batch tuple, labels included, is unpacked into
    # model.score — confirm that score accepts the labels argument.
    prediction = model.score(*databatch)
    weight_reg = model.entities.weight_loss()
    loss = loss_fun(prediction, torch.tensor(databatch[-1]))
    # NOTE(review): `delta` is not a parameter of this function, so it must
    # exist as a global; the configured value lives in params.delta — TODO
    # confirm or pass it in explicitly.
    loss = loss + delta*weight_reg
    loss.backward()
    optimizer.step()
    model.update_embeddings()