In [2]:
import sys
sys.path.append('F:\Cambridge\Project\MHMC-for-VAE\change_of_variable')
sys.path.append('F:\Cambridge\Project\MHMC-for-VAE\hmc_pytorch')
from change_of_variable_pytorch import * 
from hmc_base_pytorch import *
from hmc_unconstrained_pytorch import *
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions.normal import Normal
import matplotlib.pyplot as plt

import torch.utils.data
from torch import optim
from torchvision import datasets, transforms
from torchvision.utils import save_image
import time

In [3]:
cuda = True
batch_size = 128
epochs = 10
seed = 1
log_interval = 10
z_dim = 20

device = torch.device("cuda" if cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)

def reparameterize(mu, logvar):
    std = torch.exp(0.5*logvar)
    eps = torch.randn_like(std)
    return eps.mul(std).add_(mu)

def log_prior(z):
    dim = z.shape[1]
    mean = torch.zeros(dim).cuda()
    cov = torch.eye(dim).cuda()
    m = MultivariateNormal(mean, cov)
    m.requires_grad=True
    return m.log_prob(z)

def multivariate_normal_logpdf(mean, cov, x):
    mean = mean.cuda()
    cov = cov.cuda()
    k = x.shape[0]
    #cov = cov + (1e-6*torch.eye(k)).cuda()
    t1 = -0.5*(x - mean).view(1, k)@torch.inverse(cov)@(x - mean).view(k, 1)
    #t21 = 0.5*k*torch.log(2*torch.tensor([math.pi]).cuda())
    #t22 = 0.5*torch.log(torch.det(cov))
    #t2 = t21 + t22
    t2 = 0.5*k*torch.log(2*torch.tensor([math.pi]).cuda()) + 0.5*torch.log(torch.det(cov))
    return t1 - t2

class decoder(nn.Module):
    def __init__(self):
        super(decoder, self).__init__()
        self.fc1 = nn.Linear(z_dim, 400)
        self.fc2 = nn.Linear(400, 784)
    # single hidden layer
    def forward(self, x):
        #x = x.view(-1, 784)
        h1 = F.relu(self.fc1(x))
        return F.sigmoid(self.fc2(h1))
    
class q_z0(nn.Module):
    def __init__(self):
        super(q_z0, self).__init__()
        self.fc1 = nn.Linear(784, 300)
        self.fc2 = nn.Linear(300, 300)
        self.fc31 = nn.Linear(300, z_dim)
        self.fc32 = nn.Linear(300, z_dim)
    def forward(self, x):
        x = x.view(-1, 784)
        h1 = F.softplus(self.fc1(x))
        h2 = F.softplus(self.fc2(h1))
        logvar = self.fc31(h2)
        mu = self.fc32(h2)
        return mu, logvar
    
class r_v(nn.Module):
    def __init__(self):
        super(r_v, self).__init__()
        self.fc1 = nn.Linear(z_dim + 784, 300)
        self.fc21 = nn.Linear(300, z_dim)
        self.fc22 = nn.Linear(300, z_dim)
    def forward(self, x):
        x = x.view(-1, 784 + z_dim)
        h1 = F.softplus(self.fc1(x))
        logvar = self.fc21(h1)
        mu = self.fc22(h1)
        return mu, logvar

class q_v(nn.Module):
    def __init__(self):
        super(q_v, self).__init__()
        self.fc1 = nn.Linear(784, 300)
        # no need to output mu because the mean of momentum is default 0
        self.fc21 = nn.Linear(300, z_dim)
    def forward(self, x):
        x = x.view(-1, 784)
        h1 = F.softplus(self.fc1(x))
        logvar = self.fc21(h1)
        return logvar

In [4]:
decoder = decoder().to(device)
q_z0 = q_z0().to(device)
r_v = r_v().to(device)
q_v = q_v().to(device)
log_mass_diag = torch.randn(z_dim, requires_grad=True)

