In [7]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision import datasets, transforms

# Return cached-but-unused GPU memory from PyTorch's caching allocator (e.g. left
# over from an earlier run in this kernel). Tensors that are still referenced are
# NOT freed by this call.
torch.cuda.empty_cache()

# Set up TensorBoard

In [8]:
from torch.utils.tensorboard import SummaryWriter
import datetime

# Each execution logs into its own timestamped directory under runs/, so
# TensorBoard can show several runs side by side.
# NOTE(review): the 'mnist' prefix looks stale now that the dataset cell loads CelebA.
started_at = datetime.datetime.now()
timestamp = started_at.strftime("%Y-%m-%d_%H-%M-%S")
log_dir = f'runs/mnist_gan_exp_{timestamp}'
writer = SummaryWriter(log_dir)

## Dataset


In [9]:
# ToTensor scales pixels to [0, 1]; Normalize(0.5, 0.5) then maps them to [-1, 1],
# the same range as the generator's tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5)
])
# MNIST variant: datasets.MNIST(root='./dataset', download=True, transform=transform)
# If the CelebA download fails, see
# <https://stackoverflow.com/questions/70896841/error-downloading-celeba-dataset-using-torchvision>
data = datasets.CelebA(root='./dataset', download=True, transform=transform)

# Train on a random subset to keep epochs short. The split is seeded so the same
# images are picked on every Restart & Run All. The size is capped at the dataset
# length because random_split raises when the requested lengths do not fit.
size = min(10000, len(data))
split_seed = 0
data, _ = random_split(
    data,
    [size, len(data) - size],
    generator=torch.Generator().manual_seed(split_seed),
)
batch_size: int = 256
# drop_last=True keeps every batch exactly `batch_size` long.
dataloader = DataLoader(data, batch_size=batch_size, shuffle=True, drop_last=True)
n_epochs = 10
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

Files already downloaded and verified
cuda:0


# Model

In [10]:
class Generator(nn.Module):
    """Fully connected GAN generator: latent vector -> flattened image.

    Layer widths are ``in_dim -> 256 -> 512 -> 1024 -> out_dim``. LeakyReLU is
    applied after each hidden layer and ``tanh`` on the output, so every output
    value lies in ``[-1, 1]``.

    Args:
        in_dim: Length of the latent noise vector ``z``.
        out_dim: Number of values in one flattened image.
        negative_slope: LeakyReLU slope for the hidden layers (default 0.2).
    """

    def __init__(self, in_dim: int, out_dim: int, negative_slope: float = 0.2):
        super().__init__()
        sizes = [in_dim, 256, 512, 1024, out_dim]
        # Attribute names fc1..fc4 are the state_dict keys of saved checkpoints; keep them.
        self.fc1 = nn.Linear(sizes[0], sizes[1])
        self.fc2 = nn.Linear(sizes[1], sizes[2])
        self.fc3 = nn.Linear(sizes[2], sizes[3])
        self.fc4 = nn.Linear(sizes[3], sizes[4])
        self.negative_slope = negative_slope

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map latent vectors ``x`` of shape ``(..., in_dim)`` to ``(..., out_dim)`` in [-1, 1]."""
        x = F.leaky_relu(self.fc1(x), self.negative_slope)
        x = F.leaky_relu(self.fc2(x), self.negative_slope)
        x = F.leaky_relu(self.fc3(x), self.negative_slope)
        # torch.tanh replaces the deprecated F.tanh.
        return torch.tanh(self.fc4(x))


In [11]:
class Discriminator(nn.Module):
    """Fully connected GAN discriminator: flattened image -> probability of "real".

    Layer widths are ``in_dim -> 1024 -> 512 -> 256 -> 1``. Each hidden layer is
    followed by LeakyReLU and dropout, and the output goes through a sigmoid.

    Args:
        in_dim: Number of values in one flattened image.
        negative_slope: LeakyReLU slope for the hidden layers (default 0.2).
        dropout: Dropout probability after each hidden layer (default 0.3).
    """

    def __init__(self, in_dim: int, negative_slope: float = 0.2, dropout: float = 0.3):
        super().__init__()
        sizes = [in_dim, 1024, 512, 256, 1]
        # Attribute names fc1..fc4 are the state_dict keys of saved checkpoints; keep them.
        self.fc1 = nn.Linear(sizes[0], sizes[1])
        self.fc2 = nn.Linear(sizes[1], sizes[2])
        self.fc3 = nn.Linear(sizes[2], sizes[3])
        self.fc4 = nn.Linear(sizes[3], sizes[4])
        self.negative_slope = negative_slope
        self.dropout = dropout

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Score inputs ``x`` of shape ``(..., in_dim)``; returns ``(..., 1)`` in (0, 1)."""
        # F.dropout defaults to training=True; tie it to the module mode so that
        # discriminator.eval() actually disables dropout.
        x = F.leaky_relu(self.fc1(x), self.negative_slope)
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.leaky_relu(self.fc2(x), self.negative_slope)
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.leaky_relu(self.fc3(x), self.negative_slope)
        x = F.dropout(x, self.dropout, training=self.training)
        # torch.sigmoid replaces the deprecated F.sigmoid.
        return torch.sigmoid(self.fc4(x))

