# Import libraries

In [None]:
import time

import pandas as pd
import platform
import io

import matplotlib.pyplot as plt
from matplotlib.pyplot import cm 

import torch
import torch.nn as nn
import torch.optim as optim

import torch.nn as nn
import torch.nn.functional as F

from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
import torchvision

import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Constants: image width/height in pixels, colour channels, number of classes.
IMAGE_WIDTH, IMAGE_HEIGHT = 32, 32
COLOR_CHANNELS, N_CLASSES = 3, 10

In [None]:
# Use the first CUDA device when PyTorch reports one, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# --- Data -------------------------------------------------------------------
dataroot = '/content'   # root directory for the dataset
workers = 2             # number of workers for the dataloader
batch_size = 25         # batch size during training
image_size = 28         # training images are resized to this spatial size
nc = 3                  # channels in the training images (3 for colour)
num_classes = 10        # number of class labels

# --- Model ------------------------------------------------------------------
nz = 100                # size of the latent vector z (generator input)
ngf = 64                # size of feature maps in the generator
ndf = 64                # size of feature maps in the discriminator

# --- Optimisation -----------------------------------------------------------
num_epochs = 5          # number of training epochs
lr = 0.002              # learning rate for the optimizers
beta1 = 0.5             # beta1 hyper-parameter for the Adam optimizers

# Device everything runs on: first CUDA GPU if available, else CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load raw CIFAR-10 Dataset and Labels

In [None]:
# Target height and width (in pixels) every image is resized to.
SIZE_H, SIZE_W = 28, 28

# Preprocessing: resize, convert to a tensor, then apply (x - 0.5) / 0.5 to
# each of the three channels.
transform = transforms.Compose([
    transforms.Resize(size=(SIZE_H, SIZE_W)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
])


# Alternative per-channel statistics, currently disabled:
# image_mean = [0.485, 0.456, 0.406]
# image_std  = [0.229, 0.224, 0.225]

# transformer = transforms.Compose([
#     transforms.Resize((SIZE_H, SIZE_W)),        # scaling images to fixed size
#     transforms.ToTensor(),                      # converting to tensors
#     transforms.Normalize(image_mean, image_std) # normalize image data per-channel
# ])

In [None]:
# Training split; fetched into the current directory when not already there.
train_dataset = CIFAR10(root='.', train=True, transform=transform, download=True)

# Test split; no download is requested here, so it relies on the files
# already being present under the same root.
test_dataset = CIFAR10(root='.', train=False, transform=transform)

# Notebook display: number of samples in each split.
len(train_dataset), len(test_dataset)

In [None]:
# import torchvision.transforms as T

# # noise = torch.randn(batch_size, nz, 1, 1, device=device)
# # fake = netG(noise)
# # print (fake.shape)

# fig, axs = plt.subplots(2, 3, figsize=(20,12))

# for k in range(10):
#     tensor_img = train_dataset[k][0]
#     tensor_img = tensor_img/2 + 0.5 # unnormalize
#     img = T.ToPILImage()(tensor_img)
# #     img = img / 2. + 0.5 
#     axs[1, k%3].imshow(img)

# plt.show()

## Create DataLoader

In [None]:
batch_size = 25  # hyper-parameter

# Shuffled mini-batch iterators over the training and test splits.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

## Weights initialization

In [None]:
def weights_init(m):
    """Initialise convolution and batch-norm layers in place.

    Meant to be passed to ``Module.apply``. Layers are selected by class name:

    * name contains ``"Conv"``: weights drawn from N(mean=0.0, std=0.02);
    * name contains ``"BatchNorm"``: weights drawn from N(mean=1.0, std=0.02)
      and biases set to 0.

    Any other module is left untouched. Layers that have no learnable weight
    or bias (for example ``BatchNorm2d(affine=False)``) are skipped instead of
    raising on a ``None`` parameter.

    Args:
        m: the module visited by ``Module.apply``.
    """
    classname = type(m).__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)

    if 'Conv' in classname:
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias, 0)

## Generator

