In [1]:
import argparse
import os
import torch
import numpy as np
import matplotlib
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch import nn
from torch.autograd import Variable
from torch import nn, optim
from torch.nn import functional as F

In [2]:
%matplotlib inline

In [3]:
# Select the compute backend. Apple's Metal backend ("mps") is preferred, as in
# the original notebook, but we fall back to CUDA and then CPU so that a fresh
# "Restart & Run All" also works on machines without MPS support.
_mps_backend = getattr(torch.backends, "mps", None)  # absent on older torch builds
if _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [4]:
# --- reproducibility ---
seed = 1  # fed to torch.manual_seed

# --- sizes ---
batch_size = 128  # mini-batch size used by every DataLoader
nz = 100          # width of the noise / latent vector (GAN input and VAE code)

# --- Adam optimiser settings ---
lr = 1e-3
b1, b2 = 0.5, 0.9999  # passed as betas=(b1, b2) to the GAN optimisers

# --- schedule / logging ---
epochs = 100       # number of passes over the training set
log_interval = 10  # print progress every `log_interval` batches

In [5]:
# Seed PyTorch's default random number generator with the configured `seed`.
# The call returns the Generator object, which is what the cell displays.
torch.manual_seed(seed)

<torch._C.Generator at 0x109f786b0>

In [6]:
# os.path.dirname('.') evaluates to '', so every path built from curr_dirname
# is relative to the directory the notebook is executed from.
curr_dirname = os.path.dirname('.')
data_directory = os.path.join(curr_dirname, "data")
gan_results_directory_name = "results/gan"
gan_results_directory = os.path.join(curr_dirname, gan_results_directory_name)
# Create the directory that the sample images are actually written to
# (previously the un-joined name was created, which only matched by accident).
os.makedirs(gan_results_directory, exist_ok=True)

# ToTensor yields pixels in [0, 1]; Normalize([0.5], [0.5]) maps them to [-1, 1],
# the same range as the generator's Tanh output.
gan_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
)
dataloader = torch.utils.data.DataLoader(
    # download=True only fetches MNIST when it is missing from data_directory,
    # so the notebook no longer fails on a fresh checkout.
    datasets.MNIST(data_directory, train=True, download=True, transform=gan_transform),
    batch_size=batch_size,
    shuffle=True,
)

In [7]:
# Per-sample image shape: the generator reshapes its flat output to this shape
# and the discriminator's input width is the product of its entries.
# Presumably (channels, height, width) for MNIST digits.
img_shape = (1, 28, 28)

