# GANs for Image Generation tasks

## 2. Conditional GANs - AC-GAN

### Prepare DataLoader for MNIST dataset

In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms

# Seed PyTorch's global RNG so that repeated runs draw the same random numbers.
torch.manual_seed(1234)

# Number of images per mini-batch.
BATCH_SIZE = 256

# Preprocessing pipeline: convert each image to a tensor, then normalize it
# as (x - 0.5) / 0.5.
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize(mean=[0.5], std=[0.5])
tf = transforms.Compose([to_tensor, normalize])

# MNIST training split stored under ./datasets (download=True is requested).
train_dataset = MNIST(
    root='./datasets',
    train=True,
    download=True,
    transform=tf,
)

# Shuffled mini-batch loader. drop_last=True is set, so every batch that is
# yielded holds exactly BATCH_SIZE samples.
loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
)

### Define GANs Models

#### Define Generator

In [None]:
import torch.nn as nn

class Generator(nn.Module):
    """Generator skeleton for the AC-GAN exercise; the layers are left as TODOs.

    NOTE(review): the training cell calls ``G(FIXED_Z, FIXED_Y)`` with two
    arguments (latent code and one-hot label), whereas ``forward`` below
    declares a single input ``x``. Presumably the completed ``forward`` has to
    accept both inputs -- confirm against the slides.
    """

    def __init__(self):
        super(Generator, self).__init__()

        # ============================================================ #
        # TODO : Fill in the fully connected layers and the
        #   up-convolution layers.
        # * Specific details of the model architecture are on the slides.
        # * Hint : Use the following modules :
        #   nn.Linear(), nn.BatchNorm1d(), nn.ConvTranspose2d(),
        #   nn.BatchNorm2d(), nn.ReLU()
        # ============================================================ #

        # Size constants available to the layer definitions below.
        self.z_dim = 64           # same value as Z_DIM in the training cell
        self.hidden_dim = 256
        self.img_dim = 28 * 28    # pixels in one 28x28 image

        # Fully connected part (currently an empty nn.Sequential).
        self.fc = nn.Sequential(
            # Fill here.
        )

        # Up-convolution part (currently an empty nn.Sequential).
        self.upconv = nn.Sequential(
            # Fill here.
        )

    def forward(self, x):
        # ============================================================ #
        # TODO : Complete the forward function.
        # * Hint : Use self.fc and self.upconv defined above.
        # * `out` must be assigned here; until then the return statement
        #   below raises NameError.
        # ============================================================ #

        # Fill here.

        return out



#### Define Discriminator

