In [1]:
# Debugging helper: prints the expression passed in, its type name and its value, e.g.
# debugPrint(1+1) will output '1+1 [int] = 2'
import inspect
import re
def debugPrint(x):
    """Print ``<expression> [<type name>] = <value>`` for the argument.

    The expression text is recovered from the caller's source line, e.g.
    ``debugPrint(1+1)`` prints ``1+1 [int] = 2``. Only the text inside the
    ``debugPrint(...)`` call itself (matched by balanced parentheses) is used,
    so other calls or trailing comments on the same line do not leak into the
    label. When the source line is unavailable (e.g. code run via ``exec``) or
    the call spans several lines, ``<expr>`` is printed as the label instead
    of raising.
    """
    frame = inspect.currentframe().f_back
    try:
        context = inspect.getframeinfo(frame).code_context
    finally:
        # drop the frame reference promptly to avoid a reference cycle
        del frame
    label = '<expr>'
    line = context[0] if context else ''
    match = re.search(r"debugPrint\s*\(", line)
    if match:
        begin = match.end()
        depth = 1
        for pos in range(begin, len(line)):
            ch = line[pos]
            if ch == '(':
                depth += 1
            elif ch == ')':
                depth -= 1
                if depth == 0:
                    label = line[begin:pos]
                    break
    print("{} [{}] = {}".format(label, type(x).__name__, x))
    
    
import os
import os, sys
# sys.path.append(os.path.join('~/dev/pytorchSPH/', "lib"))
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from tqdm import trange, tqdm
import yaml
%matplotlib notebook
import warnings
warnings.filterwarnings(action='once')
from datetime import datetime

import torch
from torch_geometric.nn import radius
from torch_geometric.nn import SplineConv, fps, global_mean_pool, radius_graph, radius
from torch_scatter import scatter

from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

from scipy.optimize import minimize
import matplotlib.patches as patches
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.colors import LogNorm
from matplotlib.ticker import MaxNLocator
import matplotlib.ticker as mticker
from torch.profiler import profile, record_function, ProfilerActivity

In [2]:
# Stretch the notebook's content container to the full browser width.
from IPython.display import HTML, display

_fullWidthCss = "<style>.container { width:100% !important; }</style>"
display(HTML(_fullWidthCss))

In [3]:
def genParticlesCentered(minCoord, maxCoord, radius, support, packing, dtype = torch.float32, device = 'cpu', tolerance = 0.2):
    """Generate a square lattice of 2-D particle positions centred in a box.

    The lattice spacing is ``packing * support`` and one lattice point sits
    exactly at the box centre ``(minCoord + maxCoord) / 2``. A point is kept if
    it lies inside the box enlarged by ``tolerance * support`` on every side.

    Args:
        minCoord, maxCoord: 2-element tensors with the lower / upper box corners.
        radius: unused; kept for signature compatibility.
        support: support radius (scalar).
        packing: lattice spacing in units of ``support``.
        dtype, device: dtype / device of the generated lattice offsets.
        tolerance: margin, in units of ``support``, by which points may exceed
            the box (0.2 reproduces the previous hard-coded behaviour).

    Returns:
        ``(N, 2)`` tensor of positions ordered x-index-major (as a nested loop
        over x outer, y inner). ``N`` may be 0.
    """
    spacing = float(packing * support)
    diff = maxCoord - minCoord
    center = ((minCoord + maxCoord) / 2).to(device)
    requiredSlices = torch.div(torch.ceil(diff / packing / support).type(torch.int64), 2, rounding_mode='floor')
    nx = int(requiredSlices[0])
    ny = int(requiredSlices[1])

    # Integer lattice steps (exact in float64), expanded x-major so that the
    # output order matches a nested loop over i (outer) and j (inner).
    iSteps = torch.arange(-nx - 1, nx + 2, dtype=torch.float64)
    jSteps = torch.arange(-ny - 1, ny + 2, dtype=torch.float64)
    gridI = iSteps.repeat_interleave(jSteps.shape[0])
    gridJ = jSteps.repeat(iSteps.shape[0])
    offsets = torch.stack((gridI * spacing, gridJ * spacing), dim=1).to(device=device, dtype=dtype)

    positions = center + offsets

    margin = support * tolerance
    upper = (maxCoord + margin).to(device)
    lower = (minCoord - margin).to(device)
    inside = ((positions <= upper) & (positions >= lower)).all(dim=1)
    return positions[inside]

In [4]:
# Simulation domain corners.
minCoord = torch.tensor([-1,-1])
maxCoord = torch.tensor([1,1])


r = 0.005                               # particle radius
area = np.pi * r**2
support = np.sqrt(area * 20 / np.pi)    # radius of a disc covering 20 particle areas
packing = 0.399023                      # lattice spacing in units of `support`

