In [None]:
# Print the SciPy version so the environment used for this run is recorded
# in the notebook output.
import scipy
print(scipy.__version__)
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
import cellgrid 
import numpy as np
from MDAnalysis.lib.distances import self_distance_array,distance_array
from scipy.spatial.distance import squareform
from initialization import init_uniform
from MDAnalysis.lib.pkdtree import PeriodicKDTree
import itertools

In [None]:
def bf_select(box, points, maxdist):
    """Brute-force search for all pairs of points closer than ``maxdist``.

    Every pairwise distance is computed with ``self_distance_array`` (``box``
    is forwarded as its second argument) and only the pairs strictly below
    the cutoff are kept.

    Parameters
    ----------
    box
        Forwarded unchanged to ``self_distance_array``.
    points
        Sequence of coordinates, one entry per point (``len(points)`` is used
        as the number of points).
    maxdist
        Cutoff; a pair is kept when its distance is ``< maxdist``.

    Returns
    -------
    tuple
        ``(pairs, distances)`` where ``pairs`` is a ``(K, 2)`` integer array
        of ``(i, j)`` with ``i < j`` in row-major order, and ``distances`` is
        the matching length-``K`` array.
    """
    distance = self_distance_array(points, box)
    # The condensed vector lists the pairs (0,1), (0,2), ..., (n-2,n-1) --
    # the ordering squareform assumes -- which is exactly the ordering of
    # np.triu_indices(n, k=1). Indexing it directly avoids the dense float
    # N x N matrix and the "overwrite the lower triangle with a value above
    # the cutoff" trick, which only works for positive cutoffs.
    rows, cols = np.triu_indices(len(points), k=1)
    mask = distance < maxdist
    return np.column_stack((rows[mask], cols[mask])), distance[mask]
    
def cg_select(box, points, maxdist):
    """Select pairs closer than ``maxdist`` using the cellgrid package.

    Calls ``cellgrid.capped_self_distance_array(points, maxdist, box=box)``
    and keeps only the entries whose distance is strictly below ``maxdist``.

    Returns
    -------
    tuple
        ``(indices, distances)`` restricted to the entries under the cutoff.
        Presumably ``indices`` holds the index pairs matching ``distances``;
        confirm against the cellgrid documentation.
    """
    pairs, distances = cellgrid.capped_self_distance_array(points, maxdist, box=box)
    # Filter again: entries at or above the cutoff are dropped explicitly.
    selected = np.nonzero(distances < maxdist)
    return pairs[selected], distances[selected]

def kdtree_distance(box,points,maxdist):
    """Find neighbouring pairs with a periodic KD-tree and compute their distances.

    A ``PeriodicKDTree`` is built over ``points``; for every point the tree is
    queried with radius ``maxdist`` and only neighbours with a larger index
    are kept, so each pair is reported once. Distances for the kept pairs are
    computed with ``distance_array`` using ``box``.

    Parameters
    ----------
    box
        Passed to ``PeriodicKDTree`` and to ``distance_array``.
    points
        Array of coordinates; each row is reshaped to ``(1, 3)`` for the
        distance computation.
    maxdist
        Search radius handed to ``kdtree.search``.

    Returns
    -------
    tuple
        ``(pairs, distances)``: ``pairs`` is a ``(K, 2)`` integer array of
        ``(i, j)`` with ``i < j``, ordered by ``i`` then ``j``; ``distances``
        is the matching length-``K`` array. When no pair is found the arrays
        have shapes ``(0, 2)`` and ``(0,)``.
    """
    kdtree = PeriodicKDTree(box,bucket_size=10)
    kdtree.set_coords(points)
    pair_chunks, distance_chunks = [], []
    for idx, center in enumerate(points):
        kdtree.search(center, maxdist)
        # Keep j > idx so every pair appears once and self-matches are dropped.
        # Sorting makes the output order independent of the order in which
        # the tree hands back its hits.
        neighbours = sorted(i for i in kdtree.get_indices() if i > idx)
        if not neighbours:
            # Nothing to measure; avoid calling distance_array on an empty selection.
            continue
        dists = distance_array(center.reshape((1,3)), points[neighbours], box)
        distance_chunks.append(dists.ravel())
        pair_chunks.append(np.column_stack((np.full(len(neighbours), idx), neighbours)))
    if not pair_chunks:
        # Return well-shaped empty arrays instead of a shape-(0,) float array for the pairs.
        return np.empty((0, 2), dtype=np.intp), np.empty(0, dtype=np.float64)
    return np.concatenate(pair_chunks), np.concatenate(distance_chunks)

In [None]:
# Cutoff used by all three neighbour-search implementations.
maxdist = 10.0
# Six-component box description (presumably three edge lengths followed by
# three angles in degrees -- confirm against the MDAnalysis box convention).
box = np.array([100, 100, 100, 90, 90, 90], dtype=np.float32)
# init_uniform returns a second box object alongside the generated points;
# box1 is the one handed to the cellgrid-based search.
box1, points = init_uniform(box, Npoints=1000)

