In [25]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [5]:
# Model hyperparameters read by the cells below.
hparams = dict(
    n_features=10_000,  # width of each raw input row
    hidden_dim=128,     # output width of the first dense projection
)

In [145]:
def straight_through_estimator(logits):
    """Harden the last axis to one-hot while passing gradients straight through.

    Forward value: a one-hot encoding of ``argmax(logits, dim=-1)``. There is
    exactly one 1 per row, even when several entries tie for the maximum; the
    first maximal index wins. Comparing against the max with ``torch.eq`` would
    instead mark every tied entry and yield a multi-hot vector.

    Backward: ``(hard - logits)`` is detached, so the gradient with respect to
    ``logits`` is the identity (straight-through estimator).

    Args:
        logits: floating-point tensor whose last dimension is the category
            axis. Only its argmax matters for the forward value, so logits or
            probabilities both work.

    Returns:
        Tensor with the same shape and dtype as ``logits``.
    """
    winner = logits.argmax(dim=-1, keepdim=True)
    hard = torch.zeros_like(logits).scatter_(-1, winner, 1.0)
    # The detached term carries no gradient, so d(out)/d(logits) == identity.
    return (hard - logits).detach() + logits

def gumbel_softmax(logits, temperature=1.0, eps=1e-20, dim=-1):
    """Draw a relaxed (soft) categorical sample via the Gumbel-softmax trick.

    Gumbel noise ``g = -log(-log(u))`` with ``u ~ Uniform[0, 1)`` is added to
    ``logits``. The result is divided by ``temperature`` and passed through a
    softmax along ``dim``. Lower temperatures give sharper, more one-hot-like
    samples.

    Args:
        logits: floating-point tensor of unnormalised log-probabilities.
        temperature: positive softmax temperature.
        eps: small constant added inside both logs to avoid ``log(0)``.
            NOTE(review): 1e-20 underflows to 0 in float16/bfloat16, so the
            guard is ineffective at those precisions.
        dim: category axis to normalise over (default: last axis).

    Returns:
        Tensor of the same shape as ``logits`` whose entries sum to 1 along ``dim``.
    """
    uniform = torch.rand(logits.size(), dtype=logits.dtype, device=logits.device)
    gumbel_noise = -torch.log(-torch.log(uniform + eps) + eps)
    return F.softmax((logits + gumbel_noise) / temperature, dim=dim)

class CategoricalLayer(nn.Module):
    """Dense layer conditioned on a categorical code computed from its input.

    ``inputs`` is projected to ``categorical_dim`` logits and turned into a
    distribution over categories. That distribution is soft or sampled, and
    optionally hardened to one-hot. It is concatenated with the original
    ``inputs`` and projected to ``output_dim`` features.

    Args:
        input_dim: size of the last dimension of ``inputs``.
        categorical_dim: number of categories.
        output_dim: size of the output features; defaults to ``input_dim``.
    """

    def __init__(self, input_dim, categorical_dim, output_dim=None):
        super().__init__()

        # Identity test (PEP 8): `== None` can be hijacked by objects that
        # overload equality.
        if output_dim is None:
            output_dim = input_dim

        self.dense_in = nn.Linear(input_dim, categorical_dim, bias=True)
        # The output projection sees the raw features and the categorical
        # code side by side, hence the widened input.
        self.dense_out = nn.Linear(input_dim + categorical_dim, output_dim, bias=True)

    def forward(self, inputs, straight_through=True, sample=False, temperature=1.0):
        """Compute the categorical code and the conditioned projection.

        Args:
            inputs: floating-point tensor whose last dimension is ``input_dim``.
            straight_through: if True, harden ``dist`` to one-hot via
                ``straight_through_estimator``. Gradients still flow to the
                soft values.
            sample: if True, draw ``dist`` with ``gumbel_softmax``; otherwise
                use a deterministic softmax over the logits.
            temperature: Gumbel-softmax temperature; only used when ``sample``.

        Returns:
            ``(h, dist)``: ``h`` has last dimension ``output_dim``; ``dist``
            has last dimension ``categorical_dim``.
        """
        logits = self.dense_in(inputs)

        if sample:
            dist = gumbel_softmax(logits, temperature=temperature)
        else:
            dist = F.softmax(logits, dim=-1)

        if straight_through:
            dist = straight_through_estimator(dist)

        h = self.dense_out(torch.cat([inputs, dist], dim=-1))
        return h, dist

In [3]:
# Seed torch's global RNG so the random batch and the layer's initial weights
# are identical on every Restart & Run All.
torch.manual_seed(0)

batch_size = 8  # rows in the synthetic input batch
inputs = torch.rand(batch_size, hparams['n_features'])

# Projection from the raw feature width down to the hidden width.
input_dense = nn.Linear(hparams['n_features'], hparams['hidden_dim'], bias=True)

