# In [1]:
import sys
from hmc_base_pytorch import *
from hmc_unconstrained_pytorch import *
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions.normal import Normal
import matplotlib.pyplot as plt
import torch.utils.data
from torch import optim
from torchvision import datasets, transforms
from torchvision.utils import save_image


# Run configuration.
# Use the GPU only when one is actually present, so the script still starts on
# CPU-only machines instead of failing on the first transfer to the device.
cuda = torch.cuda.is_available()
batch_size = 64
epochs = 10
seed = 1
log_interval = 10
z_dim = 20  # size of the latent vector fed to the decoder

# Data preparation
torch.manual_seed(seed)
device = torch.device("cuda" if cuda else "cpu")
# Extra DataLoader options (one worker, pinned host memory) are only requested
# when running on CUDA.
kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}

# Only the training split is asked to download; the test split reuses the same root.
train_data = datasets.MNIST('../data', train=True, download=True, transform=transforms.ToTensor())
test_data = datasets.MNIST('../data', train=False, transform=transforms.ToTensor())

train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=batch_size, shuffle=True, **kwargs)

def binarization(data):
    """Draw a random binary array by thresholding uniform noise against ``data``.

    One Uniform[0, 1) value is drawn per entry of ``data``; the output is 1.0
    where that draw is strictly smaller than the entry and 0.0 elsewhere.

    Args:
        data: array exposing ``.shape`` and supporting elementwise ``<``.

    Returns:
        A float NumPy array with the same shape as ``data``.
    """
    noise = np.random.uniform(size=data.shape)
    below = noise < data
    return np.asarray(below).astype(float)

def _binarize_batches(loader):
    """Return a list with one binarized tensor per batch yielded by ``loader``.

    Each batch of images is flattened to rows of 784 values, passed through
    ``binarization`` as a NumPy array, and converted back to a tensor. Labels
    are ignored.
    """
    batches = []
    for images, _ in loader:
        flat = images.view(-1, 784).numpy()
        batches.append(torch.from_numpy(binarization(flat)))
    return batches


# Binarize both splits once up front; the same fixed batches are then reused
# in every epoch.
result = _binarize_batches(train_loader)
result_test = _binarize_batches(test_loader)

############################################################
def reparameterize(mu, logvar):
    """Sample ``mu + sigma * eps`` with ``eps ~ N(0, 1)`` and ``sigma = exp(logvar / 2)``.

    The noise is drawn with the shape and dtype of ``logvar``; ``mu`` itself is
    not modified.

    Args:
        mu: tensor of means.
        logvar: tensor of log-variances.

    Returns:
        A new tensor holding the sample.
    """
    sigma = torch.exp(0.5 * logvar)
    sample = torch.randn_like(sigma) * sigma
    # In-place add on the freshly created product, never on the caller's tensors.
    sample += mu
    return sample

def log_prior(z):
    """Log-density of each row of ``z`` under a standard normal N(0, I).

    Args:
        z: 2-D tensor; the second dimension is the dimensionality of the
            distribution.

    Returns:
        Tensor with one log-probability per row of ``z``.
    """
    dim = z.shape[1]
    # Build the distribution on the device z already lives on rather than
    # forcing CUDA, so the function also works for CPU tensors.
    mean = torch.zeros(dim, device=z.device)
    cov = torch.eye(dim, device=z.device)
    # Gradients flow through z in log_prob; no flag needs to be set on the
    # distribution object for that.
    return MultivariateNormal(mean, cov).log_prob(z)

def multivariate_normal_logpdf(mean, cov, x):
    """Log-density of a single vector ``x`` under N(mean, cov) with full covariance.

    Args:
        mean: tensor with k elements.
        cov: (k, k) covariance matrix; it is inverted, so it must be non-singular.
        x: tensor with k elements (k is taken from ``x.shape[0]``).

    Returns:
        Tensor of shape (1, 1) holding the log-density.
    """
    # Follow x's device instead of forcing CUDA so CPU inputs work as well.
    mean = mean.to(x.device)
    cov = cov.to(x.device)
    k = x.shape[0]
    diff = x - mean
    # Quadratic form -1/2 (x - mean)^T cov^{-1} (x - mean), kept as a 1x1 matrix.
    quad = -0.5 * diff.view(1, k) @ torch.inverse(cov) @ diff.view(k, 1)
    # Normalisation: k/2 * log(2*pi) + 1/2 * log|cov|.
    log_norm = 0.5 * k * math.log(2 * math.pi) + 0.5 * torch.log(torch.det(cov))
    return quad - log_norm

