In [None]:
from torch import optim
import os
import torchvision.utils as vutils
import numpy as np
import math
import torch
from torchvision import datasets
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Mount Google Drive so checkpoints, samples and downloaded datasets live under
# /content/drive and survive the Colab runtime. (`files` is not used in the visible cells.)
from google.colab import drive, files
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
maindir = '/content/drive/MyDrive/GANexp'

# Arguments
BATCH_SIZE = 256
EPOCHS = 100
Z_DIM = 5
LOAD_MODEL = False
CHANNELS = 1
DB = 'MNIST' # MNIST | FashionMNIST | USPS

# Image side length used for each supported dataset.
_IMAGE_SIZES = {'MNIST': 28, 'FashionMNIST': 28, 'USPS': 16}

# Raise instead of print + exit(0): exit() ends the interpreter/kernel and a 0
# status would report success for what is a configuration error.
if DB not in _IMAGE_SIZES:
    raise ValueError(f"Incorrect dataset: {DB!r} (expected one of {', '.join(_IMAGE_SIZES)})")
IMAGE_SIZE = _IMAGE_SIZES[DB]

# Generator and Discriminator each resample by a factor of 2 twice (image_size // 4),
# so the side length must be divisible by 4.
if IMAGE_SIZE % 4 != 0:
    raise ValueError(f"Incompatible Image size: {IMAGE_SIZE} (must be divisible by 4)")

# Directories for storing model checkpoints, output samples and downloaded data.
# (Previously the data directory check tested samples_path, so db_path was never created.)
model_path = os.path.join(maindir, 'model', DB)
samples_path = os.path.join(maindir, 'samples', DB)
db_path = os.path.join(maindir, 'data', DB)
for _dir in (model_path, samples_path, db_path):
    os.makedirs(_dir, exist_ok=True)


# Method for storing generated images
def generate_imgs(z, epoch=0):
    """Sample images from the global generator `gen` and save them as one PNG grid.

    Args:
        z: latent batch passed to `gen`; one image is produced per latent vector
            (first dimension of `z`).
        epoch: integer embedded in the file name ``sample_<epoch>.png``, written
            under the global `samples_path`.

    The generator is switched to eval mode for sampling and its previous
    train/eval mode is restored afterwards, so callers inside a training loop
    are not affected.
    """
    was_training = gen.training
    gen.eval()
    # Pure inference: no autograd graph is needed for the samples.
    with torch.no_grad():
        fake_imgs = gen(z)
    gen.train(was_training)

    # Roughly square grid sized from the number of generated samples.
    n_samples = fake_imgs.shape[0]
    fake_imgs_ = vutils.make_grid(fake_imgs, normalize=True, nrow=math.ceil(n_samples ** 0.5))
    print(fake_imgs_.shape)
    vutils.save_image(fake_imgs_, os.path.join(samples_path, 'sample_' + str(epoch) + '.png'))


