In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [4]:
# Hyperparameters shared by the prototype cells below.
hparams = dict(
    n_features=10000,         # width of each input feature vector
    hidden_dim=128,           # width of the hidden representation
    latent_dims=[10, 20, 30], # number of categories in each categorical layer
)

In [9]:
def straight_through_estimator(logits):
    """Discretise the last dimension to one-hot while keeping gradients.

    The forward value is a one-hot tensor selecting the largest entry along
    the last dimension. The backward pass behaves as if the function were the
    identity, so gradients flow to ``logits`` unchanged.

    Args:
        logits: Tensor of scores (or probabilities); only the position of the
            maximum along the last dimension matters for the forward value.

    Returns:
        Tensor with the same shape and dtype as ``logits``.
    """
    # Build the mask from a single argmax index instead of comparing every
    # entry with the row maximum: the comparison marks *all* tied maxima and
    # therefore produces a multi-hot row (e.g. for a uniform distribution).
    winner = logits.argmax(dim=-1, keepdim=True)
    one_hot = torch.zeros_like(logits).scatter_(-1, winner, 1)

    # (one_hot - logits) is detached, so the forward value is ``one_hot``
    # while the only differentiable term left is ``logits`` itself.
    return (one_hot - logits).detach() + logits

def gumbel_softmax(logits, temperature=1.0, eps=1e-20):
    """Draw a relaxed categorical sample by perturbing logits with Gumbel noise.

    Args:
        logits: Tensor of unnormalised scores; softmax is taken over the last
            dimension.
        temperature: Divisor applied to the perturbed logits before softmax.
        eps: Small constant added inside both logarithms to keep them finite.

    Returns:
        Tensor shaped like ``logits`` whose last dimension is a softmax output.
    """
    uniform = torch.rand(logits.size(), dtype=logits.dtype, device=logits.device)
    # Gumbel noise: -log(-log(u + eps) + eps)
    gumbel = uniform.add(eps).log().neg().add(eps).log().neg()
    perturbed = (logits + gumbel) / temperature
    return F.softmax(perturbed, dim=-1)

class CategoricalLayer(nn.Module):
    """Layer that derives a categorical distribution from its input.

    ``dense_in`` projects the input to ``categorical_dim`` logits. The
    resulting distribution is concatenated with the original input and passed
    through ``dense_out`` followed by ``tanh``.

    Args:
        input_dim: Size of the last dimension of the input.
        categorical_dim: Number of categories in the distribution.
        output_dim: Size of the returned hidden state; defaults to
            ``input_dim`` when omitted.
    """

    def __init__(self, input_dim, categorical_dim, output_dim=None):
        super().__init__()

        # Identity check rather than ``== None``: equality can be overridden.
        if output_dim is None:
            output_dim = input_dim

        self.dense_in = nn.Linear(input_dim, categorical_dim, bias=True)
        # dense_out sees the input concatenated with the distribution.
        self.dense_out = nn.Linear(input_dim + categorical_dim, output_dim, bias=True)

    def forward(self, inputs, straight_through=True, sample=False, temperature=1.0):
        """Compute the hidden state and the categorical distribution.

        Args:
            inputs: Tensor whose last dimension is ``input_dim``.
            straight_through: If true, pass the distribution through
                ``straight_through_estimator`` before it is used.
            sample: If true, use ``gumbel_softmax``; otherwise a plain softmax
                over the logits.
            temperature: Forwarded to ``gumbel_softmax``; ignored unless
                ``sample`` is true.

        Returns:
            Tuple ``(h, dist)``: the tanh-activated hidden state and the
            distribution that was concatenated with ``inputs``.
        """
        logits = self.dense_in(inputs)

        if sample:
            dist = gumbel_softmax(logits, temperature=temperature)
        else:
            dist = F.softmax(logits, dim=-1)

        if straight_through:
            dist = straight_through_estimator(dist)

        h = torch.tanh(self.dense_out(torch.cat([inputs, dist], dim=-1)))
        return h, dist
    
