In [5]:
import numpy as np
import torch.nn as nn
from s2cnn.nn.soft.so3_conv import SO3Convolution
from s2cnn.nn.soft.s2_conv import S2Convolution
from s2cnn.nn.soft.so3_integrate import so3_integrate
from s2cnn.ops.so3_localft import near_identity_grid as so3_near_identity_grid
from s2cnn.ops.s2_localft import near_identity_grid as s2_near_identity_grid
import torch.nn.functional as F
import torch
import torch.utils.data as data_utils
import gzip, pickle
import numpy as np
from torch.autograd import Variable
from torch.distributions import Normal

In [11]:
class MLP(nn.Module):
    def __init__(self ,H = [1,10,1], activation = nn.ReLU):
        super(MLP, self).__init__()
        self.H = H
        self.activation = activation()
        modules = []
        for input_dim, output_dim in zip (H, H[1:-1]):
            modules.append(nn.Linear(input_dim, output_dim))
            modules.append(self.activation)
        modules.append(nn.Linear(H[-2],H[-1]))
        self.module = nn.Sequential(*modules)    

    def forward(self, x):
        y = self.module(x)
        return y
    
class Nreparametrize(nn.Module):
    def __init__(self,input_dim, z_dim):
        super(Nreparametrize, self).__init__()
            
        self.input_dim = input_dim
        self.z_dim = z_dim
        
        self.sigma_linear = nn.Linear(input_dim, z_dim)
        self.mu_linear = nn.Linear(input_dim, z_dim)
        
    def forward(self, x, n = 1):
        x = F.relu(x)
        self.mu = self.mu_linear(x)
        self.log_sigma = self.sigma_linear(x)  
        self.z = self.nsample(self.mu, self.log_sigma, n = n)
        
        return self.z
    
    
    def kl(self):
        return -0.5 * torch.sum(1 + 2*self.log_sigma - self.mu.pow(2) - self.log_sigma.exp()**2, -1)
        
        
    def log_posterior(self):
        raise Normal(self.mu, self.log_sigma.exp()).log_prob(self.z)
        
    def log_prior(self):
        return Normal(torch.ones_like(self.mu), torch.zeros_like(self.log_sigma)).log_prob(self.z)
        
        
    @staticmethod     
    def nsample(mu, log_sigma, n = 1):
        eps = Normal(torch.ones_like(mu), torch.zeros_like(mu)).sample_n(n)
        return mu + eps*log_sigma.exp()           

In [51]:
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        self.encoder = None 
        self.decoder = None 
        self.reparametrize = [None]
        
        
    def encode(self,x, n = 1):
        h = self.encoder(x)
        self.z = [r(h, n) for r in self.reparametrize]
        return self.z
    
    def kl(self):
        # NOTE always call after encode
        # TODO make this bahaviour impossible
        kl = [r.kl() for r in self.reparametrize]
        return kl
        
    def decode(self, z):
        return self.decoder(z) 
    
    def forward(self, x):
        self.encode(x);
    
        # flatten and stack z
        d0, d1 = self.z[0].size()[:2]
        z_cat = torch.cat([v.view(d0,d1,-1) for v in self.z], -1)
        
        return self.decode(z_cat)
    
    @staticmethod
    def recon_loss(x, x_recon):
        raise NotImplemented 
    
    def elbo(self,x):
        x_recon = self.forward(x)
        kl_summed  = torch.sum(torch.stack(kl,-1), -1)
        recon_loss = self.recon_loss(x, x_recon)
        return recon_loss, kl_summed
        
        
    def log_likelihood(self, x):
        raise NotImplemented       

class NVAE(VAE):
    def __init__(self, encoder_dims = [10, 7, 5], decoder_dims = [5,7,10]):
        super(NVAE, self).__init__() 
        self.encoder_dims = encoder_dims
        self.decoder_dims = decoder_dims
        assert encoder_dims[-1] == decoder_dims[0]
        self.encoder = MLP(self.encoder_dims[:-1])
        self.decoder = MLP(self.decoder_dims)
        self.reparametrize = [Nreparametrize(self.encoder_dims[-2], self.encoder_dims[-1])]

In [49]:
vae = NVAE()

In [44]:
x = Variable(Normal(torch.Tensor(np.zeros((10))), torch.Tensor(np.ones((10)))).sample_n(6))
rec = vae(x)    

In [39]:
torch.sum(torch.stack(kl,-1), -1)

Variable containing:
 1.8103
 0.7481
 0.7372
 1.7921
 1.5197
 0.7443
[torch.FloatTensor of size 6]

In [50]:
vae.kl()

AttributeError: 'Nreparametrize' object has no attribute 'log_sigma'