def multivariate_normal_diagonal_logpdf(mean, cov_diag, x):
    """Row-wise log-density of ``x`` under Gaussians with diagonal covariance.

    Args:
        mean: means, broadcastable against ``x``.
        cov_diag: diagonal covariance entries (variances); either one vector of
            length k shared by all rows or one row of k variances per sample.
        x: (n, k) tensor, one sample per row.

    Returns:
        Tensor with n log-densities, one per row of ``x``.
    """
    # Follow x's device instead of forcing CUDA so CPU inputs work as well.
    mean = mean.to(x.device)
    cov_diag = cov_diag.to(x.device)
    k = x.shape[1]  # dimension
    diff = x - mean
    quad = torch.sum(-0.5 * diff * (1 / cov_diag) * diff, dim=1)
    # The log-determinant is summed over the feature axis only, so each row
    # is normalised with its own variances; summing over every element would
    # charge each sample with the whole batch's log-determinant.
    log_det = torch.sum(torch.log(cov_diag), dim=-1)
    log_norm = 0.5 * k * math.log(2 * math.pi) + 0.5 * log_det
    return quad - log_norm

class decoder(nn.Module):
    """Single-hidden-layer decoder: latent vector -> 784 values in (0, 1).

    A ``z_dim``-dimensional input goes through a 400-unit ReLU layer and a
    sigmoid output layer of width 784.
    """

    def __init__(self):
        # Zero-argument super() does not look the class up by its global name,
        # so construction keeps working even if that name is later rebound.
        super().__init__()
        self.fc1 = nn.Linear(z_dim, 400)
        self.fc2 = nn.Linear(400, 784)

    def forward(self, x):
        """Return the sigmoid outputs for a batch of latent vectors ``x``."""
        hidden = F.relu(self.fc1(x))
        # torch.sigmoid is the non-deprecated spelling of F.sigmoid.
        return torch.sigmoid(self.fc2(hidden))
    
class q_z0(nn.Module):
    """Encoder returning the mean and log-variance of the initial latent z0.

    The input is flattened to rows of 784 values and passed through a 300-unit
    tanh layer, followed by two parallel linear heads of width ``z_dim``.
    """

    def __init__(self):
        # Zero-argument super() does not look the class up by its global name,
        # so construction keeps working even if that name is later rebound.
        super().__init__()
        self.fc1 = nn.Linear(784, 300)
        self.fc31 = nn.Linear(300, z_dim)  # log-variance head
        self.fc32 = nn.Linear(300, z_dim)  # mean head

    def forward(self, x):
        """Return ``(mu, logvar)`` for a batch ``x`` with 784 values per sample."""
        flat = x.view(-1, 784)
        # torch.tanh is the non-deprecated spelling of F.tanh.
        hidden = torch.tanh(self.fc1(flat))
        logvar = self.fc31(hidden)
        mu = self.fc32(hidden)
        return mu, logvar
    
    
# Instantiate the networks and move them to `device`. Each class name is
# rebound to its instance here, so after this point the classes can no longer
# be instantiated through these names.
decoder = decoder().to(device)
q_z0 = q_z0().to(device)
# r_v and q_v are not defined in this file; presumably they come from the
# star imports of hmc_base_pytorch / hmc_unconstrained_pytorch -- TODO confirm.
r_v = r_v().to(device)
q_v = q_v().to(device)
#mass_diag = torch.randn(z_dim, requires_grad=True)
# Free leaf tensors of length z_dim. They are created without an explicit
# device and are not handed to any active optimizer in the visible code.
log_mass_diag = torch.randn(z_dim, requires_grad=True)
q_z0_mean = torch.randn(z_dim, requires_grad=True) 
q_z0_logvar = torch.randn(z_dim, requires_grad=True)