class HLGC(nn.Module):
    """Classifier built on a stack of ``CategoricalLayer`` modules.

    The input is projected to ``hidden_dim``, passed through one
    ``CategoricalLayer`` per entry of ``categorical_dims``, and the
    distributions produced by all layers are concatenated, projected back to
    ``hidden_dim`` with ``tanh``, and mapped to ``n_classes`` logits.

    Args:
        n_classes: Number of output logits.
        input_dim: Size of the last dimension of the input.
        categorical_dims: Iterable with the number of categories of each
            categorical layer; must not be empty.
        hidden_dim: Width of the hidden representation.

    Raises:
        ValueError: If ``categorical_dims`` is empty.
    """

    def __init__(self, n_classes, input_dim, categorical_dims, hidden_dim=128):
        super().__init__()

        # Materialise once so a generator can be both iterated and summed.
        categorical_dims = list(categorical_dims)
        if not categorical_dims:
            raise ValueError("categorical_dims must contain at least one entry")

        self.input_dense = nn.Linear(input_dim, hidden_dim, bias=True)

        self.categorical_layers = nn.ModuleList([
            CategoricalLayer(hidden_dim, dim) for dim in categorical_dims
        ])

        # global_dense consumes the concatenation of every layer's distribution.
        self.global_dense = nn.Linear(sum(categorical_dims), hidden_dim, bias=True)
        self.out_dense = nn.Linear(hidden_dim, n_classes, bias=True)

    def forward(self, inputs, straight_through=True, sample=False, temperature=1.0):
        """Run the model.

        Args:
            inputs: Tensor whose last dimension is ``input_dim``.
            straight_through: Forwarded to every categorical layer.
            sample: Forwarded to every categorical layer.
            temperature: Forwarded to every categorical layer.

        Returns:
            Tuple ``(logits, dists)``: the class logits and the list of
            distributions returned by the categorical layers, in order.
        """
        h = self.input_dense(inputs)

        dists = []
        for layer in self.categorical_layers:
            h, dist = layer(
                h,
                straight_through=straight_through,
                sample=sample,
                temperature=temperature,
            )
            dists.append(dist)

        h = torch.tanh(self.global_dense(torch.cat(dists, dim=-1)))
        return self.out_dense(h), dists

In [11]:
# Stand-alone pieces of the model, built separately so each stage can be
# inspected in the next cell.
inputs = torch.rand(8, hparams['n_features'])  # random batch of 8 feature vectors

input_dense = nn.Linear(
    in_features=hparams['n_features'],
    out_features=hparams['hidden_dim'],
    bias=True,
)

# One categorical layer per entry of latent_dims.
latent_layers = nn.ModuleList(
    CategoricalLayer(input_dim=hparams['hidden_dim'], categorical_dim=dim)
    for dim in hparams['latent_dims']
)

# global_dense takes the concatenation of all categorical distributions.
global_dense = nn.Linear(
    in_features=sum(hparams['latent_dims']),
    out_features=hparams['hidden_dim'],
    bias=True,
)
out_dense = nn.Linear(in_features=hparams['hidden_dim'], out_features=20, bias=True)

In [13]:
# Push the batch through the prototype layers, printing the shape produced
# at every stage.
h = input_dense(inputs)
print(h.shape)

dists = []
for layer in latent_layers:
    # Each layer returns the next hidden state and its categorical distribution.
    h, dist = layer(h)
    dists.append(dist)
    print(dist.shape)

# The hidden state is rebuilt from the concatenated distributions only.
h = global_dense(torch.cat(dists, dim=-1)).tanh()
print(h.shape)

logits = out_dense(h)
print(logits.shape)

torch.Size([8, 128])
torch.Size([8, 10])
torch.Size([8, 20])
torch.Size([8, 30])
torch.Size([8, 128])
torch.Size([8, 20])
