# DCGAN for CelebA

in this code, we will implement DCGAN to generate celebA dataset, this code is much the same as the simple GAN, just modify the Linear model to DCGAN.

in DCGAN's gen and disc, it use convolution layer instead of linear layer to map the input

## do some imports

In [3]:
import torch
import torch.nn as nn

import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from torch.utils.tensorboard import SummaryWriter

import numpy as np
import matplotlib.pyplot as plt

## define the model

In [None]:
class Discriminator(nn.Module):
    # input: (N, 3, 64, 64)
    # output: (N, 1, 1, 1)
    def __init__(self, channels_img, features_d):
        super(Discriminator, self).__init__()
        # when feature_d == 128
        # (1, 64, 64) ->
        # (128, 32, 32) ->
        # (256, 16, 16) ->
        # (512, 8, 8) ->
        # (1024, 4, 4) ->
        # (1, 1, 1)
        self.img_channel = channels_img
        self.disc = nn.Sequential(
            # input: N x channels_img x 64 x 64
            nn.Conv2d(
                channels_img, features_d, kernel_size=(4, 4), stride=(2, 2), padding=1, bias=False
            ),
            nn.LeakyReLU(0.2),

            # _block(in_channels, out_channels, kernel_size, stride, padding)
            self._block(features_d, features_d * 2, 4, 2, 1),
            self._block(features_d * 2, features_d * 4, 4, 2, 1),
            self._block(features_d * 4, features_d * 8, 4, 2, 1),
            # After all _block img output is 4x4 (Conv2d below makes into 1x1)
            nn.Conv2d(features_d * 8, 1, kernel_size=(4, 4), stride=(1, 1), padding=0)
            
        )
        self.sigmoid = nn.Sigmoid()

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        # classic: conv, bn, leakyrelu
        return nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        x = self.disc(x)
        x = self.sigmoid(x)
        return x



class Gen(nn.Module):
    def __init__(self, channel_noise, img_channel, feature_d) -> None:
        # input: (N, 100, 1, 1)
        # output: (N, 1, 64, 64)
        super(Gen, self).__init__()

        self.channel_noise = channel_noise
        self.img_channel = img_channel
        self.feature_d = feature_d

        # when feature_d == 128
        # (100, 1, 1) ->
        # (1024, 4, 4) ->
        # (512, 8, 8) ->
        # (256, 16, 16) ->
        # (128, 32, 32) ->
        # (3, 64, 64)

        self.net = nn.Sequential(
            self._conv_T_block(self.channel_noise, self.feature_d * 8, 4, 1, 0),
            self._conv_T_block(self.feature_d * 8, self.feature_d * 4, 4, 2, 1),
            self._conv_T_block(self.feature_d * 4, self.feature_d * 2, 4, 2, 1),
            self._conv_T_block(self.feature_d * 2, self.feature_d, 4, 2, 1),
            nn.ConvTranspose2d(self.feature_d, self.img_channel, 4, 2, 1),
            nn.Tanh())

    
    def _conv_T_block(self, inchannel, outchannel, kernel_size, stride, padding):
       return nn.Sequential(
            nn.ConvTranspose2d(in_channels=inchannel,
                               out_channels=outchannel,
                               kernel_size=kernel_size,
                               stride=stride,
                               padding=padding,
                               bias=False),
            nn.BatchNorm2d(outchannel),
            nn.ReLU()
        )

    
    def forward(self, x):
        return self.net(x)

In [None]:
# wright init as the paper
def init_weight(m):
  if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
      nn.init.normal_(m.weight.data, 0.0, 0.02)
  if isinstance(m, (nn.BatchNorm2d)):
      nn.init.normal_(m.weight.data, 1.0, 0.02)
      nn.init.constant_(m.bias.data, 0)

## define params

In [5]:
LEARNING_RATE_GEN = 2e-4
LEARNING_RATE_DISC = 2e-4
NUM_EPOCH = 5
BATCH_SIZE = 64
CHANNEL_NOISE = 100
FEATURE_D = 32
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
fix_noise = torch.randn(32, 100, 1, 1).to(DEVICE)

## prepare data