# Fixed seed so the jittered positions / supports below are reproducible.
SEED = 0
torch.manual_seed(SEED)

referenceParticles = genParticlesCentered(minCoord, maxCoord, r, support, packing)
referenceSupport = referenceParticles.new_ones(referenceParticles.shape[0]) * support

queryParticles = genParticlesCentered(minCoord, maxCoord, r, support, packing)
querySupport = queryParticles.new_ones(queryParticles.shape[0]) * support

positionJitter = 0.
supportJitter = 0.005
# Each tensor is jittered with noise of its own shape / device.
referenceParticles = referenceParticles + torch.normal(0, positionJitter, referenceParticles.shape, device = referenceParticles.device)
queryParticles = queryParticles + torch.normal(0, positionJitter, queryParticles.shape, device = queryParticles.device)

referenceSupport = referenceSupport + torch.normal(0, supportJitter, referenceSupport.shape, device = referenceSupport.device)
querySupport = querySupport + torch.normal(0, supportJitter, querySupport.shape, device = querySupport.device)


debugPrint(queryParticles.shape)

queryParticles.shape [Size] = torch.Size([50625, 2])


In [5]:
%load_ext wurlitzer

In [6]:
# JIT-compile the C++ neighbor-search sources and bind the module as `lltm_cpp`
# (build log suppressed).
from torch.utils.cpp_extension import load
lltm_cpp = load(
    name="lltm_cpp2",
    sources=["cppSrc/neighSearch.cpp"],
    verbose=False,
)




In [7]:

# Sort the query particles with the C++ sortPointSet; the `2` suffix marks the
# C++ results so they can be compared with the TorchScript sortPositions later.
(qMin2, hMax2, cells2,
 sortedPositions2, sortedSupport2, sortedIndices2, sort2) = lltm_cpp.sortPointSet(queryParticles, querySupport)



In [10]:
# Linear cell index of each sorted particle, as returned by the C++ sortPointSet.
# (debugPrint labels its output with the literal call text on this line.)
debugPrint(sortedIndices2)
# debugPrint(sortedIndices)
# debugPrint

sortedIndices2 [Tensor] = tensor([  96,   96,   96,  ..., 2160, 2160, 2160], dtype=torch.int32)




In [11]:
def setupPlot(axisLayout, plotScale = 4, domainMin = None, domainMax = None):
    """Create a grid of equal-aspect axes framing the simulation domain.

    Args:
        axisLayout: ``[nRows, nCols]`` of subplots.
        plotScale: figure inches per unit of domain extent.
        domainMin, domainMax: lower / upper domain corners (2-element tensors or
            arrays); default to the notebook-level ``minCoord`` / ``maxCoord``.

    Returns:
        ``(fig, ax2d)``; ``ax2d`` is always a 2-D array of axes (``squeeze=False``).
    """
    lo = minCoord if domainMin is None else domainMin
    hi = maxCoord if domainMax is None else domainMax
    nRows, nCols = axisLayout[0], axisLayout[1]
    extent = hi - lo

    # Figure width follows the number of columns, height the number of rows;
    # the extra 9% width presumably leaves room for a colorbar.
    figWidth = float(extent[0]) * plotScale * 1.09 * nCols
    figHeight = float(extent[1]) * plotScale * nRows
    fig, ax2d = plt.subplots(nRows, nCols, figsize=(figWidth, figHeight), squeeze = False)

    for axis in ax2d.flat:
        axis.set_xlim(float(lo[0]) * 1.1, float(hi[0]) * 1.1)
        axis.set_ylim(float(lo[1]) * 1.1, float(hi[1]) * 1.1)
        axis.axis('equal')
        axis.grid(False)
        # dashed outline of the domain boundary
        axis.axvline(float(lo[0]), c = 'black', ls= '--')
        axis.axvline(float(hi[0]), c = 'black', ls= '--')
        axis.axhline(float(lo[1]), c = 'black', ls= '--')
        axis.axhline(float(hi[1]), c = 'black', ls= '--')

    return fig, ax2d



In [12]:
# C++ sortPointSet result: sorted particle positions colored by linear cell index.
fig, axis = setupPlot([1,1], plotScale = 2)

positions = sortedPositions2.detach().cpu().numpy()
data = sortedIndices2.detach().cpu().numpy()

sc = axis[0,0].scatter(positions[:,0], positions[:,1],  c = data, s = 16 )
axis[0,0].set_title('C++ sortPointSet: particles colored by linear cell index')
ax1_divider = make_axes_locatable(axis[0,0])
cax1 = ax1_divider.append_axes("right", size="4%", pad="1%")
cbar = fig.colorbar(sc, cax=cax1,orientation='vertical')
cbar.ax.tick_params(labelsize=8)
cbar.set_label('linear cell index', fontsize=8)

fig.tight_layout()

