In [3]:
%load_ext autoreload
%autoreload 2

In [52]:
import collections
import itertools
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import scipy.spatial

import voxart

In [10]:
import cProfile
import pstats
from pstats import SortKey

# 0.profile

In [12]:
goal = voxart.Goal.from_image(Image.open("../assets/chiral_7.png"))

In [13]:
opts = voxart.SearchOptions()
opts.filled_num_batches = 2
cProfile.run('results = voxart.search(goal, opts)', '0.profile')

SearchOptions(filled_batch_size=3, filled_num_batches=2, filled_strategy='random_clear_front', filled_num_to_pursue=20, connector_num_iterations_per=30, connector_frac=0.4, top_n=20, obj_func=None)


  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/64 [00:00<?, ?it/s]

  0%|          | 0/64 [00:00<?, ?it/s]

In [15]:
p = pstats.Stats('profile0')

In [16]:
p.strip_dirs().sort_stats(SortKey.CUMULATIVE).print_stats()

Thu Mar 23 19:36:46 2023    profile0

         117476310 function calls (117398762 primitive calls) in 100.559 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
      9/1    0.000    0.000  100.573  100.573 {built-in method builtins.exec}
        1    0.000    0.000  100.573  100.573 <string>:1(<module>)
        1    0.015    0.015  100.573  100.573 search.py:558(search)
       40    0.287    0.007   58.386    1.460 search.py:360(search_connectors)
     7200   15.017    0.002   57.818    0.008 search.py:302(get_shortest_path_to_targets)
     9392    0.033    0.000   41.376    0.004 search.py:160(add)
     9392    0.219    0.000   41.318    0.004 search.py:107(__call__)
     7680    1.408    0.000   40.974    0.005 search.py:493(count_unsupported)
   620736   17.740    0.000   39.269    0.000 search.py:479(is_vox_unsupported)
  8682249   10.020    0.000   37.668    0.000 search.py:282(get_neighbors)
8499463/8498951    4.148 

<pstats.Stats at 0x7fa52ca2a340>

In [17]:
p.strip_dirs().sort_stats(SortKey.TIME).print_stats()

Thu Mar 23 19:36:46 2023    profile0

         117476310 function calls (117398762 primitive calls) in 100.559 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
   620736   17.740    0.000   39.269    0.000 search.py:479(is_vox_unsupported)
  6930261   16.316    0.000   16.316    0.000 {built-in method numpy.array}
     7200   15.017    0.002   57.818    0.008 search.py:302(get_shortest_path_to_targets)
  8682249   10.020    0.000   37.668    0.000 search.py:282(get_neighbors)
 30129674    6.633    0.000    6.633    0.000 <string>:2(__lt__)
  1583338    4.590    0.000    4.590    0.000 {method 'reduce' of 'numpy.ufunc' objects}
8499463/8498951    4.148    0.000   31.472    0.000 {built-in method numpy.core._multiarray_umath.implement_array_function}
  6930261    4.002    0.000   26.157    0.000 <__array_function__ internals>:177(copy)
  2808440    3.180    0.000    7.603    0.000 {built-in method _heapq.heappop}
  5351441    

<pstats.Stats at 0x7fa52ca2a340>

# get_neighbors

get_neighbors is suprisingly slow

Let's investigate

In [19]:
def get_neighbors_orig(vox, size):
    vox = np.asarray(vox)
    if vox.shape != (3,):
        raise ValueError(f"Only suport 3D neightbors, got shape {vox.shape}")
    for axis, delta in itertools.product([0, 1, 2], [-1, 1]):
        newval = vox[axis] + delta
        if newval < 0 or newval >= size:
            continue
        neighbor = np.copy(vox)
        neighbor[axis] = newval
        yield neighbor


In [53]:
%timeit collections.deque(get_neighbors_orig(np.array([1, 2, 3]), 4), maxlen=0)

8.06 µs ± 272 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [26]:
def get_neighbors_1(vox, size):
    #vox = np.asarray(vox)
    #if vox.shape != (3,):
    #    raise ValueError(f"Only suport 3D neightbors, got shape {vox.shape}")
    for axis, delta in itertools.product([0, 1, 2], [-1, 1]):
        newval = vox[axis] + delta
        if newval < 0 or newval >= size:
            continue
        neighbor = np.copy(vox)
        neighbor[axis] = newval
        yield neighbor


In [54]:
%timeit collections.deque(get_neighbors_1(np.array([1, 2, 3]), 4), maxlen=0)

7.81 µs ± 203 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [28]:
DELTAS = list(itertools.product([0, 1, 2], [-1, 1]))
def get_neighbors_2(vox, size):
    #vox = np.asarray(vox)
    #if vox.shape != (3,):
    #    raise ValueError(f"Only suport 3D neightbors, got shape {vox.shape}")
    for axis, delta in DELTAS:
        newval = vox[axis] + delta
        if newval < 0 or newval >= size:
            continue
        neighbor = np.copy(vox)
        neighbor[axis] = newval
        yield neighbor

