In [2]:
import argparse

import numpy as np
import torch
from torch.autograd import Variable

import cdae.util as util
from cdae.distributions import categorical_sample, categorical_logpdf

# Stand-in for parsed command-line arguments so that the project's
# util.init (which expects an argparse-style namespace) can run in a notebook.
opt = argparse.Namespace(
    seed=3,
    device=-1,
    cuda=False,
    visdom=True,
)
util.init(opt)

In [3]:
def gumbel_sample(location, scale):
    """
    Returns a Tensor of samples from Gumbel(location, scale).

    Uses inverse-CDF sampling: location - scale * log(-log(U)), U ~ Uniform[0, 1).

    input:
        location: Tensor [dim_1, ..., dim_N]
        scale: Tensor [dim_1, ..., dim_N]

    output: Tensor [dim_1, ..., dim_N]
    """

    uniforms = torch.rand(location.size())
    # torch.rand samples from [0, 1) and can return exactly 0, for which
    # log(-log(0)) = inf and the sample would be -inf. Flooring at a tiny
    # positive value keeps every draw finite; nonzero draws from torch.rand
    # lie far above this floor, so they are left unchanged.
    uniforms = uniforms.clamp(min=1e-20)

    return location - scale * torch.log(-torch.log(uniforms))

def gumbel_logpdf(value, location, scale):
    """
    Returns Gumbel logpdfs, evaluated elementwise.

    With the standardised variable z = (value - location) / scale, the
    log-density is -(z + exp(-z)) - log(scale).

    input:
        value: Tensor/Variable [dim_1, ..., dim_N]
        location: Tensor/Variable [dim_1, ..., dim_N]
        scale: Tensor/Variable [dim_1, ..., dim_N]

    output: Tensor/Variable [dim_1, ..., dim_N]
    """

    z = (value - location) / scale
    # Log-density of the standard Gumbel at z ...
    standard_part = -(z + torch.exp(-z))
    # ... minus log(scale), the log-Jacobian of the location/scale transform.
    return standard_part - torch.log(scale)

def concrete_sample(location, temperature):
    """
    Returns a Tensor of samples from Concrete(location, temperature).

    Uses the Gumbel-softmax construction: a softmax over dimension 0 of
    (log(location) + G) / temperature, with G drawn i.i.d. from Gumbel(0, 1).

    input:
        location: Tensor [num_categories, dim_1, ..., dim_N] (or [num_categories])
        temperature: Tensor [dim_1, ..., dim_N], or a single value given as
            int/float/Tensor [1] (shared by every batch position)

    output: Tensor [num_categories, dim_1, ..., dim_N] (or [num_categories])
    """

    if isinstance(temperature, (int, float)):
        temperature = torch.Tensor([temperature])
    if location.ndimension() == 1 or temperature.numel() == 1:
        # A single temperature is shared by every category and batch position.
        temperature_expanded = temperature.expand_as(location)
    else:
        # One temperature per batch position, shared across categories.
        temperature_expanded = temperature.unsqueeze(0).expand_as(location)
    gumbels = gumbel_sample(torch.zeros(location.size()), torch.ones(location.size()))

    logits = (torch.log(location) + gumbels) / temperature_expanded
    # Shift by the per-position maximum before exponentiating. The softmax is
    # unchanged, but exp can no longer overflow to inf (which would turn the
    # ratio into inf / inf = nan) when logits exceed ~88 in float32, as
    # happens routinely at low temperatures.
    max_logits = torch.max(logits, 0)[0].expand_as(logits)
    numerator = torch.exp(logits - max_logits)
    denominator = torch.sum(numerator, dim=0).expand_as(numerator)

    return numerator / denominator

def concrete_logpdf(value, location, temperature):
    """
    Returns a Tensor of Concrete logpdfs.

    With n = num_categories, alpha = location and lambda = temperature, and with
    the sums over k and i taken along dimension 0, this computes
        log p(x) = log((n-1)!) + (n-1) log(lambda)
                   + sum_k [log(alpha_k) - (lambda + 1) log(x_k)]
                   - n log(sum_i alpha_i x_i^(-lambda)).

    input:
        value: Tensor/Variable [num_categories, dim_1, ..., dim_N] (or [num_categories])
        location: Tensor/Variable [num_categories, dim_1, ..., dim_N] (or [num_categories])
        temperature: Tensor/Variable [dim_1, ..., dim_N], or a single value given
            as int/float/[1] (shared by every batch position)
    output: Tensor/Variable [dim_1, ..., dim_N] (or [1])
    """
    import math

    num_categories, *_ = value.size()

    if isinstance(temperature, (int, float)):
        if isinstance(location, Variable):
            temperature = Variable(torch.Tensor([temperature]))
        else:
            temperature = torch.Tensor([temperature])

    if location.ndimension() == 1:
        temperature_expanded = temperature.expand_as(location)
    else:
        if temperature.numel() == 1:
            # A single temperature: give it the batch shape [dim_1, ..., dim_N]
            # so the result keeps that shape.
            temperature = temperature.expand(location.size()[1:])
        temperature_expanded = temperature.unsqueeze(0).expand_as(location)

    log_location = torch.log(location)
    log_value = torch.log(value)

    # log((n-1)!) == lgamma(n): the normalising constant of the density.
    log_factorial = math.lgamma(num_categories)

    # log(sum_i alpha_i * x_i^(-lambda)) as a max-shifted log-sum-exp, since
    # x_i^(-lambda) overflows for small x_i (e.g. low-temperature samples).
    log_terms = log_location - temperature_expanded * log_value
    max_log_terms = torch.max(log_terms, 0)[0]
    log_normalizer = max_log_terms + torch.log(
        torch.sum(torch.exp(log_terms - max_log_terms.expand_as(log_terms)), dim=0))

    return log_factorial + \
        (num_categories - 1) * torch.log(temperature) + \
        torch.sum(log_location - (temperature_expanded + 1) * log_value, dim=0).squeeze(0) - \
        num_categories * log_normalizer.squeeze(0)
        
        
