In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class Quantizer(nn.Module):
    """Nearest-neighbour vector quantizer with a straight-through gradient.

    The module holds no parameters: the codebook is supplied on every call as
    ``data['e']``, and each vector in ``data['z']`` is replaced by its closest
    (Euclidean) codebook row.
    """

    def __init__(self, **kwargs):
        # kwargs are accepted for interface compatibility but are not used.
        super().__init__()

    def forward(self, data, **kwargs):
        """Quantize ``data['z']`` against the codebook ``data['e']``.

        Args:
            data: dict with
                'z': tensor of shape (..., c) — vectors to quantize. Any number
                     of leading dims is accepted (the common case is (N, c));
                     they are flattened internally and restored on output.
                'e': tensor of shape (M, c) — codebook.
            **kwargs: ignored.

        Returns:
            The same ``data`` dict, mutated in place: ``data['z']`` becomes
            ``z + (z_q - z).detach()``, where ``z_q`` holds the nearest codebook
            rows. Its shape matches the input ``z``. Its forward value equals
            ``z_q``, while its gradient w.r.t. ``z`` is the identity
            (straight-through estimator). No gradient flows to ``e`` through
            this output. ``data['e']`` is left unchanged.
        """
        z = data['z']
        e = data['e']

        # Flatten all leading dims so distances are always (N, c) x (M, c) -> (N, M);
        # the original unsqueeze/squeeze trick picked the wrong argmin axis for >2-D z.
        z_flat = z.reshape(-1, z.shape[-1])
        # (N, M) pairwise Euclidean distances
        distances = torch.cdist(z_flat, e)
        # (N,) index of the nearest codebook row for every vector
        min_indices = torch.argmin(distances, dim=1)
        # (N, c) -> (..., c), same shape as the input z
        z_q = torch.index_select(e, 0, min_indices).reshape(z.shape)

        # Straight-through: forward value is z_q, backward treats quantization as identity.
        data['z'] = z + (z_q - z).detach()

        return data

In [3]:
# Smoke test: quantize random vectors against a random codebook.
torch.manual_seed(0)  # reproducible inputs under Restart & Run All
N_VECTORS, CODEBOOK_SIZE, DIM = 128, 8192, 512
data = {'z': torch.randn(N_VECTORS, DIM),
        'e': torch.randn(CODEBOOK_SIZE, DIM)}
data = Quantizer()(data)
# Quantization must preserve the input shape and leave the codebook untouched.
assert data['z'].shape == (N_VECTORS, DIM)
assert data['e'].shape == (CODEBOOK_SIZE, DIM)
print(data['z'].shape, data['e'].shape)

torch.Size([128, 512]) torch.Size([8192, 512])
