In [1]:
# prerequisites
import os
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms
from torchvision.utils import save_image

# NOTE(review): later cells also reference `plt` (matplotlib.pyplot) and `imageio`,
# neither of which is imported anywhere in the visible notebook - add those imports
# here once the packages are confirmed to be installed.

bs = 100
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# MNIST Dataset
train_dataset = datasets.MNIST(root='./mnist_data/', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='./mnist_data/', train=False, transform=transforms.ToTensor(), download=False)

# Data Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=bs, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=bs, shuffle=False)

In [2]:
class VAE(nn.Module):
    def __init__(self, x_dim, h_dim1, h_dim2, z_dim):
        super(VAE, self).__init__()
        
        # encoder part
        self.fc1 = nn.Linear(x_dim, h_dim1)
        self.fc2 = nn.Linear(h_dim1, h_dim2)
#         self.fc2_y = torch.cat((self.fc2, y), 1)
        self.fc31 = nn.Linear(h_dim2 + 1, z_dim)
        self.fc32 = nn.Linear(h_dim2 + 1, z_dim)
        # decoder part
        self.fc4 = nn.Linear(z_dim, h_dim2)
        self.fc5 = nn.Linear(h_dim2, h_dim1)
        self.fc6 = nn.Linear(h_dim1, x_dim)
        
    def encoder(self, x, y):
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        h = torch.cat((h, y), 1)
        return self.fc31(h), self.fc32(h) # mu, log_var
    
    def sampling(self, mu, log_var):
        std = torch.exp(0.5*log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu) # return z sample
        
    def decoder(self, z):
        h = F.relu(self.fc4(z))
        h = F.relu(self.fc5(h))
        return F.sigmoid(self.fc6(h)) 
    
    def forward(self, x, y):
        mu, log_var = self.encoder(x.view(-1, 784), y.view(-1, 1))
        z = self.sampling(mu, log_var)
        return self.decoder(z), mu, log_var

In [3]:
class Discriminator(nn.Module):
    def __init__(self, d_input_dim):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(d_input_dim, 1024)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features//2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features//2)
        self.fc4 = nn.Linear(self.fc3.out_features, 1)
    
    # forward method
    def forward(self, x):
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.dropout(x, 0.3)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.dropout(x, 0.3)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = F.dropout(x, 0.3)
        return torch.sigmoid(self.fc4(x))

In [4]:
# lr = 0.0002

# # build model

# mnist_dim = train_dataset.train_data.size(1) * train_dataset.train_data.size(2)
# discriminator = Discriminator(d_input_dim=mnist_dim).cuda()

# criterion = nn.BCELoss()
# criterion.cuda()


In [5]:
mnist_dim = 784
lr = 0.0002
vae = VAE(x_dim=mnist_dim, h_dim1= 512, h_dim2=256, z_dim=2).cuda()   # generator
D = Discriminator(d_input_dim=mnist_dim).cuda()

G_optimizer = optim.Adam(vae.parameters(), lr = lr)
D_optimizer = optim.Adam(D.parameters(), lr = lr)
criterion = nn.BCELoss().cuda()

In [6]:
def D_train(x):
    D.zero_grad()
    
    # train discriminator on real
    x_real, y_real = x.view(-1, mnist_dim), torch.ones(bs, 1)
    x_real, y_real = Variable(x_real.to(device)), Variable(y_real.to(device))

    D_output = D(x_real)
#     print("D real output: ", D_output)
    D_real_loss = criterion(D_output, y_real)
    D_real_score = D_output

    # train discriminator on fake
    y_fake = Variable(torch.zeros(bs, 1).to(device))
    recon_batch, mu, log_var = vae(x_real, y_fake)

    D_output = D(recon_batch)
#     print("D fake output: ", D_output)
    D_fake_loss = criterion(D_output, y_fake)
    D_fake_score = D_output
#     print("D fake score:", D_fake_score)
    # gradient backprop & optimize ONLY D's parameters
    D_loss = D_real_loss + D_fake_loss
    D_loss.backward()
    D_optimizer.step()
        
    return  D_loss.data.item()


In [7]:
# def loss_function(recon_x_0, recon_x_1, x, mu_0, log_var_0, mu_1, log_var_1, prob):
#     BCE_0 = F.binary_cross_entropy(recon_x_0, x.view(-1, 784), reduction='sum') 
#     BCE_1 = F.binary_cross_entropy(recon_x_1, x.view(-1, 784), reduction='sum') 
# #     print("bce loss:", BCE_1)
#     KLD_0 = -0.5 * torch.sum(1 + log_var_0 - mu_0.pow(2) - log_var_0.exp())
#     KLD_1 = -0.5 * torch.sum(1 + log_var_1 - mu_1.pow(2) - log_var_1.exp())
#     return BCE_0 * (1 - prob) + BCE_1 * prob + KLD_0 * (1 - prob) + KLD_1 * prob