# Train

In [12]:
def train_discriminator(
    generator: Generator,
    discriminator: Discriminator,
    optim_discriminator,
    x: torch.Tensor,
    criterion,
    batch_size: int,
    step: int,
):
    """Run one discriminator update on a real batch plus ``batch_size`` fake samples.

    Real samples are labelled 1 and generated samples 0. The summed loss is
    back-propagated, the discriminator optimizer steps once, and the loss is
    logged to TensorBoard under ``loss/discriminator`` at ``step``.

    Relies on the notebook globals ``in_dim`` (latent size), ``device`` and
    ``writer``.

    Args:
        generator: Produces the fake batch; its weights are not updated here.
        discriminator: Network being trained.
        optim_discriminator: Optimizer over ``discriminator``'s parameters.
        x: Batch of real images; each sample is flattened to a vector.
        criterion: Loss on (prediction, target) pairs, e.g. ``nn.BCELoss``.
        batch_size: Number of fake samples to generate.
        step: Global step used for TensorBoard logging.
    """
    discriminator.zero_grad()

    # Real data: flatten every sample to a vector and label it 1. Size the labels
    # from the actual batch so a short final batch would still line up.
    x_real = x.to(device).flatten(start_dim=1)
    y_real = torch.ones(x_real.size(0), 1, device=device)
    loss_real = criterion(discriminator(x_real), y_real)

    # Fake data: generate without tracking gradients. Only the discriminator is
    # updated here, so back-propagating into the generator would be wasted work.
    z = torch.randn(batch_size, in_dim, device=device)
    with torch.no_grad():
        x_fake = generator(z)
    y_fake = torch.zeros(batch_size, 1, device=device)
    loss_fake = criterion(discriminator(x_fake), y_fake)

    loss = loss_real + loss_fake
    loss.backward()
    optim_discriminator.step()
    writer.add_scalar('loss/discriminator', loss.item(), step)
    
def train_generator(
    generator: Generator,
    discriminator: Discriminator,
    optim_generator,
    criterion,
    batch_size,
    step: int,
):
    """Run one generator update and log its loss under ``loss/generator``.

    Generated samples are scored by the discriminator against target 1 ("real").
    When ``criterion`` is BCE, this is the non-saturating GAN generator loss.
    Only ``optim_generator`` steps. The gradients this leaves on the
    discriminator are cleared by ``discriminator.zero_grad()`` at the start of
    the next discriminator update.

    Relies on the notebook globals ``in_dim`` (latent size), ``device`` and
    ``writer``.

    Args:
        generator: Network being trained.
        discriminator: Scores the generated samples.
        optim_generator: Optimizer over ``generator``'s parameters.
        criterion: Loss on (prediction, target) pairs, e.g. ``nn.BCELoss``.
        batch_size: Number of latent vectors to sample.
        step: Global step used for TensorBoard logging.
    """
    generator.zero_grad()
    # Allocate directly on the target device (plain tensors; Variable is a no-op wrapper).
    z = torch.randn(batch_size, in_dim, device=device)
    y = torch.ones(batch_size, 1, device=device)

    loss_generator = criterion(discriminator(generator(z)), y)
    loss_generator.backward()
    optim_generator.step()
    writer.add_scalar('loss/generator', loss_generator.item(), step)