In [None]:
class Generator(nn.Module):
    """Map latent vectors to images with values in [-1, 1].

    A linear layer projects the latent vector to a 128 x 7 x 7 feature map.
    Two transposed convolutions (kernel 4, stride 2, padding 1) each double
    the spatial size (7 -> 14 -> 28), and a final 1x1 transposed convolution
    produces the output channels, followed by ``tanh``.

    Args:
        latent_dim: length of the latent vector. Defaults to the module-level
            ``nz`` when omitted.
        out_channels: number of image channels produced. Defaults to the
            module-level ``nc`` when omitted.
    """

    def __init__(self, latent_dim=None, out_channels=None):
        super().__init__()
        # Fall back to the notebook-level hyper-parameters so that the
        # existing ``Generator()`` call keeps working unchanged.
        if latent_dim is None:
            latent_dim = nz
        if out_channels is None:
            out_channels = nc

        self.Linear1 = nn.Linear(latent_dim, 128 * 7 * 7)
        self.leaky_relu = nn.LeakyReLU(0.2)

        # 7x7 -> 14x14
        self.convT1 = nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=4,
                                         stride=2, padding=1, bias=False)
        self.batchnorm1 = nn.BatchNorm2d(128)
        self.leaky_relu1 = nn.LeakyReLU(0.2)

        # 14x14 -> 28x28
        self.convT2 = nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=4,
                                         stride=2, padding=1, bias=False)
        self.batchnorm2 = nn.BatchNorm2d(128)
        self.leaky_relu2 = nn.LeakyReLU(0.2)

        # 1x1 projection to the image channels; spatial size is unchanged.
        self.convT3 = nn.ConvTranspose2d(in_channels=128, out_channels=out_channels, kernel_size=1,
                                         stride=1, padding=0, bias=False)

    def forward(self, input):
        """Generate a batch of images.

        Args:
            input: latent tensor whose elements can be viewed as
                ``(batch, latent_dim)``, e.g. ``(batch, latent_dim, 1, 1)``.

        Returns:
            Tensor of shape ``(batch, out_channels, 28, 28)`` after ``tanh``.
        """
        # Flatten explicitly to (batch, latent_dim). A bare ``squeeze()``
        # would also drop any other size-1 dimension (e.g. latent_dim == 1).
        latent = input.reshape(-1, self.Linear1.in_features)

        out = self.Linear1(latent)
        out = self.leaky_relu(out)
        out = out.reshape(-1, 128, 7, 7)

        out = self.convT1(out)
        out = self.batchnorm1(out)
        out = self.leaky_relu1(out)

        out = self.convT2(out)
        out = self.batchnorm2(out)
        out = self.leaky_relu2(out)

        out = self.convT3(out)

        return torch.tanh(out)

# Create the generator
netG = Generator().to(device)

# Apply the weights_init function to every sub-module: conv weights are drawn
#  from N(mean=0, std=0.02), batch-norm weights from N(mean=1, std=0.02) and
#  batch-norm biases are set to 0.
netG.apply(weights_init)

# Smoke test: run one batch of latent noise through the generator and print
# the shape of the generated batch.
noise = torch.randn(batch_size, nz, 1, 1, device=device)
fake = netG(noise)
print(fake.shape)

## Discriminator

In [None]:
class Discriminator(nn.Module):
    """Joint real/fake discriminator and class predictor.

    Three stride-2 convolutions (each followed by batch-norm and LeakyReLU)
    feed a single linear layer that emits ``num_classes`` logits. The linear
    layer expects a 128 x 4 x 4 feature map, which the three stride-2
    convolutions produce from 28 x 28 inputs.

    Both outputs are derived from the same logits ``l``:

    * ``adv = Z / (Z + 1)`` with ``Z = sum_k exp(l_k)``, the probability that
      the input is real;
    * ``aux = softmax(l)``, the class probabilities.

    Args:
        num_classes: number of class logits produced by the linear head.
        in_channels: number of input image channels. Defaults to the
            module-level ``nc`` when omitted.
    """

    def __init__(self, num_classes, in_channels=None):
        super().__init__()
        self.num_classes = num_classes
        # Fall back to the notebook-level hyper-parameter so that the existing
        # ``Discriminator(num_classes)`` call keeps working unchanged.
        if in_channels is None:
            in_channels = nc

        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=128,
                               kernel_size=3, stride=2,
                               padding=1, bias=False)
        self.batchnorm1 = nn.BatchNorm2d(128)
        self.leaky_relu1 = nn.LeakyReLU(0.2)

        self.conv2 = nn.Conv2d(in_channels=128, out_channels=128,
                               kernel_size=3, stride=2,
                               padding=1, bias=False)
        self.batchnorm2 = nn.BatchNorm2d(128)
        self.leaky_relu2 = nn.LeakyReLU(0.2)

        self.conv3 = nn.Conv2d(in_channels=128, out_channels=128,
                               kernel_size=3, stride=2,
                               padding=1, bias=False)
        self.batchnorm3 = nn.BatchNorm2d(128)
        self.leaky_relu3 = nn.LeakyReLU(0.2)

        self.Linear = nn.Linear(128 * 4 * 4, self.num_classes)

    def forward(self, input):
        """Score a batch of images.

        Args:
            input: image batch of shape ``(batch, in_channels, H, W)`` whose
                final feature map is 4 x 4.

        Returns:
            ``(adv, aux)`` where ``adv`` has shape ``(batch,)`` and holds the
            real-vs-fake probability, and ``aux`` has shape
            ``(batch, num_classes)`` and holds softmax class probabilities.
        """
        out = self.conv1(input)
        out = self.batchnorm1(out)
        out = self.leaky_relu1(out)

        out = self.conv2(out)
        out = self.batchnorm2(out)
        out = self.leaky_relu2(out)

        out = self.conv3(out)
        out = self.batchnorm3(out)
        out = self.leaky_relu3(out)

        out = out.view(-1, 128 * 4 * 4)
        logits = self.Linear(out)

        # Z / (Z + 1) == sigmoid(log Z). Going through logsumexp avoids the
        # inf / inf = NaN that exp() produces once a logit gets large.
        adv = torch.sigmoid(torch.logsumexp(logits, dim=-1))  # real_fake

        # No squeeze(): logits are already (batch, num_classes), and squeezing
        # would drop the batch dimension for a single-sample batch.
        aux = F.softmax(logits, dim=1)  # classification

        return adv, aux

