In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Bernoulli

import numpy as np
from tools import *
from utils import *
import operator
import itertools
import os

# Expose only the first GPU (ordered by PCI bus id) to this process.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})

# Seed both RNGs used in this notebook so runs are repeatable.
torch.manual_seed(0)
np.random.seed(0)

# `opt` is the shared run configuration: the torch device to compute on and a
# flag telling whether that device is a GPU.
if torch.cuda.is_available():
    # Ask cuDNN for deterministic kernels and disable its auto-tuner.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    opt = {'device': torch.device('cuda:0'), 'if_cuda': True}
else:
    opt = {'device': torch.device('cpu'), 'if_cuda': False}

In [2]:
class vae(nn.Module):
    """Fully connected variational auto-encoder over flattened 784-pixel inputs.

    The encoder (784-600-400-200) outputs the mean and log standard deviation
    of a diagonal Gaussian over a ``z_dim``-dimensional latent; the decoder
    (200-400-600-784) maps a latent to 784 Bernoulli logits.  The module also
    owns its parameter list (``self.params``) and an Adam optimiser
    (``self.optimizer``, lr 1e-4) so callers can train it directly.

    Args:
        opt: dict with keys ``'device'`` (torch.device that the prior tensors
            and sampled latents are moved to) and ``'if_cuda'`` (stored as-is).
    """

    def __init__(self, opt):
        super().__init__()
        self.z_dim = 10
        self.x_std = 0.5  # not read anywhere in this class; kept for compatibility

        # Encoder layers; the last stage has two heads (mean, log-std).
        self.en_fc1 = nn.Linear(784, 600)
        self.en_fc2 = nn.Linear(600, 400)
        self.en_fc3 = nn.Linear(400, 200)
        self.en_fc4_1 = nn.Linear(200, self.z_dim)
        self.en_fc4_2 = nn.Linear(200, self.z_dim)

        # Decoder layers.
        self.de_fc1 = nn.Linear(self.z_dim, 200)
        self.de_fc2 = nn.Linear(200, 400)
        self.de_fc3 = nn.Linear(400, 600)
        self.de_fc4 = nn.Linear(600, 784)

        self.device = opt['device']
        self.if_cuda = opt['if_cuda']

        # Standard-normal prior.  Kept as plain tensors (not buffers) so the
        # state_dict layout of existing checkpoints is unchanged; they are
        # moved to `self.device` where they are used.
        self.prior_mu = torch.zeros(self.z_dim, requires_grad=False)
        self.prior_std = torch.ones(self.z_dim, requires_grad=False)

        self.params = list(self.parameters())
        self.optimizer = optim.Adam(self.params, lr=0.0001)

    def posterior(self, x):
        """Encode ``x`` and return ``(mu, std)`` of the approximate posterior.

        The second head is interpreted as a log standard deviation and
        exponentiated, so the returned ``std`` is strictly positive.
        """
        h = F.leaky_relu(self.en_fc1(x))
        h = F.leaky_relu(self.en_fc2(h))
        h = F.leaky_relu(self.en_fc3(h))
        mu = self.en_fc4_1(h)
        log_std = self.en_fc4_2(h)
        return mu, torch.exp(log_std)

    def model(self, z):
        """Decode latent ``z`` into 784 Bernoulli logits (no sigmoid applied)."""
        h = F.leaky_relu(self.de_fc1(z))
        h = F.leaky_relu(self.de_fc2(h))
        h = F.leaky_relu(self.de_fc3(h))
        return self.de_fc4(h)

    def _elbo_terms(self, x):
        """Run one reparameterised forward pass shared by `evaluate` and `loss`.

        Returns ``(logit, l, kl)`` where ``l`` is the Bernoulli log-likelihood
        of ``x`` summed over the 784 pixels and ``kl`` is whatever
        ``batch_KL_diag_gaussian_std`` returns for posterior vs. prior.
        """
        z_mu, z_std = self.posterior(x)
        # randn_like already allocates on z_mu's device and dtype.
        eps = torch.randn_like(z_mu)
        z = z_mu + eps * z_std
        logit = self.model(z)
        l = torch.sum(Bernoulli(logits=logit).log_prob(x.view(-1, 784)), dim=1)
        # NOTE(review): batch_KL_diag_gaussian_std is defined outside this
        # file (star-import); presumably it returns one KL value per batch
        # row for diagonal Gaussians given as (mu, std) pairs -- confirm.
        kl = batch_KL_diag_gaussian_std(
            z_mu,
            z_std,
            self.prior_mu.to(self.device),
            self.prior_std.to(self.device),
        )
        return logit, l, kl

    def evaluate(self, x):
        """Return ``(loss, l, kl, probs)`` for a batch ``x``.

        ``loss`` is the batch mean of ``-l + kl`` (not rescaled, unlike
        `loss`), ``l`` and ``kl`` are the individual terms, and ``probs`` is
        the sigmoid of the decoder logits.
        """
        logit, l, kl = self._elbo_terms(x)
        loss = torch.mean(-l + kl, dim=0)
        return loss, l, kl, torch.sigmoid(logit)

    def loss(self, x):
        """Return the training loss: batch mean of ``-l + kl`` divided by ln 2."""
        _, l, kl = self._elbo_terms(x)
        return torch.mean(-l + kl, dim=0) / np.log(2.)

    def sample(self):
        """Draw 100 latents from N(0, I), decode them and sample binary pixels.

        Returns:
            numpy array of shape (100, 784) holding the Bernoulli samples.
        """
        # No gradients are needed for generation, so skip graph construction.
        with torch.no_grad():
            # Sampled on CPU and then moved, so the draw comes from the CPU
            # generator regardless of `self.device`.
            z = torch.randn(100, self.z_dim).to(self.device)
            x_sample = Bernoulli(logits=self.model(z)).sample()
        return x_sample.cpu().numpy()

