# Compute DWPCs for hetnet-compound-disease pairs

In [1]:
import json
import threading
import concurrent.futures
import bz2
import csv
import time
import multiprocessing as mp

import pandas
import py2neo

In [2]:
from tqdm import tqdm

## Connect to neo4j server

In [3]:
# Override the default py2neo timeout by replacing the module-level
# socket_timeout with a very large value (1e8; presumably seconds -- confirm),
# presumably so that slow queries are not aborted by the HTTP client.
py2neo.packages.httpstream.http.socket_timeout = 1e8

In [4]:
# servers.json lists one entry per neo4j instance, each with a 'name' and a 'port'
with open('servers.json') as read_file:
    instances = json.load(read_file)

# Map each instance name to a Graph connection on localhost.
# The listed port is used for bolt; the HTTP port is the next one up.
name_to_neo = {}
for instance in instances:
    bolt_port = instance['port']
    neo = py2neo.Graph(
        host="localhost",
        bolt=True,
        bolt_port=bolt_port,
        http_port=bolt_port + 1,
    )
    name_to_neo[instance['name']] = neo
    
#     uri = 'http://localhost:{}/db/data/'.format(instance['port'])
#     name_to_neo[instance['name']] = py2neo.Graph(uri)

## Read metapaths

In [5]:
# Load the metapath definitions, ordered ascending by the first entry
# of each metapath's 'join_complexities' list
with open('data/metapaths.json') as read_file:
    metapaths = sorted(
        json.load(read_file),
        key=lambda metapath: metapath['join_complexities'][0],
    )

len(metapaths)

9

## Read hetnet-compound-disease pairs

In [6]:
# One row per hetnet-compound-disease combination; the added 'neo' column
# holds the Graph connection looked up from the row's hetnet name
part_df = pandas.read_csv('data/partitions.tsv', sep='\t').assign(
    neo=lambda df: df['hetnet'].map(name_to_neo)
)
parts = [row for row in part_df.itertuples()]
part_df.head(2)

Unnamed: 0,hetnet,compound_id,disease_id,status,primary,neo
0,rephetio-v2.0,DB00014,DOID:10283,1,1,<Graph uri='http://localhost:7501/db/data/'>
1,rephetio-v2.0_perm-1,DB00014,DOID:10283,0,0,<Graph uri='http://localhost:7511/db/data/'>


## Set up queries

In [7]:
# Total number of queries: every metapath is run for every partition row
total_queries = len(part_df) * len(metapaths)
format(total_queries, ',')

'271,575'

In [8]:
# Materialize the partition rows as a list so they can be re-iterated for each metapath
parts = list(part_df.itertuples())

def generate_parameters(max_elems=None, w=0.4):
    """
    Generate compound, disease, metapath combinations.

    Iterates over the module-level `metapaths` list (outer loop) and the
    module-level `parts` list (inner loop) and yields one dict per
    combination. The dict keys match the keyword arguments of `compute_dwpc`.

    Parameters
    ----------
    max_elems : int or None
        Maximum number of dicts to yield. None (the default) yields every
        metapath-part combination.
    w : float
        Value stored under the 'w' key of every yielded dict. Defaults to
        0.4, the value that was previously hard-coded.

    Yields
    ------
    dict
        Keys: 'neo', 'hetnet', 'compound_id', 'disease_id', 'metapath',
        'query' and 'w'.
    """
    n = 0
    for metapath_dict in metapaths:
        metapath = metapath_dict['abbreviation']
        query = metapath_dict['dwpc_query']
        for part_info in parts:
            if max_elems is not None and n == max_elems:
                # Limit reached: end the generator outright rather than only
                # leaving the inner loop and visiting the remaining metapaths.
                return
            yield {
                'neo': part_info.neo,
                'hetnet': part_info.hetnet,
                'compound_id': part_info.compound_id,
                'disease_id': part_info.disease_id,
                'metapath': metapath,
                'query': query,
                'w': w,
            }
            n += 1

In [9]:
def compute_dwpc(neo, hetnet, query, metapath, compound_id, disease_id, w):
    """
    Execute the neo4j query and write results to file.

    Runs `query` on `neo` with the parameters source=compound_id,
    target=disease_id and w=w, then writes one row through the module-level
    `writer` while holding the module-level `writer_lock`. Both names must be
    defined before this function is called.

    The row holds: hetnet, compound_id, disease_id, metapath, the record's
    'PC' value, w, the record's 'DWPC' value formatted to 6 significant
    digits, and the elapsed wall-clock seconds formatted to 4 significant
    digits. Returns None.
    """
    began = time.time()
    cursor = neo.run(query, source=compound_id, target=disease_id, w=w)
    record = cursor.one
    # Timing stops once the record has been fetched, before any formatting
    elapsed = time.time() - began
    fields = (
        hetnet,
        compound_id,
        disease_id,
        metapath,
        record['PC'],
        w,
        '{0:.6g}'.format(record['DWPC']),
        '{0:.4g}'.format(elapsed),
    )
    # The csv writer is shared between worker threads; write one row at a time
    with writer_lock:
        writer.writerow(fields)

## Execute queries

In [10]:
%%time

# Parameters
workers = mp.cpu_count()
max_elems = None

# Prepare writer
path = 'data/dwpc.tsv.bz2'
write_file = bz2.open(path, 'wt')
writer = csv.writer(write_file, delimiter='\t')
writer.writerow(['hetnet', 'compound_id', 'disease_id', 'metapath', 'PC', 'w', 'DWPC', 'seconds'])

# Create ThreadPoolExecutor
executor = concurrent.futures.ThreadPoolExecutor(max_workers=workers)
writer_lock = threading.Lock()

# Submit jobs
n_queries = 0
for params in tqdm(generate_parameters(max_elems), total = total_queries):
    while executor._work_queue.qsize() > 10000:
#         print('Submitted queries: {} ({:.4%})'.format(n_queries, n_queries / total_queries), end='\r')
        time.sleep(1)
    executor.submit(compute_dwpc, **params)
    n_queries += 1

# Shutdown and close
executor.shutdown()
write_file.close()

100%|██████████| 271575/271575 [08:02<00:00, 563.30it/s] 


CPU times: user 5min 40s, sys: 19.5 s, total: 6min
Wall time: 8min 50s


In [11]:
# Number of jobs that were submitted to the executor in the cell above
n_queries

271575