In [42]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

SEED = 777
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

In [43]:
from datetime import datetime
now = datetime.now()

# Training Hyperparameters
EPOCHS = 10
BATCH_SIZE = 64
LR = 0.001

ROOT_DIR = os.getcwd()
LOG_DIR = os.path.join(ROOT_DIR, "logs", now.strftime("%Y%m%d-%H%M%S"))
LOG_ITER = 100
CKPT_DIR = os.path.join(ROOT_DIR, "checkpoints")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if not os.path.exists(CKPT_DIR):
    os.makedirs(CKPT_DIR)

In [44]:
transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
)

# CIFAR-10 Dataset
train_dataset = datasets.CIFAR10(
    root="./data", train=True, download=True, transform=transform
)

# CIFAR-10 Dataloader
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2
)

Files already downloaded and verified


In [45]:
# CIFAR-10 Test Dataset
test_dataset = datasets.CIFAR10(
    root="./data", train=False, download=True, transform=transform
)

# CIFAR-10 Test Dataloader
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2
)

Files already downloaded and verified


In [46]:
# Generator for CIFAR-10
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(100, 256)
        self.fc2 = nn.Linear(256, 512)
        self.fc3 = nn.Linear(512, 1024)
        self.fc4 = nn.Linear(1024, 3072)
        self.bn1 = nn.BatchNorm1d(256)
        self.bn2 = nn.BatchNorm1d(512)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(3072)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.bn1(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.bn2(x)
        x = self.dropout(x)
        x = self.fc3(x)
        x = self.relu(x)
        x = self.bn3(x)
        x = self.dropout(x)
        x = self.fc4(x)
        x = self.tanh(x)
        return x

In [47]:
# Discriminator for CIFAR-10
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(3072, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 256)
        self.fc4 = nn.Linear(256, 1)
        self.bn1 = nn.BatchNorm1d(1024)
        self.bn2 = nn.BatchNorm1d(512)
        self.bn3 = nn.BatchNorm1d(256)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.bn1(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.bn2(x)
        x = self.fc3(x)
        x = self.relu(x)
        x = self.bn3(x)
        x = self.fc4(x)
        x = self.sigmoid(x)
        return x

In [48]:
# Initialize Generator and Discriminator
generator = Generator()
discriminator = Discriminator()

In [49]:
# Loss and Optimizer
criterion = nn.BCELoss()
d_optimizer = optim.Adam(discriminator.parameters(), lr=LR, betas=(0.5, 0.999))
g_optimizer = optim.Adam(generator.parameters(), lr=LR, betas=(0.5, 0.999))

# Setup Tensorboard
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(LOG_DIR)

In [50]:
from tqdm import tqdm

# Training
iteration = 0
for epoch in range(EPOCHS):
    pbar = tqdm(train_loader, desc="Epoch {}".format(epoch))
    for data in pbar:
        iteration += 1
        # Get input data
        real_data = data[0].to(DEVICE)
        real_labels = torch.ones(real_data.size(0), 1).to(DEVICE)
        fake_labels = torch.zeros(real_data.size(0), 1).to(DEVICE)

        # Train Discriminator
        discriminator.zero_grad()

        # Get Generator Input Data
        noise = torch.randn(real_data.size(0), 100).to(DEVICE)
        fake_data = generator(noise)

        # Train on Real Data
        d_real_output = discriminator(real_data)
        d_real_loss = criterion(d_real_output, real_labels)
        d_real_loss.backward()

        # Train on Fake Data
        d_fake_output = discriminator(fake_data)
        d_fake_loss = criterion(d_fake_output, fake_labels)
        d_fake_loss.backward()

        d_loss = d_real_loss + d_fake_loss
        d_optimizer.step()

        # Train Generator
        generator.zero_grad()

        # Train on Fake Data
        g_output = discriminator(fake_data)
        g_loss = criterion(g_output, real_labels)
        g_loss.backward()
        g_optimizer.step()

        # Log Losses
        writer.add_scalar("Discriminator Loss", d_loss, iteration)
        writer.add_scalar("Generator Loss", g_loss, iteration)

        # Log Images
        writer.add_image("Real Data", real_data[0], iteration)
        writer.add_image("Fake Data", fake_data[0], iteration)

        if iteration % LOG_ITER == 0:
            # Save Model
            torch.save(
                generator.state_dict(),
                os.path.join(CKPT_DIR, "generator-{}.pth".format(iteration)),
            )
            torch.save(
                discriminator.state_dict(),
                os.path.join(CKPT_DIR, "discriminator-{}.pth".format(iteration)),
            )

            # Log Progress
            pbar.set_description(
                "D: {:.4f} G: {:.4f}".format(d_loss.item(), g_loss.item())
            )


Epoch 0:   0%|          | 0/782 [00:07<?, ?it/s]


RuntimeError: mat1 and mat2 shapes cannot be multiplied (6144x32 and 3072x1024)

In [None]:
# Test
with torch.no_grad():
    for data in test_loader:
        real_data = data[0].to(DEVICE)
        noise = torch.randn(real_data.size(0), 100).to(DEVICE)
        fake_data = generator(noise)

        writer.add_image("Real Data", real_data[0], iteration)
        writer.add_image("Fake Data", fake_data[0], iteration)
        break

writer.close()

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x1327f5ca0>
Traceback (most recent call last):
  File "/Users/thkim/.pyenv/versions/3.9.10/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1508, in __del__
    self._shutdown_workers()
  File "/Users/thkim/.pyenv/versions/3.9.10/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1472, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/Users/thkim/.pyenv/versions/3.9.10/lib/python3.9/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
  File "/Users/thkim/.pyenv/versions/3.9.10/lib/python3.9/multiprocessing/popen_fork.py", line 40, in wait
    if not wait([self.sentinel], timeout):
  File "/Users/thkim/.pyenv/versions/3.9.10/lib/python3.9/multiprocessing/connection.py", line 936, in wait
    ready = selector.select(timeout)
  File "/Users/thkim/.pyenv/versions/3.9.10/lib/python3.9/selectors.py", line 416, in select
    

NameError: name 'iteration' is not defined

In [None]:
# Show Results
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

def show_result(num_epoch, show=False, save=False, path=None):
    with torch.no_grad():
        for data in test_loader:
            real_data = data[0].to(DEVICE)
            noise = torch.randn(real_data.size(0), 100).to(DEVICE)
            fake_data = generator(noise)

            img_list = [real_data[0], fake_data[0]]
            name_list = ["Real Data", "Fake Data"]
            for img, name in zip(img_list, name_list):
                img = img.mul(0.5).add(0.5).cpu().numpy()
                plt.figure()