## model.py

In [1]:
import torch
import torch.nn as nn

In [2]:
class Discriminator(nn.Module):
  def __init__(self, img_channels, features_d):
    super().__init__()
    self.discriminator = nn.Sequential(
        # Input: N * img_channels * 64 * 64
        nn.Conv2d(img_channels, features_d, kernel_size=4, stride=2, padding=1),
        nn.LeakyReLU(0.2),

        self._block(features_d, features_d * 2, 4, 2, 1),
        self._block(features_d * 2, features_d * 4, 4, 2, 1),
        self._block(features_d * 4, features_d * 8, 4, 2, 1),
        # After the _block() calls the output is 4 * 4

        # The Conv2d() makes the output 1 * 1
        nn.Conv2d(features_d * 8, 1, kernel_size=4, stride=2, padding=0),
        nn.Sigmoid()
    )
  
  def _block(self, in_channels, out_channels, kernel_size, stride, padding):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias = False),
        # nn.BatchNorm2d(out_channels)
        nn.LeakyReLU(0.2)
    )
  
  def forward(self, x):
    return self.discriminator(x)

In [3]:
class Generator(nn.Module):
  def __init__(self, noise_channels, img_channels, features_g):
    super().__init__()

    self.generator = nn.Sequential(
        # Input: N * noise_channels * 1 * 1
        self._block(noise_channels, features_g * 16, 4, 1, 0),                  # Img: 4x4
        self._block(features_g * 16, features_g * 8, 4, 2, 1),                  # Img: 8x8
        self._block(features_g * 8, features_g * 4, 4, 2, 1),                   # Img: 16x16
        self._block(features_g * 4, features_g * 2, 4, 2, 1),                   # Img: 32x32

        nn.ConvTranspose2d(
                features_g * 2, img_channels, kernel_size=4, stride=2, padding=1
        ),
        # Output: N x img_channels x 64 x 64
        nn.Tanh(),
    )
  
  def _block(self, in_channels, out_channels, kernel_size, stride, padding):
    return nn.Sequential(
        nn.ConvTranspose2d(
            in_channels, 
            out_channels, 
            kernel_size, 
            stride, 
            padding, 
            bias=False
        ),
        # nn.BatchNorm2d(out_channels)
        nn.ReLU()
    )
  
  def forward(self, x):
    return self.generator(x)

In [4]:
def initialize_weights(model):
  for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
      nn.init.normal_(m.weight.data, 0.0, 0.02)

In [5]:
def test():
  N, in_channels, H, W = 8, 3, 64, 64
  noise_dimension = 100
  x = torch.randn((N, in_channels, H, W))

  discriminator = Discriminator(in_channels, 8)
  assert discriminator(x).shape == (N, 1, 1, 1), "Discriminator test failed"

  gen = Generator(noise_dimension, in_channels, 8)
  z = torch.randn((N, noise_dimension, 1, 1))
  assert gen(z).shape == (N, in_channels, H, W), "Generator test failed"

In [6]:
test()

## train.py

In [7]:
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

In [8]:
# Hyperparameters etc.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LEARNING_RATE = 2e-4  # could also use two lrs, one for gen and one for disc
BATCH_SIZE = 128
IMAGE_SIZE = 64
CHANNELS_IMG = 1
NOISE_DIM = 100
NUM_EPOCHS = 10
FEATURES_DISC = 64
FEATURES_GEN = 64

In [9]:
transform = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]
        ),
    ]
)

In [10]:
# If you train on MNIST, remember to set channels_img to 1
dataset = datasets.MNIST(
    root="dataset/", 
    train=True, 
    transform=transform, 
    download=False
)

In [11]:
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
generator = Generator(NOISE_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
discriminator = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device)

In [12]:
initialize_weights(generator)
initialize_weights(discriminator)

In [13]:
opt_gen = optim.Adam(generator.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
opt_disc = optim.Adam(discriminator.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
criterion = nn.BCELoss()

In [14]:
fixed_noise = torch.randn(32, NOISE_DIM, 1, 1).to(device)
writer_real = SummaryWriter(f"logs/real")
writer_fake = SummaryWriter(f"logs/fake")
step = 0

In [15]:
generator.train()

Generator(
  (generator): Sequential(
    (0): Sequential(
      (0): ConvTranspose2d(100, 1024, kernel_size=(4, 4), stride=(1, 1), bias=False)
      (1): ReLU()
    )
    (1): Sequential(
      (0): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): ReLU()
    )
    (2): Sequential(
      (0): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): ReLU()
    )
    (3): Sequential(
      (0): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): ReLU()
    )
    (4): ConvTranspose2d(128, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (5): Tanh()
  )
)

In [16]:
discriminator.train()

Discriminator(
  (discriminator): Sequential(
    (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): LeakyReLU(negative_slope=0.2)
    )
    (3): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): LeakyReLU(negative_slope=0.2)
    )
    (4): Sequential(
      (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): LeakyReLU(negative_slope=0.2)
    )
    (5): Conv2d(512, 1, kernel_size=(4, 4), stride=(2, 2))
    (6): Sigmoid()
  )
)

In [17]:
for epoch in range(NUM_EPOCHS):
    # Target labels not needed! <3 unsupervised
    for batch_idx, (real, _) in enumerate(dataloader):
        real = real.to(device)
        noise = torch.randn(BATCH_SIZE, NOISE_DIM, 1, 1).to(device)
        fake = generator(noise)

        ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
        disc_real = discriminator(real).reshape(-1)
        loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
        disc_fake = discriminator(fake.detach()).reshape(-1)
        loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        loss_disc = (loss_disc_real + loss_disc_fake) / 2
        discriminator.zero_grad()
        loss_disc.backward()
        opt_disc.step()

        ### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z))
        output = discriminator(fake).reshape(-1)
        loss_gen = criterion(output, torch.ones_like(output))
        generator.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Print losses occasionally and print to tensorboard
        if batch_idx % 100 == 0:
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(dataloader)} \
                  Loss D: {loss_disc:.4f}, loss G: {loss_gen:.4f}"
            )

            with torch.no_grad():
                fake = generator(fixed_noise)
                # take out (up to) 32 examples
                img_grid_real = torchvision.utils.make_grid(
                    real[:32], normalize=True
                )
                img_grid_fake = torchvision.utils.make_grid(
                    fake[:32], normalize=True
                )

                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)

            step += 1

Epoch [0/10] Batch 0/469                   Loss D: 0.6951, loss G: 0.7025
Epoch [0/10] Batch 100/469                   Loss D: 0.0003, loss G: 7.9549
Epoch [0/10] Batch 200/469                   Loss D: 0.0001, loss G: 9.4695
Epoch [0/10] Batch 300/469                   Loss D: 0.0001, loss G: 9.5981
Epoch [0/10] Batch 400/469                   Loss D: 0.0000, loss G: 10.2108
Epoch [1/10] Batch 0/469                   Loss D: 0.0000, loss G: 10.6337
Epoch [1/10] Batch 100/469                   Loss D: 0.0000, loss G: 11.2912
Epoch [1/10] Batch 200/469                   Loss D: 0.0000, loss G: 11.6562
Epoch [1/10] Batch 300/469                   Loss D: 0.0000, loss G: 11.6953
Epoch [1/10] Batch 400/469                   Loss D: 0.0000, loss G: 10.6343
Epoch [2/10] Batch 0/469                   Loss D: 0.0000, loss G: 12.3728
Epoch [2/10] Batch 100/469                   Loss D: 0.0000, loss G: 12.6308
Epoch [2/10] Batch 200/469                   Loss D: 0.0000, loss G: 12.6495
Epoch [2/