In [6]:
CUDA_INDEX = 1
NAME = 'DBLP'
CLASSES = 8

In [2]:
import sys
sys.path.insert(0, '../..')
sys.path.insert(0, '../../pyged/lib')

In [7]:
import os
import pickle
import random
import time

import IPython as ipy
import matplotlib.pyplot as plt
import numpy as np
import torch
torch.cuda.set_device(CUDA_INDEX)
torch.backends.cudnn.benchmark = True
import torch.optim
import torch_geometric as tg
import torch_geometric.data
from tqdm.auto import tqdm

from neuro import config, datasets, index, metrics, models, train, utils, viz
import pyged

from importlib import reload
reload(config)
reload(datasets)
reload(index)
reload(metrics)
reload(models)
reload(pyged)
reload(train)
reload(utils)
reload(viz)

<module 'neuro.viz' from '../../neuro/viz.py'>

In [4]:
import ext.H2MN as h2mn
import ext.H2MN.models
from neuro import h2mn_utils

reload(h2mn)
reload(h2mn_utils)

<module 'neuro.h2mn_utils' from '../../neuro/h2mn_utils.py'>

In [5]:
# graphs = utils.remove_extra_attrs(utils.label_graphs(tg.datasets.CitationFull(root=f'../data/{NAME}/tg', name=f'{NAME}')))

In [5]:
# tic = time.time()
# nbrs = datasets.decompose(graphs, n_hops=2)
# toc = time.time()
# torch.save(nbrs, f'../data/{NAME}/nbrs.pt')
# tqdm.write(f'neighborhood decomposition time: {toc-tic:.3} s')
nbrs = torch.load(f'/data/rishabh/neurosim/final/data/{NAME}/nbrs.pt')

In [8]:
# queries = datasets.make_queries(nbrs, n_queries=5, n_hops=2, trav_prob=0.5, node_lim=25)
# torch.save(queries, f'../retrvs/{NAME}/queries.pt')
queries = torch.load(f'../retrvs/{NAME}/queries.pt', map_location='cpu')

In [10]:
model = models.NormSEDModel(8, CLASSES, 64, 64)
model.load_state_dict(torch.load(f'../runlogs/{NAME}/1628504752.911032/best_model.pt', map_location='cpu'))
model = model.eval()
model = model.to(config.device)

In [11]:
rw_path = f'/home/neuroGQuery/siddharth/H2MN/ckpt/{NAME}-RW'
rw_args = torch.load(os.path.join(rw_path, 'args.pt'), map_location='cpu')
rw_model = h2mn.models.Model(rw_args)
rw_model.load_state_dict(torch.load(os.path.join(rw_path, 'model.pth'), map_location='cpu'))
rw_model = rw_model.eval()
rw_model = rw_model.to(config.device)
rw_dist_fn = h2mn_utils.DistancePredictor(rw_model, batch_size=16)

In [9]:
ne_path = f'/home/neuroGQuery/siddharth/H2MN/ckpt/{NAME}_K1-NE'
ne_args = torch.load(os.path.join(ne_path, 'args.pt'), map_location='cpu')
ne_model = h2mn.models.Model(ne_args)
ne_model.load_state_dict(torch.load(os.path.join(ne_path, 'model.pth'), map_location='cpu'))
ne_model = ne_model.eval()
ne_model = ne_model.to(config.device)
ne_dist_fn = h2mn_utils.DistancePredictor(ne_model, batch_size=16)

In [13]:
tic = time.time()
target_embs_list = []
batch_size = 4096
with torch.no_grad():
    for i in tqdm(range(0,len(nbrs),batch_size), 'target batches'):
        target_embs_list.append(model.embed_model(tg.data.Batch.from_data_list(nbrs[i:i+batch_size]).to(config.device)))
target_embs = torch.cat(target_embs_list)
toc = time.time()
tqdm.write(f'total target embedding time: {toc-tic:.3} s')
tqdm.write(f'per target embedding time: {(toc-tic)/len(nbrs):.3} s')
torch.save(target_embs, f'../retrvs/{NAME}/target_embs.pt')
target_embs = torch.load(f'../retrvs/{NAME}/target_embs.pt', map_location='cpu')

target batches:   0%|          | 0/406 [00:00<?, ?it/s]