def loss_function(recon_x_0, recon_x_1, x, mu_0, log_var_0, mu_1, log_var_1, prob):
#     BCE_0 = F.binary_cross_entropy(recon_x_0, x.view(-1, 784), reduction='sum')
#     BCE_1 = F.binary_cross_entropy(recon_x_1, x.view(-1, 784), reduction='sum')
    BCE_0 = F.binary_cross_entropy(recon_x_0, x.view(-1, 784), reduction='none') 
    BCE_1 = F.binary_cross_entropy(recon_x_1, x.view(-1, 784), reduction='none')
#     print("BCE_0 loss shape:", BCE_0.shape)
#     print("BCE_0 loss:", BCE_0)
#   
    KLD_0 = -0.5 * torch.sum(1 + log_var_0 - mu_0.pow(2) - log_var_0.exp(), dim = 1)
    KLD_1 = -0.5 * torch.sum(1 + log_var_1 - mu_1.pow(2) - log_var_1.exp(), dim = 1)
#     print("KLD_0 loss shape:", KLD_0.shape)
#     print("KLD_0 loss:", KLD_0)
#     print("kld * prob", KLD_0 * prob)
    return torch.sum(BCE_0 * (1 - prob) + BCE_1 * prob) + torch.sum(KLD_0 * (1 - prob) + KLD_1 * prob)

def G_train(x):
    vae.zero_grad()        
     
    y_fake = Variable(torch.zeros(bs, 1).to(device))
    y_real = Variable(torch.ones(bs, 1).to(device))
    
    data =  Variable(x.view(-1, mnist_dim).to(device))
    recon_x_0, mu_0, log_var_0 = vae(data, y_fake)
    recon_x_1, mu_1, log_var_1 = vae(data, y_real)
    
    with torch.no_grad():
        prob = D(data)
#     print("prob shape:", prob)
    loss = loss_function(recon_x_0, recon_x_1, data, mu_0, log_var_0, mu_1, log_var_1, prob)
    loss.backward()
    G_optimizer.step()
#     print("loss: ", loss.data.item())
    return loss.data.item() / bs

In [8]:
train_hist = {}
train_hist['D_losses'] = []
train_hist['G_losses'] = []
train_hist['per_epoch_ptimes'] = []
train_hist['total_ptime'] = []
fixed_z_ = torch.randn((5 * 5, 100))    # fixed noise
with torch.no_grad():
       fixed_z_ = Variable(fixed_z_.cuda())
# fixed_z_ = Variable(fixed_z_.cuda(), volatile=True)

In [9]:
def show_train_hist(hist, show = False, save = False, path = 'Train_hist.png'):
    x = range(len(hist['D_losses']))

    y1 = hist['D_losses']
    y2 = hist['G_losses']

    plt.plot(x, y1, label='D_loss')
    plt.plot(x, y2, label='G_loss')

    plt.xlabel('Iter')
    plt.ylabel('Loss')

    plt.legend(loc=4)
    plt.grid(True)
    plt.tight_layout()

    if save:
        plt.savefig(path)

    if show:
        plt.show()
    else:
        plt.close()

In [10]:
def show_result(num_epoch, show = False, save = False, path = 'result.png', isFix=False):
#     z_ = torch.randn((5*5, 100))
# #     z_ = Variable(z_.cuda(), volatile=True)
#     with torch.no_grad():
#         z_ = Variable(z_.cuda())
    
#     y = Variable(torch.ones(bs, 1).to(device))
    
#     vae.eval()
#     if isFix:
#         test_images = vae(fixed_z_)
#     else:
#         test_images = vae(z_)
#     vae.train()

    z = torch.randn(64, 2).cuda()
    sample = vae.decoder(z).cuda()
    
    size_figure_grid = 5
    fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(5, 5))
    for i, j in itertools.product(range(size_figure_grid), range(size_figure_grid)):
        ax[i, j].get_xaxis().set_visible(False)
        ax[i, j].get_yaxis().set_visible(False)

    for k in range(5*5):
        i = k // 5
        j = k % 5
        ax[i, j].cla()
        ax[i, j].imshow(test_images[k, :].cpu().data.view(28, 28).numpy(), cmap='gray')

    label = 'Epoch {0}'.format(num_epoch)
    fig.text(0.5, 0.04, label, ha='center')
    plt.savefig(path)

    if show:
        plt.show()
    else:
        plt.close()

