In [1]:
import logging
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torchvision.datasets import ImageFolder

In [2]:
# Root logger used by every cell; setup(True) attaches the file handler.
root_logger = logging.getLogger()

# Per-iteration loss histories appended to by train() and drawn by plot().
G_losses, D_losses = [], []

# Shared runtime objects (dataloader, networks, device, lookup tables).
store = {}

# Hyper-parameters and paths a reader might want to tune.
settings = dict(
    data_root="handwriting_cvl_data/words/",
    manual_seed=42,
    nc=1,                 # image channels
    nz=100,               # length of the generator's noise vector
    ngf=64,               # base feature-map width of the generator
    ndf=64,               # base feature-map width of the discriminator
    num_workers=32,
    batch_size=128,
    num_epochs=5,
    lr=2e-4,
    beta1=0.5,
    image_size=(64, 64),
)

In [3]:
class ImageFolderWithClasses(ImageFolder):
    """ImageFolder whose items also carry the word and writer id parsed from the filename.

    Filenames are expected to look like
    ``<writer_id>-<page_num>-<line_num>-<word_num>-<word>.<ext>``.
    Only the first four hyphens are treated as separators, so the word itself
    may contain hyphens.
    """

    def __getitem__(self, idx):
        """Return the parent class's item tuple extended with ``(word, writer_id)``.

        Raises:
            ValueError: if the filename does not contain the five expected
                hyphen-separated fields.
        """
        item = super().__getitem__(idx)
        image_path = self.imgs[idx][0]

        # os.path.basename honours the platform's path separators, unlike a
        # hard-coded split on "/".
        # NOTE(review): everything after the first "." is discarded, so a word
        # that itself contains a dot would be truncated - confirm against the data.
        stem = os.path.basename(image_path).split(".")[0]

        fields = stem.split("-", 4)
        if len(fields) != 5:
            raise ValueError(
                f"Unexpected filename format (need 5 '-'-separated fields): {image_path}"
            )
        writer_id, _page_num, _line_num, _word_num, word = fields
        return item + (word, writer_id)

In [None]:
def setup(initial=False):
    """Configure logging (optionally) and re-seed the random number generators.

    Args:
        initial: when True, set the root logger to INFO and attach a file
            handler writing to "Network.log". The handler is attached at most
            once, so re-running the cell does not duplicate every log line.

    The seed is taken from ``settings["manual_seed"]`` when that key exists and
    is applied to both ``random`` and ``torch``.
    """
    if initial:
        root_logger.setLevel(logging.INFO)

        # FileHandler stores its target as an absolute path in `baseFilename`.
        log_path = os.path.abspath("Network.log")
        already_attached = any(
            isinstance(handler, logging.FileHandler) and handler.baseFilename == log_path
            for handler in root_logger.handlers
        )
        if not already_attached:
            root_logger.addHandler(logging.FileHandler("Network.log", "w", "utf-8"))

    if "manual_seed" in settings:
        manual_seed = settings["manual_seed"]
        random.seed(manual_seed)
        torch.manual_seed(manual_seed)
        print(f"[setup] Set manual seed to: {manual_seed}")

setup(True)

In [None]:
def load_dataset():
    """Build the dataset, dataloader and label lookup tables and put them in ``store``.

    Reads ``data_root``, ``batch_size``, ``image_size`` and ``num_workers`` from
    ``settings`` (plus an optional ``device`` entry, defaulting to "cpu") and
    writes the following ``store`` keys: ``dataloader``, ``dataset``, ``nwords``,
    ``nwriters``, ``word_to_idx``, ``writer_to_idx``, ``idx_to_word``,
    ``idx_to_class`` and ``device``.
    """
    data_root = settings["data_root"]
    batch_size = settings["batch_size"]
    image_size = settings["image_size"]
    num_workers = settings["num_workers"]

    dataset = ImageFolderWithClasses(root=data_root, transform=transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        # One-element tuples: a single mean/std for the single grayscale channel.
        transforms.Normalize((0.5,), (0.5,)),
    ]))

    words = set()
    writers = set()
    for i in range(len(dataset)):
        # Index the dataset once per sample: every dataset[i] access loads and
        # transforms the image, so fetching it twice doubles the cost of this pass.
        sample = dataset[i]
        words.add(sample[2])
        writers.add(sample[3])

    # Sort so that the index assigned to each label is deterministic.
    words = sorted(words)
    word_to_idx = { word: i for i, word in enumerate(words) }

    writers = sorted(writers)
    writer_to_idx = { writer: i for i, writer in enumerate(writers) }

    store["dataloader"] = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)
    store["dataset"] = dataset

    store["nwords"] = len(words)
    store["nwriters"] = len(writers)

    store["word_to_idx"] = word_to_idx
    store["writer_to_idx"] = writer_to_idx
    store["idx_to_word"] = { val: key for key, val in word_to_idx.items() }
    store["idx_to_class"] = { val: key for key, val in writer_to_idx.items() }
    # Defaults to the CPU; add a "device" entry to settings to override.
    store["device"] = torch.device(settings.get("device", "cpu"))

