In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [18]:
class Latent(nn.Module):
    """Soft vector-quantisation layer with one small learnable codebook per latent group.

    The channel axis of ``data['z']`` is split into ``n_latents`` groups of
    ``z_dim`` channels. Each group owns ``n_codes`` learnable code vectors, and
    every group vector is replaced by the softmax-weighted average of the codes
    of its group, the weights being derived from the negative squared
    Euclidean distance between the vector and each code.
    """

    def __init__(self, n_latents, z_dim, n_codes=2, **kwargs):
        """Create the codebook.

        Args:
            n_latents: number of groups the channel axis is split into.
            z_dim: number of channels in each group.
            n_codes: number of code vectors per group. Defaults to 2, the size
                that used to be hard-coded.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        self.L = n_latents
        self.z_dim = z_dim
        self.n_codes = n_codes
        # (1, L, K, z): leading 1 broadcasts the codebook over all positions.
        self.e = nn.Parameter(torch.randn(1, n_latents, n_codes, z_dim))

    def forward(self, data, **kwargs):
        """Replace ``data['z']`` with its soft-quantised version.

        Args:
            data: dict whose ``'z'`` entry is a tensor of shape
                ``(N, n_latents * z_dim, H, W)``.
            **kwargs: optional ``latent_temp`` (default ``1.0``), the divisor
                applied to the distances before the softmax. Smaller positive
                values concentrate the weights on the nearest code. A value of
                zero is not guarded against and produces non-finite results.

        Returns:
            The same ``data`` dict (modified in place) with ``data['z']``
            replaced by a tensor of the same shape as the input.

        Raises:
            ValueError: if ``data['z']`` is not 4-D or its channel count is not
                ``n_latents * z_dim``.
        """
        T = kwargs.get('latent_temp', 1.0)

        # (1, L, K, z)
        e = self.e
        z = data['z']
        N, C, H, W = z.size()
        L = self.L
        z_dim = self.z_dim
        if C != L * z_dim:
            raise ValueError(
                f"data['z'] has {C} channels, expected "
                f"n_latents * z_dim = {L} * {z_dim} = {L * z_dim}"
            )

        # (NHW, L, 1, z). The leading size is spelled out instead of -1 so
        # that an empty batch (N*H*W == 0) can still be reshaped.
        z = z.permute(0, 2, 3, 1).reshape(N * H * W, L, 1, z_dim)
        # (NHW, L, z, 1): column-vector view used by both terms below.
        z_col = z.transpose(2, 3)

        # Squared distance via ||e||^2 - 2 e.z + ||z||^2:
        # (NHW, L, K, 1) = (1, L, K, 1) - (NHW, L, K, 1) + (NHW, L, 1, 1)
        distance = (
            (e ** 2).sum(3, keepdim=True)
            - (2 * e) @ z_col
            + (z_col ** 2).sum(2, keepdim=True)
        )
        # (NHW, L, K, 1): weights over the K codes of each group.
        ratio = torch.softmax(-distance / T, dim=2)
        # (NHW, L, z): weighted average of the codes.
        zq = torch.sum(e * ratio, dim=2)
        # (N, Lz, H, W)
        zq = zq.reshape(N, H, W, L * z_dim).permute(0, 3, 1, 2)
        data['z'] = zq

        return data

In [20]:
# Smoke test: push a random batch through Latent and print the output shape,
# which should equal the input shape (N, L*z_dim, H, W).
N, H, W = 16, 32, 32
L, z_dim = 24, 16
# The input tensor is drawn before the module is built, so the random
# generator is consumed in the same order as before.
data = dict(z=torch.randn(N, L * z_dim, H, W))
latent = Latent(n_latents=L, z_dim=z_dim)
data = latent(data)
print(data['z'].shape)

torch.Size([16, 384, 32, 32])
