<a href="https://colab.research.google.com/github/zhangxingeng/SpikeKG/blob/main/SpikeKG.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

- Here is the source of the countries dataset [countries](https://github.com/ZhenfengLei/KGDatasets/tree/master/Countries)

In [127]:
from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [128]:
import numpy as np
import os, time
import torch
import matplotlib.pyplot as plt

## Data Loading

In [129]:
def get_rel2id(datapath, fname):
  '''
  Convenience function returning a dictionary that turns relation names into ids.
  '''
  file = open(os.path.join(datapath, fname))
  content = file.readlines()
  ids, nodenames = [], []
  for i in range(len(content)):
    a = content[i].split('\t')
    ids.append(int(a[0]))
    nodenames.append(a[-1][:-1])
  rel2id = dict(zip(nodenames, ids))
  return rel2id

def get_id2rel(datapath):
  '''
  Convenience function returning a dictionary that turns ids into relation names.
  '''
  rel2id = get_rel2id(datapath)
  id2rel = dict(zip(rel2id.values(), rel2id.keys()))
  return id2rel

def get_ent2id(datapath, fname):
  '''
  Convenience function returning a dictionary that turns entity names into ids.
  '''
  file = open(os.path.join(datapath, fname))
  content = file.readlines()
  ids, nodenames = [], []
  for i in range(len(content)):
    a = content[i].split('\t')
    ids.append(int(a[0]))
    nodenames.append(a[-1][:-1])
  ent2id = dict(zip(nodenames, ids))
  return ent2id

def get_id2ent(datapath, fname):
  '''
  Convenience function returning a dictionary that turns ids into entity names.
  '''
  ent2id = get_ent2id(datapath, fname)
  id2ent = dict(zip(ent2id.values(), ent2id.keys()))
  return id2ent

def load_data(datapath):
  train_data = np.array(np.loadtxt(os.path.join(datapath, 'train.del')), dtype=int)
  valid_data = np.array(np.loadtxt(os.path.join(datapath, 'valid.del')), dtype=int)
  node_num = np.max([np.max(train_data[:,0]), np.max(train_data[:,2])])+1
  predicate_num = np.max(train_data[:,1])+1
  return train_data, valid_data, node_num, predicate_num

## Filters For avoid direct result in training dataset

In [130]:
def filter_data_sp(train_data, valid_data, node_num):
  '''
  Filters for objects.
  Given a triple spo, all objects x are removed that
  appear as spx in the training or validation data.
  '''
  filtered_sp = [[] for i in range(len(valid_data))]
  for i in range(len(valid_data)):
    subj, pred, obj = valid_data[i]
    for tr in train_data:
      if tr[0] == subj and tr[1] == pred:
        filtered_sp[i].append(tr[2])
    for tr in valid_data:
      if tr[0] == subj and tr[1] == pred:
        filtered_sp[i].append(tr[2])
    filtered_sp[i] = list(set(range(node_num)).difference(set(filtered_sp[i])))
    filtered_sp[i] = [obj] + filtered_sp[i]
  return filtered_sp

def filter_data_po(train_data, valid_data, node_num):
  '''
  Filters for subjects.
  Given a triple spo, all subjects x are removed that
  appear as xpo in the training or validation data.
  '''
  filtered_po = [[] for i in range(len(valid_data))]
  for i in range(len(valid_data)):
    subj, pred, obj = valid_data[i]
    for tr in train_data:
      if tr[2] == obj and tr[1] == pred:
        filtered_po[i].append(tr[0])
    for tr in valid_data:
      if tr[2] == obj and tr[1] == pred:
        filtered_po[i].append(tr[0])
    filtered_po[i] = list(set(range(node_num)).difference(set(filtered_po[i])))
    filtered_po[i] = [subj] + filtered_po[i]
  return filtered_po
  
def filter_data(datapath):
  '''
  Create filters for data (neglect known statements that are ranked
  higher during testing than the evaluated triple).
  '''
  train_data, valid_data, node_num, _ = load_data(datapath)

  valid_filter_sp = filter_data_sp(train_data, valid_data, node_num)
  valid_filter_po = filter_data_po(train_data, valid_data, node_num)
  train_filter_sp = filter_data_sp(valid_data, train_data, node_num)
  train_filter_po = filter_data_po(valid_data, train_data, node_num)

  np.save(os.path.join(datapath, 'valid_filter_sp'), valid_filter_sp)
  np.save(os.path.join(datapath, 'valid_filter_po'), valid_filter_po)
  np.save(os.path.join(datapath, 'train_filter_sp'), train_filter_sp)
  np.save(os.path.join(datapath, 'train_filter_po'), train_filter_po)

## Batch Creation

In [131]:
from copy import deepcopy
class batch_provider:
  def __init__(self, data, batchsize, num_negSamples = 2, seed = 1231245):
    '''
    Helper class to provide data in batches with negative examples.
    data: Training data triples
    batchsize: size of the mini-batches
    num_negSamples: number of neg. samples.
    seed: random seed for neg. sample generation
    '''
    self.data = deepcopy(data)
    self.node_num = np.max([np.max(data[:,0]), np.max(data[:,2])])

    np.random.seed(seed)
    np.random.shuffle(self.data)

    self.batchsize = batchsize
    self.number_minibatches = int(len(self.data)/batchsize)
    self.current_minibatch = 0

    self.num_negSamples = num_negSamples

  def next_batch(self):
    '''
    Return the next mini-batch.
    Data triples are shuffled after each epoch.
    '''
    i = self.current_minibatch
    di = self.batchsize
    mbatch = deepcopy(self.data[i*di:(i+1)*di])
    self.current_minibatch += 1
    if self.current_minibatch == self.number_minibatches:
      np.random.shuffle(self.data)
      self.current_minibatch = 0
    if self.num_negSamples > 0:
      subj, pred, obj, labels = self.apply_neg_examples(list(mbatch[:,0]), list(mbatch[:,1]), list(mbatch[:,2]))
      return subj, pred, obj, labels
    else:
      return mbatch[:,0], mbatch[:,1], mbatch[:,2]

  def apply_neg_examples(self, subj, pred, obj):
    '''
    Generate neg. samples for a mini-batch.
    Both subject and object neg. samples are generated.
    '''
    vsize = len(subj)
    labels = np.array([1 for i in range(vsize)] + [-1 for i in range(self.num_negSamples*2*vsize)])
    neg_subj = list(np.random.randint(self.node_num, size = self.num_negSamples*vsize))
    neg_obj = list(np.random.randint(self.node_num, size = self.num_negSamples*vsize))
    return np.concatenate([subj, neg_subj, subj*self.num_negSamples]), np.concatenate([pred*(2*self.num_negSamples+1)]), np.concatenate([obj, obj*self.num_negSamples, neg_obj]), labels


## Train and Eval

In [132]:
import scipy.stats as ss

def score_triple(rel2id, ent2id, model, subj, pred, obj):
    '''
    Score single triple. Takes text name of graph entities as input.
    '''
    subj = torch.tensor([ent2id[subj]])
    obj = torch.tensor([ent2id[obj]])
    pred = torch.tensor([rel2id[pred]])
    return model.score(subj, pred, obj)

def get_rank_sp(model, data, whichOne, dataPath, savePath = None):
    '''
    Get metrics like mean rank, MRR, hits@k when the object is replaced.

    model: trained model used to evaluate data
    data: triples to be tested
    whichOne: which set is used (train, valid)
    dataPath: path of data folder
    savePath: if not None, path where results are stored
    '''
    ranks = []
    filtered_sp = np.load('{}/{}_filter_sp.npy'.format(dataPath, whichOne), allow_pickle=True)
    for i in range(len(data)):
        subj, pred = data[i][0], data[i][1]
        rankings = model.score([subj for j in range(len(filtered_sp[i]))], [pred for j in range(len(filtered_sp[i]))], filtered_sp[i]).detach().numpy()
        rank = ss.rankdata(-rankings)[0]
        ranks.append(rank)
    ranks_modes = [np.percentile(ranks, 25), np.median(ranks), np.percentile(ranks, 75), np.mean(ranks), np.mean(1/(np.array(ranks)))]
    hitsAt = []
    for hit in [1, 3, 10, 100, 500]: hitsAt.append(np.mean(np.array(ranks) <= hit))
    if savePath is None: return ranks, ranks_modes, hitsAt
    else:
        save_ranks(ranks, ranks_modes, hitsAt, savePath, whichOne, 'sp')

def get_rank_po(model, data, whichOne, dataPath, savePath = None):
    '''
    Get metrics like mean rank, MRR, hits@k when the subject is replaced.

    model: trained model used to evaluate data
    data: triples to be tested
    whichOne: which set is used (train, valid)
    dataPath: path of data folder
    savePath: if not None, path where results are stored
    '''
    ranks = []
    filtered_po = np.load('{}/{}_filter_po.npy'.format(dataPath, whichOne), allow_pickle=True)
    for i in range(len(filtered_po)):
        pred, obj = data[i][1], data[i][2]
        rankings = model.score(filtered_po[i], [pred for j in range(len(filtered_po[i]))], [obj for j in range(len(filtered_po[i]))]).detach().numpy()
        rank= ss.rankdata(-rankings)[0]
        ranks.append(rank)
    ranks_modes = [np.percentile(ranks, 25), np.median(ranks), np.percentile(ranks, 75), np.mean(ranks), np.mean(1/(np.array(ranks)))]
    hitsAt = []
    for hit in [1, 3, 10, 100, 500]: hitsAt.append(np.mean(np.array(ranks) <= hit))
    if savePath is None: return ranks, ranks_modes, hitsAt
    else: 
        save_ranks(ranks, ranks_modes, hitsAt, savePath, whichOne, 'po')

def save_ranks(ranks, ranks_modes, hitsAt, savePath, whichOne, mode):
    '''
    Save rank metric results.
    '''
    plt.close()
    plt.bar(np.arange(len(hitsAt)+1), hitsAt+[ranks_modes[-1]])
    plt.xticks(np.arange(len(hitsAt)+1), [1, 3, 10, 100, 500, 'MRR'])
    plt.ylim(0,1)
    plt.savefig('{}/{}_{}_hits_and_RMR.png'.format(savePath, whichOne, mode))
    np.savetxt('{}/{}_{}_hits.txt'.format(savePath, whichOne, mode), [[1, 3, 10, 100, 500], hitsAt])
    np.savetxt('{}/{}_{}_ranking.txt'.format(savePath, whichOne, mode), ranks)
    np.savetxt('{}/{}_{}_ranking_modes.txt'.format(savePath, whichOne, mode), ranks_modes)

def get_scores(model, train_data, valid_data, neg_data, savePath, post=''):
    '''
    Calculate and save scores of training data, validation data and a set of negative triples.
    '''
    train_scores = model.score(train_data[:,0], train_data[:,1], train_data[:,2]).detach().numpy()
    valid_scores = model.score(valid_data[:,0], valid_data[:,1], valid_data[:,2]).detach().numpy()
    neg_scores = model.score(neg_data[:,0], neg_data[:,1], neg_data[:,2]).detach().numpy()

    plt.close()
    _ = plt.hist(train_scores, bins = 100, alpha = 0.35, color = 'gray', density=True, label = 'train')
    _ = plt.hist(neg_scores, bins = 100, density=True, alpha = 0.7, histtype='step', linewidth = 2.5, label = 'neg')
    _ = plt.hist(valid_scores, bins = 100, density=True, alpha = 0.7, histtype='step', linewidth = 2.5, label = 'pos')
    plt.legend()
    plt.xlim(-10,.1)
    plt.savefig('{}/{}score_histogram.png'.format(savePath, post))
    np.save('{}/{}train_scores.npy'.format(savePath, post), train_scores)
    np.save('{}/{}valid_scores.npy'.format(savePath, post), valid_scores)
    np.save('{}/{}neg_scores.npy'.format(savePath, post), neg_scores)


In [133]:
def train(optimizer, batcher, model, delta, steps, data = None):
  if data is not None:
    train_data, valid_data, neg_data = data[0], data[1], data[2]
  loss_fun = torch.nn.SoftMarginLoss()

  start_time = time.time()
  for k in range(steps):
    # Print scores
    if k % 100 == 0:
      estimate_time = (time.time()-start_time)/(k+1)*(steps-k)/60.0
      if data is not None:
        train_score = float(torch.mean(model.score(*train_data)).detach().numpy())
        valid_score = float(torch.mean(model.score(*valid_data)).detach().numpy())
        neg_score = float(torch.mean(model.score(*neg_data)).detach().numpy())
        print('SpikE {}: ETA {}min; train: {}, valid: {}, neg: {}'.format(k, np.round(estimate_time,2), train_score, valid_score, neg_score))
      else:
        print('SpikE {}: ETA {}min'.format(k, np.round(estimate_time,2)))
    
    optimizer.zero_grad()
    databatch = batcher.next_betch()
    prediction = model.score(*databatch)
    weight_reg = model.entities.weight_loss()
    loss = loss_fun(prediction, torch.tensor(databatch[-1]))
    loss = loss + delta*weight_reg
    loss.backward()
    optimizer.step()
    model.update_embeddings()


def train_and_evaluate(optimizer, batcher, model, delta, steps, eval_points, datapath, data):
    train_data, valid_data = data[0], data[1]
    loss_fun = torch.nn.SoftMarginLoss()
    starttime = time.time()
    for k in range(steps):
        # Evaluate current Model
        if k in eval_points:
            print('{}:\n'.format(k))
            _, ranks_modes, hitsAt = get_rank_sp(model, train_data, 'train', datapath)
            print('train ~~~~SP~~~~ hits@1: {0:.4f} __hits@3: {1:.4f} __ mean: {2:.4f} __ MRR: {3:.4f}\n'.format(hitsAt[0], hitsAt[1], ranks_modes[3], ranks_modes[4]))
            _, ranks_modes, hitsAt = get_rank_po(model, train_data, 'train', datapath)
            print('train ~~~~PO~~~~ hits@1: {0:.4f} __hits@3: {1:.4f} __ mean: {2:.4f} __ MRR: {3:.4f}\n'.format(hitsAt[0], hitsAt[1], ranks_modes[3], ranks_modes[4]))
            _, ranks_modes, hitsAt = get_rank_sp(model, valid_data, 'valid', datapath)
            print('valid ~~~~SP~~~~ hits@1: {0:.4f} __hits@3: {1:.4f} __ mean: {2:.4f} __ MRR: {3:.4f}\n'.format(hitsAt[0], hitsAt[1], ranks_modes[3], ranks_modes[4]))
            _, ranks_modes, hitsAt = get_rank_po(model, valid_data, 'valid', datapath)
            print('valid ~~~~PO~~~~ hits@1: {0:.4f} __hits@3: {1:.4f} __ mean: {2:.4f} __ MRR: {3:.4f}\n'.format(hitsAt[0], hitsAt[1], ranks_modes[3], ranks_modes[4]))

            eta = (time.time()-starttime)/(k+1)*(steps-k)/60.
            print('ETA {0:.2f}min \n\n'.format(eta))
        
        # Update current model on new batch
        optimizer.zero_grad()
        databatch = batcher.next_batch()
        prediction = model.score(databatch[0], databatch[1], databatch[2])
        weight_reg = model.entities.weight_loss()
        loss = loss_fun(prediction, torch.tensor(databatch[-1]))
        loss = loss + delta*weight_reg
        loss.backward()
        optimizer.step()
        model.update_embeddings()


## Neuron Model

In [134]:
from copy import deepcopy

class nLIF:
    def __init__(self, node_dim, embed_dim, input_dim, tau, maxSpan, seed):
        torch.manual_seed(seed)
        # neuron params
        self.tau = tau
        self.threshold = 1.
        self.winit = 0.2
        self.maxSpan = maxSpan/2

        # network params
        self.node_dim = node_dim
        self.embed_dim = embed_dim
        self.input_dim = input_dim

        # Generate random input spikes with max height of maxSpan/2
        self.input_spikes = (torch.rand(input_dim)-0.5) * maxSpan
        self.input_spikes = self.input_spikes.sort().values
        self.input_exp = torch.exp(self.input_spikes/self.tau)
        # create spike sequence
        self.spike_seq = torch.zeros((self.input_dim, self.input_dim))
        for i in range(self.input_dim):
            for j in range(self.input_dim): # create exp input spike at time step i
                if i >= j:
                    self.spike_seq[i][j] = self.input_exp[j]
        self.spike_mask = (self.spike_seq > 0)*1. # binary value
        
        # create input -> population weights
        self.weights = self.init_weights()
        self.spike_times = self.get_spike_times()

    def init_weights(self):
        # Initialize weights using normal distribution.
        weights = torch.nn.Embedding(self.node_dim, self.embed_dim*self.input_dim)
        for i in range(self.node_dim):
            # guarantee that all populations spike once
            while bool((weights.weight.view(-1,self.embed_dim,self.input_dim)[i].sum(-1) < 1).any()) == True:
                weights.weight.data[i] = torch.normal(mean = self.winit, std = 1, size = (1, self.embed_dim*self.input_dim))
        return weights

    def get_spike_times(self):
        spike_times = torch.zeros((self.node_dim, self.embed_dim))+self.maxSpan+0.5
        weights = self.weights.weight.view(-1, self.embed_dim, self.input_dim)
        for j in range(self.input_dim):
            wSumExp = torch.matmul(weights, self.spike_seq[j])
            wSum = torch.matmul(weights, self.spike_mask[j])
            wSumDiff = wSum - self.threshold
            wQuotient = wSumExp/(wSumDiff+ 1e-10) + 1e-10 # for stability
            times = self.tau * torch.log(wQuotient)
            if j < self.input_dim-1:
                # check condition for spiking
                # 1. has not spiked yet
                # 2. weights sum over threshold
                # 3. next input spike does not hinder firing
                new_spikes = (spike_times == self.maxSpan+0.5)*(wSum > self.threshold)*(wQuotient < self.input_exp[j+1])
            else: # for last spike
                new_spikes =  (spike_times == self.maxSpan+0.5)*(wSum > self.threshold)
            spike_times[new_spikes] = times[new_spikes]
        return spike_times

    def update_embeddings(self):
        ''' 
        Spike embeddings are stored after calculation to reduce compute time 
        when evaluating several times without training updates.
        '''
        self.spike_times = self.get_spike_times()

    def embeddings(self, s_embs, o_embs):
        '''
        Read out embeddings.
        Input: list of subjects, list of objects
        Output: list of subject embeddings, list of object embeddings
        '''
        s_embs = torch.tensor(s_embs).long()
        o_embs = torch.tensor(o_embs).long()
        return self.spike_times[s_embs], self.spike_times[o_embs]

    def weight_loss(self):
        '''
        Regularization term that increases weights when their sum is below the 
        threshold value.
        Output: regularization term
        '''
        weight_norm = self.weights.weight.view(-1, self.embed_dim, self.input_dim).sum(-1)
        return ((self.threshold-weight_norm)*(weight_norm < self.threshold)).sum()

    def _integrate_model_using_Euler(self):
        '''
        Euler method to obtain spike times. Used to cross-check the analytical solution.
        Output: spike times, membrane potentials over time
        '''
        results = []
        spike_times = torch.zeros((self.node_dim, self.embed_dim))+1.5
        voltage = torch.zeros((self.node_dim, self.embed_dim))
        weights = self.weights.weight.view(-1, self.embed_dim, self.input_dim)

        # Euler integration from t = -maxSpan to maxSpan to solve ODE
        # return spike times + voltage traces
        t = -self.maxSpan
        dt = 0.01

        results.append(deepcopy(voltage.detach().numpy()))
        while t <= self.maxSpan:
            voltage.data = torch.matmul(weights, (1-torch.exp(-(t-self.input_spikes)/self.tau)) * (t > self.input_spikes))
            t += dt
            results.append(deepcopy(voltage.detach().numpy()))

            new_spikes = (t-dt-1.5)*(voltage > self.threshold)+1.5
            spiked = spike_times == 1.5
            spike_times = torch.logical_not(spiked)*spike_times + spiked*new_spikes

        return spike_times, results


torch.set_printoptions(precision=2)
nlif = nLIF(4, 4, 2, 0.2, 10, 0)
print(nlif.spike_times)

tensor([[2.75, 0.04, 0.14, 0.10],
        [0.40, 0.12, 3.09, 0.30],
        [2.98, 0.31, 3.00, 0.17],
        [0.11, 2.98, 0.14, 0.19]], grad_fn=<IndexPutBackward0>)


## Network Model

In [135]:
## Scoring Helper Functions
from torch.nn import functional as F

def ASYmmetric_score(s_emb, o_emb, p_emb):
  ''' TransE score '''
  return -F.pairwise_distance(s_emb - o_emb, p_emb, p=1)

def SYMmetric_score(s_emb, o_emb, p_emb):
  ''' Symmetric version of TransE score '''
  return -F.pairwise_distance((s_emb - o_emb).abs(), p_emb, p=1)

In [136]:
class SpikE_Scorer_S:
    def __init__(self, node_dim, embed_dim, input_dim, relation_dim, tau, maxSpan = 2, seed = 1231245):
        torch.manual_seed(seed)
        self.entities = nLIF(node_dim, embed_dim, input_dim, tau, maxSpan, seed)
        self.predicates = torch.nn.Embedding(relation_dim, embed_dim)
        self.node_num = node_dim
    
    def score(self, subj, pred, obj):
        '''
        Calculate the score of a list of triples.
        Input: list of subjects, predicates, objects, e.g., [s0, s1, ...], [p0, p1, ...], [o0, o1, ...]
        Output: list of scores
        '''
        s_emb, o_emb = self.entities.embeddings(subj, obj)
        p_emb = self.predicates(torch.tensor(pred).long())
        return SYMmetric_score(s_emb, o_emb, p_emb)

    def update_embeddings(self):
        self.entities.update_embeddings() # pass to nLIF
    
    def save(self, savepath, appdix = ''):
        ''' Save some stuff. '''
        pred_embs = self.predicates.weight.data.detach().numpy()
        ent_embs = self.entities.get_spike_times().detach().numpy()
        np.save('{}/weights_{}.npy'.format(savepath, appdix), self.entities.weights.weight.data.detach().numpy())
        np.save('{}/input_spikes_{}.npy'.format(savepath, appdix), self.entities.input_spikes.detach().numpy())
        np.save('{}/predicate_embeddings_{}.npy'.format(savepath, appdix), pred_embs)
        np.save('{}/entity_embeddings_{}.npy'.format(savepath, appdix), ent_embs)

        plt.close()
        for j in range(50): plt.vlines(ent_embs[j], j+0.1, (j+1)-0.1)
        plt.savefig('{}/entity_embeddings_{}.png'.format(savepath, appdix))

        plt.close()
        for j in range(len(pred_embs)): plt.vlines(pred_embs[j], j+0.1, (j+1)-0.1)
        plt.savefig('{}/predicate_embeddings_{}.png'.format(savepath, appdix))

class SpikE_Scorer_AS(SpikE_Scorer_S):
    def __init__(self, node_dim, embed_dim, input_dim, relation_dim, tau, maxSpan = 2, seed = 1231245):
        super().__init__(node_dim, embed_dim, input_dim, relation_dim, tau, maxSpan, seed)
    
    def score(self, subj, pred, obj):
        '''
        Calculate the score of a list of triples.
        Input: list of subjects, predicates, objects, e.g., [s0, s1, ...], [p0, p1, ...], [o0, o1, ...]
        Output: list of scores
        '''
        s_emb, o_emb = self.entities.embeddings(subj, obj)
        p_emb = self.predicates(torch.tensor(pred).long())

        return ASYmmetric_score(s_emb, o_emb, p_emb)



## New Stuff

In [137]:
from argparse import Namespace

In [138]:
datapath = '/content/gdrive/MyDrive/SpikeKG/data/Countries_S1'
ent2id = get_ent2id(datapath, 'entity_ids.del')
rel2id = get_rel2id(datapath, 'relation_ids.del')
id2ent = get_id2ent(datapath, 'entity_ids.del')
train_data, valid_data, node_num, predicate_num = load_data(datapath)

# print(f"ent2id:{dict(list(ent2id.items())[:4])}...")
# print(f"rel2id:{dict(list(rel2id.items())[:4])}...")
# print(f"id2ent:{dict(list(id2ent.items())[:4])}...")

# print(f"train-data shape:{np.shape(train_data)}, valid-data shape: {np.shape(valid_data)}")
# print(f"total # ents: {node_num}, total # rels: {predicate_num}")

if not os.path.exists('{}/valid_filter_po.npy'.format(datapath)):
    print('Creating data for filtered metrics...')
    filter_data(datapath)

params_dict = {
  'embed_dim': 40,
  'input_dim': 40,
  'tau': 0.5,
  'batchsize': 64,
  'delta': 0.001,
  'lr': .1,
  'L2': 0.,
  'steps': 801,
  'neg_samples': 10,
  'maxSpan': 2,
}
params = Namespace(**params_dict)

eval_points = list(range(0, 1001, 200))
seed = np.random.randint(1e8)

batcher = batch_provider(train_data, params.batchsize, params.neg_samples, seed)

model = SpikE_Scorer_AS(node_num, predicate_num, params.embed_dim, params.input_dim, 
                        params.tau, params.maxSpan, seed)
optimizer = torch.optim.Adagrad([model.entities.weights.weight, model.predicates.weight], 
                                lr=params.lr, weight_decay = params.L2)
train_and_evaluate(optimizer, batcher, model, params.delta, params.steps, eval_points, 
      datapath = datapath, data = [train_data, valid_data])



Creating data for filtered metrics...


  arr = np.asanyarray(arr)


0:

train ~~~~SP~~~~ hits@1: 0.0009 __hits@3: 0.0054 __ mean: 139.2286 __ MRR: 0.0180

train ~~~~PO~~~~ hits@1: 0.0036 __hits@3: 0.0072 __ mean: 131.1206 __ MRR: 0.0199

valid ~~~~SP~~~~ hits@1: 0.0000 __hits@3: 0.0000 __ mean: 175.0833 __ MRR: 0.0112

valid ~~~~PO~~~~ hits@1: 0.0000 __hits@3: 0.0000 __ mean: 103.6667 __ MRR: 0.0179

ETA 14.17min 


200:

train ~~~~SP~~~~ hits@1: 0.0648 __hits@3: 0.1323 __ mean: 35.3384 __ MRR: 0.1558

train ~~~~PO~~~~ hits@1: 0.0198 __hits@3: 0.0576 __ mean: 60.0657 __ MRR: 0.0780

valid ~~~~SP~~~~ hits@1: 0.1667 __hits@3: 0.3750 __ mean: 7.0417 __ MRR: 0.3475

valid ~~~~PO~~~~ hits@1: 0.0000 __hits@3: 0.0000 __ mean: 98.2083 __ MRR: 0.0243

ETA 0.26min 


400:

train ~~~~SP~~~~ hits@1: 0.0819 __hits@3: 0.1872 __ mean: 28.2574 __ MRR: 0.1804

train ~~~~PO~~~~ hits@1: 0.0171 __hits@3: 0.0693 __ mean: 53.0963 __ MRR: 0.0880

valid ~~~~SP~~~~ hits@1: 0.1667 __hits@3: 0.3750 __ mean: 6.8750 __ MRR: 0.3420

valid ~~~~PO~~~~ hits@1: 0.0000 __hits@3: 0.0000 