# Create the Discriminator
num_classes = 10
netD = Discriminator(num_classes).to(device)

# Apply the weights_init function to every sub-module: conv weights are drawn
#  from N(mean=0, std=0.02), batch-norm weights from N(mean=1, std=0.02) and
#  batch-norm biases are set to 0.
netD.apply(weights_init)

# Smoke test: feed a batch of random 3x28x28 "images" through the
# discriminator and print the shape of the class-probability output.
noise = torch.randn(batch_size, 3, 28, 28, device=device)
adv, aux = netD(noise)
print(aux.shape)

In [None]:
# Create batch of latent vectors that we will use to visualize
#  the progression of the generator
fixed_noise = torch.randn(batch_size, nz, 1, 1, device=device)

# Setup Adam optimizers for both G and D; both use the same learning rate
# `lr` and betas=(beta1, 0.999).
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

In [None]:
from tqdm import tqdm
from numpy.random import randint

def get_supervised_samples(n_samples=100):
    """Pick a class-balanced labelled subset from ``train_dataset``.

    The dataset is scanned in order and the first ``n_samples // 10``
    samples of each of the 10 classes are kept. Scanning stops as soon as
    every class quota is filled, so the whole dataset is only traversed when
    some class is under-represented.

    Args:
        n_samples: total number of labelled samples requested. It is split
            evenly over the 10 classes; any remainder is dropped.

    Returns:
        ``(images, labels)``: ``images`` stacks the selected samples along a
        new leading dimension, ``labels`` is a 1-D integer tensor of the
        corresponding class indices, in dataset order.

    Raises:
        ValueError: if ``n_samples`` is too small to give every class at
            least one sample.
    """
    n_classes = 10
    n_per_class = int(n_samples // n_classes)
    if n_per_class < 1:
        raise ValueError(
            f'n_samples must be at least {n_classes} to draw one sample per class, '
            f'got {n_samples}'
        )

    remaining = {label: n_per_class for label in range(n_classes)}
    still_needed = n_per_class * n_classes

    # Collect in Python lists and stack once at the end; concatenating
    # tensors inside the loop copies the whole buffer on every hit.
    images, targets = [], []
    for X, y in tqdm(train_dataset):
        if remaining[y] > 0:
            remaining[y] -= 1
            still_needed -= 1
            images.append(X)
            targets.append(y)
            if still_needed == 0:
                break

    return torch.stack(images), torch.tensor(targets)

In [None]:
# Build the labelled subset once: 100 samples requested, which
# get_supervised_samples splits evenly across its 10 classes.
final, final_labels = get_supervised_samples(100)

In [None]:
def train_netD(real_data, labels, optimizerD):
    """Run one discriminator update and report its loss and accuracies.

    The loss has three parts:

    * a classification loss on the fixed labelled subset held in the
      module-level ``final`` / ``final_labels``;
    * a real-vs-fake loss pushing ``real_data`` towards 1;
    * a real-vs-fake loss pushing freshly generated images towards 0.

    The total is ``0.5 * (real + fake) + classification``.

    Relies on the module-level ``netD``, ``netG``, ``device``, ``nz``,
    ``batch_size``, ``final`` and ``final_labels``.

    Args:
        real_data: batch of real images, already on ``device``.
        labels: class labels of ``real_data``. They are used only for the
            monitoring accuracy, never for the loss.
        optimizerD: optimizer holding the discriminator parameters.

    Returns:
        ``(lossD, d_acc, d_acc_test)``: the scalar loss, the accuracy on the
        labelled subset, and the accuracy on ``batch_size // 2`` samples
        drawn (with replacement) from the current training batch. Despite
        its name, ``d_acc_test`` is not measured on held-out data.
    """
    optimizerD.zero_grad()

    n_real = real_data.shape[0]
    supervised_batch = batch_size // 2
    bce_loss = nn.BCELoss()

    # --- supervised part: classify the fixed labelled subset ---------------
    supervised_real = final.to(device)
    supervised_targets = final_labels.to(device)
    _, class_probs = netD(supervised_real)
    # netD returns softmax probabilities, so take the log and use NLL.
    # CrossEntropyLoss would apply a second softmax on top of them. The clamp
    # keeps log() finite if a probability underflows to 0.
    classific_real_loss = F.nll_loss(torch.log(class_probs.clamp_min(1e-12)),
                                     supervised_targets)

    # --- unsupervised part: real images should score 1 ---------------------
    real_fake, _ = netD(real_data)
    real_targets = torch.ones(n_real, dtype=torch.float32, device=device)
    real_fake_loss = bce_loss(real_fake.view(-1), real_targets)

    # --- unsupervised part: generated images should score 0 ----------------
    # Targets are built locally instead of being read from a global that
    # only exists while the training loop is running.
    fake_targets = torch.zeros(n_real, dtype=torch.float32, device=device)
    noise = torch.randn(n_real, nz, 1, 1, device=device)
    with torch.no_grad():  # the generator is not updated in this step
        fake = netG(noise)
    fake_real, _ = netD(fake)
    fake_real_loss = bce_loss(fake_real.view(-1), fake_targets)

    # loss
    lossD = 0.5 * (real_fake_loss + fake_real_loss) + classific_real_loss
    lossD.backward()
    # Update D
    optimizerD.step()

    with torch.no_grad():
        # Accuracy on the labelled subset, from the pre-update forward pass.
        hits = class_probs.argmax(dim=1) == supervised_targets
        d_acc = hits.double().mean().item()

        # Accuracy on a random draw from the current training batch.
        idxs = np.random.randint(0, n_real, supervised_batch)
        _, batch_probs = netD(real_data[idxs])
        batch_hits = batch_probs.argmax(dim=1).cpu() == labels[idxs].cpu()
        d_acc_test = batch_hits.double().mean().item()

    return lossD.item(), d_acc, d_acc_test

In [None]:
def train_netG(real_targets, optimizerG):
    """Run one generator update and return its loss.

    A batch of generated images, as large as ``real_targets``, is scored by
    the discriminator; the generator is trained to make those scores match
    ``real_targets``.

    Relies on the module-level ``netG``, ``netD``, ``device`` and ``nz``.

    Args:
        real_targets: 1-D float tensor of target scores on ``device``. Its
            length sets the number of images generated.
        optimizerG: optimizer holding the generator parameters.

    Returns:
        The scalar generator loss as a Python float.
    """
    optimizerG.zero_grad()

    bce_loss = nn.BCELoss()

    # Size the batch from the argument rather than from a `real_data` global
    # that only exists while the training loop is running.
    n_samples = real_targets.shape[0]

    # update generator (g)
    noise = torch.randn(n_samples, nz, 1, 1, device=device)
    fake = netG(noise)
    adv_g, _ = netD(fake)
    lossG = bce_loss(adv_g.view(-1), real_targets)
    lossG.backward()

    # Update G
    optimizerG.step()

    return lossG.item()

In [None]:
from IPython.display import clear_output
import torchvision.transforms as T

# Lists to keep track of progress
img_list = []
img_list_fixedN = []
G_losses = []
D_losses = []
iters = 0

# Per-iteration discriminator accuracies over the whole run.
train_accuracy = []
test_accuracy = []

sample_interval = 100
num_epochs = 10
len_train_loader = len(train_loader)

supervised_batch = batch_size // 2

netD.train(True)
netG.train(True)
t0 = time.time()
print("Starting Training Loop...")
for epoch in range(num_epochs):
    # Per-iteration accuracies of the current epoch only.
    ep_train_accuracy = []
    ep_test_accuracy = []
    for i, (imgs, labels) in enumerate(train_loader):
        real_data = imgs.to(device)

        # Target scores for real / generated images. They are kept as
        # module-level names because code outside this loop may read them.
        real_targets = torch.ones(real_data.shape[0], dtype=torch.float32).to(device)
        fake_targets = torch.zeros(real_data.shape[0], dtype=torch.float32).to(device)

        # One discriminator step followed by one generator step.
        lossD, d_acc, d_acc_test = train_netD(real_data, labels, optimizerD)
        lossG = train_netG(real_targets, optimizerG)

        # Save losses and accuracies for plotting later
        G_losses.append(lossG)
        D_losses.append(lossD)
        train_accuracy.append(d_acc)
        test_accuracy.append(d_acc_test)

        ep_train_accuracy.append(d_acc)
        ep_test_accuracy.append(d_acc_test)

        # (A commented-out per-iteration live plot used to sit here; the same
        # curves and samples are drawn by the plotting cell after training.)

        iters += 1

    # End-of-epoch report: elapsed wall-clock time and mean accuracies.
    print(f'epoch ={epoch}/{num_epochs}')
    print(f'spent_time ={time.time() - t0}')
    print('ep_train_accuracy', np.mean(ep_train_accuracy))
    print('ep_test_accuracy', np.mean(ep_test_accuracy))
    print('----------------------')

# Losses, metrics and generated examples

In [None]:
# Top row: loss and accuracy curves. Bottom row: three generated samples.
fig, axs = plt.subplots(2, 3, figsize=(20,12))

axs[0, 0].plot(G_losses[:], label=f'ep_gen={epoch}', color='green')
axs[0, 0].plot(D_losses[:], label=f'ep_discr={epoch}', color='blue')
axs[0, 0].legend()
axs[0, 0].grid()

axs[0, 1].plot(train_accuracy[:], label=f'ep_acc={epoch}', color='black')
axs[0, 1].legend()
axs[0, 1].grid()

axs[0, 2].plot(test_accuracy[:], label=f'ep_acc={epoch}', color='black')
axs[0, 2].legend()
axs[0, 2].grid()

# Generate a batch as large as the last training batch; no gradients are
# needed for visualisation.
noise = torch.randn(real_data.shape[0], nz, 1, 1, device=device)
with torch.no_grad():
    fake_img = netG(noise)

# One sample per bottom-row axis. Drawing more images than there are axes
# would only paint over the earlier ones.
for col, ax in enumerate(axs[1]):
    tensor_img = fake_img[col]/2 + 0.5  # undo the (x - 0.5) / 0.5 normalisation
    img = T.ToPILImage()(tensor_img.cpu())
    ax.imshow(img)
plt.show()

In [None]:
# Evaluate the discriminator's class head on the test split.
# NOTE(review): the network is left in training mode, so batch-norm uses the
# statistics of each test batch (and keeps updating its running averages).
# Confirm this is intended; `netD.eval()` is the usual choice for evaluation.
netD.train()
test_loss = 0
correct = 0

with torch.no_grad():
    for imgs, labels in test_loader:
        targets = labels.to(device)
        adv, aux = netD(imgs.to(device))
        # `aux` holds softmax probabilities while nll_loss expects
        # log-probabilities; the clamp keeps log() finite at 0.
        log_probs = torch.log(aux.clamp_min(1e-12))
        test_loss += F.nll_loss(log_probs, targets, reduction='sum').item()
        pred = aux.argmax(dim=1, keepdim=True)
        # .item() keeps `correct` a plain int so it prints cleanly below.
        correct += pred.eq(targets.view_as(pred)).sum().item()

test_loss /= len(test_loader.dataset)

print(f'Avg. loss: {test_loss}, Accuracy: {correct}/{len(test_loader.dataset)} {100. * correct / len(test_loader.dataset)}')

In [None]:
# Show a 2 x 3 grid of freshly generated samples.
fig, axs = plt.subplots(2, 3, figsize=(20,12))


# Generate a batch as large as the last training batch; no gradients are
# needed for visualisation.
noise = torch.randn(real_data.shape[0], nz, 1, 1, device=device)
with torch.no_grad():
    fake_img = netG(noise)

# One sample per axis, filling both rows of the grid.
for ax, sample in zip(axs.flat, fake_img):
    tensor_img = sample/2 + 0.5  # undo the (x - 0.5) / 0.5 normalisation
    img = T.ToPILImage()(tensor_img.cpu())
    ax.imshow(img)
plt.show()