In [None]:
class Discriminator(nn.Module):
    """Discriminator skeleton for the AC-GAN exercise; the layers are left as TODOs.

    ``forward`` returns two values: ``out_disc`` (real/fake head) and
    ``out_cls`` (classification head).

    NOTE(review): the training cell defines ``BCELoss`` and
    ``CrossEntropyLoss`` criteria; presumably ``out_disc`` feeds the former
    and ``out_cls`` the latter -- confirm against the slides.
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        # ============================================================ #
        # TODO : Fill in the convolution layers and the fully connected
        #   layers.
        # * Specific details of the model architecture are on the slides.
        # * Hint : Use the following modules :
        #   nn.Conv2d(), nn.LeakyReLU(), nn.Linear(), nn.BatchNorm1d(),
        #   nn.Sigmoid()
        # ============================================================ #

        # Size constants available to the layer definitions below.
        self.img_dim = 28 * 28    # pixels in one 28x28 image
        self.hidden_dim = 256
        self.num_class = 10       # same value as NUM_CLASS in the training cell

        # Convolutional part (currently an empty nn.Sequential).
        self.conv = nn.Sequential(
            # Fill here.

        )
        # Fully connected part (currently an empty nn.Sequential).
        self.fc = nn.Sequential(
            # Fill here.

        )
        # Head for real/fake discrimination (currently empty).
        self.fc_disc = nn.Sequential(
            # Fill here.

        )

        # Head for classification (currently empty).
        self.fc_cls = nn.Sequential(
            # Fill here.
        )

    def forward(self, x):
        # ============================================================ #
        # TODO : Complete the forward function.
        # * Hint : Use self.conv, self.fc, self.fc_disc and self.fc_cls
        #   defined above.
        #    - out_disc : head for real/fake discrimination
        #    - out_cls : head for classification
        # * Both names must be assigned here; until then the return
        #   statement below raises NameError.
        # ============================================================ #

        # Fill here.

        return out_disc, out_cls


#### Prepare GAN model and Optimizers

In [None]:
# Weight initialization function.
def weights_init(net):
    """Re-initialize the conv, transposed-conv and linear layers inside ``net``.

    Walks ``net.modules()`` (which includes ``net`` itself). For every
    ``nn.Conv2d``, ``nn.ConvTranspose2d`` or ``nn.Linear`` found, the weight
    is redrawn from a normal distribution with mean 0 and std 0.02, and the
    bias, when the layer has one, is set to zero. All other module types are
    left untouched. The update happens in place; nothing is returned.
    """
    target_types = (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)
    for layer in net.modules():
        if not isinstance(layer, target_types):
            continue
        nn.init.normal_(layer.weight, mean=0.0, std=0.02)
        if layer.bias is not None:
            nn.init.zeros_(layer.bias)

# Build the generator and the discriminator and move both to the GPU.
G = Generator().cuda()
D = Discriminator().cuda()

# Weight initialization. weights_init() already walks net.modules(), so it is
# called directly: going through nn.Module.apply() would invoke it once more
# for every submodule and re-initialize each layer several times.
# (Train mode is not set here; the training loop calls .train() every epoch.)
weights_init(G)
weights_init(D)

# One Adam optimizer per network: lr=2e-4, betas=(0.5, 0.999).
optimizer_G = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5,0.999))
optimizer_D = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5,0.999))


#### Start training GAN

In [None]:
# install tensorboardx to use tensorboard.
%pip install tensorboardx

from tensorboardX import SummaryWriter
from torchvision.utils import make_grid

# Hyper-parameters.
# ====== You don't need to change here ===== #
EPOCHS = 50
Z_DIM = 64
NUM_CLASS = 10
# ========================================== #

# Size of the fixed preview batch, derived from NUM_CLASS instead of being
# hard-coded (10 samples per class -> 100 samples when NUM_CLASS is 10).
SAMPLES_PER_CLASS = 10
NUM_FIXED = NUM_CLASS * SAMPLES_PER_CLASS

# Logger for tensorboard.
logger = SummaryWriter()

# Fixed latent variable z and label y for visualization.
# FIXED_Y starts as the class indices 0..NUM_CLASS-1 repeated SAMPLES_PER_CLASS
# times and is then turned into a (NUM_FIXED, NUM_CLASS) one-hot float tensor
# by scattering a 1 into the column given by each index.
FIXED_Z = torch.randn(size=(NUM_FIXED, Z_DIM)).cuda()
FIXED_Y = torch.arange(NUM_CLASS).repeat(SAMPLES_PER_CLASS)
FIXED_Y = torch.zeros(size=(NUM_FIXED, NUM_CLASS)).scatter_(1, FIXED_Y.unsqueeze(1), 1).cuda()

# GT labels for calculating binary cross entropy loss.
# Both have the fixed shape (BATCH_SIZE, 1).
real_label = torch.ones(size=(BATCH_SIZE,1)).cuda()
fake_label = torch.zeros(size=(BATCH_SIZE,1)).cuda()

# Criterion for the binary cross entropy (real/fake) loss.
BCE_criterion = torch.nn.BCELoss()
# Criterion for the cross entropy (classification) loss.
CE_criterion = torch.nn.CrossEntropyLoss()

for epoch in range(EPOCHS):
    # Set both models to train mode.
    G.train()
    D.train()

    # Running sums of the per-batch losses; divided by the number of batches
    # after the epoch to obtain the averages that are printed below.
    loss_G_total, loss_D_total = 0., 0.

    for batch_idx, (data, label) in enumerate(loader):
        data = data.cuda()
        label = label.cuda()

        # ============================================================ #
        # TODO : Fill the part for updating D&G.
        # First sample z and y.
        # z : (BATCH_SIZE, Z_DIM) size random latent variable
        # y : (BATCH_SIZE, NUM_CLASS) size random label
        # Then calculate GAN loss (loss_D, loss_G)
        # * Don't forget, you should also consider classification loss!!!
        # ============================================================ #

        # ================= Update D ================== #

        # Fill here.
        # First compute loss_D
        # Then update the network with loss_D using optimizer_D

        # ================= Update G ================== #

        # Fill here.
        # First compute loss_G
        # Then update the network with loss_G using optimizer_G.
        # Note that we need additional auxiliary classification loss.

        # loss_D and loss_G must be defined by the TODO sections above.
        loss_D_total += loss_D.item()
        loss_G_total += loss_G.item()

        # print current states
        if batch_idx % 100  == 0:
            print('Epoch : {} || {}/{} || loss_G={:.3f} loss_D={:.3f}'.format(
                epoch, batch_idx, len(loader), loss_G.item(), loss_D.item()
            ))

    # Turn the accumulated sums into per-batch averages.
    loss_G_total /= len(loader)
    loss_D_total /= len(loader)

    # ================= Generate example samples ================== #
    # The preview grid is only meant for logging and is never back-propagated
    # through, so autograd is disabled to avoid building an unused graph.
    # G's train/eval mode is left as it is.
    with torch.no_grad():
        fake_img = G(FIXED_Z, FIXED_Y)
        fake_img = fake_img.view(fake_img.shape[0], 1, 28, 28)
        # Invert the (x - 0.5) / 0.5 input normalization.
        fake_img = (fake_img + 1)*0.5
        fake_img = make_grid(fake_img, nrow=10)


    # ============================================================ #
    # TODO : Logging on the tensorboard
    # * log loss_G_total, loss_D_total, and fake_img
    # * use logger.add_scalar() and logger.add_image() for logging
    # ============================================================ #

    # Fill here

    # print current states
    print('Epoch : {} has done. AVG loss : loss_G={:.3f} loss_D={:.3f}'.format(
        epoch, loss_G_total, loss_D_total
    ))


In [None]:
# Check TensorBoard (IPython magics; these lines only work inside a notebook).
# List the contents of the `runs` directory.
# NOTE(review): presumably this is where SummaryWriter() writes its logs -- confirm.
%ls runs
# Load the TensorBoard notebook extension.
%load_ext tensorboard
# Start TensorBoard on the `runs` log directory, using port 9999.
%tensorboard --logdir runs --port 9999