In [None]:
# Code: https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/aae/aae.py
# Paper: https://arxiv.org/abs/1511.05644

import argparse
import os
import numpy as np
import math
import itertools

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

In [None]:
# Output directory for the sampled image grids written during training
os.makedirs("images", exist_ok=True)

# Hyperparameters (everything a reader might want to tune lives here)
n_epochs = 200            # number of epochs of training
batch_size = 64           # size of the batches
lr = 0.0002               # adam: learning rate
b1 = 0.5                  # adam: decay of first order momentum of gradient
b2 = 0.999                # adam: decay of second order momentum of gradient
# NOTE(review): n_cpu is not referenced anywhere else in the visible notebook
# (the DataLoader is built without num_workers) - confirm whether it is needed.
n_cpu = 8                 # number of cpu threads to use during batch generation
latent_dim = 10           # dimensionality of the latent code
img_size = 32             # size of each image dimension
channels = 1              # number of image channels
sample_interval = 400     # interval (in batches) between image sampling

In [None]:
# Shape of a single image as (channels, height, width)
img_shape = (channels, img_size, img_size)

# Use CUDA float tensors whenever PyTorch can see a GPU, CPU float tensors otherwise
cuda = torch.cuda.is_available()
Tensor = torch.FloatTensor if not cuda else torch.cuda.FloatTensor

def reparameterization(mu, logvar):
    """Draw ``z = mu + eps * exp(logvar / 2)`` with ``eps ~ N(0, I)``.

    The noise is created with ``torch.randn_like`` so that it always has the
    same shape, dtype and device as the inputs. The function therefore no
    longer relies on the module-level ``Tensor`` / ``latent_dim`` globals and
    never round-trips through NumPy on the CPU.

    Args:
        mu: Tensor holding the mean of the latent distribution.
        logvar: Tensor of the same shape as ``mu`` holding the log-variance.

    Returns:
        A tensor shaped like ``mu`` that is differentiable with respect to
        both ``mu`` and ``logvar``.
    """
    std = torch.exp(logvar / 2)
    # Noise comes from torch's RNG (seed with torch.manual_seed for reproducibility)
    eps = torch.randn_like(std)
    return eps * std + mu

class Encoder(nn.Module):
    """Fully connected encoder that maps a batch of images to latent samples.

    The flattened image runs through a shared two-layer trunk. Two linear
    heads then produce the mean and log-variance, which are passed to
    ``reparameterization`` to draw the latent code.
    """

    def __init__(self):
        super().__init__()

        n_pixels = int(np.prod(img_shape))
        trunk_layers = [
            nn.Linear(n_pixels, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        self.model = nn.Sequential(*trunk_layers)

        # Separate heads for the mean and the log-variance of the code
        self.mu = nn.Linear(512, latent_dim)
        self.logvar = nn.Linear(512, latent_dim)

    def forward(self, img):
        """Flatten ``img`` per sample, encode it, and return a sampled latent code."""
        features = self.model(img.view(img.shape[0], -1))
        return reparameterization(self.mu(features), self.logvar(features))


class Decoder(nn.Module):
    """Fully connected decoder that maps latent codes to Tanh-bounded images."""

    def __init__(self):
        super().__init__()

        n_pixels = int(np.prod(img_shape))
        stack = [
            nn.Linear(latent_dim, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, n_pixels),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*stack)

    def forward(self, z):
        """Decode ``z`` and reshape the flat output to ``(batch, *img_shape)``."""
        flat = self.model(z)
        return flat.view(flat.shape[0], *img_shape)


class Discriminator(nn.Module):
    """Fully connected classifier over latent codes.

    Ends in a sigmoid, so each code is scored with a single value in (0, 1).
    """

    def __init__(self):
        super().__init__()

        stack = [
            nn.Linear(latent_dim, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*stack)

    def forward(self, z):
        """Return the validity score for every latent code in ``z``."""
        return self.model(z)


# Losses: binary cross-entropy for the adversarial game on latent codes,
# L1 for the pixel-wise reconstruction error
adversarial_loss = torch.nn.BCELoss()
pixelwise_loss = torch.nn.L1Loss()

# Initialize generator (encoder + decoder) and discriminator
encoder = Encoder()
decoder = Decoder()
discriminator = Discriminator()

# Move the networks and losses to the GPU when one is available. The training
# loop casts every batch with `Tensor`, which is a CUDA tensor type whenever
# `cuda` is true, so leaving the modules on the CPU would raise a device
# mismatch on the first forward pass.
if cuda:
    encoder.cuda()
    decoder.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()
    pixelwise_loss.cuda()

# Configure data loader: MNIST resized to img_size and normalized to [-1, 1]
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../../data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=batch_size,
    shuffle=True,
)

# Optimizers: one Adam instance updates encoder and decoder jointly,
# a second one updates the discriminator
optimizer_G = torch.optim.Adam(
    itertools.chain(encoder.parameters(), decoder.parameters()), lr=lr, betas=(b1, b2)
)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

def sample_image(n_row, batches_done):
    """Save an ``n_row`` x ``n_row`` grid of digits decoded from Gaussian noise.

    Args:
        n_row: Number of images per row of the grid; ``n_row ** 2`` latent
            codes are sampled in total.
        batches_done: Global batch counter, used as the PNG file name inside
            the ``images`` directory.
    """
    # Sample noise: Gaussian noise
    z = Tensor(np.random.normal(0, 1, (n_row ** 2, latent_dim)))
    # Sampling is inference only, so skip building an autograd graph.
    # The decoder's train/eval mode is deliberately left untouched.
    with torch.no_grad():
        gen_imgs = decoder(z)
    save_image(gen_imgs, "images/%d.png" % batches_done, nrow=n_row, normalize=True)


# ----------
#  Training
# ----------

# Alternating optimisation: for every batch, first update the encoder/decoder
# (reconstruction + fooling the discriminator), then update the discriminator
# (telling prior samples apart from encoded images).
for epoch in range(n_epochs):
    for i, (imgs, _) in enumerate(dataloader):

        # Adversarial ground truths: (batch, 1) targets of ones and zeros
        valid = Variable(Tensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False)
        fake = Variable(Tensor(imgs.shape[0], 1).fill_(0.0), requires_grad=False)

        # Configure input: cast the batch to the tensor type selected by `Tensor`
        real_imgs = Variable(imgs.type(Tensor))

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        encoded_imgs = encoder(real_imgs)
        decoded_imgs = decoder(encoded_imgs)

        # Loss measures generator's ability to fool the discriminator
        # (weight 0.001) plus the L1 reconstruction error (weight 0.999)
        g_loss = 0.001 * adversarial_loss(discriminator(encoded_imgs), valid) + 0.999 * pixelwise_loss(
            decoded_imgs, real_imgs
        )

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        # Also clears the discriminator gradients accumulated by g_loss.backward()
        optimizer_D.zero_grad()

        # Sample noise as discriminator ground truth
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], latent_dim))))

        # Measure discriminator's ability to classify real from generated samples.
        # Prior samples are labelled valid, encoded images fake; detach() keeps
        # the discriminator loss from back-propagating into the encoder.
        real_loss = adversarial_loss(discriminator(z), valid)
        fake_loss = adversarial_loss(discriminator(encoded_imgs.detach()), fake)
        d_loss = 0.5 * (real_loss + fake_loss)

        d_loss.backward()
        optimizer_D.step()

        # NOTE(review): this logs once per batch, which floods the cell output
        # over a full run - consider logging less often.
        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
        )

        # Save a 10x10 sample grid every `sample_interval` batches (global count)
        batches_done = epoch * len(dataloader) + i
        if batches_done % sample_interval == 0:
            sample_image(n_row=10, batches_done=batches_done)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