In [55]:
%timeit collections.deque(get_neighbors_2(np.array([1, 2, 3]), 4), maxlen=0)

7.49 µs ± 158 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [84]:
def get_neighbors_3(vox, size):
    vox = np.array(vox)
    if vox.shape != (3,):
        raise ValueError(f"Only suport 3D neightbors, got shape {vox.shape}")
    for axis, delta in itertools.product([0, 1, 2], [-1, 1]):
        newval = vox[axis] + delta
        if newval < 0 or newval >= size:
            continue
        #neighbor = np.copy(vox)
        vox[axis] = newval
        yield vox
        vox[axis] -= delta


In [85]:
list(get_neighbors_3(np.array([1, 2, 3]), 4))

[array([1, 2, 3]),
 array([1, 2, 3]),
 array([1, 2, 3]),
 array([1, 2, 3]),
 array([1, 2, 3])]

In [86]:
np.fromiter(get_neighbors_3(np.array([1, 2, 3]), 4), dtype=(int, 3))

array([[0, 2, 3],
       [2, 2, 3],
       [1, 1, 3],
       [1, 3, 3],
       [1, 2, 2]])

In [87]:
orig_vox = np.array([1, 2, 3])
for neigh in get_neighbors_3(orig_vox, 4):
    print(orig_vox, neigh)

[1 2 3] [0 2 3]
[1 2 3] [2 2 3]
[1 2 3] [1 1 3]
[1 2 3] [1 3 3]
[1 2 3] [1 2 2]


In [88]:
%timeit collections.deque(get_neighbors_3(np.array([1, 2, 3]), 4), maxlen=0)

4.27 µs ± 146 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [70]:
DELTA_MATRIX = np.array([
    [-1, 0, 0],
    [1, 0, 0],
    [0, -1, 0],
    [0, 1, 0],
    [0, 0, -1],
    [0, 0, 1],
    ])
def get_neighbors_4(vox, size):
    all_neighs = vox + DELTA_MATRIX
    #display(all_neighs)
    #vox = np.asarray(vox)
    #if vox.shape != (3,):
    #    raise ValueError(f"Only suport 3D neightbors, got shape {vox.shape}")
    invalid = np.logical_or.reduce((all_neighs < 0) | (all_neighs >= size), axis=1)
    #display(invalid)
    for neigh_idx in range(6):
        #print(invalid[neigh_idx])
        if invalid[neigh_idx]:
            continue
        yield all_neighs[neigh_idx]

In [71]:
list(get_neighbors_4(np.array([1, 2, 3]), 4))

[array([0, 2, 3]),
 array([2, 2, 3]),
 array([1, 1, 3]),
 array([1, 3, 3]),
 array([1, 2, 2])]

In [72]:
%timeit collections.deque(list(get_neighbors_4(np.array([1, 2, 3]), 4)), maxlen=0)

9.01 µs ± 79.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [62]:
DELTA_MATRIX = np.array([
    [-1, 0, 0],
    [1, 0, 0],
    [0, -1, 0],
    [0, 1, 0],
    [0, 0, -1],
    [0, 0, 1],
    ])
def get_neighbors_5(vox, size):
    all_neighs = vox + DELTA_MATRIX
    #display(all_neighs)
    #vox = np.asarray(vox)
    #if vox.shape != (3,):
    #    raise ValueError(f"Only suport 3D neightbors, got shape {vox.shape}")
    #invalid = np.logical_or.reduce((all_neighs < 0) | (all_neighs >= size), axis=1)
    #display(invalid)
    for neigh_idx in range(6):
        #print(invalid[neigh_idx])
        #if invalid[neigh_idx]:
        #    continue
        yield all_neighs[neigh_idx, :]

In [63]:
%timeit collections.deque(list(get_neighbors_5(np.array([1, 2, 3]), 4)), maxlen=0)

4.78 µs ± 217 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [73]:
DELTA_MATRIX = np.array([
    [-1, 0, 0],
    [1, 0, 0],
    [0, -1, 0],
    [0, 1, 0],
    [0, 0, -1],
    [0, 0, 1],
    ])
def get_neighbors_6(vox, size):
    all_neighs = vox + DELTA_MATRIX
    #display(all_neighs)
    #vox = np.asarray(vox)
    #if vox.shape != (3,):
    #    raise ValueError(f"Only suport 3D neightbors, got shape {vox.shape}")
    #invalid = np.logical_or.reduce((all_neighs < 0) | (all_neighs >= size), axis=1)
    #display(invalid)
    for neigh_idx in range(6):
        neigh = all_neighs[neigh_idx]
        if np.any(neigh < 0) or np.any(neigh >= size):
            continue
        #print(invalid[neigh_idx])
        #if invalid[neigh_idx]:
        #    continue
        yield neigh

In [74]:
%timeit collections.deque(list(get_neighbors_6(np.array([1, 2, 3]), 4)), maxlen=0)

71.5 µs ± 8.68 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


Based on all this, I am switching to the form in get_neighbors_3