In [3]:
# Load the MNIST test split from the current directory (no download) and
# binarise every image by rounding its pixel intensities to 0 or 1.
test_data = torchvision.datasets.MNIST(
    './', train=False, download=False,
    transform=torchvision.transforms.ToTensor(),
)
# torch.round keeps the images as tensors without a detour through NumPy's
# ufunc machinery; it uses the same round-half-to-even rule as np.rint.
test_data_list = [torch.round(image) for image, _ in test_data]

# Fresh, untrained model on the configured device.
vae_model = vae(opt).to(opt['device'])
# Uncomment to restore the weights written by the training cell instead:
# vae_model.load_state_dict(torch.load("./model_save/binary_vae.pth"))

<All keys matched successfully>

In [4]:
# Train the VAE on binarised MNIST.
# NOTE(review): `plt` and `show_many` are not imported in the visible cells;
# presumably they come from the star-imports of `tools` / `utils` -- confirm.
train_data = torchvision.datasets.MNIST(
    './', train=True, download=False,
    transform=torchvision.transforms.ToTensor(),
)
# Binarise by rounding (round-half-to-even, same rule as np.rint) while
# staying in torch.
train_data_list = [torch.round(image) for image, _ in train_data]
# Stack once up front so each mini-batch is a single indexing operation
# instead of re-stacking 100 tensors from a Python list at every step.
train_x = torch.stack(train_data_list).view(-1, 784)
n_train = train_x.shape[0]

vae_model = vae(opt).to(opt['device'])
loss_list = []

# Make sure the checkpoint directory exists before hours of training are spent.
os.makedirs('./model_save', exist_ok=True)

for epoch in range(0, 1000):
    # After epoch 500, decay the learning rate every 50 epochs.  The rate is
    # changed in place on the existing optimiser so Adam's running moment
    # estimates and step count are kept rather than reset.
    if epoch > 500 and epoch % 50 == 0:
        lr = 0.0001 / (epoch / 100)
        for group in vae_model.optimizer.param_groups:
            group['lr'] = lr

    # 600 steps of 100 images each, drawn uniformly with replacement.
    for _ in range(0, 600):
        index = np.random.choice(n_train, 100)
        batch_data = train_x[torch.as_tensor(index, dtype=torch.long)].to(opt['device'])
        vae_model.optimizer.zero_grad()
        loss = vae_model.loss(batch_data)
        loss.backward()
        loss_list.append(loss.item())
        vae_model.optimizer.step()

    # Every 50 epochs: report progress, show samples and checkpoint the model.
    if epoch % 50 == 0 and epoch != 0:
        print('epoch', epoch)
        print('loss', loss.item())
        plt.plot(loss_list)
        plt.show()

        x_sample = vae_model.sample()
        show_many(x_sample, 10)
        torch.save(vae_model.state_dict(), './model_save/binary_vae.pth')

In [108]:
# Score the first test image: flatten it into a single 784-wide row, move it
# to the configured device and print the model's loss for it.
test_data_t = torch.stack(test_data_list[:1]).reshape(-1, 784)
first_image_loss = vae_model.loss(test_data_t.to(opt['device']))
print(first_image_loss)

tensor(83.7531, grad_fn=<MeanBackward1>)