def discrete_sample(probabilities):
    """
    Returns a Tensor of samples from a Discrete(probabilities).

    The support is the integer categories 0 .. num_categories - 1 indexed by
    dimension 0; sampling is delegated to categorical_sample.

    input:
        probabilities: Tensor [num_categories, dim_1, ..., dim_N] (or [num_categories])

    output: Tensor [dim_1, ..., dim_N] (or [1])
    """

    num_categories = probabilities.size(0)
    trailing_ones = [1] * (probabilities.ndimension() - 1)
    # Lay the category indices along dimension 0, then broadcast them over the
    # batch dimensions so they align elementwise with `probabilities`.
    category_grid = torch.arange(0, num_categories).view(num_categories, *trailing_ones)

    return categorical_sample(category_grid.expand_as(probabilities), probabilities)


def discrete_logpdf(value, probabilities):
    """
    Returns Discrete logpdfs.

    Builds the category indices 0 .. num_categories - 1 along dimension 0 and
    delegates the evaluation to categorical_logpdf.

    input:
        value: Tensor/Variable [dim_1, ..., dim_N] (or [1])
        probabilities: Tensor/Variable [num_categories, dim_1, ..., dim_N] (or [num_categories])

    output: Tensor/Variable [dim_1, ..., dim_N] (or [1])
    """

    num_categories = probabilities.size(0)
    grid_shape = [num_categories] + [1] * (probabilities.ndimension() - 1)
    category_grid = torch.arange(0, num_categories).view(*grid_shape).expand_as(probabilities)

    if isinstance(probabilities, Variable):
        # Wrap to match `probabilities` (presumably so categorical_logpdf
        # receives consistent Tensor/Variable inputs).
        category_grid = Variable(category_grid)

    return categorical_logpdf(value, category_grid, probabilities)

def one_hot_discrete_sample(probabilities):
    """
    Returns a Tensor of samples from a Discrete(probabilities) in a one-hot form.

    Draws integer categories with discrete_sample and scatters a 1 into a
    zero tensor at each drawn index along dimension 0.

    input:
        probabilities: Tensor [num_categories, dim_1, ..., dim_N] (or [num_categories])

    output: Tensor [num_categories, dim_1, ..., dim_N] (or [num_categories])
    """

    indices = discrete_sample(probabilities).long()
    if probabilities.ndimension() != 1:
        # discrete_sample returns [dim_1, ..., dim_N]; add back the category
        # axis so the index tensor has the same rank as the one-hot output.
        indices = indices.unsqueeze(0)

    one_hot = torch.zeros(probabilities.size())
    return one_hot.scatter_(0, indices, 1)
    
def one_hot_discrete_logpdf(value, probabilities):
    """
    Returns logpdfs of one-hot valued Discrete.

    Computes log(sum_k value_k * probabilities_k) over dimension 0; for a
    one-hot `value` this is the log-probability of the active category.

    input:
        value: Tensor/Variable [num_categories, dim_1, ..., dim_N] (or [num_categories])
        probabilities: Tensor/Variable [num_categories, dim_1, ..., dim_N] (or [num_categories])

    output: Tensor/Variable [dim_1, ..., dim_N] (or [1])
    """

    selected_probability = torch.sum(value * probabilities, dim=0)
    return torch.log(selected_probability).squeeze(0)

In [11]:
# A random 3-way probability vector (uniform draws normalised to sum to 1)
# and three random category values to sample from.
uniforms = torch.rand(3)
probabilities = uniforms.div(uniforms.sum())
categories = torch.rand(3)

In [18]:
# Shape check: a draw from categorical_sample with these 1-D inputs is
# expected to have size [1]. Asserting (not just displaying the comparison)
# makes Restart & Run All fail loudly if that ever regresses.
sample_shape = categorical_sample(categories, probabilities).size()
assert sample_shape == torch.Size([1]), "unexpected sample shape: {}".format(sample_shape)
sample_shape == torch.Size([1])

True