In [8]:
class Generator(nn.Module):
    """Fully connected generator mapping a noise vector to an image.

    Args:
        latent_dim: Width of the input noise vector. Defaults to the
            notebook-level ``nz`` when omitted.
        image_shape: Per-sample shape of the produced image. Defaults to the
            notebook-level ``img_shape`` when omitted.

    The final ``Tanh`` keeps every output value inside [-1, 1].
    """

    def __init__(self, latent_dim=None, image_shape=None):
        super(Generator, self).__init__()

        # Resolve defaults lazily so the notebook globals are only needed when
        # the caller does not pass explicit sizes.
        self.latent_dim = nz if latent_dim is None else latent_dim
        self.image_shape = tuple(img_shape if image_shape is None else image_shape)

        # create a reusable block with #in_feat input features and #out_feat output features
        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # The second positional argument of BatchNorm1d is `eps`, not
                # momentum. The former `BatchNorm1d(out_feat, 0.8)` therefore
                # added 0.8 to the variance before normalising. Use the
                # library defaults instead.
                layers.append(nn.BatchNorm1d(out_feat))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        # Layer order is unchanged, so state_dict keys stay compatible.
        self.model = nn.Sequential(
            *block(self.latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            nn.Linear(512, int(np.prod(self.image_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        """Return images of shape ``(z.size(0), *image_shape)`` for noise ``z``."""
        img = self.model(z)
        img = img.view(img.size(0), *self.image_shape)
        return img


class Discriminator(nn.Module):
    """Fully connected discriminator scoring images as real or generated.

    Args:
        image_shape: Per-sample shape of the input images. Defaults to the
            notebook-level ``img_shape`` when omitted.

    The final ``Sigmoid`` yields one value in (0, 1) per sample, which is the
    input format ``BCELoss`` expects.
    """

    def __init__(self, image_shape=None):
        super(Discriminator, self).__init__()

        # Resolve the default lazily so the notebook global is only needed
        # when the caller does not pass an explicit shape.
        self.image_shape = tuple(img_shape if image_shape is None else image_shape)

        # create a reusable block with #in_feat input features and #out_feat output features
        def block(in_feat, out_feat):
            return [nn.Linear(in_feat, out_feat), nn.LeakyReLU(0.2, inplace=True)]

        # Layer order is unchanged, so state_dict keys stay compatible.
        self.model = nn.Sequential(
            *block(int(np.prod(self.image_shape)), 512),
            *block(512, 256),
            *block(256, 128),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        """Flatten each image in the batch and return a ``(batch, 1)`` score tensor."""
        img_flat = img.view(img.size(0), -1)
        out = self.model(img_flat)
        return out


In [9]:
# loss function
# Binary cross-entropy on probabilities: the discriminator ends in a Sigmoid,
# so its outputs are already in (0, 1) as BCELoss requires.
adversarial_loss = torch.nn.BCELoss()

In [10]:
# initialize generator and discriminator
# Both networks are moved to the selected `device` right after construction.
generator = Generator().to(device)
discriminator = Discriminator().to(device)

In [11]:
# optimizers
# Each network gets its own Adam instance; both share the learning rate `lr`
# and the betas (b1, b2) from the configuration cell.
optimizer_generator = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

In [12]:
# Alias for torch.FloatTensor.
# NOTE(review): no other visible cell references `Tensor` — confirm whether it is still needed.
Tensor = torch.FloatTensor

In [13]:
# ----------
#  Training
# ----------
# One (g_loss, real_loss, fake_loss, d_loss) tuple per epoch, taken from the
# last batch of that epoch. Values are stored as plain Python floats so the
# history does not keep autograd graphs or device memory alive.
gan_losses = []
for epoch in range(1, epochs+1):
    for i, (real_imgs, _) in enumerate(dataloader):
        real_imgs = real_imgs.to(device)

        # Adversarial ground truths, sized to the current batch (the last
        # batch of an epoch may be smaller than batch_size)
        valid = torch.ones([real_imgs.size(0), 1], device=device)
        fake = torch.zeros([real_imgs.size(0), 1], device=device)

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_generator.zero_grad()

        # Sample noise as generator input
        z = torch.randn(real_imgs.shape[0], nz, device=device)

        # Generate a batch of images
        gen_imgs = generator(z)

        # classify images using discriminator
        classifications = discriminator(gen_imgs)

        # Loss measures generator's ability to fool the discriminator:
        # generated images are scored against the "valid" labels
        g_loss = adversarial_loss(classifications, valid)

        g_loss.backward()
        optimizer_generator.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        # Also clears the discriminator gradients accumulated by g_loss.backward()
        optimizer_discriminator.zero_grad()

        # Measure discriminator's ability to classify real from generated samples
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        # detaching since the backprop does not need to run on the generator
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        # take the average of loss against generated images and loss against real images
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_discriminator.step()

        if i % log_interval == 0:
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tDiscriminator Loss: {:.6f}\tGenerator Loss: {:.6f}".format(
                    epoch,
                    i * len(real_imgs),
                    len(dataloader.dataset),
                    100.0 * i / len(dataloader),
                    d_loss.item(),
                    g_loss.item()
                )
            )
    # .item() converts each scalar loss tensor to a float (see note above).
    gan_losses.append(
        (g_loss.item(), real_loss.item(), fake_loss.item(), d_loss.item())
    )
    with torch.no_grad():
        n_images = 64
        sample = torch.randn(n_images, nz).to(device)
        # NOTE: the generator is not switched to eval() here, so its BatchNorm
        # layers normalise with the statistics of this sample batch.
        sample = generator(sample).cpu()
        save_image(
            sample.view(n_images, *img_shape),
            gan_results_directory + "/sample_" + str(epoch) + ".png",
            # The generator's Tanh output lies in [-1, 1]; without rescaling,
            # save_image clamps everything below 0 to black.
            normalize=True,
            value_range=(-1, 1),
        )







































































































In [21]:
# Persist the learned parameters (state_dict only, not the module objects) of
# both GAN networks under curr_dirname.
torch.save(generator.state_dict(), os.path.join(curr_dirname, "gan_generator.pt"))
torch.save(discriminator.state_dict(), os.path.join(curr_dirname, "gan_discriminator.pt"))

In [14]:
vae_results_directory_name = "results/vae"
vae_results_directory = os.path.join(curr_dirname, vae_results_directory_name)
# Create the directory that reconstructions and samples are actually written to.
os.makedirs(vae_results_directory, exist_ok=True)

# The VAE is trained with binary cross-entropy, so pixels must stay in [0, 1],
# which is exactly what ToTensor produces. The former Normalize([0], [1]) was
# an identity transform ((x - 0) / 1) and has been dropped.
vae_transform = transforms.ToTensor()

train_loader = torch.utils.data.DataLoader(
    # download=True only fetches MNIST when it is missing from data_directory.
    datasets.MNIST(data_directory, train=True, download=True, transform=vae_transform),
    batch_size=batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(data_directory, train=False, download=True, transform=vae_transform),
    batch_size=batch_size,
    shuffle=False,
)

In [15]:
class VAE(nn.Module):
    """Fully connected variational auto-encoder over flattened 784-value images.

    Encoder: 784 -> 400 -> 200 -> (mu, logvar), each of width ``nz``.
    Decoder: ``nz`` -> 200 -> 400 -> 784, followed by a sigmoid.
    """

    def __init__(self):
        super().__init__()

        # encoder trunk
        self.fc1 = nn.Linear(784, 400)
        self.fc2 = nn.Linear(400, 200)
        # two heads on the shared trunk: mean and log-variance of q(z|x)
        self.fc31 = nn.Linear(200, nz)
        self.fc32 = nn.Linear(200, nz)
        # decoder
        self.fc4 = nn.Linear(nz, 200)
        self.fc5 = nn.Linear(200, 400)
        self.fc6 = nn.Linear(400, 784)

    def encode(self, x):
        """Return ``(mu, logvar)`` for a batch of flattened inputs ``x``."""
        hidden = F.relu(self.fc2(F.relu(self.fc1(x))))
        mu = self.fc31(hidden)
        logvar = self.fc32(hidden)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps`` sampled from a standard normal."""
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map latent codes ``z`` to 784 sigmoid-activated outputs per sample."""
        hidden = F.relu(self.fc4(z))
        hidden = F.relu(self.fc5(hidden))
        logits = self.fc6(hidden)
        return torch.sigmoid(logits)

    def forward(self, x):
        """Encode ``x``, sample a latent code and return ``(reconstruction, mu, logvar)``."""
        flat = x.view(-1, 784)
        mu, logvar = self.encode(flat)
        latent = self.reparameterize(mu, logvar)
        return self.decode(latent), mu, logvar



In [16]:
# Build the VAE on the selected device and give it an Adam optimiser with the
# shared learning rate `lr` (Adam's betas are left at their defaults here).
vae = VAE().to(device)
optimizer = optim.Adam(vae.parameters(), lr=lr)

In [19]:
vae_train_losses = []


def vae_loss_function(recon_x, x, mu, logvar):
    """Return the VAE loss: reconstruction BCE plus KL divergence.

    Both terms are summed over every element and every sample in the batch.
    ``x`` is flattened to rows of 784 values before it is compared with
    ``recon_x``.
    """
    target = x.view(-1, 784)
    reconstruction = F.binary_cross_entropy(recon_x, target, reduction="sum")

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # KLD = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    per_element = 1 + logvar - mu.pow(2) - logvar.exp()
    divergence = -0.5 * torch.sum(per_element)

    return reconstruction + divergence


def train_vae(epoch):
    """Run one training epoch of the VAE over ``train_loader``.

    Prints progress every ``log_interval`` batches and an epoch summary, and
    appends the epoch's average per-sample loss to ``vae_train_losses``.

    Args:
        epoch: Epoch number, used only in the printed messages.
    """
    vae.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = vae(data)
        loss = vae_loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        # loss is summed over the batch, so train_loss is a running total
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch,
                    batch_idx * len(data),
                    len(train_loader.dataset),
                    100.0 * batch_idx / len(train_loader),
                    loss.item() / len(data),
                )
            )
    # Record the per-sample average (the value printed below) rather than the
    # raw sum, so the history is on the same scale as vae_test_losses.
    average_loss = train_loss / len(train_loader.dataset)
    vae_train_losses.append(average_loss)
    print(
        "====> Epoch: {} Average loss: {:.4f}".format(
            epoch, average_loss
        )
    )

vae_test_losses = []
def test_vae(epoch):
    """Evaluate the VAE on ``test_loader`` without gradient tracking.

    Appends the average per-sample loss to ``vae_test_losses``, prints it, and
    saves a comparison image for the first batch: up to 8 inputs on the top
    row with their reconstructions below.

    Args:
        epoch: Epoch number, used in the name of the saved image.
    """
    vae.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            recon_batch, mu, logvar = vae(data)
            test_loss += vae_loss_function(recon_batch, data, mu, logvar).item()
            if i == 0:
                n = min(data.size(0), 8)
                # Reshape by the actual batch length: using the global
                # batch_size fails whenever the first batch is smaller.
                comparison = torch.cat(
                    [data[:n], recon_batch.view(data.size(0), 1, 28, 28)[:n]]
                )
                save_image(
                    comparison.cpu(),
                    vae_results_directory + "/reconstruction_" + str(epoch) + ".png",
                    nrow=n,
                )

    test_loss /= len(test_loader.dataset)
    vae_test_losses.append(test_loss)
    print("====> Test set loss: {:.4f}".format(test_loss))


In [20]:
## VAE TRAINING & TESTING
# For every epoch: train on the training set, evaluate on the test set, then
# decode 64 latent vectors drawn from a standard normal and save them as an image.
for epoch in range(1, epochs + 1):
    train_vae(epoch)
    test_vae(epoch)
    with torch.no_grad():
        # noise is drawn on the CPU and then moved to the device
        sample = torch.randn(64, nz).to(device)
        sample = vae.decode(sample).cpu()
        save_image(
            # decode() returns 784 values per sample; reshape to 1x28x28 images
            sample.view(64, 1, 28, 28),
            vae_results_directory + "/sample_" + str(epoch) + ".png",
        )

====> Epoch: 1 Average loss: 142.1480
====> Test set loss: 131.1435
====> Epoch: 2 Average loss: 125.2136
====> Test set loss: 119.2233
====> Epoch: 3 Average loss: 117.7575
====> Test set loss: 114.7346


====> Epoch: 4 Average loss: 114.1963
====> Test set loss: 112.0254
====> Epoch: 5 Average loss: 111.5654
====> Test set loss: 109.7251
====> Epoch: 6 Average loss: 109.5441
====> Test set loss: 107.9999


====> Epoch: 7 Average loss: 107.8994
====> Test set loss: 106.8758
====> Epoch: 8 Average loss: 106.6681
====> Test set loss: 105.6139
====> Epoch: 9 Average loss: 105.6663
====> Test set loss: 104.9686


====> Epoch: 10 Average loss: 104.8578
====> Test set loss: 104.6448
====> Epoch: 11 Average loss: 104.2126
====> Test set loss: 103.3818
====> Epoch: 12 Average loss: 103.6321
====> Test set loss: 103.4323
====> Epoch: 13 Average loss: 103.2575
====> Test set loss: 103.0581


====> Epoch: 14 Average loss: 102.7961
====> Test set loss: 102.8155
====> Epoch: 15 Average loss: 102.4266
====> Test set loss: 102.5111
====> Epoch: 16 Average loss: 102.1265
====> Test set loss: 102.1825


====> Epoch: 17 Average loss: 101.8560
====> Test set loss: 102.1841
====> Epoch: 18 Average loss: 101.6130
====> Test set loss: 102.0249
====> Epoch: 19 Average loss: 101.3736
====> Test set loss: 101.8819


====> Epoch: 20 Average loss: 101.1383
====> Test set loss: 101.9012
====> Epoch: 21 Average loss: 101.0019
====> Test set loss: 101.4864
====> Epoch: 22 Average loss: 100.7807
====> Test set loss: 101.4435


====> Epoch: 23 Average loss: 100.6046
====> Test set loss: 101.2037
====> Epoch: 24 Average loss: 100.4606
====> Test set loss: 101.1397
====> Epoch: 25 Average loss: 100.3149
====> Test set loss: 101.1570
====> Epoch: 26 Average loss: 100.1753
====> Test set loss: 101.2160


====> Epoch: 27 Average loss: 100.0626
====> Test set loss: 100.9570
====> Epoch: 28 Average loss: 99.9350
====> Test set loss: 100.8416
====> Epoch: 29 Average loss: 99.8147
====> Test set loss: 100.7697


====> Epoch: 30 Average loss: 99.7131
====> Test set loss: 100.7968
====> Epoch: 31 Average loss: 99.6359
====> Test set loss: 100.8541
====> Epoch: 32 Average loss: 99.4969
====> Test set loss: 100.5499


====> Epoch: 33 Average loss: 99.4028
====> Test set loss: 100.5459
====> Epoch: 34 Average loss: 99.3427
====> Test set loss: 100.5425
====> Epoch: 35 Average loss: 99.2496
====> Test set loss: 100.5855
====> Epoch: 36 Average loss: 99.1794


====> Test set loss: 100.4709
====> Epoch: 37 Average loss: 99.0959
====> Test set loss: 100.3382
====> Epoch: 38 Average loss: 98.9398
====> Test set loss: 100.3857
====> Epoch: 39 Average loss: 98.9065
====> Test set loss: 100.3399


====> Epoch: 40 Average loss: 98.8372
====> Test set loss: 100.1973
====> Epoch: 41 Average loss: 98.7671
====> Test set loss: 100.0409
====> Epoch: 42 Average loss: 98.7230
====> Test set loss: 99.7114


====> Epoch: 43 Average loss: 98.6690
====> Test set loss: 100.0089
====> Epoch: 44 Average loss: 98.5773
====> Test set loss: 99.9681
====> Epoch: 45 Average loss: 98.5515
====> Test set loss: 100.0191


====> Epoch: 46 Average loss: 98.4877
====> Test set loss: 100.0435
====> Epoch: 47 Average loss: 98.4565
====> Test set loss: 99.9800
====> Epoch: 48 Average loss: 98.3852
====> Test set loss: 100.0293
====> Epoch: 49 Average loss: 98.3252
====> Test set loss: 99.9252


====> Epoch: 50 Average loss: 98.2882
====> Test set loss: 99.9337
====> Epoch: 51 Average loss: 98.2381
====> Test set loss: 100.0044
====> Epoch: 52 Average loss: 98.1352
====> Test set loss: 99.8451


====> Epoch: 53 Average loss: 98.1202
====> Test set loss: 99.9504
====> Epoch: 54 Average loss: 98.0823
====> Test set loss: 99.6501
====> Epoch: 55 Average loss: 98.0540
====> Test set loss: 99.8217


====> Epoch: 56 Average loss: 98.0181
====> Test set loss: 99.6651
====> Epoch: 57 Average loss: 97.9367
====> Test set loss: 99.6775
====> Epoch: 58 Average loss: 97.9230
====> Test set loss: 99.8521
====> Epoch: 59 Average loss: 97.8751
====> Test set loss: 99.7096


====> Epoch: 60 Average loss: 97.8341
====> Test set loss: 99.7084
====> Epoch: 61 Average loss: 97.8035
====> Test set loss: 99.8080
====> Epoch: 62 Average loss: 97.7938
====> Test set loss: 99.8251


====> Epoch: 63 Average loss: 97.7758
====> Test set loss: 99.7492
====> Epoch: 64 Average loss: 97.6985
====> Test set loss: 99.8284
====> Epoch: 65 Average loss: 97.6843
====> Test set loss: 99.5167


====> Epoch: 66 Average loss: 97.6188
====> Test set loss: 99.6312
====> Epoch: 67 Average loss: 97.5943
====> Test set loss: 99.6230
====> Epoch: 68 Average loss: 97.5548
====> Test set loss: 99.5043


====> Epoch: 69 Average loss: 97.5108
====> Test set loss: 99.4490
====> Epoch: 70 Average loss: 97.4987
====> Test set loss: 99.7581
====> Epoch: 71 Average loss: 97.4724
====> Test set loss: 99.2469
====> Epoch: 72 Average loss: 97.4409
====> Test set loss: 99.5610


====> Epoch: 73 Average loss: 97.4294
====> Test set loss: 99.3927
====> Epoch: 74 Average loss: 97.3899
====> Test set loss: 99.4246
====> Epoch: 75 Average loss: 97.3685
====> Test set loss: 99.6340


====> Epoch: 76 Average loss: 97.3265
====> Test set loss: 99.2424
====> Epoch: 77 Average loss: 97.3312
====> Test set loss: 99.3798
====> Epoch: 78 Average loss: 97.2963
====> Test set loss: 99.1872


====> Epoch: 79 Average loss: 97.2300
====> Test set loss: 99.8560
====> Epoch: 80 Average loss: 97.2512
====> Test set loss: 99.2827
====> Epoch: 81 Average loss: 97.2017
====> Test set loss: 99.4667
====> Epoch: 82 Average loss: 97.1699
====> Test set loss: 99.3549


====> Epoch: 83 Average loss: 97.1669
====> Test set loss: 99.6957
====> Epoch: 84 Average loss: 97.1362
====> Test set loss: 99.0256
====> Epoch: 85 Average loss: 97.1128
====> Test set loss: 99.2963


====> Epoch: 86 Average loss: 97.1072
====> Test set loss: 99.0395
====> Epoch: 87 Average loss: 97.0178
====> Test set loss: 99.3423
====> Epoch: 88 Average loss: 97.0713
====> Test set loss: 99.1499


====> Epoch: 89 Average loss: 97.0227
====> Test set loss: 99.1738
====> Epoch: 90 Average loss: 96.9786
====> Test set loss: 99.4335
====> Epoch: 91 Average loss: 96.9955
====> Test set loss: 99.2537
====> Epoch: 92 Average loss: 96.9646


====> Test set loss: 99.1596
====> Epoch: 93 Average loss: 96.9264
====> Test set loss: 99.1257
====> Epoch: 94 Average loss: 96.8941
====> Test set loss: 99.1499
====> Epoch: 95 Average loss: 96.9153
====> Test set loss: 99.4033


====> Epoch: 96 Average loss: 96.8800
====> Test set loss: 99.4504
====> Epoch: 97 Average loss: 96.8779
====> Test set loss: 99.1594
====> Epoch: 98 Average loss: 96.8241
====> Test set loss: 99.3363


====> Epoch: 99 Average loss: 96.8429
====> Test set loss: 99.3587
====> Epoch: 100 Average loss: 96.7721
====> Test set loss: 99.2598


In [22]:
# Persist the VAE's learned parameters (state_dict only) under curr_dirname.
torch.save(vae.state_dict(), os.path.join(curr_dirname, "vae.pt"))