def train(
    writer: SummaryWriter,
    generator: Generator,
    discriminator: Discriminator,
    dataloader: DataLoader,
    n_epochs: int,
    lr: float = 0.0001,
    checkpoint_dir: str = 'models',
):
    """Alternate discriminator and generator updates for ``n_epochs`` epochs.

    After each epoch:
      * a batch of generated images is logged to ``writer`` under
        ``generated_image``;
      * both networks' state_dicts are saved to ``checkpoint_dir``, overwriting
        the previous epoch's files.

    Note that ``train_discriminator``/``train_generator`` log their losses to
    the module-level ``writer``, not to this argument.

    Relies on the notebook globals ``device``, ``batch_size``, ``in_dim`` and
    ``shape`` (per-sample image shape).

    Args:
        writer: TensorBoard writer used for the image summaries.
        generator: Generator network (already on ``device``).
        discriminator: Discriminator network (already on ``device``).
        dataloader: Yields ``(images, labels)`` batches; labels are ignored.
        n_epochs: Number of passes over ``dataloader``.
        lr: Learning rate for both optimizers.
        checkpoint_dir: Directory for the checkpoint files; created if missing.
    """
    criterion = nn.BCELoss()
    # NOTE(review): the generator uses AdamW (non-zero default weight decay) while
    # the discriminator uses Adam. Kept as-is; confirm the mismatch is intended.
    optim_generator = optim.AdamW(generator.parameters(), lr=lr)
    optim_discriminator = optim.Adam(discriminator.parameters(), lr=lr)
    # torch.save does not create parent directories.
    os.makedirs(checkpoint_dir, exist_ok=True)

    for epoch in range(n_epochs):
        print(f'[{datetime.datetime.now()}] Epoch {epoch}')
        for i, (x, _) in enumerate(dataloader):
            x = x.to(device)
            print(f'[{datetime.datetime.now()}] Epoch {epoch} batch {i}')
            step = epoch * len(dataloader) + i
            train_discriminator(generator, discriminator, optim_discriminator, x, criterion, batch_size, step)
            train_generator(generator, discriminator, optim_generator, criterion, batch_size, step)

        # Log a batch of samples. The generator ends in tanh, so its outputs are in
        # [-1, 1], but add_images expects float images in [0, 1]. Undo
        # Normalize(0.5, 0.5) first, otherwise every negative value is clipped to black.
        with torch.no_grad():
            z = torch.randn(batch_size, in_dim, device=device)
            images = generator(z).view(-1, *shape)
            images = (images * 0.5 + 0.5).clamp(0.0, 1.0)
            writer.add_images('generated_image', images, epoch)

        # Checkpoint after every epoch (overwrites the previous epoch's files).
        torch.save(generator.state_dict(), os.path.join(checkpoint_dir, 'generator.pt'))
        torch.save(discriminator.state_dict(), os.path.join(checkpoint_dir, 'discriminator.pt'))

# Per-sample image shape as produced by ToTensor: (channels, height, width).
# Read it from a single dataset item rather than pulling a whole shuffled batch
# through the DataLoader. Also, `iter(dataloader).next()` raises AttributeError on
# PyTorch versions whose DataLoader iterator has no `.next()` method.
shape = dataloader.dataset[0][0].shape
print(f'Shape: {shape}')
in_dim = 100  # length of the latent noise vector z
out_dim = shape.numel()  # flattened image size: C * H * W
generator = Generator(in_dim, out_dim).to(device)
discriminator = Discriminator(out_dim).to(device)
train(writer, generator, discriminator, dataloader, n_epochs = n_epochs)

[2022-06-11 21:17:28.010473] Epoch 0
[2022-06-11 21:17:28.349190] Epoch 0 batch 0
[2022-06-11 21:17:28.978848] Epoch 0 batch 1
[2022-06-11 21:17:29.594143] Epoch 0 batch 2
[2022-06-11 21:17:35.513948] Epoch 1
[2022-06-11 21:17:37.076653] Epoch 1 batch 0
[2022-06-11 21:17:37.691240] Epoch 1 batch 1
[2022-06-11 21:17:38.296946] Epoch 1 batch 2
[2022-06-11 21:17:44.813323] Epoch 2
[2022-06-11 21:17:46.692300] Epoch 2 batch 0
[2022-06-11 21:17:47.553536] Epoch 2 batch 1
[2022-06-11 21:17:48.459173] Epoch 2 batch 2
[2022-06-11 21:17:54.814449] Epoch 3
[2022-06-11 21:17:56.822811] Epoch 3 batch 0
[2022-06-11 21:17:57.438916] Epoch 3 batch 1
[2022-06-11 21:17:58.041238] Epoch 3 batch 2
[2022-06-11 21:18:04.398869] Epoch 4
[2022-06-11 21:18:06.518738] Epoch 4 batch 0
[2022-06-11 21:18:07.343023] Epoch 4 batch 1
[2022-06-11 21:18:08.121174] Epoch 4 batch 2
[2022-06-11 21:18:14.445384] Epoch 5
[2022-06-11 21:18:16.045917] Epoch 5 batch 0
[2022-06-11 21:18:16.709247] Epoch 5 batch 1
[2022-06-11 2