<IPython.core.display.Javascript object>



In [13]:
# Unsorted query particles colored by their (jittered) support radius.
fig, axis = setupPlot([1,1], plotScale = 2)

positions = queryParticles.detach().cpu().numpy()
data = querySupport.detach().cpu().numpy()

sc = axis[0,0].scatter(positions[:,0], positions[:,1],  c = data, s = 16 )
axis[0,0].set_title('Query particles colored by support radius')
ax1_divider = make_axes_locatable(axis[0,0])
cax1 = ax1_divider.append_axes("right", size="4%", pad="1%")
cbar = fig.colorbar(sc, cax=cax1,orientation='vertical')
cbar.ax.tick_params(labelsize=8)
cbar.set_label('support radius', fontsize=8)

fig.tight_layout()

<IPython.core.display.Javascript object>

In [14]:
@torch.jit.script
def sortPositions(queryParticles, querySupport):
    """Sort particles by the linear index of the uniform grid cell they fall into.

    The cell size is the largest support radius ``hMax``. The grid origin
    ``qMin`` lies one ``hMax`` below the smallest coordinate, and the grid
    extends two ``hMax`` beyond the largest one. A particle's cell coordinate
    is ``ceil((x - qMin) / hMax)``, and its linear index is ``cx + cells[0] * cy``.

    Args:
        queryParticles: particle positions, indexed as ``[:, 0]`` / ``[:, 1]``.
        querySupport: per-particle support radii.

    Returns:
        ``qMin``, ``hMax``, ``cells`` (cell count per axis), ``sortedPositions``,
        ``sortedSupport``, ``sortedIndices`` (linear cell index per sorted
        particle) and ``sort`` (permutation with
        ``sortedPositions == queryParticles[sort]``).
    """
    with record_function("sort"):
        with record_function("sort - bound Calculation"):
            hMax = torch.max(querySupport)
            qMin = torch.min(queryParticles, dim=0)[0] - hMax
            qMax = torch.max(queryParticles, dim=0)[0] + 2 * hMax
            qEx = qMax - qMin

        with record_function("sort - index Calculation"):
            cells = torch.ceil(qEx / hMax).to(torch.int64)
            indices = torch.ceil((queryParticles - qMin) / hMax).to(torch.int64)
            # row-major linearisation of the 2-D cell coordinate
            linearIndices = indices[:, 0] + cells[0] * indices[:, 1]

        with record_function("sort - actual argsort"):
            sort = torch.argsort(linearIndices)

        with record_function("sort - sorting data"):
            sortedIndices = linearIndices[sort]
            sortedPositions = queryParticles[sort, :]
            sortedSupport = querySupport[sort]

    return qMin, hMax, cells, sortedPositions, sortedSupport, sortedIndices, sort

(qMin, hMax, cells,
 sortedPositions, sortedSupport, sortedIndices, sort) = sortPositions(queryParticles, querySupport)



In [16]:
# Profile 128 calls of the TorchScript sortPositions. CUDA activity is only
# requested when a GPU is available.
profilerActivities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    profilerActivities.append(ProfilerActivity.CUDA)

with profile(activities=profilerActivities, with_stack=True, profile_memory=True) as prof:
    for _ in range(128):
        qMin, hMax, cells, sortedPositions, sortedSupport, sortedIndices, sort = sortPositions(queryParticles, querySupport)

print(prof.key_averages().table(sort_by='cpu_time_total'))

-------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                 Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
-------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                        sortPositions         0.27%     741.000us       100.00%     271.689ms       2.123ms     173.04 Mb         512 b           128  
                                 sort         0.45%       1.217ms        99.03%     269.057ms       2.102ms     173.04 Mb      -8.92 Kb           128  
                sort - actual argsort         0.68%       1.845ms        72.33%     196.517ms       1.535ms      49.44 Mb     -10.66 Kb           128  
                        aten::argsort         0.19%     509.000us        71.58%     194.



In [17]:
# Profile 128 calls of the C++ sortPointSet. Results go to the `2`-suffixed names
# so the TorchScript sortPositions results (qMin, sortedPositions, ...) are not
# silently replaced by C++ output (whose sortedIndices are int32).
with profile(activities=[ProfilerActivity.CPU], with_stack=True, profile_memory=False) as prof:
    for _ in range(128):
        qMin2, hMax2, cells2, sortedPositions2, sortedSupport2, sortedIndices2, sort2 = lltm_cpp.sortPointSet(queryParticles, querySupport)

print(prof.key_averages().table(sort_by='cpu_time_total'))

