In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from nnutils import create_var, create_onehot, log_standard_categorical

In [2]:
class NN(nn.Module):
    """Two-layer MLP: Linear -> LeakyReLU -> Linear.

    Args:
        input_size: Number of input features.
        latent_size: Width of the hidden layer.
        y_size: Number of outputs (raw, unnormalised scores).
    """

    def __init__(self, input_size, latent_size, y_size):
        # nn.Module.__init__ must run before any sub-module is assigned,
        # otherwise PyTorch raises AttributeError on the first assignment.
        super().__init__()
        self.fc1 = nn.Linear(input_size, latent_size)
        self.fc2 = nn.Linear(latent_size, y_size)

    def forward(self, x):
        """Return the raw ``y_size``-wide outputs for input ``x``."""
        # torch.nn.functional has no ``LeakyReLU``; the functional form is
        # ``leaky_relu`` (default negative_slope=0.01, same as nn.LeakyReLU).
        x = F.leaky_relu(self.fc1(x))
        return self.fc2(x)

In [None]:
class NNVAE(nn.Module):
    """Semi-supervised VAE with an auxiliary classifier.

    Sub-networks:

    * ``classifier``: ``x`` -> ``y_size`` class logits.
    * ``encoder`` + ``z_mean`` / ``z_var``: ``x`` -> parameters of a diagonal
      Gaussian over a ``latent_size``-dimensional latent ``z``.
    * ``decoder``: ``concat(z, one_hot(y))`` -> ``input_size`` outputs.

    Args:
        input_size: Number of input features.
        hidden_size: Width of the hidden layer in each sub-network.
        latent_size: Dimensionality of ``z``.
        y_size: Number of classes.
        alpha: Weight applied to the supervised classification loss.
    """

    def __init__(self, input_size, hidden_size, latent_size, y_size, alpha):
        # nn.Module.__init__ must run before any sub-module is assigned,
        # otherwise PyTorch raises AttributeError on the first assignment.
        super().__init__()
        self.y_size = y_size
        self.alpha = alpha

        self.classifier = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.LeakyReLU(),
            nn.Linear(hidden_size, y_size),
        )

        self.encoder = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.LeakyReLU(),
        )
        self.z_mean = nn.Linear(hidden_size, latent_size)
        # Despite its name, this head's output is used as a log-variance
        # (see ``rsample``).
        self.z_var = nn.Linear(hidden_size, latent_size)

        self.decoder = nn.Sequential(
            nn.Linear(latent_size + y_size, hidden_size),
            nn.LeakyReLU(),
            nn.Linear(hidden_size, input_size),
        )

        self.cross_entropy_loss = nn.CrossEntropyLoss()

    def rsample(self, h):
        """Sample ``z`` with the reparameterisation trick.

        Args:
            h: Encoder output.

        Returns:
            ``(z_vecs, kl_loss)``: the sampled latents and the KL divergence
            to a standard normal, summed over all elements and divided by the
            batch size.
        """
        batch_size = h.size(0)
        z_mean = self.z_mean(h)
        # -|.| keeps the log-variance <= 0, i.e. the variance <= 1.
        z_log_var = -torch.abs(self.z_var(h))
        kl_loss = -0.5 * torch.sum(
            1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var)
        ) / batch_size
        epsilon = create_var(torch.randn_like(z_mean))
        z_vecs = z_mean + torch.exp(z_log_var / 2) * epsilon
        return z_vecs, kl_loss

    def _to_onehot(self, y, dtype):
        """Return ``y`` as one-hot rows of width ``y_size`` cast to ``dtype``.

        A 1-D tensor is treated as integer class indices and expanded. Any
        other tensor is assumed to already hold one-hot/probability rows and
        is only cast.
        """
        if y.dim() == 1:
            y = F.one_hot(y.long(), self.y_size)
        return y.to(dtype)

    def decode(self, z, y):
        """Decode latent ``z`` conditioned on one-hot rows ``y``."""
        latent = torch.cat((z, y), 1)
        return self.decoder(latent)

    def compute_reconstruction_loss(self, x, z, y, reduction="mean"):
        """Cross-entropy between ``decode(z, y)`` and the input ``x``.

        Args:
            x: Reconstruction target. NOTE(review): a float target with the
                same shape as the logits is treated by cross-entropy as
                per-row class probabilities; confirm ``x`` is normalised
                that way.
            z: Latent sample.
            y: Class indices ``(batch,)`` or one-hot rows ``(batch, y_size)``.
            reduction: Forwarded to ``F.cross_entropy``. ``"none"`` returns
                one loss per sample.
        """
        x_hat = self.decode(z, self._to_onehot(y, z.dtype))
        return F.cross_entropy(x_hat, x, reduction=reduction)

    def forward(self, x, y=None):
        """Compute the training loss for a labelled or unlabelled batch.

        Args:
            x: Input batch; the encoder and classifier expect its last
                dimension to be ``input_size``.
            y: Optional labels, as class indices ``(batch,)`` or one-hot
                rows. ``None`` marks the batch as unlabelled.

        Returns:
            ``(loss, pred_loss, pred_acc)``. For unlabelled batches
            ``pred_loss`` and ``pred_acc`` are ``0``.
        """
        pred_loss, pred_acc = 0, 0

        h = self.encoder(x)
        y_hat = self.classifier(x)  # class logits
        z, kl_div = self.rsample(h)

        # NOTE(review): ``log_standard_categorical`` (from nnutils) receives
        # raw logits here; confirm its expected input (labels/probabilities)
        # and the sign with which the prior term should enter the loss.
        logy = torch.mean(log_standard_categorical(y_hat))
        loss = logy + kl_div

        if y is not None:
            y = create_var(y)
            # Class indices drive both the classification loss and accuracy.
            y_idx = y if y.dim() == 1 else y.argmax(dim=1)
            pred_loss = self.cross_entropy_loss(y_hat, y_idx) * self.alpha
            # Compare predicted classes (not raw logits) with the labels.
            n_correct = (y_hat.argmax(dim=1) == y_idx).sum().item()
            pred_acc = float(n_correct) / y_idx.size(0)
            recon_loss = self.compute_reconstruction_loss(x, z, y_idx)
            loss = loss + pred_loss + recon_loss
        else:
            # Unlabelled batch: take the expectation of the per-sample
            # reconstruction loss under q(y|x), then subtract the entropy
            # of q(y|x). The classifier emits logits, so normalise first.
            q_y = F.softmax(y_hat, dim=1)
            batch_size = x.size(0)
            for i in range(self.y_size):
                y_i = torch.full(
                    (batch_size,), i, dtype=torch.long, device=x.device
                )
                recon_loss = self.compute_reconstruction_loss(
                    x, z, y_i, reduction="none"
                )
                loss = loss + torch.mean(q_y[:, i] * recon_loss)

            # sum_y q log q (= -H) per sample, averaged over the batch to
            # match the scale of the other terms.
            neg_entropy = torch.sum(q_y * torch.log(q_y + 1e-8), dim=1).mean()
            loss = loss + neg_entropy

        return loss, pred_loss, pred_acc