In [6]:
from dciknn_cuda import DCI

In [8]:
import torch



In [11]:
# dci_db = DCI(128,2,10,100,10)

# dci_db.add(data)
# index, _ = dci_db.query(query)
# print(index[0][0].long())
# dci_db.clear()

In [12]:
import lightning as L

def find_nearest_neighbor(x, y):
    """Return the index of the row of ``x`` closest to ``y`` (Euclidean distance).

    Exact brute-force search; ``DCI_Helper.alt`` uses it as an alternative to
    the DCI query.

    Args:
        x: candidate points, one per row (distances are summed over dim 1, so
            ``x`` is expected to be at least 2-D — in this notebook
            ``[num_latent, latent_dim]``).
        y: query point(s) broadcastable against ``x`` — in this notebook
            ``[1, latent_dim]``.

    Returns:
        0-d integer tensor holding the row index of the nearest candidate.
    """
    # Squared distance has the same argmin as the true distance, so no sqrt is needed.
    squared_dist = (x - y).pow(2).sum(dim=1)
    return squared_dist.argmin()

class DCI_Helper:
    """Single-query nearest-neighbour lookup backed by one reusable ``DCI`` object.

    Each call adds a candidate set to the index, queries it, and then calls
    ``clear()`` on it (presumably dropping the added points) so the next call
    starts fresh.
    """

    def __init__(self, dim=128, num_comp_indices=2, num_simp_indices=10,
                 block_size=100, thread_size=10):
        """Construct the underlying DCI index.

        The defaults equal the previously hard-coded ``DCI(128, 2, 10, 100, 10)``.
        NOTE(review): parameter names follow the dciknn_cuda README ordering
        ``(dim, num_comp_indices, num_simp_indices, block_size, thread_size)``;
        confirm against the installed version. ``dim`` presumably has to match
        the last dimension of the tensors passed to ``__call__``.
        """
        # Passed positionally so we do not depend on DCI's keyword names.
        self.dci_db = DCI(dim, num_comp_indices, num_simp_indices,
                          block_size, thread_size)

    def __call__(self, x, y):
        """Return, as a Python ``int``, the DCI nearest-neighbour index of ``y`` in ``x``.

        Args:
            x: candidate points to index — in this notebook ``[num_latent, latent_dim]``.
            y: query points — in this notebook ``[1, latent_dim]``. Only the
                first neighbour of the first query is returned.

        ``clear()`` runs even if ``add`` or ``query`` raises, so a failed call
        cannot leave stale candidates in the shared index for the next call.
        """
        try:
            self.dci_db.add(x)
            index, _ = self.dci_db.query(y)
            # The return value is computed before the ``finally`` clause clears the index.
            return index[0][0].long().item()
        finally:
            self.dci_db.clear()

    def alt(self, x, y):
        """Exact brute-force alternative to ``__call__`` via ``find_nearest_neighbor``.

        NOTE(review): unlike ``__call__`` this returns a 0-d tensor rather than
        an ``int``; both can be used to index ``x``.
        """
        return find_nearest_neighbor(x, y)