total target embedding time: 3.19e+02 s
per target embedding time: 0.000192 s


In [14]:
tic = time.time()
with torch.no_grad():
    query_embs = model.embed_model(tg.data.Batch.from_data_list(queries).to(config.device))
toc = time.time()
tqdm.write(f'total query embedding time: {toc-tic:.3} s')
tqdm.write(f'per query embedding time: {(toc-tic)/len(queries):.3} s')
torch.save(query_embs, f'../retrvs/{NAME}/query_embs.pt')
query_embs = torch.load(f'../retrvs/{NAME}/query_embs.pt', map_location='cpu')

total query embedding time: 0.00408 s
per query embedding time: 0.000815 s


In [15]:
temb = target_embs
qemb = query_embs

In [25]:
tobj_h2mn = nbrs[:10000]
qobj_h2mn = queries[:1]

In [34]:
len(nbrs)

1661871

## Vectorised Linear Scan on GPU

In [21]:
config.device = torch.device('cuda')
tic = time.time()
index_str = index.FastLinearScan(temb, dist_fn=utils.norm_sed_func)
toc = time.time()
tqdm.write(f'preprocessing time: {toc-tic:.3} s')

fast (vectorised) linear scan index
config.device: cuda
preprocessing time: 0.00844 s


In [22]:
k = 10
tic = time.time()
for q in tqdm(qemb, f'k = {k} | queries'):
    index_str.knn_query(q, k, verbose=False)
toc = time.time()
tqdm.write('time per query in s: ' f'{(toc-tic)/qemb.shape[0]:.3}')

k = 10 | queries:   0%|          | 0/5 [00:00<?, ?it/s]

time per query in s: 0.126


In [23]:
r = 2
tic = time.time()
for q in tqdm(qemb, f'r = {r} | queries'):
    index_str.range_query(q, r, verbose=False)
toc = time.time()
tqdm.write('time per query in s: ' f'{(toc-tic)/qemb.shape[0]:.3}')

r = 2 | queries:   0%|          | 0/5 [00:00<?, ?it/s]

time per query in s: 0.0696


In [30]:
del index_str
torch.cuda.empty_cache()

In [20]:
config.device = torch.device('cuda')
tic = time.time()
index_str = index.FastLinearScan(tobj_h2mn, dist_fn=ne_dist_fn)
toc = time.time()
tqdm.write(f'preprocessing time: {toc-tic:.3} s')

fast (vectorised) linear scan index
config.device: cuda
preprocessing time: 0.000832 s


In [21]:
k = 10
tic = time.time()
for q in tqdm(qobj_h2mn, f'k = {k} | queries'):
    index_str.knn_query(q, k, verbose=False)
toc = time.time()
tqdm.write('time per query in s: ' f'{(toc-tic)/len(qobj_h2mn):.3}')

k = 10 | queries:   0%|          | 0/1 [00:00<?, ?it/s]

batches:   0%|          | 0/625 [00:00<?, ?it/s]

RuntimeError: mat1 dim 1 must match mat2 dim 0

In [21]:
r = 2
tic = time.time()
for q in tqdm(qobj_h2mn, f'r = {r} | queries'):
    index_str.range_query(q, r, verbose=False)
toc = time.time()
tqdm.write('time per query in s: ' f'{(toc-tic)/len(qobj_h2mn):.3}')

r = 2 | queries:   0%|          | 0/1 [00:00<?, ?it/s]

batches:   0%|          | 0/103867 [00:00<?, ?it/s]

time per query in s: 8.55e+03


In [22]:
del index_str
torch.cuda.empty_cache()

In [31]:
config.device = torch.device('cuda')
tic = time.time()
index_str = index.FastLinearScan(tobj_h2mn, dist_fn=ne_dist_fn)
toc = time.time()
tqdm.write(f'preprocessing time: {toc-tic:.3} s')

fast (vectorised) linear scan index
config.device: cuda
preprocessing time: 0.000741 s


In [32]:
k = 10
tic = time.time()
for q in tqdm(qobj_h2mn, f'k = {k} | queries'):
    index_str.knn_query(q, k, verbose=False)
toc = time.time()
tqdm.write('time per query in s: ' f'{(toc-tic)/len(qobj_h2mn):.3}')