# Data loaders
# Normalize((0.5,), (0.5,)) maps ToTensor's [0, 1] pixels to [-1, 1], the same
# range as the generator's tanh output.
mean = np.array([0.5])
std = np.array([0.5])
transform = transforms.Compose([
    transforms.Resize([IMAGE_SIZE, IMAGE_SIZE]),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# The supported DB names are exactly the torchvision.datasets class names.
if DB in ('MNIST', 'FashionMNIST', 'USPS'):
    dataset_cls = getattr(datasets, DB)
    dataset = dataset_cls(db_path, train=True, download=True, transform=transform)

data_loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    # Every batch holds exactly BATCH_SIZE samples, matching the fixed-size label tensors.
    drop_last=True,
)



In [None]:

# Networks
def conv_block(c_in, c_out, k_size=4, stride=2, pad=1, use_bn=True, transpose=False):
    """Build a (transposed) 2-D convolution, optionally followed by BatchNorm2d.

    Args:
        c_in: number of input channels.
        c_out: number of output channels.
        k_size: kernel size.
        stride: convolution stride.
        pad: zero padding.
        use_bn: append ``nn.BatchNorm2d(c_out)``. The conv bias is disabled in
            that case, since BatchNorm's affine shift would make it redundant.
        transpose: use ``nn.ConvTranspose2d`` instead of ``nn.Conv2d``.

    Returns:
        nn.Sequential holding the conv layer and, when requested, the BatchNorm.
    """
    conv_cls = nn.ConvTranspose2d if transpose else nn.Conv2d
    layers = [conv_cls(c_in, c_out, k_size, stride, pad, bias=not use_bn)]
    if use_bn:
        layers.append(nn.BatchNorm2d(c_out))
    return nn.Sequential(*layers)


class Generator(nn.Module):
    """Latent vector -> image generator.

    A linear layer produces a (conv_dim * 2, image_size // 4, image_size // 4)
    feature map. Two transposed-conv blocks (kernel 4, stride 2, padding 1) each
    double the spatial size, giving (channels, image_size, image_size). A final
    tanh squashes the output to [-1, 1].
    """

    def __init__(self, z_dim=10, image_size=28, channels=1, conv_dim=8):
        super().__init__()
        self.image_size = image_size
        side = image_size // 4
        hidden_ch = conv_dim * 2

        self.fc1 = nn.Linear(z_dim, side * side * hidden_ch)
        self.tconv2 = conv_block(hidden_ch, conv_dim, transpose=True, use_bn=True)
        self.tconv3 = conv_block(conv_dim, channels, transpose=True, use_bn=False)

    def forward(self, x):
        side = self.image_size // 4
        h = F.relu(self.fc1(x))
        # Unflatten to (batch, conv_dim * 2, side, side); channel count is inferred.
        h = h.reshape(h.shape[0], -1, side, side)
        h = F.relu(self.tconv2(h))
        return torch.tanh(self.tconv3(h))


class Discriminator(nn.Module):
    """Image -> probability-of-real discriminator.

    Two stride-2 conv blocks (LeakyReLU, slope 0.2) shrink each side by 4x. A
    linear layer plus sigmoid then gives one score per image. The output has
    shape (batch,) with values in (0, 1), matching nn.BCELoss with (batch,)
    targets.
    """

    def __init__(self, image_size=28, channels=1, conv_dim=8):
        super(Discriminator, self).__init__()
        self.conv1 = conv_block(channels, conv_dim, use_bn=False)
        self.conv2 = conv_block(conv_dim, conv_dim * 2, use_bn=True)
        # After two stride-2 convs the feature map is (conv_dim*2, image_size//4, image_size//4).
        self.fc3 = nn.Linear((image_size//4)*(image_size//4)*conv_dim*2, 1)

    def forward(self, x):
        alpha = 0.2  # LeakyReLU negative slope
        x = F.leaky_relu(self.conv1(x), alpha)
        x = F.leaky_relu(self.conv2(x), alpha)
        x = x.reshape([x.shape[0], -1])
        x = torch.sigmoid(self.fc3(x))
        # squeeze(1) rather than squeeze(): a batch of one must stay shape (1,),
        # not collapse to a 0-d tensor that mismatches the label tensor.
        return x.squeeze(1)


In [None]:


gen = Generator(z_dim=Z_DIM, image_size=IMAGE_SIZE, channels=CHANNELS)
dis = Discriminator(image_size=IMAGE_SIZE, channels=CHANNELS)

# Load previous model.
# Checkpoints are written from models that may live on the GPU, so map every
# tensor to CPU first. Otherwise loading fails on a CPU-only runtime. The
# models are moved to CUDA afterwards when it is available.
if LOAD_MODEL:
    gen.load_state_dict(torch.load(os.path.join(model_path, 'gen.pkl'), map_location='cpu'))
    dis.load_state_dict(torch.load(os.path.join(model_path, 'dis.pkl'), map_location='cpu'))

# Model Summary
print("------------------Generator------------------")
print(gen)
print("------------------Discriminator------------------")
print(dis)

------------------Generator------------------
Generator(
  (fc1): Linear(in_features=5, out_features=784, bias=True)
  (tconv2): Sequential(
    (0): ConvTranspose2d(16, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (tconv3): Sequential(
    (0): ConvTranspose2d(8, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  )
)
------------------Discriminator------------------
Discriminator(
  (conv1): Sequential(
    (0): Conv2d(1, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  )
  (conv2): Sequential(
    (0): Conv2d(8, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (fc3): Linear(in_features=784, out_features=1, bias=True)
)


In [None]:


# Define Optimizers: both networks use Adam with identical hyper-parameters.
adam_kwargs = dict(lr=0.0002, betas=(0.5, 0.999), weight_decay=2e-5)
g_opt = optim.Adam(gen.parameters(), **adam_kwargs)
d_opt = optim.Adam(dis.parameters(), **adam_kwargs)

# Binary cross-entropy on the discriminator's sigmoid output.
loss_fn = nn.BCELoss()

# Fixed latent batch passed to generate_imgs, so samples are comparable across epochs.
fixed_z = torch.randn(BATCH_SIZE, Z_DIM)

# Target labels: 1 for real images, 0 for generated ones.
real_label = torch.ones(BATCH_SIZE)
fake_label = torch.zeros(BATCH_SIZE)

# GPU Compatibility: move models, labels and the fixed latents when CUDA is present.
is_cuda = torch.cuda.is_available()
if is_cuda:
    gen = gen.cuda()
    dis = dis.cuda()
    real_label = real_label.cuda()
    fake_label = fake_label.cuda()
    fixed_z = fixed_z.cuda()

# Counters used by the training loop's progress log.
total_iters = 0
max_iter = len(data_loader)

In [None]:
# TensorBoard writers for generated and real sample grids.
# SummaryWriter was previously used without being imported (NameError).
# Nothing in the visible training loop writes to these writers yet.
from torch.utils.tensorboard import SummaryWriter

writerFake = SummaryWriter("logs/fake")
writerReal = SummaryWriter("logs/real")

In [None]:
# Short run: overrides the EPOCHS value set in the configuration cell.
EPOCHS = 4


# Training
for epoch in range(EPOCHS):
    gen.train()
    dis.train()

    for i, data in enumerate(data_loader):

        total_iters += 1

        # Loading data (class labels are discarded; the generator is conditioned only on z)
        x_real, _ = data
        z_fake = torch.randn(BATCH_SIZE, Z_DIM)

        if is_cuda:
            x_real = x_real.cuda()
            z_fake = z_fake.cuda()

        # Generate fake data
        x_fake = gen(z_fake)

        # Train Discriminator: push D(fake) -> 0 and D(real) -> 1.
        # x_fake is detached so this step does not backpropagate into the generator.
        # x_real comes straight from the data loader and has no graph to detach.
        fake_out = dis(x_fake.detach())
        real_out = dis(x_real)
        d_loss = (loss_fn(fake_out, fake_label) + loss_fn(real_out, real_label)) / 2

        d_opt.zero_grad()
        d_loss.backward()
        d_opt.step()

        # Train Generator: score the (non-detached) fakes against the "real" labels.
        fake_out = dis(x_fake)
        g_loss = loss_fn(fake_out, real_label)

        g_opt.zero_grad()
        g_loss.backward()
        g_opt.step()

        if i % 50 == 0:
            print("Epoch: " + str(epoch + 1) + "/" + str(EPOCHS)
                  + "\titer: " + str(i) + "/" + str(max_iter)
                  + "\ttotal_iters: " + str(total_iters)
                  + "\td_loss:" + str(round(d_loss.item(), 4))
                  + "\tg_loss:" + str(round(g_loss.item(), 4))
                  )

    # Checkpoint and sample every 2 epochs, and always after the final epoch, so an
    # odd EPOCHS value does not lose the last epoch's weights.
    is_last_epoch = (epoch + 1) == EPOCHS
    if (epoch + 1) % 2 == 0 or is_last_epoch:
        torch.save(gen.state_dict(), os.path.join(model_path, 'gen.pkl'))
        torch.save(dis.state_dict(), os.path.join(model_path, 'dis.pkl'))

        generate_imgs(fixed_z, epoch=epoch + 1)

    # TODO: per-step TensorBoard logging (prepareVisualization with writerFake /
    # writerReal) is not wired in. It needs the helper defined above this cell and a
    # `step` counter initialised before the loop.

generate_imgs(fixed_z)

Epoch: 1/4	iter: 0/234	total_iters: 1	d_loss:0.8276	g_loss:0.4827
Epoch: 1/4	iter: 50/234	total_iters: 51	d_loss:0.3035	g_loss:1.1681
Epoch: 1/4	iter: 100/234	total_iters: 101	d_loss:0.1918	g_loss:1.5606
Epoch: 1/4	iter: 150/234	total_iters: 151	d_loss:0.1326	g_loss:1.9701
Epoch: 1/4	iter: 200/234	total_iters: 201	d_loss:0.0977	g_loss:2.3606
Epoch: 2/4	iter: 0/234	total_iters: 235	d_loss:0.0833	g_loss:2.5659
Epoch: 2/4	iter: 50/234	total_iters: 285	d_loss:0.0665	g_loss:2.8057
Epoch: 2/4	iter: 100/234	total_iters: 335	d_loss:0.0581	g_loss:2.9872
Epoch: 2/4	iter: 150/234	total_iters: 385	d_loss:0.0493	g_loss:3.1526
Epoch: 2/4	iter: 200/234	total_iters: 435	d_loss:0.0396	g_loss:3.3706
torch.Size([3, 482, 482])
Epoch: 3/4	iter: 0/234	total_iters: 469	d_loss:0.0392	g_loss:3.3601
Epoch: 3/4	iter: 50/234	total_iters: 519	d_loss:0.0351	g_loss:3.5017
Epoch: 3/4	iter: 100/234	total_iters: 569	d_loss:0.0286	g_loss:3.745
Epoch: 3/4	iter: 150/234	total_iters: 619	d_loss:0.0323	g_loss:3.6451
Epoch: 

In [None]:
# Sanity check: shapes of the last real batch, generated batch and latent batch from the training loop.
(x_real.shape, BATCH_SIZE, x_fake.shape, z_fake.shape)

(torch.Size([256, 1, 28, 28]),
 256,
 torch.Size([256, 1, 28, 28]),
 torch.Size([256, 5]))

In [None]:
# NOTE(review): unfinished TensorBoard logging helper, kept commented out.
# Before enabling it:
#   - `Config.numEpochs`, `real` and `torchvision` are not defined/imported in the
#     visible cells (candidates: EPOCHS, x_real, and the imported `vutils`);
#   - the reshape to (-1, 1, 28, 28) hard-codes MNIST's size; IMAGE_SIZE/CHANNELS differ for USPS;
#   - this cell must run *before* the training loop that would call it, and the
#     caller must initialise `step` (e.g. `step = 0`) first.
# def prepareVisualization(epoch,
#                          batchIdx,
#                          loaderLen,
#                          lossD,
#                          lossG,
#                          writerFake,
#                          writerReal,
#                          step):
#     print(
#         f"Epoch [{epoch}/{Config.numEpochs}] Batch {batchIdx}/{loaderLen} \
#                               Loss DISC: {lossD:.4f}, loss GEN: {lossG:.4f}"
#     )

#     with torch.no_grad():
#         # Generate noise via Generator
#         fake = gen(z_fake).reshape(-1, 1, 28, 28)

#         # Get real data
#         data = real.reshape(-1, 1, 28, 28)

#         # Plot the grid
#         imgGridFake = torchvision.utils.make_grid(fake,
#                                                   normalize=True)
#         imgGridReal = torchvision.utils.make_grid(data,
#                                                   normalize=True)

#         writerFake.add_image("Mnist Fake Images",
#                              imgGridFake,
#                              global_step=step)
#         writerReal.add_image("Mnist Real Images",
#                              imgGridReal,
#                              global_step=step)
#         # increment step
#         step += 1

#     return step