-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                         Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                aten::argsort         0.12%     322.000us        70.65%     184.569ms       1.442ms           128  
                   aten::sort        67.07%     175.230ms        70.52%     184.247ms       1.439ms           128  
                  aten::copy_         7.03%      18.360ms         7.03%      18.360ms      17.930us          1024  
                    aten::sub         7.03%      18.355ms         7.03%      18.355ms      47.799us           384  
                    aten::max         4.59%      11.993ms         4.91%      12.833ms      50.129us           256  
                     aten::to         0.15%     396.000us         4.14% 



In [18]:
# Sorted particle positions colored by their linear cell index.
fig, axis = setupPlot([1,1], plotScale = 2)

positions = sortedPositions.detach().cpu().numpy()
data = sortedIndices.detach().cpu().numpy()

sc = axis[0,0].scatter(positions[:,0], positions[:,1],  c = data, s = 16 )
axis[0,0].set_title('Sorted particles colored by linear cell index')
ax1_divider = make_axes_locatable(axis[0,0])
cax1 = ax1_divider.append_axes("right", size="4%", pad="1%")
cbar = fig.colorbar(sc, cax=cax1,orientation='vertical')
cbar.ax.tick_params(labelsize=8)
cbar.set_label('linear cell index', fontsize=8)

fig.tight_layout()

<IPython.core.display.Javascript object>



In [19]:
# Number of hash-map buckets: a fixed power of two. (The previously overwritten
# alternative was one bucket per particle: sortedPositions.shape[0] + 1.)
hashMapLength = 2**14
debugPrint(hashMapLength)

# Recompute both sortings: TorchScript sortPositions and, `2`-suffixed, the C++ sortPointSet.
qMin, hMax, cells, sortedPositions, sortedSupport, sortedIndices, sort = sortPositions(queryParticles, querySupport)
qMin2, hMax2, cells2, sortedPositions2, sortedSupport2, sortedIndices2, sort2 = lltm_cpp.sortPointSet(queryParticles, querySupport)

hashMapLength [int] = 16384




In [20]:
# Side-by-side comparison of the TorchScript sortPositions results and the
# `2`-suffixed C++ sortPointSet results. The calls are kept literal because
# debugPrint labels its output with the call text from this source line.
# NOTE(review): the C++ cell indices print with dtype=torch.int32 while the
# TorchScript ones are int64 — keep in mind when comparing with torch.equal.
debugPrint(sortedIndices)
debugPrint(sortedIndices2)

debugPrint(sortedPositions)
debugPrint(sortedPositions2)

debugPrint(sortedSupport)
debugPrint(sortedSupport2)

sortedIndices [Tensor] = tensor([  96,   96,   96,  ..., 2160, 2160, 2160])
sortedIndices2 [Tensor] = tensor([  96,   96,   96,  ..., 2160, 2160, 2160], dtype=torch.int32)
sortedPositions [Tensor] = tensor([[-0.9725, -0.9547],
        [-0.9547, -0.9547],
        [-0.9547, -0.9636],
        ...,
        [ 0.9725,  0.9725],
        [ 0.9725,  0.9815],
        [ 0.9725,  0.9904]])
sortedPositions2 [Tensor] = tensor([[-0.9725, -0.9547],
        [-0.9547, -0.9547],
        [-0.9547, -0.9636],
        ...,
        [ 0.9725,  0.9725],
        [ 0.9725,  0.9815],
        [ 0.9725,  0.9904]])
sortedSupport [Tensor] = tensor([0.0235, 0.0295, 0.0185,  ..., 0.0191, 0.0237, 0.0197])
sortedSupport2 [Tensor] = tensor([0.0235, 0.0295, 0.0185,  ..., 0.0191, 0.0237, 0.0197])


In [21]:
# Number of grid cells per axis, as computed by sortPositions.
debugPrint(cells)

cells [Tensor] = tensor([47, 47])




In [22]:
@torch.jit.script
def linearFrom2D(qID, cells):
    """Row-major linear cell index ``x + cells[0] * y`` for every row ``(x, y)`` of ``qID``."""
    xs = qID[:, 0]
    ys = qID[:, 1]
    return xs + ys * cells[0]
@torch.jit.script
def hashFrom2D(qID, hashMapLength):
    """Spatial hash of every row ``(x, y)`` of ``qID`` into ``hashMapLength`` buckets."""
    mixed = qID[:, 1] * 15233249 + qID[:, 0] * 1212047
    return mixed % hashMapLength
def _runStarts(values: torch.Tensor) -> torch.Tensor:
    """Start offset of every run of equal consecutive values, followed by ``len(values)``."""
    n = values.shape[0]
    isBoundary = torch.empty(n + 1, dtype=torch.bool, device=values.device)
    isBoundary[1:-1] = values[:-1] != values[1:]
    isBoundary[0] = True
    isBoundary[-1] = True
    return torch.arange(n + 1, dtype=torch.int64, device=values.device)[isBoundary]