k = 10 | queries:   0%|          | 0/1 [00:00<?, ?it/s]

batches:   0%|          | 0/625 [00:00<?, ?it/s]

time per query in s: 58.9


In [33]:
r = 2
tic = time.time()
for q in tqdm(qobj_h2mn, f'r = {r} | queries'):
    index_str.range_query(q, r, verbose=False)
toc = time.time()
tqdm.write('time per query in s: ' f'{(toc-tic)/len(qobj_h2mn):.3}')

r = 2 | queries:   0%|          | 0/1 [00:00<?, ?it/s]

batches:   0%|          | 0/625 [00:00<?, ?it/s]

time per query in s: 53.6


In [None]:
del index_str
torch.cuda.empty_cache()

## Unvectorised Linear Scan on CPU (threads=1)

In [23]:
config.device = torch.device('cpu')
torch.set_num_threads(1)
tic = time.time()
index_str = index.LinearScan(temb, dist_fn=utils.norm_sed_func)
toc = time.time()
tqdm.write(f'preprocessing time: {toc-tic:.3} s')
index_str = index_str.to(config.device)

slow (pythonic) linear scan index
config.device: cpu
preprocessing time: 0.00279 s


In [24]:
k = 10
tic = time.time()
for q in tqdm(qemb, f'k = {k} | queries'):
    index_str.knn_query(q, k, verbose=False)
toc = time.time()
tqdm.write('time per query in s: ' f'{(toc-tic)/qemb.shape[0]:.3}')

k = 10 | queries:   0%|          | 0/5 [00:00<?, ?it/s]

targets:   0%|          | 0/1661871 [00:02<?, ?it/s]

targets:   0%|          | 0/1661871 [00:01<?, ?it/s]

targets:   0%|          | 0/1661871 [00:11<?, ?it/s]

targets:   0%|          | 0/1661871 [00:01<?, ?it/s]

targets:   0%|          | 0/1661871 [00:01<?, ?it/s]

time per query in s: 50.4


In [25]:
r = 2
tic = time.time()
for q in tqdm(qemb, f'r = {r} | queries'):
    index_str.range_query(q, r, verbose=False)
toc = time.time()
tqdm.write('time per query in s: ' f'{(toc-tic)/qemb.shape[0]:.3}')

r = 2 | queries:   0%|          | 0/5 [00:00<?, ?it/s]

targets:   0%|          | 0/1661871 [00:13<?, ?it/s]

targets:   0%|          | 0/1661871 [00:01<?, ?it/s]

targets:   0%|          | 0/1661871 [00:13<?, ?it/s]

targets:   0%|          | 0/1661871 [00:01<?, ?it/s]

targets:   0%|          | 0/1661871 [00:15<?, ?it/s]

time per query in s: 48.0


In [26]:
del index_str
torch.cuda.empty_cache()

## Unvectorised (Metric-)Tree on CPU (threads=1)

In [27]:
config.device = torch.device('cpu')
torch.set_num_threads(1)
tic = time.time()
index_str = index.AsymTree(temb, dist_fn=utils.norm_sed_func, max_leaf_size=128)
toc = time.time()
tqdm.write(f'preprocessing time: {toc-tic:.3} s')
index_str = index_str.to(config.device)

construct (metric-)tree for asymmetric/symmetric distance function
config.device: cpu


  0%|          | 0/1661871 [00:00<?, ?it/s]

distance computations: 27252518
internal nodes: 10486 / 41945
leaf nodes: 31459 / 41945
big leaf nodes: 101 / 31459
max leaf size: 949
preprocessing time: 16.8 s


In [28]:
k = 10
tic = time.time()
for q in tqdm(qemb, f'k = {k} | queries'):
    index_str.knn_query(q, k, slow=True, verbose=False)
toc = time.time()
tqdm.write('time per query in s: ' f'{(toc-tic)/qemb.shape[0]:.3}')

k = 10 | queries:   0%|          | 0/5 [00:00<?, ?it/s]

time per query in s: 18.6


In [29]:
r = 2
tic = time.time()
for q in tqdm(qemb, f'r = {r} | queries'):
    index_str.range_query(q, r, slow=True, verbose=False)
toc = time.time()
tqdm.write('time per query in s: ' f'{(toc-tic)/qemb.shape[0]:.3}')