setup()
load_dataset()

In [6]:
def init_weights(m):
    """Re-initialise a single module in place; intended for ``Module.apply``.

    * ``nn.Conv2d`` / ``nn.ConvTranspose2d``: weights drawn from N(0, 0.02).
    * ``nn.BatchNorm2d``: weights drawn from N(1, 0.02), bias set to 0.

    The match is on the exact type, so subclasses and every other module kind
    are left untouched.
    """
    layer_kind = type(m)

    if layer_kind is nn.ConvTranspose2d or layer_kind is nn.Conv2d:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return

    if layer_kind is nn.BatchNorm2d:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, val=0)

In [7]:
nc = settings["nc"]
nz = settings["nz"]
ngf = settings["ngf"]

class Generator(nn.Module):
    """Transposed-convolution generator.

    A stack of five ``ConvTranspose2d`` layers takes the channel count from
    ``nz`` through ``ngf * 8``, ``ngf * 4``, ``ngf * 2`` and ``ngf`` down to
    ``nc``. Every layer but the last is followed by batch norm and ReLU; the
    last one is followed by ``Tanh``.
    """

    def __init__(self):
        super().__init__()

        # (in_channels, out_channels, stride, padding) of each hidden stage.
        hidden_stages = (
            (nz, ngf * 8, 1, 0),
            (ngf * 8, ngf * 4, 2, 1),
            (ngf * 4, ngf * 2, 2, 1),
            (ngf * 2, ngf, 2, 1),
        )

        layers = []
        for in_ch, out_ch, stride, padding in hidden_stages:
            layers.append(nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=stride, padding=padding, bias=False))
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU(inplace=True))

        # Output stage: project to the image channels and apply Tanh.
        layers.append(nn.ConvTranspose2d(ngf, nc, kernel_size=4, stride=2, padding=1, bias=False))
        layers.append(nn.Tanh())

        self.main = nn.Sequential(*layers)

    def forward(self, image_input):
        """Pass ``image_input`` through the layer stack and return the result."""
        generated = self.main(image_input)
        return generated

In [None]:
def construct_generator():
    """Create the generator on the configured device, initialise it and print it.

    The network is stored under ``store["netG"]``.
    """
    generator = Generator().to(store["device"])
    store["netG"] = generator

    # Module.apply calls init_weights on every sub-module.
    generator.apply(init_weights)
    print(generator)

setup()
construct_generator()

In [9]:
nc = settings["nc"]
nz = settings["nz"]
ndf = settings["ndf"]

class Discriminator(nn.Module):
    """Convolutional discriminator.

    Four stride-2 ``Conv2d`` layers take the channel count from ``nc`` through
    ``ndf``, ``ndf * 2`` and ``ndf * 4`` to ``ndf * 8``. Each is followed by
    ``LeakyReLU(0.2)``, with batch norm in between on all but the first. A
    final convolution to one channel plus ``Sigmoid`` produces the score.
    """

    def __init__(self):
        super().__init__()

        # Input stage has no batch norm.
        layers = [
            nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        ]

        # Three down-sampling stages, each doubling the channel count.
        for width in (ndf, ndf * 2, ndf * 4):
            layers.append(nn.Conv2d(width, width * 2, kernel_size=4, stride=2, padding=1, bias=False))
            layers.append(nn.BatchNorm2d(width * 2))
            layers.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))

        # Output stage: single-channel score passed through Sigmoid.
        layers.append(nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=1, padding=0, bias=False))
        layers.append(nn.Sigmoid())

        self.main = nn.Sequential(*layers)

    def forward(self, image_input):
        """Pass ``image_input`` through the layer stack and return the result."""
        score = self.main(image_input)
        return score

In [None]:
def construct_discriminator():
    """Create the discriminator on the configured device, initialise it and print it.

    The network is stored under ``store["netD"]``.
    """
    discriminator = Discriminator().to(store["device"])
    store["netD"] = discriminator

    # Module.apply calls init_weights on every sub-module.
    discriminator.apply(init_weights)
    print(discriminator)

setup()
construct_discriminator()