@torch.jit.script
def constructHashMap2(qMin, hMax, cells, sortedPositions, sortedSupport, sortedIndices, sort, hashMapLength: int):
    """Build a bucketed hash map over the occupied grid cells of a cell-sorted particle set.

    ``sortedIndices`` holds the linear cell index of each particle and is assumed
    sorted, so equal cells are contiguous. ``qMin``, ``hMax``, ``sortedPositions``,
    ``sortedSupport`` and ``sort`` are accepted for signature compatibility but unused.

    Returns:
        hashMap: ``(hashMapLength,)`` int64. For each bucket, the position of its
            first cell in the hash-sorted list of occupied cells, or -1 if empty.
        hashSpan: ``(hashMapLength,)`` int64. The number of occupied cells per bucket.
        cellMap: linear index of every occupied cell, in ascending cell order.
        cellSpan: number of particles in every occupied cell (same order as cellMap).

    NOTE(review): ``hashMap`` positions refer to the hash-sorted cell order, while
    ``cellMap`` / ``cellSpan`` are returned in ascending linear-cell order. The
    permutation between the two is not returned, so callers cannot combine them
    directly.
    """
    device = sortedIndices.device

    # Run-length encode the sorted particle cells: one entry per occupied cell.
    cellStarts = _runStarts(sortedIndices)
    cellMap = sortedIndices[cellStarts[:-1]]
    cellSpan = cellStarts[1:] - cellStarts[:-1]

    # Hash each occupied cell from its 2-D cell coordinate.
    xIndices = cellMap % cells[0]
    yIndices = torch.div(cellMap, cells[0], rounding_mode='trunc')
    compactHashes = (xIndices * 1212047 + yIndices * 15233249) % hashMapLength

    # Group cells by bucket; colliding cells become contiguous runs.
    hashSortIndex = torch.argsort(compactHashes)
    sortedHashes = compactHashes[hashSortIndex]
    hashStarts = _runStarts(sortedHashes)
    occupiedBuckets = sortedHashes[hashStarts[:-1]]

    hashMap = torch.full([hashMapLength], -1, dtype=torch.int64, device=device)
    hashSpan = torch.zeros(hashMapLength, dtype=torch.int64, device=device)
    hashMap[occupiedBuckets] = hashStarts[:-1]
    hashSpan[occupiedBuckets] = hashStarts[1:] - hashStarts[:-1]

    return hashMap, hashSpan, cellMap, cellSpan
    
(hashMap, hashSpan,
 cellMap, cellSpan) = constructHashMap2(qMin, hMax, cells, sortedPositions, sortedSupport, sortedIndices, sort, hashMapLength)



In [36]:
# Non-empty buckets of constructHashMap2: start position in the hash-sorted cell
# list (hashMap) and number of cells per bucket (hashSpan).
debugPrint(hashMap[hashMap!= -1])
debugPrint(hashSpan[hashMap!= -1])

hashMap[hashMap!= -1] [Tensor] = tensor([   0,    1,    2,  ..., 1933, 1934, 1935])
hashSpan[hashMap!= -1] [Tensor] = tensor([1, 1, 1,  ..., 1, 1, 1])




In [42]:
# Number of occupied grid cells found by constructHashMap2.
debugPrint(cellMap.shape)

cellMap.shape [Size] = torch.Size([1936])




In [41]:
# hashTable is produced by constructHashMap, which is defined further down in this
# notebook (In [84]). Guard the cell so Restart & Run All does not fail with a NameError here.
if 'hashTable' in globals():
    debugPrint(hashTable[hashTable[:,0] != -1,0])
    debugPrint(hashTable[hashTable[:,0] != -1,1])
else:
    print('hashTable not built yet - run the constructHashMap cell first')

hashTable[hashTable[:,0] != -1,0] [Tensor] = tensor([   0,    1,    2,  ..., 1933, 1934, 1935])
hashTable[hashTable[:,0] != -1,1] [Tensor] = tensor([1, 1, 1,  ..., 1, 1, 1])




In [25]:

# prof.export_chrome_trace("trace.json")




In [26]:
# debugPrint(hashMap)
# debugPrint(hashSpan)



In [27]:
# hashTable, cumHash, hashIndexSorting, cellIndices, cumCell = constructHashMap(qMin, hMax, cells, sortedPositions, sortedSupport, sortedIndices, sort, hashMapLength)

# debugPrint(hashTable[:,0])
# debugPrint(hashTable[:,1])



In [28]:
# diff = hashTable[:,1] - hashSpan
# debugPrint(diff)
# debugPrint(torch.sum(diff))
# diff = hashTable[:,0] - hashMap
# debugPrint(diff)
# debugPrint(torch.sum(diff))



In [29]:
# debugPrint(linearIndices)
# debugPrint(sortedIndices)