r = 2 | queries:   0%|          | 0/5 [00:00<?, ?it/s]

time per query in s: 20.9


In [30]:
del index_str
torch.cuda.empty_cache()

## Alignment

In [15]:
tic = time.time()
k = 10
topk = []
for q in tqdm(qemb, f'k = {k} | queries'):
    topk.append(index_str.knn_query(q, k, verbose=False))
toc = time.time()
tqdm.write('time in s: ' f'{(toc-tic)/qemb.shape[0]:.3}')

k = 10 | queries:   0%|          | 0/10 [00:00<?, ?it/s]

time in s: 0.00447


In [16]:
tic = time.time()
for q, ts in zip(tqdm(queries, 'queries'), topk):
    for ti in tqdm(ts, 'targets'):
        node_map, (lb, ub) = pyged.sed_align(utils.to_pyged(q), utils.to_pyged(nbrs[ti]), 'f2', '')
        print(lb, ub, end = ' | ')
    print()
toc = time.time()
print(toc-tic)

queries:   0%|          | 0/10 [00:00<?, ?it/s]

targets:   0%|          | 0/10 [00:00<?, ?it/s]

0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 1.0 1.0 | 0.0 0.0 | 0.0 0.0 | 


targets:   0%|          | 0/10 [00:00<?, ?it/s]

0.0 0.0 | 2.0 2.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 


targets:   0%|          | 0/10 [00:00<?, ?it/s]

1.0 1.0 | 0.0 0.0 | 1.0 1.0 | 1.0 1.0 | 1.0 1.0 | 0.0 0.0 | 1.0 1.0 | 1.0 1.0 | 2.0 2.0 | 1.0 1.0 | 


targets:   0%|          | 0/10 [00:00<?, ?it/s]

0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 


targets:   0%|          | 0/10 [00:00<?, ?it/s]

0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 


targets:   0%|          | 0/10 [00:00<?, ?it/s]

0.0 0.0 | 0.0 0.0 | 1.0 1.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 


targets:   0%|          | 0/10 [00:00<?, ?it/s]

0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 


targets:   0%|          | 0/10 [00:00<?, ?it/s]

0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 


targets:   0%|          | 0/10 [00:00<?, ?it/s]

0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 


targets:   0%|          | 0/10 [00:00<?, ?it/s]

1.0 1.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 1.0 1.0 | 0.0 0.0 | 
528.0619425773621


In [17]:
tic = time.time()
for q, ts in zip(tqdm(queries, 'queries'), topk):
    for ti in tqdm(ts, 'targets'):
        node_map, (lb, ub) = pyged.sed_align(utils.to_pyged(q), utils.to_pyged(nbrs[ti]), 'f2', '--threads 64')
        print(lb, ub, end = ' | ')
    print()
toc = time.time()
print(toc-tic)

queries:   0%|          | 0/10 [00:00<?, ?it/s]

targets:   0%|          | 0/10 [00:00<?, ?it/s]

0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 1.0 1.0 | 0.0 0.0 | 0.0 0.0 | 


targets:   0%|          | 0/10 [00:00<?, ?it/s]

0.0 0.0 | 2.0 2.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 


targets:   0%|          | 0/10 [00:00<?, ?it/s]

1.0 1.0 | 0.0 0.0 | 1.0 1.0 | 1.0 1.0 | 1.0 1.0 | 0.0 0.0 | 1.0 1.0 | 1.0 1.0 | 2.0 2.0 | 1.0 1.0 | 


targets:   0%|          | 0/10 [00:00<?, ?it/s]

0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 


targets:   0%|          | 0/10 [00:00<?, ?it/s]

0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 


targets:   0%|          | 0/10 [00:00<?, ?it/s]

0.0 0.0 | 0.0 0.0 | 1.0 1.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 


targets:   0%|          | 0/10 [00:00<?, ?it/s]

0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 


targets:   0%|          | 0/10 [00:00<?, ?it/s]

0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 


targets:   0%|          | 0/10 [00:00<?, ?it/s]

0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 


targets:   0%|          | 0/10 [00:00<?, ?it/s]

1.0 1.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 0.0 0.0 | 1.0 1.0 | 0.0 0.0 | 
314.4020688533783