In [None]:
# Brute force: pairs and distances from the all-pairs computation.
# The result is a (pairs, distances) tuple.
bf = bf_select(box,points,maxdist)
print(bf)

In [None]:
# CellGrid selection: same search through the cellgrid package.
# Note that this call uses box1 (the box returned by init_uniform), not box.
cg = cg_select(box1,points,maxdist)
print(cg)

In [None]:
# KD-tree selection: same search through the periodic KD-tree.
# The container type is printed as well to confirm a tuple is returned.
kd = kdtree_distance(box,points,maxdist)
print(kd,"type:",type(kd))

In [None]:
# Testing
# Number of contacts: all three methods must report the same number of pairs.
np.testing.assert_equal(kd[1].shape,cg[1].shape)
np.testing.assert_equal(bf[1].shape,cg[1].shape)

# Array values: compare the sorted distances so the checks do not depend on
# the order in which each method emits its pairs.
np.testing.assert_array_almost_equal(np.sort(bf[1].astype(np.float64)),np.sort(cg[1]),decimal=5)
np.testing.assert_almost_equal(np.sort(bf[1]),np.sort(kd[1]))

In [None]:
## Benchmarking: run time as a function of the number of particles
# Same cutoff and box as in the correctness check.
maxdist = 10.0
box = np.array([100, 100, 100, 90, 90, 90], dtype=np.float32)
# Ten system sizes, logarithmically spaced between 10**2 and 10**5 and
# truncated to int32.
# NOTE(review): the brute-force path computes all N*(N-1)/2 pair distances;
# confirm the largest sizes fit in memory before running the full sweep.
Npoints = np.logspace(2, 5, num=10, dtype='int32')

In [None]:
# One list per method; each entry is [average, stdev] of the timing for one
# system size, appended in the order of Npoints.
time_kdpair,time_bfpair,time_cgpair = [],[],[]
# Lookup tables by method name. They are not referenced by the cells shown
# here; time_dict aliases the three lists above, so it sees every append.
func_dict = dict(KDtree=kdtree_distance,BruteForce=bf_select,Cellgrid=cg_select)
time_dict = dict(t_kdtree=time_kdpair,t_bruteforce=time_bfpair,t_cellgrid=time_cgpair)
for num in Npoints:
    # Fresh coordinates for every system size. This rebinds the notebook-level
    # names box1 and points used by earlier cells.
    box1, points = init_uniform(box,Npoints=num)
    # %timeit flags: -q suppresses the printed summary, -o returns the
    # TimeitResult object, -n 10 executes the statement 10 times per run.
    kdpair = %timeit -q -o -n 10 kdtree_distance(box,points,maxdist)
    time_kdpair.append([kdpair.average,kdpair.stdev])
    bfpair = %timeit -q -o -n 10 bf_select(box,points,maxdist)
    time_bfpair.append([bfpair.average,bfpair.stdev])
    # The cellgrid variant is timed with box1, matching the correctness check.
    cgpair = %timeit -q -o -n 10 cg_select(box1,points,maxdist)
    time_cgpair.append([cgpair.average,cgpair.stdev])

In [None]:
# Transpose each timing list so that row 0 holds the averages and row 1 the
# standard deviations, one column per system size.
kd_plot, bf_plot, cg_plot = (
    np.array(timings).T for timings in (time_kdpair, time_bfpair, time_cgpair)
)
# Error-bar sizes in the (2, N) layout accepted by errorbar's yerr: row 0 is
# the lower extent and row 1 the upper extent. Both are magnitudes, so both
# rows are +stdev (a negative lower row is invalid for yerr).
kd_error, bf_error, cg_error = (
    np.array([plot[1], plot[1]]) for plot in (kd_plot, bf_plot, cg_plot)
)

In [None]:
import matplotlib
import matplotlib.pyplot as plt

In [None]:
plotvals = dict(KDtree=kd_plot,BruteForce=bf_plot,CellGrid=cg_plot)
errorvals = dict(KDtree=kd_error,BruteForce=bf_error,CellGrid=cg_error)

# One curve per method: average run time against system size, with the
# standard deviation as error bars.
for key in plotvals:
    plt.errorbar(
        Npoints,
        plotvals[key][0],
        # yerr takes non-negative magnitudes for the lower/upper bars.
        yerr=np.abs(errorvals[key]),
        # Solid line with triangle markers; the marker is given once, in fmt,
        # instead of separately in fmt and in a marker keyword.
        fmt='-^',
        label=key,
    )
plt.xlabel('Number of particles')
plt.ylabel('Time per call (s)')
# Build the legend once, after all curves have been added.
plt.legend();