In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.data.sampler import SubsetRandomSampler
import numpy as np
from torch.utils.tensorboard import SummaryWriter

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
lr = 2e-4
batch_size = 128
image_size = 64
channels_img = 1
z_dim = 100
num_epochs = 31
features_disc = 64
features_gen = 64

In [3]:
transforms = torchvision.transforms.Compose([transforms.Resize(image_size),
                                             transforms.ToTensor(),
                                transforms.Normalize([0.5 for _ in range(channels_img)], [0.5 for _ in range(channels_img)])])

In [4]:
dataset = datasets.MNIST(root="dataset/", transform=transforms, download=True)

In [5]:
def get_indices(dataset, class_name):
    indices = []
    for i in range(len(dataset.targets)):
        if dataset.targets[i] == class_name:
            indices.append(i)
    return indices

In [6]:
idx = get_indices(dataset, 9)

In [7]:
loader = DataLoader(dataset, batch_size=batch_size, sampler = SubsetRandomSampler(idx))

In [8]:
class Discriminator(nn.Module):
    def __init__(self, channels_img, features_d):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            self._block(features_d, features_d*2, 4, 2, 1),
            self._block(features_d*2, features_d*4, 4, 2, 1),
            self._block(features_d*4, features_d*8, 4, 2, 1),
            nn.Conv2d(features_d*8, 1, kernel_size=4, stride=2, padding=0),
            nn.Sigmoid(),
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2)
        )
    def forward(self, x):
        return self.disc(x)

In [9]:
class Generator(nn.Module):
    def __init__(self, z_dim, channels_img, features_g):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            self._block(z_dim, features_g*16, 4, 1, 0),
            self._block(features_g*16, features_g*8, 4, 2, 1),
            self._block(features_g*8, features_g*4, 4, 2, 1),
            self._block(features_g*4, features_g*2, 4, 2, 1),
            nn.ConvTranspose2d(features_g*2, channels_img, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.gen(x)

In [10]:
def initialize_weights(model):
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)

In [11]:
def test():
    N, in_channels, H, W = 8, 3, 64, 64
    z_dim = 100
    x = torch.randn((N, in_channels, H, W))
    disc = Discriminator(in_channels, 8)
    initialize_weights(disc)
    assert disc(x).shape == (N, 1, 1, 1)

    gen = Generator(z_dim, in_channels, 8)
    initialize_weights(gen)
    z = torch.randn((N, z_dim, 1, 1))
    assert gen(z).shape == (N, in_channels, H, W)
    print("Success")

In [12]:
test()

Success


In [13]:
gen = Generator(z_dim, channels_img, features_gen).to(device)
initialize_weights(gen)

In [14]:
disc = Discriminator(channels_img, features_disc).to(device)
initialize_weights(disc)

In [15]:
opt_gen = optim.Adam(gen.parameters(), lr=lr, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=lr, betas=(0.5, 0.999))

In [16]:
criterion = nn.BCELoss()

In [17]:
fixed_noise = torch.randn((32, z_dim, 1, 1)).to(device)

In [18]:
writer_real = SummaryWriter(f"runs/DCGAN_Nines/real")
writer_fake = SummaryWriter(f"runs/DCGAN_Nines/fake")

In [19]:
gen.train()
disc.train()

Discriminator(
  (disc): Sequential(
    (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (3): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (4): Sequential(
      (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (5): Conv2d(512, 1, kernel_size=(4, 4), stride=(2, 2))
    (6): Sigmoid()
  )
)

In [20]:
step = 0

In [21]:
for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        real = real.to(device)
        noise = torch.randn((batch_size, z_dim, 1, 1)).to(device)

        ##train disc
        disc.zero_grad()

        fake = gen(noise)
        disc_real = disc(real).reshape(-1)
        loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
        disc_fake = disc(fake).reshape(-1)
        loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        loss_disc_total = (loss_disc_real + loss_disc_fake) / 2
        loss_disc_total.backward(retain_graph=True)
        opt_disc.step()

        ##train gen
        gen.zero_grad()
        output = disc(fake).reshape(-1)
        loss_gen = criterion(output, torch.ones_like(output))
        loss_gen.backward()
        opt_gen.step()
        
        if batch_idx % 100 == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} \
                Loss D: {loss_disc_total:.4f}, loss G: {loss_gen:.4f}"
            )

            with torch.no_grad():
                fake = gen(noise)

                img_grid_real = torchvision.utils.make_grid(
                    real[:32], normalize=True
                )

                img_grid_fake = torchvision.utils.make_grid(
                    fake[:32], normalize=True
                )

                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)
            step += 1

Epoch [0/31] Batch 0/47                 Loss D: 0.6971, loss G: 0.7839
Epoch [1/31] Batch 0/47                 Loss D: 0.0542, loss G: 2.8374
Epoch [2/31] Batch 0/47                 Loss D: 0.0161, loss G: 3.9836
Epoch [3/31] Batch 0/47                 Loss D: 0.0076, loss G: 4.7184
Epoch [4/31] Batch 0/47                 Loss D: 0.0400, loss G: 3.2208
Epoch [5/31] Batch 0/47                 Loss D: 0.3626, loss G: 2.0845
Epoch [6/31] Batch 0/47                 Loss D: 0.4966, loss G: 1.9290
Epoch [7/31] Batch 0/47                 Loss D: 0.5545, loss G: 0.7807
Epoch [8/31] Batch 0/47                 Loss D: 0.6517, loss G: 0.8996
Epoch [9/31] Batch 0/47                 Loss D: 0.6399, loss G: 0.7805
Epoch [10/31] Batch 0/47                 Loss D: 0.6269, loss G: 0.8499
Epoch [11/31] Batch 0/47                 Loss D: 0.6606, loss G: 1.0748
Epoch [12/31] Batch 0/47                 Loss D: 0.6378, loss G: 0.9113
Epoch [13/31] Batch 0/47                 Loss D: 0.6883, loss G: 0.8702
Ep

In [1]:
import tensorboard

In [2]:
%load_ext tensorboard

In [23]:
torch.save(gen, "Nines_generator")

In [22]:
torch.save(disc, "Nines_discriminator")

In [3]:
%tensorboard --logdir=runs