In [51]:
def lower_bound(decoder, q_z0, r_v, q_v, data, log_mass_diag, T):
    data = data.to(device)
    #mu_z0, logvar_z0 = q_z0(data)
    #var_z0 = torch.exp(logvar_z0)
    #print(logvar_z0.shape)

    # sample z0
    
    mu_z0 = torch.zeros(128,20).cuda()
    logvar_z0 = torch.zeros(128,20).cuda()
    var_z0 = torch.ones(128,20).cuda()
    z0 = reparameterize(mu_z0, logvar_z0)
    #print("z0: " + str(z0.shape))
    #print(z0)

    # get joint probaility p(x, z0)
    log_prior_z0 = log_prior(z0)
    #print("log_prior_z0: " + str(log_prior_z0.shape))
    decoder_output = decoder(z0)
    #print("decoder_output: " + str(decoder_output.shape))
    log_likelihood = 0. - F.binary_cross_entropy(decoder_output, data.view(-1, 784), size_average=False, reduce=False)
    #print("log_likelihood: " + str(log_likelihood.shape))
    log_likelihood = torch.sum(log_likelihood, dim = 1)
    #print(log_likelihood)
    #print(log_prior_z0)
    #print("log_likelihood: " + str(log_likelihood.shape))
    log_joint = log_prior_z0 + log_likelihood
    #print(log_joint)
    #print("log_joint: " + str(log_joint.shape))

    # get log q_z0
    log_q_z0 = torch.zeros(0).cuda()
    for i in range(batch_size):
        one_cov = torch.diag(var_z0[i])
        #m = MultivariateNormal(mu_z0[i], one_cov)
        #one_q_z0 = m.log_prob(z0[i]).view(1)
        one_q_z0 = multivariate_normal_logpdf(mu_z0[i], one_cov, z0[i])
        #print("one q z0: " + str(one_q_z0))
        log_q_z0 = torch.cat((log_q_z0,one_q_z0),0)
    #print("log_q_z0: " + str(log_q_z0.shape))

    #print("?????????????")
    #print(log_joint.shape)
    #print(log_q_z0.shape)
    
    # initial L for 128 samples
    L = log_joint - log_q_z0.view(batch_size)
    #print(L)
    #print(L.shape)
    L = torch.sum(L)
    #print("L "+str(L))
    #print(L.shape)

    #print("====================================")
    for i in range(T):
        # sample v1
        mass_diag = torch.exp(log_mass_diag)
        mass_matrix = torch.diag(mass_diag)
        mass_matrix.cuda()
        var_v1_matrix = torch.inverse(mass_matrix)
        var_v1_diag = torch.diag(var_v1_matrix)
        logvar_v1_diag = torch.log(var_v1_diag)

        logvar_v1 = logvar_v1_diag.repeat(batch_size,1).cuda()
        #print(logvar_v1)
        mu_v1 = torch.zeros(logvar_v1.shape[0], logvar_v1.shape[1]).cuda()
        v1 = reparameterize(mu_v1, logvar_v1)
        #print(v1)
        
        # get q_v1
        log_q_v1 = torch.zeros(0).cuda()
        for i in range(batch_size):
            one_cov = var_v1_matrix
            #m = MultivariateNormal(mu_v1[i], one_cov)
            #one_q_v1 = m.log_prob(v1[i]).view(1)
            one_q_v1 = multivariate_normal_logpdf(mu_v1[i], one_cov, v1[i])
            log_q_v1 = torch.cat((log_q_v1,one_q_v1),0)
        #print("log_q_v1: "+str(log_q_v1.shape))
        

        log_joint_t = torch.zeros(0).cuda() # list of all the joint
        log_r_vt = torch.zeros(0).cuda()
        alpha = torch.tensor([0.]).cuda() # lower bound for each batch (128 samples)
        for j in range(batch_size):
            def energy_function(z, cache):
                z.retain_grad()
                z = z.view(1, z.shape[0])
                z = z.cuda()
                one_log_prior = log_prior(z)
                decoder_output = decoder(z)
                one_log_likelihood = 0. - F.binary_cross_entropy(decoder_output, data.view(-1, 784)[j], size_average=False, reduce=False)
                #print(one_log_likelihood.shape)
                one_log_likelihood = torch.sum(one_log_likelihood, dim = 1)
                one_log_joint = one_log_prior + one_log_likelihood
                return 0 - one_log_joint
            sampler = IsotropicHmcSampler(energy_function, energy_grad=None, prng=None,
                                          mom_resample_coeff=1., dtype=np.float64)
            init = torch.zeros(z_dim).cuda()
            
            pos_samples, mom_samples, ratio = sampler.get_samples(init, 0.1, 3, 2, mass_matrix, mom = v1[j].view(z_dim))
            #print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~`")
            #print(pos_samples[1].shape)

            # get joint probaility p(x, zt)
            zt = pos_samples[1].cuda()
            vt = mom_samples[1].cuda()
            zt = zt.view(1, zt.shape[0])
            vt = vt.view(vt.shape[0])
            
            # get joint probaility p(x, zt)
            one_log_prior_zt = log_prior(zt)
            #print("one_log_prior_zt: " + str(one_log_prior_zt.shape))
            one_decoder_output_t = decoder(zt)
            #print("one_decoder_output_t: " + str(one_decoder_output_t.shape))
            one_log_likelihood_t = 0. - F.binary_cross_entropy(one_decoder_output_t, data.view(-1, 784)[j], size_average=False, reduce=False)
            one_log_likelihood_t = torch.sum(one_log_likelihood_t, dim = 1)
            #print("one_log_likelihood_t: " + str(one_log_likelihood_t.shape))
            one_log_joint_t = one_log_prior_zt + one_log_likelihood_t
            #print("one_log_joint_t: " + str(one_log_joint_t.shape))
            log_joint_t = torch.cat((log_joint_t, one_log_joint_t), 0)
            #print(one_log_joint_t)

            # get r_vt
            d = data.view(-1, 784)[j].view(1, 784)
            one_new_data = torch.cat((d, zt), 1) # append data with zt
            one_mu_vt, one_logvar_vt = r_v(one_new_data)
            one_var_vt = torch.exp(one_logvar_vt)
            one_mu_vt = one_mu_vt.view(one_mu_vt.shape[1])
            one_cov = torch.diag(one_var_vt.view(one_var_vt.shape[1]))
            #m = MultivariateNormal(one_mu_vt, one_cov)
            #one_log_r_vt = m.log_prob(vt).view(1)
            one_log_r_vt = multivariate_normal_logpdf(one_mu_vt, one_cov, vt)
            #print(one_log_r_vt)
            log_r_vt = torch.cat((log_r_vt, one_log_r_vt), 0)
            

            # get L for each sample
            one_log_alpha = log_joint_t[j] + log_r_vt[j] - log_joint[j] - log_q_v1[j]
            #print(log_joint_t[j])
            #print(log_r_vt[j])
            #print(log_joint[j])
            #print(log_q_v1[j])
            #print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
            #print("one log alpha: "+str(one_log_alpha))
            #one_log_alpha = torch.log(one_alpha)
            L = L + one_log_alpha
            #alpha = alpha + one_alpha
        #L = L + torch.log(alpha)

    #print("~~~~~~~~~~~~~~~~~~~ new L " + str(L) + " ~~~~~~~~~~~~~~~~~~~")
    return L

                

In [52]:
params1 = list(decoder.parameters())+list(q_z0.parameters())+list(r_v.parameters())+list(q_v.parameters())
optimizer1 = optim.Adam(params1, lr=0.0005)
optimizer2 = optim.Adam([log_mass_diag], lr=0.0005)

for batch_idx, (data, _) in enumerate(train_loader):
    print("++++++++++ " + str(batch_idx) + " ++++++++++")
    optimizer1.zero_grad()
    optimizer2.zero_grad()
    L = lower_bound(decoder, q_z0, r_v, q_v, data, log_mass_diag, 1)
    loss = 0. - L
    loss.backward()
    optimizer1.step()
    optimizer2.step()
    print(L)
print(L)

++++++++++ 0 ++++++++++


  "Please ensure they have the same size.".format(target.size(), input.size()))


tensor([-73820.7656], device='cuda:0')
++++++++++ 1 ++++++++++
tensor([-69017.7812], device='cuda:0')
++++++++++ 2 ++++++++++


KeyboardInterrupt: 