def lower_bound(decoder, q_z0, r_v, data, T):
    """Single-sample estimate of the variational lower bound, averaged over the batch.

    For each sample, z0 is drawn from the encoder distribution and the bound is
    ``log p(z0) + log p(x | z0) - log q(z0 | x)``, with a Bernoulli likelihood
    on the 784 decoder outputs.

    Args:
        decoder: module mapping latent vectors to 784 Bernoulli means.
        q_z0: module returning ``(mu, logvar)`` for a batch of inputs.
        r_v: accepted for interface compatibility; not used by the visible code.
        data: batch of inputs with 784 values per sample.
        T: number of transition steps; each step currently contributes zero.

    Returns:
        Scalar tensor: the sum of the per-sample bounds divided by the batch size.
    """
    n_samples = data.view(-1, 784).shape[0]
    data = data.to(device)

    # Encoder distribution q(z0 | x) and a reparameterized draw from it.
    mu_z0, logvar_z0 = q_z0(data)
    var_z0 = torch.exp(logvar_z0)
    z0 = reparameterize(mu_z0, logvar_z0)

    # Joint log p(x, z0) = log p(z0) + log p(x | z0).
    # NOTE(review): this computation was disabled inside a string literal, which
    # left `log_prior_z0` undefined (NameError), and the bound was written as
    # `log_prior_z0 - log_q_z0`. The likelihood term is restored here so the
    # decoder receives a gradient -- confirm this is the intended objective.
    log_prior_z0 = log_prior(z0)
    decoder_output = decoder(z0)
    log_likelihood = 0. - F.binary_cross_entropy(
        decoder_output, data.view(-1, 784).float(), reduction='none')
    log_likelihood = torch.sum(log_likelihood, dim=1)
    log_joint = log_prior_z0 + log_likelihood

    # get log q_z0
    log_q_z0 = multivariate_normal_diagonal_logpdf(mu_z0, var_z0, z0)

    L = log_joint - log_q_z0

    print("initial L "+str(L))

    # Placeholder for the per-step correction terms: every step adds a zero
    # log-ratio to all samples. The original indexed `L[j]` with an undefined
    # `j`; the addition is done for the whole batch at once and out of place.
    for _ in range(T):
        one_log_alpha = torch.zeros(1, device=L.device)
        L = L + one_log_alpha

    return torch.sum(L)/n_samples
                    
                
    
# Train
# optimizer1 updates the decoder and r_v; optimizer2 updates the encoder with
# a much smaller learning rate.
params1 = list(decoder.parameters())+list(r_v.parameters())#+list(q_z0.parameters())
optimizer1 = optim.Adam(params1, lr=0.0005, weight_decay=5e-5)
#optimizer2 = optim.Adam([log_mass_diag], lr=0.0005, weight_decay=1e-4)
optimizer2 = optim.Adam(q_z0.parameters(), lr=0.00001, weight_decay=1e-3)

for epoch in range(epochs):
    print("Epoch: "+str(epoch+1))

    # One loss value (negative lower bound) per line, one file per epoch.
    # `with` guarantees the log is closed even if a batch raises.
    with open("result5_"+str(epoch)+".txt", "w") as train_log:
        for i, batch in enumerate(result):
            print("++++++++++ batch: " + str(i) + " ++++++++++")

            data = batch.float()
            optimizer1.zero_grad()
            optimizer2.zero_grad()
            L = lower_bound(decoder, q_z0, r_v, data, 1)
            loss = 0. - L
            loss.backward()

            print('weight grad after backward')
            optimizer1.step()
            optimizer2.step()
            train_log.write(str(0.-L.item())+"\n")
            print(L.item())

    # Evaluation never calls backward, so skip building autograd graphs.
    with open("result5_test_"+str(epoch)+".txt", "w") as test_log, torch.no_grad():
        for i, batch in enumerate(result_test):
            print("++++++++++ test batch: " + str(i) + " ++++++++++")
            data = batch.float()
            L = lower_bound(decoder, q_z0, r_v, data, 1)
            test_log.write(str(0.-L.item())+"\n")
            print(L.item())

    # Decode 64 latent draws from N(0, I) and save them as an image grid.
    with torch.no_grad():
        sample = torch.randn(64, z_dim).to(device)
        sample = decoder(sample).cpu()
    save_image(sample.view(64, 1, 28, 28), 'sample5_' + str(epoch) + '.png')
        


# Output: ImportError: No module named 'hmc_base_pytorch'