In [26]:
class LitIMLEGenerator(L.LightningModule):
    """Lightning module that, per sample, selects the generated latent nearest a real latent."""

    def __init__(self):
        super().__init__()
        # Despite the attribute name, this holds a DCI_Helper wrapper, not a raw DCI index.
        self.dci_db = DCI_Helper()

    def find_closest_latent(self, generated_pc_latent, real_pc_latent):
        """For each sample ``i``, return the row of ``generated_pc_latent[i]`` nearest ``real_pc_latent[i]``.

        Args:
            generated_pc_latent: ``[n_samples, num_latent, latent_dim]`` candidates.
            real_pc_latent: ``[n_samples, latent_dim]`` targets.

        Returns:
            ``[n_samples, latent_dim]`` tensor on the same device as
            ``generated_pc_latent``. An empty batch (``n_samples == 0``) yields
            an empty ``[0, latent_dim]`` tensor instead of raising.
        """
        device = generated_pc_latent.device
        n_samples = generated_pc_latent.shape[0]

        # The helper answers one query at a time, so search each sample separately.
        nearest = [
            self.dci_db(generated_pc_latent[i], real_pc_latent[i].unsqueeze(0))
            for i in range(n_samples)
        ]

        rows = torch.arange(n_samples, device=device)
        cols = torch.as_tensor(nearest, dtype=torch.long, device=device)
        # One advanced-indexing gather replaces the per-sample stack. It produces
        # the same values and gradients and also works when the batch is empty.
        selected_generated_pc_latent = generated_pc_latent[rows, cols]

        # NOTE(review): emptying the CUDA cache on every call usually costs time
        # and does not give PyTorch itself more memory; kept to preserve
        # behaviour. Consider removing it.
        torch.cuda.empty_cache()
        return selected_generated_pc_latent

    def forward(self, num_latents=80, batch_size=32, dimension=128, device="cuda"):
        """Run ``find_closest_latent`` on random Gaussian latents.

        The defaults reproduce the original hard-coded sizes and CUDA placement.
        ``dimension`` presumably must match the dimension the DCI index was built
        with (128 for the default ``DCI_Helper``).
        """
        # Sample on CPU and then move to the device, exactly like the original
        # ``.cuda()``, so seeded runs produce the same values as before.
        generated_latents = torch.randn(batch_size, num_latents, dimension).to(device)
        real_latent = torch.randn(batch_size, dimension).to(device)

        return self.find_closest_latent(generated_latents, real_latent)
    

In [22]:
# Synthetic inputs for exercising find_closest_latent (same sizes as LitIMLEGenerator.forward).
SEED = 0
# Seed every torch generator so the latents, and the outputs displayed below,
# are reproducible under Restart & Run All.
torch.manual_seed(SEED)

num_latents = 80
batch_size = 32
# NOTE(review): presumably must match the 128 passed as DCI's first argument in DCI_Helper.
dimension = 128

generated_latents = torch.randn(batch_size, num_latents, dimension).cuda()
real_latent = torch.randn(batch_size, dimension).cuda()

In [27]:
# Build the module; its __init__ creates a DCI_Helper, which constructs the DCI index.
model = LitIMLEGenerator()

In [25]:
# For each sample in the batch, select the generated latent nearest its real latent.
model.find_closest_latent(
    generated_pc_latent=generated_latents,
    real_pc_latent=real_latent,
)

tensor([[-1.2044, -0.0964,  0.7220,  ..., -1.0811, -0.0609,  0.4459],
        [ 1.2331, -0.2506,  1.0042,  ...,  0.8392, -1.2361, -0.4691],
        [ 0.1467,  0.4821, -1.6091,  ...,  0.9150, -0.1763, -1.6714],
        ...,
        [-0.7575,  0.4880,  0.2686,  ...,  1.0795,  0.2781, -0.1403],
        [ 0.4398,  1.0322,  1.3751,  ..., -0.0102, -1.7053,  0.9193],
        [ 1.7316,  0.6041,  1.0555,  ..., -0.8117, -0.9223, -1.2265]],
       device='cuda:0')

In [28]:
# Repeat the previous cell's search on the same inputs, e.g. to compare the two selections.
model.find_closest_latent(
    real_pc_latent=real_latent,
    generated_pc_latent=generated_latents,
)

tensor([[-1.2044, -0.0964,  0.7220,  ..., -1.0811, -0.0609,  0.4459],
        [ 1.2331, -0.2506,  1.0042,  ...,  0.8392, -1.2361, -0.4691],
        [ 0.1467,  0.4821, -1.6091,  ...,  0.9150, -0.1763, -1.6714],
        ...,
        [-0.7575,  0.4880,  0.2686,  ...,  1.0795,  0.2781, -0.1403],
        [ 0.4398,  1.0322,  1.3751,  ..., -0.0102, -1.7053,  0.9193],
        [ 1.7316,  0.6041,  1.0555,  ..., -0.8117, -0.9223, -1.2265]],
       device='cuda:0')