In [30]:
# cellIndices, cellInverse, cellCounters = torch.unique_consecutive(sortedIndices, return_counts=True, return_inverse=True)
# debugPrint(cellIndices)
# debugPrint(cellIndices.shape)
# debugPrint(cellCounters)



In [84]:
@torch.jit.script
def constructHashMap(qMin, hMax, cells, sortedPositions, sortedSupport, sortedIndices, sort, hashMapLength: int):
    """Build a ``(hashMapLength, 2)`` bucket table over the occupied cells of a cell-sorted particle set.

    ``sortedIndices`` holds the linear cell index of each particle and is assumed
    sorted, so equal cells are contiguous. ``qMin``, ``hMax``, ``sortedPositions``,
    ``sortedSupport`` and ``sort`` are unused.

    Returns:
        hashTable: row ``b`` is ``[start, count]``. ``start`` is the position of the
            first cell of bucket ``b`` in the hash-sorted cell arrays below, or -1
            if the bucket is empty. ``count`` is the number of cells in bucket ``b``.
        hashIndexSorting: permutation sorting the occupied cells by bucket.
        cellIndices: linear index of each occupied cell, in hash-sorted order.
        cumCell: number of particles in each cell (hash-sorted order).
        cellSpan: offset of each cell's first particle in the sorted particle
            arrays (hash-sorted order).
    """
    with record_function("hashmap"):
        with record_function("hashmap - cell cumulation"):
            cellIndices, cellCounters = torch.unique_consecutive(sortedIndices, return_counts=True, return_inverse=False)
            # exclusive prefix sum: offset of each cell's first particle
            cellStarts = torch.cumsum(cellCounters, dim=0) - cellCounters

        with record_function('hashmap - compute indices'):
            xIndices = cellIndices % cells[0]
            yIndices = torch.div(cellIndices, cells[0], rounding_mode='trunc')
            hashedIndices = (xIndices * 1212047 + yIndices * 15233249) % hashMapLength

        with record_function('hashmap - sort hashes'):
            hashIndexSorting = torch.argsort(hashedIndices)

        with record_function('hashmap - collision detection'):
            hashMap, hashMapCounters = torch.unique_consecutive(hashedIndices[hashIndexSorting], return_counts=True, return_inverse=False)
            cellIndices = cellIndices[hashIndexSorting]
            cellSpan = cellStarts[hashIndexSorting]
            cumCell = cellCounters[hashIndexSorting]

        with record_function('hashmap - hashmap construction'):
            hashTable = torch.full([hashMapLength, 2], -1, dtype=hashMap.dtype, device=hashMap.device)
            hashTable[:, 1] = 0
            # Start of each bucket's run in the hash-sorted cell list. This is an
            # exclusive prefix sum of the bucket sizes, not the bucket's ordinal;
            # the two differ as soon as one bucket holds more than one cell.
            hashTable[hashMap, 0] = torch.cumsum(hashMapCounters, dim=0) - hashMapCounters
            hashTable[hashMap, 1] = hashMapCounters

    return hashTable, hashIndexSorting, cellIndices, cumCell, cellSpan

(hashTable, hashIndexSorting, cellIndices,
 cumCell, cellSpan) = constructHashMap(qMin, hMax, cells, sortedPositions, sortedSupport, sortedIndices, sort, hashMapLength)



In [85]:
# Hash-sorted cell arrays returned by constructHashMap.
# NOTE(review): the names are swapped relative to their contents. cumCell holds
# the per-cell particle counts (cellCounters) and cellSpan holds the per-cell
# particle start offsets, as the printed values below also show.
debugPrint(cellIndices)
debugPrint(cumCell)
debugPrint(cellSpan)

cellIndices [Tensor] = tensor([1487, 1560, 1260,  ...,  154, 1641, 1714])
cumCell [Tensor] = tensor([25, 25, 30,  ..., 25, 30, 25])
cellSpan [Tensor] = tensor([34245, 35955, 28785,  ...,  1635, 37950, 39890])




In [63]:
# Permutation (from constructHashMap) that orders the occupied cells by hash bucket.
debugPrint(hashIndexSorting)

hashIndexSorting [Tensor] = tensor([1304, 1371, 1092,  ...,   55, 1449, 1516])




In [64]:
# Profile 128 calls of constructHashMap. CUDA activity is only requested when a
# GPU is available.
profilerActivities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    profilerActivities.append(ProfilerActivity.CUDA)

with profile(activities=profilerActivities, with_stack=True, profile_memory=True) as prof:
    for _ in range(128):
        hashTable, hashIndexSorting, cellIndices, cumCell, cellSpan = constructHashMap(qMin, hMax, cells, sortedPositions, sortedSupport, sortedIndices, sort, hashMapLength)

print(prof.key_averages().table(sort_by='cpu_time_total'))

