In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Bernoulli

import numpy as np
from tools import *
from distributions import *
from utils import *
import operator
import itertools
import os

# Enumerate GPUs in PCI bus order and expose only device 0 to this process.
os.environ.update({"CUDA_DEVICE_ORDER": "PCI_BUS_ID", "CUDA_VISIBLE_DEVICES": "0"})

# Fix the NumPy and PyTorch seeds so repeated runs draw the same random numbers.
for seed_fn in (np.random.seed, torch.manual_seed):
    seed_fn(0)

# Run-wide options: start from the CPU settings and upgrade when CUDA is usable.
opt = dict(device=torch.device('cpu'), if_cuda=False)
if torch.cuda.is_available():
    # Ask cuDNN for deterministic kernels and disable its autotuner.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    opt.update(device=torch.device('cuda:0'), if_cuda=True)

In [2]:
class vae(nn.Module):
    """Fully connected variational autoencoder for flattened 784-dimensional inputs.

    The encoder (784-600-400-200) outputs the mean and log standard deviation of a
    diagonal Gaussian over a ``z_dim``-dimensional latent.  The decoder
    (z_dim-200-400-600) outputs a per-pixel mean and log scale, which ``loss``
    hands to ``discretized_logistic_logp``.  The module also owns its Adam
    optimizer (``self.optimizer``) and the parameter list it was built from
    (``self.params``).

    Args:
        opt: dict providing ``'device'`` (a ``torch.device``) and ``'if_cuda'``.
    """

    def __init__(self,opt):
        super(vae, self).__init__()
        self.z_dim=10
        # Encoder layers; the two "4" heads produce the posterior mean and log-std.
        self.en_fc1 = nn.Linear(784, 600)
        self.en_fc2 = nn.Linear(600, 400)
        self.en_fc3 = nn.Linear(400, 200)
        self.en_fc4_1 = nn.Linear(200, self.z_dim)
        self.en_fc4_2 = nn.Linear(200, self.z_dim)
        # Decoder layers; the two "4" heads produce the pixel mean and log-scale.
        self.de_fc1 = nn.Linear(self.z_dim, 200)
        self.de_fc2 = nn.Linear(200, 400)
        self.de_fc3 = nn.Linear(400, 600)
        self.de_fc4_1 = nn.Linear(600, 784)
        self.de_fc4_2 = nn.Linear(600, 784)

        self.device=opt['device']
        self.if_cuda=opt['if_cuda']
        # Standard-normal prior.  Kept as plain attributes (not buffers) so the
        # keys of state_dict() stay exactly as existing checkpoints expect.
        self.prior_mu=torch.zeros(self.z_dim, requires_grad=False)
        self.prior_std=torch.ones(self.z_dim, requires_grad=False)
        self.params = list(self.parameters())
        self.optimizer = optim.Adam(self.params, lr=1e-4)


    def posterior(self, x):
        """Encode ``x`` and return ``(mu, std)`` of the diagonal-Gaussian posterior.

        The second head is interpreted as a log standard deviation and is
        exponentiated before being returned, so ``std`` is strictly positive.
        """
        h = F.leaky_relu(self.en_fc1(x))
        h = F.leaky_relu(self.en_fc2(h))
        h = F.leaky_relu(self.en_fc3(h))
        mu = self.en_fc4_1(h)
        log_std = self.en_fc4_2(h)
        return mu, torch.exp(log_std)


    def model(self, z):
        """Decode latents ``z`` and return ``(mean, log_scale)`` per pixel.

        ``mean`` is clamped to ``[-0.5 + 1/512, 0.5 - 1/512]``; ``log_scale`` is the
        raw (unbounded) output of its linear head and is NOT exponentiated here.
        """
        h = F.leaky_relu(self.de_fc1(z))
        h = F.leaky_relu(self.de_fc2(h))
        h = F.leaky_relu(self.de_fc3(h))
        mean = self.de_fc4_1(h)
        log_scale=self.de_fc4_2(h)
        # NOTE(review): the notebook's data cell produces pixels in [0, 1] while this
        # mean is confined to (-0.5, 0.5) -- confirm discretized_logistic_logp
        # re-centres x, otherwise the two ranges do not line up.
        return mean.clamp(min=-0.5 + 1. / 512., max=0.5 - 1. / 512.),log_scale

    def loss(self,x):
        """Return the training loss for a batch ``x``.

        A latent is drawn with the reparameterisation ``z = mu + std * eps``, then the
        loss is ``mean(-l + kl, dim=0)`` where ``l`` comes from
        ``discretized_logistic_logp`` and ``kl`` from ``batch_KL_diag_gaussian_std``
        against the standard-normal prior (both helpers are defined outside this
        class; presumably they return one value per batch element -- TODO confirm).
        """
        z_mu, z_std=self.posterior(x)
        # randn_like already allocates on z_mu's device, so no extra transfer is needed.
        eps = torch.randn_like(z_mu)
        z=eps.mul(z_std).add_(z_mu)
        mean,log_scale=self.model(z)
        l = discretized_logistic_logp(mean,log_scale,x)
        kl=batch_KL_diag_gaussian_std(z_mu,z_std,self.prior_mu.to(self.device),self.prior_std.to(self.device))
        loss= torch.mean(-l+kl,dim=0)
        return loss


    def sample(self,n=100):
        """Draw ``n`` samples from the model without tracking gradients.

        Latents are drawn from N(0, I) and decoded; each pixel is then sampled from
        a logistic distribution by inverse-CDF (``mean + scale * logit(u)``) and
        floored onto the 1/256 grid.

        Returns:
            Tensor of shape ``(n, 784)`` on ``self.device``.
        """
        with torch.no_grad():
            z = torch.randn(n, self.z_dim).to(self.device)
            x_mean,x_log_scale=self.model(z)
            # model() returns the LOG of the scale (that is what loss() feeds to the
            # likelihood), so it has to be exponentiated before it multiplies the noise.
            x_scale=torch.exp(x_log_scale)
            # Keep u strictly inside (0, 1) so both logarithms stay finite.
            uniform_noise=torch.clamp(torch.rand_like(x_mean),1e-7,1-1e-7)
            x_sample=x_mean + x_scale * (torch.log(uniform_noise) - torch.log(1-uniform_noise))
            return torch.floor(x_sample*256) /256

