# In [35]:
# ! python main.py evaluation --batch_size=128

import torch
import torch.nn.functional as F
import numpy as np

from data.dataset import MNIST
from torch.utils.data import DataLoader
from torchnet import meter
from tqdm import tqdm

from torch.optim import Adam

class VAE:
    """Variational auto-encoder skeleton built from raw tensors.

    The model holds two dictionaries of learnable tensors (``weights`` and
    ``biases``) and exposes a one-hidden-layer tanh encoder and a
    one-hidden-layer tanh decoder operating on them.

    Attributes:
        name: Fixed model label, ``'VAE'``.
        device: Device on which all parameters are created.
        weights: Mapping from layer name to weight matrix.
        biases: Mapping from layer name to ``(1, width)`` bias row.
    """

    # TODO: dataset loading (MNIST), a configurable activation function and
    # the training criterion are not wired up yet.

    def __init__(self, in_dim, encoder_width, decoder_width, latent_dim, device=torch.device('cpu')):
        """Create the parameters and run a one-off shape check.

        Args:
            in_dim: Size of one flattened input sample.
            encoder_width: Number of units in the encoder hidden layer.
            decoder_width: Number of units in the decoder hidden layer.
            latent_dim: Size of the latent code.
            device: Device on which parameters are allocated.
        """
        self.name = 'VAE'
        self.device = device

        # initialize encoder/decoder weights and biases
        self.weights, self.biases = self.init_vae_params(in_dim, encoder_width, decoder_width, latent_dim)

        # Push a dummy batch through encoder and decoder to verify the shapes.
        self._shape_check(in_dim)

    def _shape_check(self, in_dim, batch_size=7):
        """Run a random batch through the model and print every tensor size."""
        # The probe follows the configured input size and device; a fixed
        # (7, 784) CPU tensor would break for any other configuration.
        X = torch.randn(batch_size, in_dim, device=self.device)
        print(X.size())

        z_mean, z_std = self._encoding(X, self.weights, self.biases)
        print(z_mean.size())
        print(z_std.size())

        Xstar = self._decoding(z_mean, self.weights, self.biases)
        print(Xstar.size())

    def _encoding(self, X, weights, biases):
        """Map inputs to the two latent heads.

        Args:
            X: Input batch whose last dimension matches the rows of
                ``weights['encoder_hidden']``.
            weights: Mapping with keys ``'encoder_hidden'``,
                ``'latent_mean'`` and ``'latent_std'``.
            biases: Mapping with the same keys as ``weights``.

        Returns:
            Tuple ``(mean_output, std_output)``; both are plain linear
            read-outs of the shared tanh hidden layer.
        """
        # Kingma Supplementary C.2
        # NOTE(review): the "std" head is an unconstrained linear output and
        # can be negative; confirm how it is meant to be interpreted (e.g. as
        # a log-variance) before using it for sampling.
        hidden = torch.tanh(torch.matmul(X, weights['encoder_hidden']) + biases['encoder_hidden'])
        mean_output = torch.matmul(hidden, weights['latent_mean']) + biases['latent_mean']
        std_output = torch.matmul(hidden, weights['latent_std']) + biases['latent_std']

        return mean_output, std_output

    def _decoding(self, Z, weights, biases):
        """Map latent codes back to input space.

        Args:
            Z: Latent batch whose last dimension matches the rows of
                ``weights['decoder_hidden']``.
            weights: Mapping with keys ``'decoder_hidden'`` and
                ``'decoder_out'``.
            biases: Mapping with the same keys as ``weights``.

        Returns:
            The linear output layer applied to the tanh hidden layer; no
            output non-linearity is applied.
        """
        hidden = torch.tanh(torch.matmul(Z, weights['decoder_hidden']) + biases['decoder_hidden'])
        Xstar = torch.matmul(hidden, weights['decoder_out']) + biases['decoder_out']

        return Xstar

    def init_vae_params(self, in_dim, encoder_width, decoder_width, latent_dim):
        """Create all learnable tensors of the encoder and decoder.

        Args:
            in_dim: Size of one flattened input sample.
            encoder_width: Number of units in the encoder hidden layer.
            decoder_width: Number of units in the decoder hidden layer.
            latent_dim: Size of the latent code.

        Returns:
            Tuple ``(weights, biases)`` of dictionaries keyed by layer name.
            Biases are ``(1, width)`` rows drawn by the same initializer as
            the weights.
        """
        weights = {
            'encoder_hidden': self.xavier_init(in_dim, encoder_width),
            'latent_mean': self.xavier_init(encoder_width, latent_dim),
            'latent_std': self.xavier_init(encoder_width, latent_dim),
            'decoder_hidden': self.xavier_init(latent_dim, decoder_width),
            'decoder_out': self.xavier_init(decoder_width, in_dim),
        }

        biases = {
            'encoder_hidden': self.xavier_init(1, encoder_width),
            'latent_mean': self.xavier_init(1, latent_dim),
            'latent_std': self.xavier_init(1, latent_dim),
            'decoder_hidden': self.xavier_init(1, decoder_width),
            'decoder_out': self.xavier_init(1, in_dim),
        }

        return weights, biases

    def xavier_init(self, in_d, out_d):
        """Return an ``(in_d, out_d)`` tensor drawn from N(0, 2 / (in_d + out_d)).

        The tensor is created on ``self.device`` with ``requires_grad=True``.
        """
        # Plain Python float so the scalar-std overload of torch.normal is used.
        xavier_stddev = float(np.sqrt(2.0 / (in_d + out_d)))
        W = torch.normal(mean=0.0, std=xavier_stddev, size=(in_d, out_d),
                         requires_grad=True, device=self.device)
        return W
    
    
# Instantiate the VAE for flattened 784-dimensional inputs with 512-unit
# hidden layers and a 256-dimensional latent code.
model = VAE(
    in_dim=784,
    encoder_width=512,
    decoder_width=512,
    latent_dim=256,
)


# Output:
# torch.Size([7, 784])
# torch.Size([7, 256])
# torch.Size([7, 256])
# torch.Size([7, 784])