In [4]:
transform = transforms.Compose([
  transforms.Resize(64),
  transforms.ToTensor(),
  transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

In [11]:
dataset = datasets.MNIST(root='./../../dataset/MNIST', download=False, transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

SSLError: HTTPSConnectionPool(host='docs.google.com', port=443): Max retries exceeded with url: /uc?export=download&id=0B7EVK8r0v71pZjFTYXZWM3FlRnM (Caused by SSLError(SSLEOFError(8, 'EOF occurred in violation of protocol (_ssl.c:1129)')))

In [12]:
x, _ = next(iter(dataloader))

NameError: name 'dataloader' is not defined

In [None]:
plt.figure(figsize=(12, 12))
for i in range(64):
    plt.subplot(8, 8, i + 1)
    plt.imshow(x[i].squeeze().numpy())
    plt.axis("off")

## define optim, loss ...

In [None]:
if 'gen' in globals():
  del gen
  del disc
gen = Gen(100, 1, 64).to(DEVICE)
disc = Discriminator(1, 64).to(DEVICE)

gen.apply(init_weight)
disc.apply(init_weight)
opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE_GEN, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE_DISC, betas=(0.5, 0.999))

loss = nn.BCELoss()

Gen(
  (net): Sequential(
    (0): Sequential(
      (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (1): Sequential(
      (0): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (2): Sequential(
      (0): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (3): Sequential(
      (0): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (4): ConvTranspose2d(64, 1, kernel_size=

## set up the training loop

In [None]:
! rm -rf ./run
real_writer = SummaryWriter(log_dir="./run/DCGAN/real")
fake_writer = SummaryWriter(log_dir="./run/DCGAN/fake")

In [None]:
step = 0
# Original GAN Loss
for epoch in range(NUM_EPOCH):
    for batch_idx, (x, _) in enumerate(dataloader):
        x = x.to(DEVICE)
        
        gen.train()
        disc.train()
        noise = torch.randn(BATCH_SIZE, CHANNEL_NOISE, 1, 1).to(DEVICE)
        fake_gen = gen(noise)
        fake_disc = disc(fake_gen).view(-1)
        real_disc = disc(x).view(-1)


        # train discrimnator
        # make fake_gen 0, real_gen 1
        fake_disc_loss = loss(fake_disc, torch.zeros_like(fake_disc))
        real_disc_loss = loss(real_disc, torch.ones_like(real_disc))    
        disc_loss = (fake_disc_loss + real_disc_loss) / 2
        disc.zero_grad()
        disc_loss.backward(retain_graph=True)
        opt_disc.step()

        # clip network weight
        for p in disc.parameters():
          p.data.clamp_(min=-0.01, max=0.01)

        # train generator
        # make fake_gen 1
        fake_gen = gen(noise)
        fake_disc = disc(fake_gen).view(-1)
        loss_gen = loss(fake_disc, torch.ones(fake_disc))
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()



        if batch_idx % 10 == 0:

          print(f"epoch:[{epoch:>2d}/{NUM_EPOCH:>2d}], batch[{batch_idx:>4d}/{len(dataloader):>4d}], loss:[G:{loss_gen:>6f},D:{disc_loss:6f}]")

          gen.eval()
          disc.eval()
          # GAN sample real and gen_result
          with torch.no_grad():
            fake_out = gen(fix_noise)
            grid_fake = torchvision.utils.make_grid(fake_out, normalize=True)
            grid_real = torchvision.utils.make_grid(x, normalize=True)
            real_writer.add_image(tag="real", img_tensor=grid_real, global_step=step)
            fake_writer.add_image("fake", grid_fake, step)

          step += 1



epoch:[ 0/ 5], batch[   0/ 938], loss:[G:2.994233,D:0.742415]
epoch:[ 0/ 5], batch[  10/ 938], loss:[G:6.528258,D:0.177330]
epoch:[ 0/ 5], batch[  20/ 938], loss:[G:7.242606,D:0.367552]
epoch:[ 0/ 5], batch[  30/ 938], loss:[G:8.267038,D:0.344901]
epoch:[ 0/ 5], batch[  40/ 938], loss:[G:7.915555,D:0.305940]
epoch:[ 0/ 5], batch[  50/ 938], loss:[G:6.081284,D:0.122923]
epoch:[ 0/ 5], batch[  60/ 938], loss:[G:6.069611,D:0.097718]
epoch:[ 0/ 5], batch[  70/ 938], loss:[G:4.182104,D:1.246040]
epoch:[ 0/ 5], batch[  80/ 938], loss:[G:4.549936,D:0.176098]
epoch:[ 0/ 5], batch[  90/ 938], loss:[G:2.798174,D:0.026363]
epoch:[ 0/ 5], batch[ 100/ 938], loss:[G:7.296311,D:0.254291]
epoch:[ 0/ 5], batch[ 110/ 938], loss:[G:3.177748,D:0.059517]
epoch:[ 0/ 5], batch[ 120/ 938], loss:[G:2.310700,D:0.075628]
epoch:[ 0/ 5], batch[ 130/ 938], loss:[G:2.896574,D:0.046743]
epoch:[ 0/ 5], batch[ 140/ 938], loss:[G:2.360650,D:0.076969]
epoch:[ 0/ 5], batch[ 150/ 938], loss:[G:1.848305,D:0.078786]
epoch:[ 

KeyboardInterrupt: ignored

In [None]:
# WGAN LOSS
for epoch in range(NUM_EPOCH):
    for batch_idx, (x, _) in enumerate(dataloader):

        gen.train()
        disc.train()

        x = x.to(DEVICE)
        # train disc for some iteration
        for disc_iter in range(5):
          noise = torch.randn((BATCH_SIZE, CHANNEL_NOISE, 1, 1)).to(DEVICE)
          fake_gen = gen(noise)
          fake_disc = disc(fake_gen).reshape(-1)
          real_disc = disc(x).reshape(-1)


          # train discrimnator
          # maximize E(real_disc) - E(fake_disc) 
          disc_loss = fake_disc.mean() - real_disc.mean()
          disc.zero_grad()
          disc_loss.backward(retain_graph=True)
          opt_disc.step()

        # clip disc's params weight
        for p in disc.parameters():
          p.data.clamp_(min=-0.01, max=0.01)

        # train generator
        # maximize E(fake_gen)
        fake_disc = disc(fake_gen).reshape(-1)
        loss_gen = - fake_disc.mean()
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        if batch_idx % 10 == 0:

          print(f"epoch:[{epoch:>2d}/{NUM_EPOCH:>2d}], batch[{batch_idx:>4d}/{len(dataloader):>4d}], loss:[G:{loss_gen:>6f},D:{disc_loss:6f}]")

          gen.eval()
          disc.eval()
          # GAN sample real and gen_result
          with torch.no_grad():
            fake_out = gen(fix_noise)
            grid_fake = torchvision.utils.make_grid(fake_out[:32], normalize=True)
            grid_real = torchvision.utils.make_grid(x, normalize=True)

            real_writer.add_image(tag="real", img_tensor=grid_real, global_step=batch_idx)
            fake_writer.add_image("fake", grid_fake, batch_idx)



epoch:[ 0/15], batch[   0/ 938], loss:[G:-0.502614,D:0.000000]
epoch:[ 0/15], batch[  10/ 938], loss:[G:-0.000002,D:-0.999994]
epoch:[ 0/15], batch[  20/ 938], loss:[G:-0.000003,D:-0.999993]
epoch:[ 0/15], batch[  30/ 938], loss:[G:-0.000008,D:-0.999992]
epoch:[ 0/15], batch[  40/ 938], loss:[G:-0.000011,D:-0.999989]
epoch:[ 0/15], batch[  50/ 938], loss:[G:-0.000011,D:-0.999990]
epoch:[ 0/15], batch[  60/ 938], loss:[G:-0.000011,D:-0.999989]
epoch:[ 0/15], batch[  70/ 938], loss:[G:-0.000009,D:-0.999991]
epoch:[ 0/15], batch[  80/ 938], loss:[G:-0.000064,D:-0.999935]
epoch:[ 0/15], batch[  90/ 938], loss:[G:-0.000024,D:-0.999976]
epoch:[ 0/15], batch[ 100/ 938], loss:[G:-0.000016,D:-0.999984]
epoch:[ 0/15], batch[ 110/ 938], loss:[G:-0.000013,D:-0.999987]
epoch:[ 0/15], batch[ 120/ 938], loss:[G:-0.000010,D:-0.999990]
epoch:[ 0/15], batch[ 130/ 938], loss:[G:-0.000009,D:-0.999991]
epoch:[ 0/15], batch[ 140/ 938], loss:[G:-0.000007,D:-0.999993]
epoch:[ 0/15], batch[ 150/ 938], loss:[G:

KeyboardInterrupt: ignored