# DCGAN
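* task: image generation
* model:
    1. $G$ takes random noise and outputs an image
    2. $D$ takes an image and outputs a patch: a small map of real/fake probabilities rather than a single score
* loss: the standard (non-saturating) GAN loss, i.e. binary cross-entropy on $D$'s sigmoid outputs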

## do some imports

In [16]:
import torch
from torch import nn

import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

import torchvision
import torchvision.transforms as transforms

from torch.utils.tensorboard import SummaryWriter

from tqdm import tqdm

import numpy as np
import matplotlib.pyplot as plt
import sys
sys.path.append("./../utils")
from utils import BasicConv, init_weight, set_seed

# Seed the RNGs through the project helper from ../utils so that weight
# initialisation and noise sampling are repeatable. NOTE(review): presumably
# seeds torch (and possibly numpy/random) -- confirm in utils.set_seed.
set_seed(2022)

## define the model

### $G$'s model
1. Gradually decrease the number of channels while increasing the resolution.
2. In the first conv block, map the noise (N, 100, 1, 1) to an initial feature map (N, 256, 4, 4), then follow step 1.

In [9]:
class Gen(nn.Module):
    """
    Generator: turns a latent noise tensor into an image.

    A transposed conv lifts the 1x1 noise to a 4x4 map with 256 channels,
    four UpConv blocks then double the resolution up to 64x64, and a final
    3x3 conv projects to ``img_channel`` channels. Output values are clipped
    to [0, 1].

    input: (N, 100, 1, 1)   (with noise_channel=100)
    output: (N, 3, 64, 64)  (with img_channel=3)
    """
    def __init__(self, noise_channel, img_channel) -> None:
        super().__init__()
        # (N, noise_channel, 1, 1) -> (N, 256, 4, 4)
        self.init_conv = nn.Sequential(
            nn.ConvTranspose2d(noise_channel, 256, 4, 1, 0),
            nn.LeakyReLU(0.2, True),
        )
        # Consecutive pairs give the UpConv widths: 256->128->64->64->32,
        # spatially 4 -> 8 -> 16 -> 32 -> 64.
        widths = [256, 128, 64, 64, 32]
        stages = [UpConv(c_in, c_out) for c_in, c_out in zip(widths, widths[1:])]
        # 3x3 conv with padding 1: keeps 64x64, maps to image channels.
        stages.append(nn.Conv2d(widths[-1], img_channel, 3, 1, 1))
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        features = self.init_conv(x)
        image = self.model(features)
        return image.clip(0, 1)

class UpConv(nn.Module):
    """
    Upsampling block: stride-2 transposed convolution followed by LeakyReLU.

    input: (N, in_channel, H, W)
    output: (N, out_channel, 2*H, 2*W)
    """
    def __init__(self, in_channel, out_channel) -> None:
        super().__init__()
        # kernel 2, stride 2, no padding -> exactly doubles H and W.
        upsample = nn.ConvTranspose2d(in_channel, out_channel, kernel_size=2, stride=2, padding=0)
        activation = nn.LeakyReLU(negative_slope=0.2)
        self.block = nn.Sequential(upsample, activation)

    def forward(self, x):
        out = self.block(x)
        return out
def test_gen():
    """Smoke-test Gen: print its trainable-parameter count and output shape."""
    model = Gen(100, 3)
    x = torch.rand(8, 100, 1, 1)
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"model params num:{num_params}")
    # Only the shape is inspected, so skip building an autograd graph.
    with torch.no_grad():
        out = model(x)
    print(f"input size{x.shape}, output size: {out.shape}")
test_gen()

model params num:599427
input sizetorch.Size([8, 100, 1, 1]), output size: torch.Size([8, 3, 64, 64])


### $D$'s model
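1. Gradually decrease the resolution (64×64 input down to an 8×8 output) while reducing the number of channels (256 → 128 → 64 → 32).
2. In the last conv block, map the 32-channel feature map to a 1-channel map, then apply a sigmoid.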
   

In [3]:
class Disc(nn.Module):
    """
    Patch discriminator: scores each region of an image as real or fake.

    A stack of BasicConv blocks followed by a 3x3 conv down to one channel
    and a sigmoid, so every output element lies in (0, 1).

    input: (N, 3, 64, 64)
    output: (N, 1, 8, 8)
    """
    def __init__(self, img_channel):
        super().__init__()
        # Positional args forwarded to BasicConv, presumably
        # (in_channels, out_channels, kernel_size, stride) -- see utils.BasicConv.
        specs = [
            (img_channel, 256, 5, 2),
            (256, 128, 3, 2),
            (128, 64, 3, 2),
            (64, 32, 3, 1),
        ]
        layers = [BasicConv(c_in, c_out, k, s, leaky_relu=False) for c_in, c_out, k, s in specs]
        # 3x3 conv with padding 1: keeps spatial size, emits one score channel.
        layers.append(nn.Conv2d(32, 1, 3, 1, 1))
        self.model = nn.Sequential(*layers)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        logits = self.model(x)
        return self.sigmoid(logits)
def test_disc():
    """Smoke-test Disc: print its trainable-parameter count and output shape."""
    model = Disc(3)
    x = torch.rand(8, 3, 64, 64)
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"model params num:{num_params}")
    # Only the shape is inspected, so skip building an autograd graph.
    with torch.no_grad():
        out = model(x)
    print(f"input size{x.shape}, output size: {out.shape}")
test_disc()

model params num:407041
input sizetorch.Size([8, 3, 64, 64]), output size: torch.Size([8, 1, 8, 8])


## define params

In [5]:
# Learning rates for the generator / discriminator optimizers.
l_r_gen = 2e-4
l_r_disc = 2e-4

NUM_EPOCH = 5
BATCH_SIZE = 64

noise_channel = 100  # channels of the (N, noise_channel, 1, 1) latent input to Gen
img_channel = 3      # channels of the images Gen produces and Disc consumes

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Fixed latent batch reused at every logging step so generated samples are
# comparable across training. Sized from noise_channel (not a literal) so it
# stays consistent with the generator input if noise_channel changes.
fix_noise = torch.randn(32, noise_channel, 1, 1).to(DEVICE)

## prepare data

In [15]:
# Placeholder: replace with a real DataLoader before running the training
# loop. The loop unpacks each batch as `(x, _)` and calls len() on it, so it
# fails while this is None. NOTE(review): the recorded log shows 938 batches
# per epoch at BATCH_SIZE=64 (~60k images, e.g. MNIST) -- confirm the dataset.
dataloader = None

In [4]:
# Image preprocessing:
#   Resize(64)     scales the shorter edge to 64 (aspect ratio kept),
#   CenterCrop(64) then guarantees a 64x64 image even for non-square inputs,
#                  matching the generator's 64x64 output (no-op for squares),
#   ToTensor()     converts to a float tensor in [0, 1] -- the same range
#                  the generator clips its output to.
transform = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
])

## define optim, loss ...

In [None]:
# Build both networks on the target device.
gen = Gen(noise_channel, img_channel).to(DEVICE)
disc = Disc(img_channel).to(DEVICE)

# Re-initialise weights with the project's init scheme: G first, then D.
for _net in (gen, disc):
    _net.apply(init_weight)

# Adam with beta1 = 0.5 (lower than Adam's default of 0.9), shared by both.
_adam_betas = (0.5, 0.999)
opt_gen = optim.Adam(gen.parameters(), lr=l_r_gen, betas=_adam_betas)
opt_disc = optim.Adam(disc.parameters(), lr=l_r_disc, betas=_adam_betas)

# Binary cross-entropy on the discriminator's sigmoid outputs.
loss = nn.BCELoss()

Gen(
  (net): Sequential(
    (0): Sequential(
      (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (1): Sequential(
      (0): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (2): Sequential(
      (0): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (3): Sequential(
      (0): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (4): ConvTranspose2d(64, 1, kernel_size=

## set up the training loop

In [None]:
# Separate writers (and log directories) for real and generated image grids,
# so TensorBoard shows them as two runs.
real_writer, fake_writer = (
    SummaryWriter(log_dir="./run/DCGAN/real"),
    SummaryWriter(log_dir="./run/DCGAN/fake"),
)

In [None]:
# Adversarial training loop.
# Per batch: one discriminator step (real -> 1, fake -> 0), then one
# generator step (fake -> 1). Every 10 batches the losses are printed and
# image grids of the real batch and of gen(fix_noise) go to TensorBoard.
step = 0  # TensorBoard global step, advanced once per logging event
for epoch in range(NUM_EPOCH):
    loop = tqdm(dataloader, leave=True)
    for batch_idx, (x, _) in enumerate(loop):
        x = x.to(DEVICE)

        gen.train()
        disc.train()
        # Match the real batch size: the last batch of an epoch may be smaller.
        noise = torch.randn(x.size(0), noise_channel, 1, 1).to(DEVICE)
        fake_gen = gen(noise)

        # train discriminator
        # make fake_gen 0, real 1. detach() stops D's backward pass from
        # flowing into (and wasting work on) the generator.
        fake_disc = disc(fake_gen.detach()).view(-1)
        real_disc = disc(x).view(-1)
        fake_disc_loss = loss(fake_disc, torch.zeros_like(fake_disc))
        real_disc_loss = loss(real_disc, torch.ones_like(real_disc))
        disc_loss = (fake_disc_loss + real_disc_loss) / 2
        disc.zero_grad()
        disc_loss.backward()
        opt_disc.step()

        # train generator
        # make fake_gen 1: reuse the same fakes, now scored by the updated D.
        fake_disc = disc(fake_gen).view(-1)
        loss_gen = loss(fake_disc, torch.ones_like(fake_disc))
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Logging runs inside the batch loop so it fires every 10 batches.
        if batch_idx % 10 == 0:
            print(f"epoch:[{epoch:>2d}/{NUM_EPOCH:>2d}], batch[{batch_idx:>4d}/{len(dataloader):>4d}], loss:[G:{loss_gen:>6f},D:{disc_loss:6f}]")
            gen.eval()
            disc.eval()
            # GAN sample real and gen_result
            with torch.no_grad():
                fake_out = gen(fix_noise)
                grid_fake = torchvision.utils.make_grid(fake_out, normalize=True)
                grid_real = torchvision.utils.make_grid(x, normalize=True)
                # A running step (not the epoch) keeps successive grids from
                # sharing one TensorBoard step.
                real_writer.add_image(tag="real", img_tensor=grid_real, global_step=step)
                fake_writer.add_image("fake", grid_fake, step)
            step += 1

epoch:[ 0/ 5], batch[   0/ 938], loss:[G:2.994233,D:0.742415]
epoch:[ 0/ 5], batch[  10/ 938], loss:[G:6.528258,D:0.177330]
epoch:[ 0/ 5], batch[  20/ 938], loss:[G:7.242606,D:0.367552]
epoch:[ 0/ 5], batch[  30/ 938], loss:[G:8.267038,D:0.344901]
epoch:[ 0/ 5], batch[  40/ 938], loss:[G:7.915555,D:0.305940]
epoch:[ 0/ 5], batch[  50/ 938], loss:[G:6.081284,D:0.122923]
epoch:[ 0/ 5], batch[  60/ 938], loss:[G:6.069611,D:0.097718]
epoch:[ 0/ 5], batch[  70/ 938], loss:[G:4.182104,D:1.246040]
epoch:[ 0/ 5], batch[  80/ 938], loss:[G:4.549936,D:0.176098]
epoch:[ 0/ 5], batch[  90/ 938], loss:[G:2.798174,D:0.026363]
epoch:[ 0/ 5], batch[ 100/ 938], loss:[G:7.296311,D:0.254291]
epoch:[ 0/ 5], batch[ 110/ 938], loss:[G:3.177748,D:0.059517]
epoch:[ 0/ 5], batch[ 120/ 938], loss:[G:2.310700,D:0.075628]
epoch:[ 0/ 5], batch[ 130/ 938], loss:[G:2.896574,D:0.046743]
epoch:[ 0/ 5], batch[ 140/ 938], loss:[G:2.360650,D:0.076969]
epoch:[ 0/ 5], batch[ 150/ 938], loss:[G:1.848305,D:0.078786]
epoch:[ 

KeyboardInterrupt: ignored