A notebook for developing and testing `knn_mc`, a memory-constrained k-nearest-neighbors (kNN) function.

In [1]:
# Reload imported modules (e.g. janelia_core) before each cell executes, so edits to
# knn_mc are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch.cuda import memory_stats

from janelia_core.ml.torch_fcns import knn_mc
from janelia_core.ml.utils import list_torch_devices
from janelia_core.ml.utils import torch_devices_memory_usage

## Parameters

In [3]:
# Problem size for the benchmark.
n_centers = 1_050      # number of centers to search among
n_data_pts = 130_000   # number of query data points
d = 3                  # dimensionality of each point / center

## Randomly generate centers and data

In [4]:
# Take the first entry of the first element returned by list_torch_devices() as the
# compute device. NOTE(review): presumably the first element lists available devices
# with GPUs first — confirm against janelia_core.ml.utils.
available_devices = list_torch_devices()[0]
device = available_devices[0]

Found 1 GPUs


In [5]:
# Seed torch's global RNG so the generated data and centers are identical on every
# Restart & Run All (torch.manual_seed also seeds the CUDA generators).
torch.manual_seed(0)
x = torch.randn([n_data_pts, d], device=device)      # data points, [n_data_pts, d]
ctrs = torch.randn([n_centers, d], device=device)    # centers, [n_centers, d]

In [10]:
def knn(x, ctrs, k):
    """ Non memory efficient kNN implementation - just for comparison with knn_mc.

    For every data point (row of x), finds the k centers (rows of ctrs) with the
    smallest squared Euclidean distance.

    Args:
        x: Data points of shape [n_x, d].
        ctrs: Centers of shape [n_ctrs, d].
        k: Number of nearest centers to return per data point. torch.topk requires
            k <= n_ctrs.

    Returns:
        Index tensor of shape [k, n_x]. Column j holds the indices (into ctrs) of the
        k centers closest to x[j], ordered nearest first.

    Note:
        Materializes the full [n_ctrs, n_x, d] difference tensor and its square at
        the same time, so peak memory grows as roughly
        2 * n_ctrs * n_x * d * x.element_size() bytes.
    """
    d = ctrs.shape[1]
    n_ctrs = ctrs.shape[0]

    # Broadcast [n_x, d] against [n_ctrs, 1, d] -> [n_ctrs, n_x, d].
    diffs = x - torch.reshape(ctrs, [n_ctrs, 1, d])
    sq_distances = torch.sum(diffs**2, dim=2)  # [n_ctrs, n_x]

    # Pick the smallest distances along the centers axis. largest=False avoids
    # allocating a negated copy of sq_distances.
    return torch.topk(sq_distances, k=k, dim=0, largest=False).indices

In [11]:
# Naive baseline, disabled by default: on the full data it needs roughly the
# "expected max memory for naive implementation" printed at the end of this notebook
# (~3.3e9 bytes with float32 inputs), and running it would also inflate the peak-memory
# statistic reported for knn_mc.
#nn0 = knn(x=x, ctrs=ctrs, k=3)

In [12]:
# CUDA's peak-allocation counter is cumulative for the whole process. Reset it so the
# max memory read afterwards reflects knn_mc (plus the already-resident x and ctrs),
# not anything run earlier in this kernel session.
# NOTE(review): assumes torch_devices_memory_usage(..., 'max_memory_allocated') reads
# torch.cuda.max_memory_allocated — confirm.
if torch.device(device).type == 'cuda':
    torch.cuda.reset_peak_memory_stats(device)

# Memory-constrained kNN. NOTE(review): `m` presumably controls how the work is split
# to bound memory — confirm against knn_mc's docstring.
nn = knn_mc(x=x, ctrs=ctrs, k=3, m=10)

## Print max memory usage

In [13]:
# Peak memory allocated on `device`, as reported by janelia_core's helper.
max_memory_used = torch_devices_memory_usage([device], 'max_memory_allocated')
print('Max memory used: ' + "{:e}".format(max_memory_used[0]))

# The naive knn holds the [n_centers, n_data_pts, d] difference tensor and its
# element-wise square at the same time, hence the factor of 2. x.element_size() gives
# bytes per element for x's dtype (4 for float32) instead of hard-coding it.
exp_max_memory = 2*n_centers*n_data_pts*d*x.element_size()
print('Expected max memory for naive implementation: ' + "{:e}".format(exp_max_memory))

Max memory used: 7.315661e+07
Expected max memory for naive implementation: 3.276000e+09
