In [1]:
from tqdm import tqdm, trange
import matplotlib.pyplot as plt
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [46]:
import pickle
import torchvision.transforms.functional as f_tf
data_path = "./data/mnist.pkl"

# SECURITY: pickle.load can execute arbitrary code embedded in the file; only
# load pickles that come from a trusted source.
with open(data_path, "rb") as f:
    data_obj = pickle.load(f)

# data_obj is indexed by the split names "train" / "test". torch.as_tensor with
# an explicit dtype replaces the legacy torch.Tensor constructor; both produce
# float32 tensors.
train_data = torch.as_tensor(np.asarray(data_obj["train"]), dtype=torch.float32)
test_data = torch.as_tensor(np.asarray(data_obj["test"]), dtype=torch.float32)

# The training loop permutes each batch with (0, 3, 1, 2), so every sample must
# be 3-D (presumably H, W, C -- TODO confirm against the pickle). Fail here
# rather than midway through training.
assert train_data.ndim == 4, (
    f"expected train data of shape (N, H, W, C), got {tuple(train_data.shape)}"
)



In [47]:
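# # Implementation details (following DCGAN, Radford et al.)
# No spatial pooling; strided convolutions are used instead
# No fully connected hidden layers on top of the conv features (the discriminator only flattens its last conv output into a single sigmoid unit)
# BatchNorm applied to all layers except the generator output and the discriminator input
# * BatchNorm helps prevent mode collapse, but applying it to every layer caused sample oscillation and instability (per the authors)
# ReLU in the generator with a tanh output; LeakyReLU (slope 0.2) for all discriminator layers
# Adam with lr = 0.0002 and beta1 (the momentum term) reduced from the default 0.9 to 0.5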

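# Note: the blocks below apply the activation before BatchNorm (conv -> ReLU -> BN), the reverse of the common conv -> BN -> ReLU ordering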




class Generator(nn.Module):
    """DCGAN-style generator: latent noise -> single-channel 64x64 image.

    Five transposed convolutions upsample the noise, viewed as a
    ``(latent_dim, 1, 1)`` map, through 4x4 -> 8x8 -> 16x16 -> 32x32 -> 64x64.
    Hidden layers apply ReLU followed by BatchNorm; the output layer uses tanh,
    so pixel values lie in [-1, 1].

    Args:
        in_shape: Image shape. Only ``in_shape[0]`` is read, as the base channel
            width of the hidden layers; ``nin`` stores the product of its entries.
        batch_size: Number of samples drawn when ``forward`` is called without
            explicit noise.
        latent_dim: Length of each latent noise vector.
    """

    def __init__(self, in_shape, batch_size, latent_dim):
        super(Generator, self).__init__()
        self.in_shape = in_shape
        self.nin = np.prod(self.in_shape)
        self.batch_size = batch_size
        self.latent_dim = latent_dim

        # Base channel width, taken from in_shape[0] (64 for the (64, 64, 1)
        # shape used when training), not from the image's channel count.
        self.features = in_shape[0]

        # 1x1 -> 4x4 (kernel 4, stride 1, no padding); every following
        # stride-2 / kernel-4 / padding-1 layer doubles height and width.
        c1 = nn.ConvTranspose2d(self.latent_dim, self.features * 16, (4, 4), stride=1, padding=0)
        c2 = nn.ConvTranspose2d(self.features * 16, self.features * 8, 4, stride=2, padding=1)
        c3 = nn.ConvTranspose2d(self.features * 8, self.features * 4, 4, stride=2, padding=1)
        c4 = nn.ConvTranspose2d(self.features * 4, self.features * 2, 4, stride=2, padding=1)
        c5 = nn.ConvTranspose2d(self.features * 2, 1, 4, stride=2, padding=1)

        self.relu = nn.ReLU()
        self.act_out = nn.Tanh()

        # Layer order (and therefore the "enc.<i>.*" state_dict keys) is kept as
        # is so saved checkpoints still load. ReLU is stateless, so one instance
        # is shared between the blocks.
        self.enc = nn.Sequential(
            c1, self.relu, nn.BatchNorm2d(self.features * 16),
            c2, self.relu, nn.BatchNorm2d(self.features * 8),
            c3, self.relu, nn.BatchNorm2d(self.features * 4),
            c4, self.relu, nn.BatchNorm2d(self.features * 2),
            c5, self.act_out,
        )

    def forward(self, z=None):
        """Generate a batch of images.

        Args:
            z: Optional latent codes of shape ``(N, latent_dim)`` or
                ``(N, latent_dim, 1, 1)``. When omitted, ``batch_size`` codes are
                drawn from a standard normal distribution.

        Returns:
            Tensor of shape ``(N, 1, 64, 64)`` with values in [-1, 1].
        """
        if z is None:
            # Sample on the parameters' device and dtype so the generator keeps
            # working after .to("cuda") / .double(); for a float32 CPU model this
            # draws exactly the same values as a plain torch.randn call.
            weight = next(self.parameters())
            z = torch.randn(
                self.batch_size, self.latent_dim,
                device=weight.device, dtype=weight.dtype,
            )
        if z.dim() == 2:
            z = z[:, :, None, None]  # (N, latent_dim) -> (N, latent_dim, 1, 1)
        return self.enc(z)

class Discriminator(nn.Module):
    """DCGAN-style discriminator: single-channel 64x64 image -> real/fake probability.

    Five stride-2 convolutions (kernel 4, padding 1) halve the spatial size
    64 -> 32 -> 16 -> 8 -> 4 -> 2 while the channel count doubles from
    ``hidden_size`` up to ``16 * hidden_size``. Each convolution is followed by
    LeakyReLU(0.2) and then BatchNorm. The final 2x2 map is flattened and fed to
    a single linear unit with a sigmoid, so ``forward`` returns shape ``(N, 1)``.

    Args:
        in_shape: Stored as an attribute only; the layers are sized for a
            1-channel 64x64 input regardless of its value.
        hidden_size: Channel count of the first convolution.
        latent_dim: Stored as an attribute only; unused by this module.

    NOTE(review): BatchNorm is applied right after the first convolution, although
    the DCGAN guidelines exclude the discriminator's input layer from BatchNorm --
    confirm whether that is intended.
    """

    def __init__(self, in_shape, hidden_size, latent_dim):
        super(Discriminator, self).__init__()

        self.in_shape = in_shape
        self.hidden_size = hidden_size
        self.latent_dim = latent_dim

        # Channel progression 1 -> n -> 2n -> 4n -> 8n -> 16n, with n = hidden_size.
        channels = [1] + [hidden_size * 2 ** level for level in range(5)]
        convs = [
            nn.Conv2d(c_in, c_out, 4, 2, 1)
            for c_in, c_out in zip(channels[:-1], channels[1:])
        ]
        # Each conv is registered on its own as c1..c5 as well as inside self.disc.
        self.c1, self.c2, self.c3, self.c4, self.c5 = convs

        self.flatten = nn.Flatten(1, -1)
        # 16n channels on a 2x2 map remain after five halvings of a 64x64 input.
        self.out = nn.Linear(channels[-1] * 2 * 2, 1)

        layers = []
        for conv in convs:
            layers.extend([conv, nn.LeakyReLU(0.2), nn.BatchNorm2d(conv.out_channels)])
        layers.extend([self.flatten, self.out, nn.Sigmoid()])
        self.disc = nn.Sequential(*layers)

    def forward(self, X):
        """Return the sigmoid score for each image in ``X``, shape ``(N, 1)``."""
        return self.disc(X)



In [None]:
SEED = 0
torch.manual_seed(SEED)  # reproducible weight init, latent noise and shuffling

batch_size = 128
loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
gen = Generator((64, 64, 1), batch_size, 100)
disc = Discriminator((64, 64, 1), 64, 100)
num_epochs = 50
loss_fn = nn.BCELoss()

# Each optimizer must own its own network's parameters (the generator's optimizer
# was previously built from disc.parameters(), so the generator never learned).
optim_disc = torch.optim.Adam(disc.parameters(), lr=0.0002, betas=(0.5, 0.99))
optim_gen = torch.optim.Adam(gen.parameters(), lr=0.0002, betas=(0.5, 0.99))

for ep in trange(num_epochs):
    for src in loader:
        # Given this permute, batches are channels-last (N, H, W, C); move channels
        # first for the conv layers.
        src = src.permute(0, 3, 1, 2)
        src[src > 0] = 1.  # binarize positive pixels
        real = f_tf.resize(src, [64, 64])
        # Map [0, 1] -> [-1, 1] to match the generator's tanh output range;
        # otherwise the discriminator can separate real from fake by sign alone.
        # NOTE(review): assumes raw pixels are non-negative -- confirm for this pickle.
        real = real * 2 - 1

        fake = gen()

        # Discriminator step: real -> 1, fake -> 0. `fake` is detached so this
        # loss neither trains nor leaves gradients on the generator.
        optim_disc.zero_grad()
        disc_real = disc(real)
        disc_fake = disc(fake.detach())
        disc_real_loss = loss_fn(disc_real, torch.ones_like(disc_real))
        disc_fake_loss = loss_fn(disc_fake, torch.zeros_like(disc_fake))
        disc_loss = (disc_real_loss + disc_fake_loss) / 2
        disc_loss.backward()
        optim_disc.step()

        # Generator step (non-saturating loss): push the updated discriminator's
        # output on the fakes towards 1. Gradients this leaves on disc are cleared
        # by optim_disc.zero_grad() before the next discriminator step.
        optim_gen.zero_grad()
        out = disc(fake)
        gen_loss = loss_fn(out, torch.ones_like(out))
        gen_loss.backward()
        optim_gen.step()

    torch.save({
        "epoch": ep,
        "disc_state": disc.state_dict(),
        "gen_state": gen.state_dict(),
        "disc_opt": optim_disc.state_dict(),
        "gen_opt": optim_gen.state_dict(),
        # Plain floats from the epoch's last batch, detached from the autograd graph.
        "gen_loss": gen_loss.item(),
        "disc_loss": disc_loss.item(),
    }, f"model_ep_{ep}.pt")

        