[Epoch 194/200] [Batch 628/938] [D loss: 0.374363] [G loss: 0.065762]
[Epoch 194/200] [Batch 629/938] [D loss: 0.309727] [G loss: 0.071695]
[Epoch 194/200] [Batch 630/938] [D loss: 0.434890] [G loss: 0.061431]
[Epoch 194/200] [Batch 631/938] [D loss: 0.432157] [G loss: 0.068579]
[Epoch 194/200] [Batch 632/938] [D loss: 0.348890] [G loss: 0.057953]
[Epoch 194/200] [Batch 633/938] [D loss: 0.295368] [G loss: 0.064998]
[Epoch 194/200] [Batch 634/938] [D loss: 0.383879] [G loss: 0.062157]
[Epoch 194/200] [Batch 635/938] [D loss: 0.275718] [G loss: 0.064279]
[Epoch 194/200] [Batch 636/938] [D loss: 0.273116] [G loss: 0.056921]
[Epoch 194/200] [Batch 637/938] [D loss: 0.342675] [G loss: 0.057757]
[Epoch 194/200] [Batch 638/938] [D loss: 0.254984] [G loss: 0.059400]
[Epoch 194/200] [Batch 639/938] [D loss: 0.334623] [G loss: 0.070791]
[Epoch 194/200] [Batch 640/938] [D loss: 0.389221] [G loss: 0.065436]
[Epoch 194/200] [Batch 64

In [None]:
# Keep-alive cell: blocks the kernel forever by sleeping in 100-second steps.
# NOTE(review): presumably meant to stop a hosted session from idling out - confirm.
# It never returns on its own, so it has to be interrupted manually, and under
# "Run All" the cells below it are never reached.
import time
while True: time.sleep(100)

In [None]:
# List the working directory (shell escape) to check which outputs exist
!ls

images	sample_data


In [None]:
# Recursively pack the images directory into images.zip (shell escape)
!zip -r images.zip images

  adding: images/ (stored 0%)
  adding: images/124000.png (deflated 5%)
  adding: images/124400.png (deflated 5%)
  adding: images/145600.png (deflated 5%)
  adding: images/95600.png (deflated 5%)
  adding: images/154000.png (deflated 5%)
  adding: images/93200.png (deflated 5%)
  adding: images/88000.png (deflated 5%)
  adding: images/800.png (deflated 3%)
  adding: images/60400.png (deflated 5%)
  adding: images/90400.png (deflated 5%)
  adding: images/140400.png (deflated 6%)
  adding: images/155600.png (deflated 5%)
  adding: images/44800.png (deflated 5%)
  adding: images/52800.png (deflated 5%)
  adding: images/174000.png (deflated 5%)
  adding: images/69200.png (deflated 5%)
  adding: images/108000.png (deflated 5%)
  adding: images/92400.png (deflated 5%)
  adding: images/160800.png (deflated 5%)
  adding: images/173600.png (deflated 5%)
  adding: images/152000.png (deflated 5%)
  adding: images/117200.png (deflated 5%)
  adding: images/150000.png (deflated 5%)
  adding: images