In [None]:
def train():
    """Train the generator and discriminator held in ``store`` and save their weights.

    Reads ``dataloader``, ``device``, ``netD`` and ``netG`` from ``store`` and
    ``num_epochs``, ``nz``, ``lr`` and ``beta1`` from ``settings``. For every
    batch the discriminator is updated first, then the generator. Per-iteration
    losses are appended to the module-level ``D_losses`` / ``G_losses`` lists
    and logged every 10 iterations. After the last epoch both state dicts are
    written to the ``models/`` directory.
    """
    dataloader = store["dataloader"]
    device = store["device"]
    netD = store["netD"]
    netG = store["netG"]

    num_epochs = settings["num_epochs"]
    # Read the noise size from settings rather than a notebook-global `nz`, so
    # this function does not depend on which cell last assigned that global.
    nz = settings["nz"]
    optimizerD = optim.Adam(netD.parameters(), lr=settings["lr"], betas=(settings["beta1"], 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=settings["lr"], betas=(settings["beta1"], 0.999))
    criterion = nn.BCELoss()
    real_label = 1.0
    fake_label = 0.0

    # Create the output directory up front: a missing directory would otherwise
    # only surface at torch.save, after the whole training run.
    os.makedirs("models", exist_ok=True)

    def train_discriminator(data):
        """Run one discriminator update on a batch and return its total loss."""
        netD.zero_grad()
        real_data = data[0].to(device)
        b_size = real_data.size(0)

        # discriminate real image data
        real_output = netD(real_data).view(-1)

        # calculate error using tensor of real_labels
        # the goal is to get the discriminator to detect real instances
        real_labels = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        errD_real = criterion(real_output, real_labels)

        # calculate gradients for D in backward pass
        errD_real.backward()

        # generate fake data
        # and detach gradients from the output because we're training the discriminator
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        generated_data = netG(noise).detach()

        # discriminate fake image data
        fake_output = netD(generated_data).view(-1)

        # calculate error using tensor of fake_labels
        # the goal is to get discriminator to detect fake instances
        fake_labels = torch.full((b_size,), fake_label, dtype=torch.float, device=device)
        errD_fake = criterion(fake_output, fake_labels)

        # calculate the gradients for this batch, accumulated (summed) with previous gradients
        errD_fake.backward()

        # apply the accumulated real + fake gradients
        optimizerD.step()

        # total discriminator loss for this batch
        return errD_real + errD_fake

    def train_generator(data):
        """Run one generator update and return its loss.

        ``data`` is only used to match the batch size of the real batch.
        """
        netG.zero_grad()
        real_data = data[0].to(device)
        b_size = real_data.size(0)

        # generate fake data
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        generated = netG(noise)

        # discriminate fake image data
        fake_output = netD(generated).view(-1)

        # calculate error using tensor of real_labels
        # the goal is to get the generator to generate "real" instances
        real_labels = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        errG = criterion(fake_output, real_labels)

        # calculate gradients for G
        errG.backward()

        # update the generator only; netD's gradients from this pass are
        # cleared by netD.zero_grad() at the start of the next discriminator step
        optimizerG.step()
        return errG

    root_logger.info("Starting training loop...")
    for epoch in range(num_epochs):
        for i, data in enumerate(dataloader):
            # Convert each loss to a Python float once and reuse it below.
            loss_d = train_discriminator(data).item()
            loss_g = train_generator(data).item()

            D_losses.append(loss_d)
            G_losses.append(loss_g)

            if i % 10 == 0:
                epoch_progress = f"{epoch}/{num_epochs} epochs"
                iteration_progress = f"{i}/{len(dataloader)} iterations"
                root_logger.info(f"{epoch_progress}, {iteration_progress}, Loss_D: {loss_d}, Loss_G: {loss_g}")

    torch.save(netG.state_dict(), "models/Network.G.pth")
    torch.save(netD.state_dict(), "models/Network.D.pth")

    root_logger.info("training complete")
    root_logger.info("models saved")

setup()
train()

In [None]:
def plot():
    """Plot the recorded loss curves and a side-by-side grid of real and generated images.

    Reads ``dataloader``, ``device`` and ``netG`` from ``store``, ``nz`` from
    ``settings`` and the module-level ``G_losses`` / ``D_losses`` lists.
    """
    dataloader = store["dataloader"]
    device = store["device"]
    netG = store["netG"]
    nz = settings["nz"]

    # Figure 1: loss per training iteration.
    plt.figure(figsize=(10, 5))
    plt.title("Iterations vs. Loss")
    plt.plot(G_losses, label="G")
    plt.plot(D_losses, label="D")
    plt.xlabel("Iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()

    # Figure 2: up to 64 real images next to 64 generated ones.
    plt.figure(figsize=(15, 15))
    real_images = next(iter(dataloader))
    plt.subplot(1, 2, 1)
    plt.axis("off")
    plt.title("Real Images")
    # make_grid returns (C, H, W); imshow expects (H, W, C).
    plt.imshow(np.transpose(vutils.make_grid(real_images[0].to(device)[:64], padding=5, normalize=True).cpu(), (1, 2, 0)))

    fixed_noise = torch.randn(64, nz, 1, 1, device=device)
    # No gradients are needed for visualisation, so skip building the autograd graph.
    with torch.no_grad():
        fake_images = netG(fixed_noise).cpu()
    plt.subplot(1, 2, 2)
    plt.axis("off")
    plt.title("Fake Images")
    plt.imshow(np.transpose(vutils.make_grid(fake_images, padding=2, normalize=True), (1, 2, 0)))
    plt.show()

plot()