----------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                              Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
----------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                  constructHashMap         0.86%     653.000us        99.98%      75.916ms     593.094us      37.67 Mb         512 b           128  
                           hashmap         1.27%     963.000us        96.84%      73.534ms     574.484us      37.67 Mb      -7.41 Kb           128  
    hashmap - hashmap construction         5.20%       3.950ms        24.83%      18.853ms     147.289us      28.22 Mb      -8.99 Kb           128  
         hashmap - cell cumulation         2.84%       2.159ms        20.14%      15.293ms     119.477us  



In [23]:
# Profile 128 calls of constructHashMap2. CUDA activity is only requested when a
# GPU is available.
profilerActivities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    profilerActivities.append(ProfilerActivity.CUDA)

with profile(activities=profilerActivities, with_stack=True, profile_memory=True) as prof:
    for _ in range(128):
        hashMap, hashSpan, cellMap, cellSpan = constructHashMap2(qMin, hMax, cells, sortedPositions, sortedSupport, sortedIndices, sort, hashMapLength)

print(prof.key_averages().table(sort_by='cpu_time_total'))

-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                         Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
            constructHashMap2        16.22%      13.897ms        99.99%      85.685ms     669.414us      35.78 Mb      -2.60 Mb           128  
                  aten::index        11.79%      10.103ms        25.64%      21.969ms      34.327us     -53.96 Mb     -67.20 Mb           640  
                 aten::arange        11.59%       9.931ms        22.62%      19.386ms      25.242us     106.44 Mb           0 b           768  
                aten::argsort         0.35%     304.000us        12.39%      10.614ms      82.922us       1.89 Mb    -529.38 Kb         



In [20]:
# connectivity = torch.zeros(queryParticles.shape[0],queryParticles.shape[0], dtype = torch.int64, layout = torch.sparse_coo, device = qP.device, requires_grad = False)

In [108]:
@torch.jit.script
def queryHashMap(qID, hashTable, cumHash, hashIndexSorting, cellIndices, cumCell, cellSpan, sort, cells, hashMapLength):
    """Return the particle indices stored in grid cell ``qID`` (a 2-element cell coordinate).

    The bucket ``hash(qID)`` is looked up in ``hashTable``, whose rows are
    ``[start, count]`` with start == -1 for an empty bucket. The function then
    scans the hash-sorted occupied cells ``cellIndices[start:start + count]`` for
    the cell whose linear index equals ``qID[0] + cells[0] * qID[1]``. For that
    cell ``k`` it returns ``sort[cellSpan[k] : cellSpan[k] + cumCell[k]]``; here
    ``cellSpan`` holds per-cell particle start offsets and ``cumCell`` per-cell
    particle counts, as produced by constructHashMap.

    ``cumHash`` and ``hashIndexSorting`` are unused.

    Returns:
        A 1-D int64 tensor of particle indices (mapped through ``sort``). It is
        empty (shape ``(0,)``) if the cell is not occupied.
    """
    qLin = qID[0] + cells[0] * qID[1]
    qHash = (qID[0] * 1212047 + qID[1] * 15233249) % hashMapLength
    hashEntries = hashTable[qHash]

    candidates = torch.empty([0], dtype=torch.int64, device=qLin.device)
    if int(hashEntries[0]) == -1:
        return candidates

    firstEntry = int(hashEntries[0])
    lastEntry = firstEntry + int(hashEntries[1])
    # Several cells may share a bucket; pick the one whose linear index matches.
    for hashIndex in range(firstEntry, lastEntry):
        if bool(cellIndices[hashIndex] == qLin):
            begin = int(cellSpan[hashIndex])
            end = begin + int(cumCell[hashIndex])
            candidates = sort[begin:end]
            break
    return candidates



In [109]:

# linearFrom2D and hashFrom2D are defined (identically) in the hash-map
# construction cell above; the duplicate definitions that used to be here were
# removed so each function has a single definition.

