In [5]:
%matplotlib widget

import numpy as np
import matplotlib.pyplot as plt
from tqdm import TqdmExperimentalWarning
import warnings
warnings.filterwarnings(action='once')
warnings.filterwarnings("ignore", category=TqdmExperimentalWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning) 
import torch

import numpy as np
import torchCompactRadius as tcr
from torchCompactRadius import radiusSearch, volumeToSupport
from torchCompactRadius.util import countUniqueEntries
from torchCompactRadius.radiusNaive import radiusNaive, radiusNaiveFixed
from torchCompactRadius.cppWrapper import neighborSearchSmall, neighborSearchSmallFixed
import platform
import pandas as pd
import time
from tqdm.autonotebook import tqdm
import copy
import seaborn as sns

from torch_cluster import radius as radius_cluster

In [2]:
def generateNeighborTestData(nx, targetNumNeighbors, dim, maxDomain_0, periodic, device):


    minDomain = torch.tensor([-1] * dim, dtype = torch.float32, device = device)
    maxDomain = torch.tensor([ 1] * dim, dtype = torch.float32, device = device)
    maxDomain[0] = maxDomain_0
    periodicity = [periodic] * dim

    extent = maxDomain - minDomain
    shortExtent = torch.min(extent, dim = 0)[0].item()
    dx = (shortExtent / nx)
    ny = int(1 // dx)
    h = volumeToSupport(dx**dim, targetNumNeighbors, dim)
    dy = dx

    # print(f"dx = {dx}, dy = {dy}, h = {h}")
    # print(f"nx = {nx}, ny = {ny}")
    # print(f"minDomain = {minDomain}, maxDomain = {maxDomain}")
    # print(f"periodicity = {periodicity}")
    # print(f"dim = {dim}")
    # print(f"device = {device}")
    # print(f"maxDomain_0 = {maxDomain_0}")
    # print(f"targetNumNeighbors = {targetNumNeighbors}")
    

    positions = []
    for d in range(dim):
        positions.append(torch.linspace(minDomain[d] + dx / 2, maxDomain[d] - dx / 2, int((extent[d] - dx) / dx) + 1, device = device))
    grid = torch.meshgrid(*positions, indexing = 'xy')
    positions = torch.stack(grid, dim = -1).reshape(-1,dim).to(device)
    supports = torch.ones(positions.shape[0], device = device) * h
    ypositions = []
    for d in range(dim):
        ypositions.append(torch.linspace(-0.5 + dy / 2, 0.5 - dy / 2, ny, device = device))
    grid = torch.meshgrid(*ypositions, indexing = 'xy')
    y = torch.stack(grid, dim = -1).reshape(-1,dim).to(device)
    ySupport = torch.ones(y.shape[0], device = device) * supports[0]
    return (y, positions), (ySupport, supports), (minDomain, maxDomain), periodicity, positions.shape[0]


In [4]:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if platform.system() == 'Darwin':
    device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
device = torch.device('cpu')
targetNumNeighbors = 50
hashMapLength = 4096
nx = 32
dim = 2
periodic = False


In [6]:
(y, positions), (ySupport, supports), (minDomain, maxDomain), periodicity, hashMapLength = generateNeighborTestData(nx, targetNumNeighbors, dim, 1.0, False, device)

In [7]:
# def callRadiusNaive(y, positions, ySupport, supports, periodicity, minDomain, maxDomain, hashMapLength, mode, device, numTrials):
#     for i in range(numTrials):
#         r = radiusNaive(y, positions, ySupport, supports, periodicity, minDomain, maxDomain, mode)
#     return r

#     # i, j = neighborSearchSmall(y, ySupport, positions, supports, minDomain, maxDomain, periodicTensor, 'symmetric')

# def callRadiusSmall(y, positions, ySupport, supports, periodicity, minDomain, maxDomain, hashMapLength, mode, device, numTrials):
#     for i in range(numTrials):
#         r = neighborSearchSmall(y, ySupport, positions, supports, minDomain, maxDomain, torch.tensor(periodicity).to(device), mode)
#     return r

# def callRadiusCompact(y, positions, ySupport, supports, periodicity, minDomain, maxDomain, hashMapLength, mode, device, numTrials):
#     for i in range(numTrials):
#         r = neighborSearch((y, positions), (ySupport, supports), (minDomain, maxDomain), periodicity, hashMapLength, mode, 'cpp')
#     return r

# def callRadius(y, positions, ySupport, supports, periodicity, minDomain, maxDomain, hashMapLength, mode, device, numTrials):
#     for i in range(numTrials):
#         (i,j) = radius(y, positions, supports[0], max_num_neighbors = 384)
#     return (j,i)

In [7]:
(y, positions), (ySupport, supports), (minDomain, maxDomain), periodicity, hashMapLength = generateNeighborTestData(nx, targetNumNeighbors, dim, 1.0, False, device)
h = ySupport[0].cpu().item()

In [11]:
def test_ij(i, j, y, positions, periodic):
    if y.shape == positions.shape and torch.all(y == positions):
        if periodic:
            assert i.shape[0] == j.shape[0], f'i.shape[0] = {i.shape[0]} != j.shape[0] = {j.shape[0]}'
            assert i.shape[0] == 46080, f'i.shape[0] = {i.shape[0]} != 11520'
            assert j.shape[0] == 46080, f'i.shape[0] = {j.shape[0]} != 11520'
            ii, ni = countUniqueEntries(i, y)
            jj, nj = countUniqueEntries(j, positions)
            assert ni.min() == ni.max(), f'ni.min() = {ni.min()} != ni.max() = {ni.max()}'
            assert ni.min() == 45, f'ni.min() = {ni.min()} != 45'
            print('✅', end = '')
        else:
            assert i.shape[0] == j.shape[0], f'i.shape[0] = {i.shape[0]} != j.shape[0] = {j.shape[0]}'
            assert i.shape[0] == 41580, f'i.shape[0] = {i.shape[0]} != 41580'
            assert j.shape[0] == 41580, f'i.shape[0] = {j.shape[0]} != 41580'
            ii, ni = countUniqueEntries(i, y)
            jj, nj = countUniqueEntries(j, positions)
            assert ni.min() != ni.max(), f'ni.min() = {ni.min()} == ni.max() = {ni.max()}'
            assert nj.min() != nj.max(), f'nj.min() = {nj.min()} == nj.max() = {nj.max()}'

            assert ni.min() == 15, f'ni.min() = {ni.min()} != 15'
            assert ni.max() == 45, f'ni.min() = {ni.min()} != 45'
            print('✅', end = '')
    else:
        assert i.shape[0] == j.shape[0], f'i.shape[0] = {i.shape[0]} != j.shape[0] = {j.shape[0]}'
        assert i.shape[0] == 11520, f'i.shape[0] = {i.shape[0]} != 11520'
        assert j.shape[0] == 11520, f'i.shape[0] = {j.shape[0]} != 11520'
        ii, ni = countUniqueEntries(i, y)
        jj, nj = countUniqueEntries(j, positions)
        assert ni.min() == ni.max(), f'ni.min() = {ni.min()} != ni.max() = {ni.max()}'
        assert ni.min() == 45, f'ni.min() = {ni.min()} != 45'
        print('✅', end = '')


In [14]:
periodic = True
reducedSet = True
algorithm = 'naive'

for periodic in [True, False]:
    for reducedSet in [True, False]:
        for algorithm in ['naive', 'small', 'compact']:
            print(f'periodic = {periodic}, \treducedSet = {reducedSet}, \talgorithm = {algorithm}\t', end = '')
            i, j = radiusSearch(y if reducedSet else positions, positions, h, algorithm = algorithm, periodicity = periodic, domainMin = minDomain, domainMax = maxDomain)
            test_ij(i, j, y if reducedSet else positions, positions, periodic)
            
            i, j = radiusSearch(y if reducedSet else positions, positions, ySupport if reducedSet else supports, algorithm = algorithm, periodicity = periodic, domainMin = minDomain, domainMax = maxDomain)
            test_ij(i, j, y if reducedSet else positions, positions, periodic)

            i, j = radiusSearch(y if reducedSet else positions, positions, (ySupport, supports) if reducedSet else (supports, supports), algorithm = algorithm, periodicity = periodic, domainMin = minDomain, domainMax = maxDomain)
            test_ij(i, j, y if reducedSet else positions, positions, periodic)

            i, j = radiusSearch(y if reducedSet else positions, positions, (ySupport, supports) if reducedSet else (supports, supports), algorithm = algorithm, periodicity = periodic, domainMin = minDomain, domainMax = maxDomain, mode = 'scatter')
            i, j = radiusSearch(y if reducedSet else positions, positions, (ySupport, supports) if reducedSet else (supports, supports), algorithm = algorithm, periodicity = periodic, domainMin = minDomain, domainMax = maxDomain, mode = 'gather')
            i, j = radiusSearch(y if reducedSet else positions, positions, (ySupport, supports) if reducedSet else (supports, supports), algorithm = algorithm, periodicity = periodic, domainMin = minDomain, domainMax = maxDomain, mode = 'symmetric')
            test_ij(i, j, y if reducedSet else positions, positions, periodic)
            print('')

periodic = True, 	reducedSet = True, 	algorithm = naive	✅✅✅✅
periodic = True, 	reducedSet = True, 	algorithm = small	✅✅✅✅
periodic = True, 	reducedSet = True, 	algorithm = compact	✅✅✅✅
periodic = True, 	reducedSet = False, 	algorithm = naive	✅✅✅✅
periodic = True, 	reducedSet = False, 	algorithm = small	✅✅✅✅
periodic = True, 	reducedSet = False, 	algorithm = compact	✅✅✅✅
periodic = False, 	reducedSet = True, 	algorithm = naive	✅✅✅✅
periodic = False, 	reducedSet = True, 	algorithm = small	✅✅✅✅
periodic = False, 	reducedSet = True, 	algorithm = compact	✅✅✅✅
periodic = False, 	reducedSet = False, 	algorithm = naive	✅✅✅✅
periodic = False, 	reducedSet = False, 	algorithm = small	✅✅✅✅
periodic = False, 	reducedSet = False, 	algorithm = compact	✅✅✅✅