In [3]:
# Load the MNIST training split from disk (no download) and dequantise every image once up front.
train_data=torchvision.datasets.MNIST('../dataset/', train=True, download=False,transform=torchvision.transforms.ToTensor())
train_data_list=[]
for x,y in train_data:
    # ToTensor() returns intensities as k/255, so the integer level k is round(x*255).
    # Scaling by 256 instead pushed bright pixels into the next 1/256 bin and sent
    # k=255 above 1.  Adding U[0,1) noise and dividing by 256 spreads each level
    # uniformly over its own bin [k/256, (k+1)/256).
    # The clamp is kept only as a guard against float32 rounding at the top edge.
    x=torch.clamp((torch.round(x*255)+torch.rand_like(x))/256,0,1)
    train_data_list.append(x)
    


In [4]:
# Train the VAE on random mini-batches, reporting and checkpointing after every epoch.
N_EPOCHS = 1000
STEPS_PER_EPOCH = 600
BATCH_SIZE = 100
BASE_LR = 1e-4
CHECKPOINT_DIR = './model_save'

vae_model = vae(opt).to(opt['device'])
loss_list=[]

# torch.save does not create missing directories, so make sure the target exists.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
n_train = len(train_data_list)

for epoch in range(N_EPOCHS):
    # After epoch 500, every 50th epoch lowers the step size to BASE_LR / (epoch / 100).
    if epoch > 500 and epoch % 50 == 0:
        lr = BASE_LR / (epoch / 100)
        # Adjust the existing optimizer in place: constructing a fresh Adam here
        # would discard its running moment estimates at every decay step.
        for param_group in vae_model.optimizer.param_groups:
            param_group['lr'] = lr

    for i in range(STEPS_PER_EPOCH):
        # Sample a mini-batch of indices (with replacement) and flatten images to 784 values.
        index = np.random.choice(n_train, BATCH_SIZE)
        batch_data_list = [train_data_list[j] for j in index]
        batch_data = torch.stack(batch_data_list).view(-1,784).to(opt['device'])

        vae_model.optimizer.zero_grad()
        loss = vae_model.loss(batch_data)
        loss.backward()
        loss_list.append(loss.item())
        vae_model.optimizer.step()

    # Epoch 0 is skipped for reporting; every later epoch plots, samples and checkpoints.
    if epoch != 0:
        print('epoch',epoch)
        print('loss',loss.item())
        plt.plot(loss_list)
        plt.show()

        x_sample=vae_model.sample()
        show_many(x_sample,10)
        torch.save(vae_model.state_dict(), os.path.join(CHECKPOINT_DIR, 'vae.pth'))

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/usr/local/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3331, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-4-8c715b7d3ed8>", line 14, in <module>
    loss.backward()
  File "/usr/local/lib/python3.7/site-packages/torch/tensor.py", line 166, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/usr/local/lib/python3.7/site-packages/torch/autograd/__init__.py", line 99, in backward
    allow_unreachable=True)  # allow_unreachable flag