@torch.jit.script
def findNeighbors(queryParticles, support, qMin, hMax, cells, hashTable, cumHash, hashIndexSorting, cellIndices, cumCell, cellSpan, sort, hashMapLength):
    """All particle pairs closer than ``support``, found through the spatial hash.

    Each occupied cell in ``cellIndices`` is treated in turn as the centre cell.
    Its particles and those of the surrounding 3x3 block of cells are fetched
    with ``queryHashMap``, and every pair with Euclidean distance ``<= support``
    is kept. Each particle is therefore also paired with itself.

    NOTE(review): only the adjacent cells are searched, which presumably
    assumes ``support`` does not exceed the cell size ``hMax`` — confirm.
    ``qMin`` and ``hMax`` are unused here.

    Returns:
        ``(rows, cols)``. ``rows`` are indices of particles in the 3x3
        neighbourhood and ``cols`` are indices of particles in the centre cell.
    """
    device = queryParticles.device
    cellsX = int(cells[0])

    overallRows = []
    overallCols = []

    for i in range(cellIndices.shape[0]):
        cell = int(cellIndices[i])
        qID = torch.tensor([cell % cellsX, cell // cellsX], dtype=torch.int64, device=device)
        indices = queryHashMap(qID, hashTable, cumHash, hashIndexSorting, cellIndices, cumCell, cellSpan, sort, cells, hashMapLength)

        cumRows = []
        cumCols = []

        for ii in range(-1, 2):
            for jj in range(-1, 2):
                currentID = qID + torch.tensor([ii, jj], device=device, dtype=qID.dtype)
                currentIndices = queryHashMap(currentID, hashTable, cumHash, hashIndexSorting, cellIndices, cumCell, cellSpan, sort, cells, hashMapLength)
                if currentIndices.shape[0] > 0:
                    # all (neighbour, centre) combinations between the two cells
                    rows = currentIndices.repeat_interleave(indices.shape[0])
                    cols = indices.repeat([currentIndices.shape[0]])

                    distances = torch.linalg.norm(queryParticles[rows] - queryParticles[cols], dim = -1)
                    withinSupport = distances <= support

                    cumRows.append(rows[withinSupport])
                    cumCols.append(cols[withinSupport])

        overallRows.append(torch.hstack(cumRows))
        overallCols.append(torch.hstack(cumCols))

    rows = torch.hstack(overallRows)
    cols = torch.hstack(overallCols)

    return rows, cols

# Run the hash-grid neighbour search over all query particles. The `None` fills
# the cumHash slot, which findNeighbors only forwards to queryHashMap, where it is unused.
rows, cols = findNeighbors(queryParticles, support, qMin, hMax, cells, hashTable, None, hashIndexSorting, cellIndices, cumCell, cellSpan, sort, hashMapLength)
debugPrint(rows.shape)
debugPrint(cols.shape)
debugPrint(queryParticles.shape)
# for ptcl, ptclSupport in zip(queryParticles, querySupport):

#     debugPrint(ptcl)
#     debugPrint(ptclSupport)
#     qLin = linearFrom2D(qID, cells)
#     qHash = hashFrom2D(qID, hashMapLength)

# #     debugPrint(qID)
# #     debugPrint(qLin)
# #     debugPrint(qHash)

#     neighborCells = []
#     for i in range(-1,2):
#         for j in range(-1,2):
#             newID = qID + torch.tensor([i,j],device=qID.device, dtype=qID.dtype)
#             c = queryHashMap(linearFrom2D(newID, cells), hashFrom2D(newID, hashMapLength), hashTable, cumHash, hashIndexSorting, cellIndices, cumCell, sort)
# #             if
# #             debugPrint(c)
#             if c.shape[0] > 0:
#                 neighborCells.append(c)
#     cN = torch.hstack(neighborCells)
# #     debugPrint(cN)
# #     debugPrint(cN.shape[0])
#     cC = torch.ones_like(cN) * qLin
# #     cC = torch.zeros(cN.shape, dtype = qID.dtype, device= qID.dtype)
# #     cC = torch.empty(0,0, dtype = torch.int64, device = qLin.device)
# #     cC[:] = qLin
#     return torch.stack((cC, cN))
    
    
#     break

rows.shape [Size] = torch.Size([1053245])
cols.shape [Size] = torch.Size([1053245])
queryParticles.shape [Size] = torch.Size([50625, 2])




In [102]:
# Reference result from torch_geometric's `radius` search, to compare the pair
# count with findNeighbors.
# NOTE(review): this overwrites `rows`/`cols` from findNeighbors; use different
# names if both results are needed afterwards.
rows,cols = radius(queryParticles, queryParticles, support,max_num_neighbors = 256)
debugPrint(rows.shape)

rows.shape [Size] = torch.Size([1053245])




In [110]:
# Profile a single torch_geometric `radius` call. CUDA activity is only requested
# when a GPU is available.
profilerActivities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    profilerActivities.append(ProfilerActivity.CUDA)

with profile(activities=profilerActivities, with_stack=True, profile_memory=True) as prof:
    rows, cols = radius(queryParticles, queryParticles, support, max_num_neighbors = 256)

print(prof.key_averages().table(sort_by='cpu_time_total'))

-------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                     Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                   radius         0.05%      38.000us        99.96%      75.014ms      75.014ms      16.07 Mb           0 b             1  
    torch_cluster::radius        98.08%      73.602ms        99.91%      74.975ms      74.975ms      16.07 Mb         -16 b             1  
       aten::index_select         1.53%       1.148ms         1.54%       1.154ms       1.154ms      16.07 Mb      16.07 Mb             1  
              aten::empty         0.22%     167.000us         0.22%     167.000us      83.500us          16 b          16 b             2  
             aten::s