In [11]:
import time
n_epoch = 200
start_time = time.time()
for epoch in range(1, n_epoch+1):           
    D_losses, G_losses = [], []
    epoch_start_time = time.time()
    for batch_idx, (x, _) in enumerate(train_loader):
        D_losses.append(D_train(x))
        G_losses.append(G_train(x))
    
    epoch_end_time = time.time()
    per_epoch_ptime = epoch_end_time - epoch_start_time
    
    train_hist['D_losses'].append(torch.mean(torch.FloatTensor(D_losses)))
    train_hist['G_losses'].append(torch.mean(torch.FloatTensor(G_losses)))
    train_hist['per_epoch_ptimes'].append(per_epoch_ptime)
    
    print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (
            (epoch), n_epoch, torch.mean(torch.FloatTensor(D_losses)), torch.mean(torch.FloatTensor(G_losses))))
    
    fixed_p = '/samples/MNIST_AAVAE2_' + str(epoch) + '.png'
    if (epoch + 1) % 20 == 0:
        show_pic(fixed_p)



[1/200]: loss_d: 0.040, loss_g: 663835.562
[2/200]: loss_d: 0.000, loss_g: 2848.756
[3/200]: loss_d: 0.000, loss_g: 2848.670
[4/200]: loss_d: 0.000, loss_g: 2848.714
[5/200]: loss_d: 0.000, loss_g: 2848.722
[6/200]: loss_d: 0.000, loss_g: 2850.637
[7/200]: loss_d: 0.000, loss_g: 2849.488
[8/200]: loss_d: 0.000, loss_g: 2848.573
[9/200]: loss_d: 0.000, loss_g: 2848.573
[10/200]: loss_d: 0.000, loss_g: 2848.573
[11/200]: loss_d: 0.000, loss_g: 2848.573
[12/200]: loss_d: 0.000, loss_g: 2848.573
[13/200]: loss_d: 0.000, loss_g: 2848.573
[14/200]: loss_d: 0.000, loss_g: 2848.573
[15/200]: loss_d: 0.000, loss_g: 2848.573
[16/200]: loss_d: 0.000, loss_g: 2848.573
[17/200]: loss_d: 0.000, loss_g: 2848.573
[18/200]: loss_d: 0.000, loss_g: 2848.574
[19/200]: loss_d: 0.000, loss_g: 2848.574


NameError: name 'show_pic' is not defined

In [None]:
end_time = time.time()
total_ptime = end_time - start_time
train_hist['total_ptime'].append(total_ptime)

if not os.path.isdir('MNIST_AAVAE_results'):
    os.mkdir('MNIST_AAVAE_results')
if not os.path.isdir('MNIST_AAVAE_results/Random_results'):
    os.mkdir('MNIST_AAVAE_results/Random_results')
if not os.path.isdir('MNIST_AAVAE_results/Fixed_results'):
    os.mkdir('MNIST_AAVAE_results/Fixed_results')

In [None]:
print("Avg per epoch ptime: %.2f, total %d epochs ptime: %.2f" % (torch.mean(torch.FloatTensor(train_hist['per_epoch_ptimes'])), n_epoch, total_ptime))
print("Training finish!... save training results")
torch.save(vae.state_dict(), "MNIST_AAVAE_results/generator_param.pkl")
torch.save(D.state_dict(), "MNIST_AAVAE_results/discriminator_param.pkl")
with open('MNIST_AAVAE_results/train_hist.pkl', 'wb') as f:
    pickle.dump(train_hist, f)

show_train_hist(train_hist, save=True, path='MNIST_AAVAE_results/MNIST_AAVAE_train_hist.png')

images = []
for e in range(n_epoch):
    img_name = 'MNIST_AAVAE_results/Fixed_results/MNIST_AAVAE_' + str(e + 1) + '.png'
    images.append(imageio.imread(img_name))
imageio.mimsave('MNIST_AAVAE_results/generation_animation.gif', images, fps=5)

In [140]:
def show_pic(filename):
    with torch.no_grad():
        z = torch.randn(64, 2).cuda()
        sample = vae.decoder(z).cuda()
        save_image(sample.view(64, 1, 28, 28), filename)

In [145]:
# Ad-hoc check: decode a batch of random latents and write the grid to a file
# given by a relative path.
fixed_p = 'MNIST_AAVAE2debug1.png'
show_pic(fixed_p)