KeyboardInterrupt

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 2044, in showtraceback
    stb = value._render_traceback_()
AttributeError: 'KeyboardInterrupt' object has no attribute '_render_traceback_'

During handling of the above exception, another exception occu

KeyboardInterrupt: 

In [None]:
# Scratch check: run the last training batch through the encoder and decoder
# and print the decoded per-pixel means.
z_mu, z_std = vae_model.posterior(batch_data)
eps = torch.randn_like(z_mu).to(device=vae_model.device)
# Reparameterised latent draw: z = mu + std * eps.
z = z_mu + eps * z_std
mean, log_scale = vae_model.model(z)
print(mean)
# Remaining loss terms, left disabled (copied from vae.loss, hence the `self` references):
# l = discretized_logistic_logp(mean,log_scale,x)
# kl=batch_KL_diag_gaussian_std(z_mu,z_std,self.prior_mu.to(self.device),self.prior_std.to(self.device))
# loss= torch.mean(-l+kl,dim=0)

In [None]:
# Decode 100 latents drawn from the standard-normal prior and display the decoder means.
with torch.no_grad():
    # The latents must sit on the same device as the model, otherwise the first
    # decoder layer raises a device-mismatch error when the model is on the GPU.
    z = torch.randn(100, vae_model.z_dim).to(vae_model.device)
    x_mean,x_log_scale=vae_model.model(z)
    # model() returns the log of the logistic scale; exponentiate before using it.
    x_scale=torch.exp(x_log_scale)
    # Inverse-CDF logistic sampling; u is kept inside (0, 1) so the logs stay finite.
    uniform_noise=torch.clamp(torch.rand_like(x_mean),1e-7,1-1e-7)
    x_sample=x_mean + x_scale * (torch.log(uniform_noise) - torch.log(1-uniform_noise))
    x_sample=torch.floor(x_sample*256)/256
    # Only the means are displayed here; x_sample stays available for inspection.
    show_many(x_mean,10)

In [None]:
# Inspect the raw decoder means (x_mean is assigned in the preceding sampling cell).
print(x_mean)

In [None]:
# test_data=torchvision.datasets.MNIST('../dataset/', train=False, download=False,transform=torchvision.transforms.ToTensor())
# test_data_list=[]
# for x,y in test_data:
#     test_data_list.append(np.rint(x))
    
# vae_model = vae(opt).to(opt['device'])
# # vae_model.load_state_dict(torch.load("./model_save/binary_vae.pth"))