In [21]:
# Download the jangedoo/utkface-new dataset archive with the Kaggle CLI.
!kaggle datasets download jangedoo/utkface-new

# Extract the downloaded archive into the working directory.
!unzip utkface-new
!echo "crop_part1.tar.gz Extraction done"
# Move the extracted utkface_aligned_cropped/crop_part1 folder into the current directory.
!mv utkface_aligned_cropped/crop_part1 .

# Run the dataset preparation script (not shown in this notebook); presumably it
# creates the train/ and test/ folders that are moved below -- TODO confirm.
!python prepare_dataset.py
!echo "Dataset folder structure created"
# Collect train/ and test/ under old2young/, the root folder passed to the datasets later on.
!mkdir old2young
!mv test old2young/
!mv train old2young/

Extraction done
Create the Directories for Dataset
No of images in young is 2323 and old is 2390
Dataset folder structure created


In [1]:
import os
import glob
import random
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torch import nn
from torch.autograd import Variable
from torchvision.utils import save_image, make_grid

In [19]:
# Define the dataset class
class CycleGANDataset(Dataset):
    """Image dataset yielding one sample from domain A and one from domain B.

    Files are discovered with the glob pattern ``*.*`` under
    ``<root>/<mode>/A`` and ``<root>/<mode>/B`` and kept in sorted order.

    Args:
        root: Dataset root directory.
        transforms_: Optional list of callables applied, in order, to every
            loaded image. When ``None`` or empty the RGB ``PIL`` images are
            returned unchanged.
        unaligned: When true, the domain-B image is drawn at random on every
            access; otherwise it is selected by ``index`` (wrapping around).
        mode: Name of the sub-directory of ``root`` to read (e.g. ``"train"``).

    Raises:
        FileNotFoundError: If either domain folder contains no matching file.
    """

    def __init__(self, root, transforms_=None, unaligned=False, mode="train"):
        # Only build a Compose when there is something to compose: with the
        # default ``transforms_=None`` a Compose would fail on its first call.
        self.transform = transforms.Compose(transforms_) if transforms_ else None
        self.unaligned = unaligned
        self.files_A = sorted(glob.glob(os.path.join(root, "%s/A" % mode) + "/*.*"))
        self.files_B = sorted(glob.glob(os.path.join(root, "%s/B" % mode) + "/*.*"))

        # Debugging prints
        print(f"Found {len(self.files_A)} images in {mode}/A")
        print(f"Found {len(self.files_B)} images in {mode}/B")

        # Fail early with a clear message instead of a ZeroDivisionError or
        # randint error on the first __getitem__ call.
        missing = [
            domain
            for domain, files in (("A", self.files_A), ("B", self.files_B))
            if not files
        ]
        if missing:
            raise FileNotFoundError(
                "No images found under %s for domain(s): %s"
                % (os.path.join(root, mode), ", ".join(missing))
            )

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return a 3-channel RGB copy, closing the file."""
        # Grayscale / RGBA files would otherwise produce tensors whose channel
        # count does not match a 3-channel normalisation step.
        with Image.open(path) as img:
            return img.convert("RGB")

    def _apply_transform(self, image):
        """Run the configured transform pipeline, if any."""
        return image if self.transform is None else self.transform(image)

    def __getitem__(self, index):
        """Return ``{"A": item_A, "B": item_B}`` for the given index."""
        image_A = self._load_rgb(self.files_A[index % len(self.files_A)])
        if self.unaligned:
            image_B = self._load_rgb(self.files_B[random.randint(0, len(self.files_B) - 1)])
        else:
            image_B = self._load_rgb(self.files_B[index % len(self.files_B)])
        item_A = self._apply_transform(image_A)
        item_B = self._apply_transform(image_B)
        return {"A": item_A, "B": item_B}

    def __len__(self):
        """Length of the larger domain, so every file can be visited."""
        return max(len(self.files_A), len(self.files_B))

In [13]:
class ResidualBlock(nn.Module):
    """Residual unit: two padded 3x3 conv + instance-norm stages with a skip connection.

    Args:
        in_features: Number of channels of both the input and the output.
    """

    def __init__(self, in_features):
        super().__init__()

        def padded_conv_norm():
            # Reflection padding of 1 offsets the shrinkage of the 3x3 kernel,
            # so the spatial size is unchanged.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_features, in_features, 3),
                nn.InstanceNorm2d(in_features),
            ]

        first_stage = padded_conv_norm()
        second_stage = padded_conv_norm()
        # Only the first stage is followed by a non-linearity.
        self.block = nn.Sequential(*first_stage, nn.ReLU(inplace=True), *second_stage)

    def forward(self, x):
        """Return the input plus the output of the convolutional branch."""
        residual = self.block(x)
        return x + residual

class GeneratorResNet(nn.Module):
    """ResNet-style image-to-image generator.

    Layout: 7x7 input conv -> two stride-2 downsampling convs ->
    ``num_residual_blocks`` residual blocks -> two upsample+conv stages ->
    7x7 output conv followed by ``tanh``.

    Args:
        input_shape: Sequence whose first entry is the number of image
            channels (only that entry is read here).
        num_residual_blocks: Number of ``ResidualBlock`` modules in the middle.
    """

    def __init__(self, input_shape, num_residual_blocks=9):
        super(GeneratorResNet, self).__init__()
        channels = input_shape[0]
        # A 7x7 kernel needs 3 pixels of padding per side to keep H x W.
        # The padding used to be tied to `channels`, which is only correct
        # for 3-channel input and changed the output size otherwise.
        edge_padding = 3

        # Initial convolution block
        out_features = 64
        model = [
            nn.ReflectionPad2d(edge_padding),
            nn.Conv2d(channels, out_features, 7),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True),
        ]
        in_features = out_features

        # Downsampling: each stage halves the resolution and doubles the channels
        for _ in range(2):
            out_features *= 2
            model += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        # Residual blocks at the lowest resolution
        for _ in range(num_residual_blocks):
            model += [ResidualBlock(out_features)]

        # Upsampling: each stage doubles the resolution and halves the channels
        for _ in range(2):
            out_features //= 2
            model += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(in_features, out_features, 3, stride=1, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        # Output layer: back to `channels` feature maps, squashed by tanh
        model += [
            nn.ReflectionPad2d(edge_padding),
            nn.Conv2d(out_features, channels, 7),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Translate a batch of images; output has ``channels`` feature maps."""
        return self.model(x)

class Discriminator(nn.Module):
    """Convolutional discriminator that outputs a single-channel score map.

    Args:
        input_shape: ``(channels, height, width)`` of the input images.

    Attributes:
        output_shape: ``(1, height // 16, width // 16)``, the per-sample
            shape used to build the adversarial target tensors.
    """

    def __init__(self, input_shape):
        super().__init__()
        channels, height, width = input_shape

        # Four stride-2 stages, each halving the resolution: factor 2 ** 4 overall.
        reduction = 2 ** 4
        self.output_shape = (1, height // reduction, width // reduction)

        # (in_filters, out_filters, normalize) for every downsampling stage;
        # the first stage is the only one without instance normalisation.
        stages = (
            (channels, 64, False),
            (64, 128, True),
            (128, 256, True),
            (256, 512, True),
        )
        layers = []
        for n_in, n_out, use_norm in stages:
            layers.extend(self.discriminator_block(n_in, n_out, normalize=use_norm))
        layers.append(nn.ZeroPad2d((1, 0, 1, 0)))
        layers.append(nn.Conv2d(512, 1, 4, padding=1))
        self.model = nn.Sequential(*layers)

    def discriminator_block(self, in_filters, out_filters, normalize=True):
        """Return the layers of one stage: conv, optional instance norm, LeakyReLU."""
        conv = nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)
        activation = nn.LeakyReLU(0.2, inplace=True)
        if normalize:
            return [conv, nn.InstanceNorm2d(out_filters), activation]
        return [conv, activation]

    def forward(self, img):
        """Score a batch of images."""
        return self.model(img)


In [14]:
def weights_init_normal(m):
    """Initialise a module in place; intended for ``model.apply(...)``.

    Modules are selected by class name:

    * names containing ``"Conv"``: weights ~ N(0, 0.02), bias set to 0;
    * names containing ``"BatchNorm2d"``: weights ~ N(1, 0.02), bias set to 0.

    Any other module is left untouched. Parameters that are absent or
    ``None`` (e.g. ``bias=False``, ``affine=False``, or a container class
    whose name merely contains "Conv") are skipped instead of raising.

    Args:
        m: The ``nn.Module`` to initialise.
    """
    classname = m.__class__.__name__
    # `apply` visits every sub-module, so do not assume the parameters exist.
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)

    if "Conv" in classname:
        weight_mean = 0.0
    elif "BatchNorm2d" in classname:
        weight_mean = 1.0
    else:
        return

    if weight is not None:
        torch.nn.init.normal_(weight.data, weight_mean, 0.02)
    if bias is not None:
        torch.nn.init.constant_(bias.data, 0.0)

class ReplayBuffer:
    """Fixed-capacity store of previously generated samples.

    Args:
        max_size: Maximum number of samples kept; must be positive.
    """

    def __init__(self, max_size=50):
        assert max_size > 0, "Warning: Trying to create an Empty buffer"
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        """Insert the samples of ``data`` and return a batch of the same size.

        While the buffer has room, each incoming sample is stored and returned
        as is. Once it is full, each incoming sample either replaces a randomly
        chosen stored sample (whose copy is returned instead) or is returned
        directly without being stored, depending on a random draw.
        """
        batch_out = []
        for sample in data.data:
            sample = torch.unsqueeze(sample, 0)

            # Filling phase: keep and pass through.
            if len(self.data) < self.max_size:
                self.data.append(sample)
                batch_out.append(sample)
                continue

            # Buffer full: swap with a stored sample when the draw exceeds 0.5.
            swap_with_history = random.uniform(0, 1) > 0.5
            if swap_with_history:
                slot = random.randint(0, self.max_size - 1)
                batch_out.append(self.data[slot].clone())
                self.data[slot] = sample
            else:
                batch_out.append(sample)

        stacked = torch.cat(batch_out)
        return Variable(stacked)


In [25]:
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Loss functions: squared error for the adversarial term,
# absolute error for the cycle-consistency and identity terms.
criterion_GAN = nn.MSELoss().to(device)
criterion_cycle = nn.L1Loss().to(device)
criterion_identity = nn.L1Loss().to(device)

# Training schedule
n_epochs = 10            # total number of epochs
epoch = 0                # first epoch to run; non-zero resumes from a checkpoint
offset = 0               # epoch offset added inside the learning-rate schedule
decay_start_epoch = 3    # epoch from which the learning rate starts to decay

# Data and optimisation
input_shape = (3, 200, 200)  # (channels, height, width)
batch_size = 1
lr = 2e-4

# Bookkeeping
checkpoint_interval = 1  # save the models every N epochs (-1 disables saving)
sample_interval = 100    # save a sample image grid every N batches

# Weights of the cycle-consistency and identity terms in the generator loss
lambda_cyc = 10
lambda_id = 5

# Two generators (A -> B, B -> A) and one discriminator per domain,
# created in a fixed order and moved to the selected device.
G_AB = GeneratorResNet(input_shape, num_residual_blocks=3).to(device)
G_BA = GeneratorResNet(input_shape, num_residual_blocks=3).to(device)
D_A = Discriminator(input_shape).to(device)
D_B = Discriminator(input_shape).to(device)

In [29]:
import itertools
import sys

In [26]:
# Either resume from the checkpoints of epoch `epoch` or start from freshly
# initialised weights.
if epoch != 0:
    # map_location lets checkpoints saved on a GPU be loaded on a CPU-only
    # machine (and vice versa).
    # NOTE(review): checkpoints are written at the end of an epoch under that
    # epoch's index, so resuming with the same index re-runs that epoch -- confirm intent.
    G_AB.load_state_dict(torch.load(f"saved_models/G_AB_{epoch}.pth", map_location=device))
    G_BA.load_state_dict(torch.load(f"saved_models/G_BA_{epoch}.pth", map_location=device))
    D_A.load_state_dict(torch.load(f"saved_models/D_A_{epoch}.pth", map_location=device))
    D_B.load_state_dict(torch.load(f"saved_models/D_B_{epoch}.pth", map_location=device))
else:
    G_AB.apply(weights_init_normal)
    G_BA.apply(weights_init_normal)
    D_A.apply(weights_init_normal)
    D_B.apply(weights_init_normal)

# One optimizer drives both generators; each discriminator has its own.
optimizer_G = torch.optim.Adam(
    itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=lr, betas=(0.5, 0.999)
)
optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=lr, betas=(0.5, 0.999))


def _linear_lr_decay(epoch_idx):
    """Learning-rate multiplier for scheduler step ``epoch_idx``.

    Stays at 1.0 until ``epoch_idx + offset`` reaches ``decay_start_epoch``,
    then decreases linearly, reaching 0.0 at ``n_epochs``.
    """
    return 1.0 - max(0, epoch_idx + offset - decay_start_epoch) / (n_epochs - decay_start_epoch)


# The same schedule is shared by all three optimizers.
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=_linear_lr_decay)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(optimizer_D_A, lr_lambda=_linear_lr_decay)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(optimizer_D_B, lr_lambda=_linear_lr_decay)

# Tensor type matching the available hardware (used when sampling images).
Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor

# Buffers of previously generated images used when training the discriminators.
fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

# Augmentation: upscale by 12 %, take a random crop of the target size,
# flip at random, then map pixel values to [-1, 1].
transforms_ = [
    transforms.Resize(int(input_shape[1] * 1.12), transforms.InterpolationMode.BICUBIC),
    transforms.RandomCrop((input_shape[1], input_shape[2])),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

train_dataloader = DataLoader(
    CycleGANDataset("old2young", transforms_=transforms_, unaligned=True),
    batch_size=batch_size,
    shuffle=True,
    num_workers=1,
)

# Validation batches of 5 images feed the sample grids.
val_dataloader = DataLoader(
    CycleGANDataset("old2young", transforms_=transforms_, unaligned=True, mode="test"),
    batch_size=5,
    shuffle=True,
    num_workers=1,
)


Found 2223 images in train/A
Found 2290 images in train/B
Found 100 images in test/A
Found 100 images in test/B


In [None]:
def sample_images(batches_done):
    """Save a grid of real and translated validation images.

    Takes one batch from ``val_dataloader``, translates it with both
    generators and writes a PNG made of four stacked rows
    (real A, fake B, real B, fake A) to ``images/old2young/<batches_done>.png``.

    The generators are switched to eval mode and are not switched back here.

    Args:
        batches_done: Global batch counter, used as the file name.
    """
    # save_image does not create missing folders.
    os.makedirs("images/old2young", exist_ok=True)

    imgs = next(iter(val_dataloader))
    G_AB.eval()
    G_BA.eval()
    # Pure inference: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        real_A = imgs["A"].to(device)
        fake_B = G_AB(real_A)
        real_B = imgs["B"].to(device)
        fake_A = G_BA(real_B)
    # Arrange every group of images as one row of (up to) 5 tiles
    real_A = make_grid(real_A, nrow=5, normalize=True)
    real_B = make_grid(real_B, nrow=5, normalize=True)
    fake_A = make_grid(fake_A, nrow=5, normalize=True)
    fake_B = make_grid(fake_B, nrow=5, normalize=True)
    # Stack the rows vertically (dim 1 is the height of a C x H x W grid)
    image_grid = torch.cat((real_A, fake_B, real_B, fake_A), 1)
    save_image(image_grid, f"images/old2young/{batches_done}.png", normalize=False)

In [32]:
# Training loop
# Create the output folders up front so that a missing directory cannot make
# a checkpoint or sample save fail after an epoch of work.
os.makedirs("saved_models", exist_ok=True)
os.makedirs("images/old2young", exist_ok=True)

for epoch in range(epoch, n_epochs):
    print(f"Starting Epoch: {epoch}")
    for i, batch in enumerate(train_dataloader):
        real_A = batch["A"].to(device)
        real_B = batch["B"].to(device)

        # Adversarial targets (1 = real, 0 = fake), shaped like the
        # discriminator output and created directly on the target device.
        valid = torch.ones((real_A.size(0), *D_A.output_shape), device=device)
        fake = torch.zeros((real_A.size(0), *D_A.output_shape), device=device)

        # ------------------
        #  Train Generators
        # ------------------

        # sample_images() leaves the generators in eval mode, so switch back.
        G_AB.train()
        G_BA.train()

        optimizer_G.zero_grad()

        # Identity loss: a generator fed an image of its target domain
        # should return it unchanged.
        loss_id_A = criterion_identity(G_BA(real_A), real_A)
        loss_id_B = criterion_identity(G_AB(real_B), real_B)
        loss_identity = (loss_id_A + loss_id_B) / 2

        # GAN loss: translated images should be scored as real.
        fake_B = G_AB(real_A)
        loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
        fake_A = G_BA(real_B)
        loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)
        loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

        # Cycle loss: translating there and back should recover the input.
        recov_A = G_BA(fake_B)
        loss_cycle_A = criterion_cycle(recov_A, real_A)
        recov_B = G_AB(fake_A)
        loss_cycle_B = criterion_cycle(recov_B, real_B)
        loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

        # Total loss
        loss_G = loss_GAN + lambda_cyc * loss_cycle + lambda_id * loss_identity

        loss_G.backward()
        optimizer_G.step()

        # -----------------------
        #  Train Discriminator A
        # -----------------------

        # Also clears the gradients the generator step left on D_A.
        optimizer_D_A.zero_grad()

        # Real loss
        loss_real = criterion_GAN(D_A(real_A), valid)
        # Fake loss (on batch of previously generated samples)
        fake_A_ = fake_A_buffer.push_and_pop(fake_A)
        loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)
        # Total loss
        loss_D_A = (loss_real + loss_fake) / 2

        loss_D_A.backward()
        optimizer_D_A.step()

        # -----------------------
        #  Train Discriminator B
        # -----------------------

        optimizer_D_B.zero_grad()

        # Real loss
        loss_real = criterion_GAN(D_B(real_B), valid)
        # Fake loss (on batch of previously generated samples)
        fake_B_ = fake_B_buffer.push_and_pop(fake_B)
        loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)
        # Total loss
        loss_D_B = (loss_real + loss_fake) / 2

        loss_D_B.backward()
        optimizer_D_B.step()

        loss_D = (loss_D_A + loss_D_B) / 2

        # --------------
        #  Log Progress
        # --------------
        # Print log
        batches_done = epoch * len(train_dataloader) + i

        # "\r" rewrites the same console line on every batch.
        sys.stdout.write(
            f"\r[Epoch {epoch}/{n_epochs}] [Batch {i}/{len(train_dataloader)}] [D loss: {loss_D.item()}] [G loss: {loss_G.item()}, adv: {loss_GAN.item()}, cycle: {loss_cycle.item()}, identity: {loss_identity.item()}]"
        )

        # If at sample interval save image
        if batches_done % sample_interval == 0:
            sample_images(batches_done)

    # Terminate the progress line so the next "Starting Epoch" message
    # starts on a line of its own.
    sys.stdout.write("\n")

    # Update learning rates
    lr_scheduler_G.step()
    lr_scheduler_D_A.step()
    lr_scheduler_D_B.step()

    if checkpoint_interval != -1 and epoch % checkpoint_interval == 0:
        # Save model checkpoints
        torch.save(G_AB.state_dict(), f"saved_models/G_AB_{epoch}.pth")
        torch.save(G_BA.state_dict(), f"saved_models/G_BA_{epoch}.pth")
        torch.save(D_A.state_dict(), f"saved_models/D_A_{epoch}.pth")
        torch.save(D_B.state_dict(), f"saved_models/D_B_{epoch}.pth")


Starting Epoch: 0
[Epoch 0/10] [Batch 2289/2290] [D loss: 0.0991305336356163] [G loss: 2.8886890411376953, adv: 0.4523525834083557, cycle: 0.16132834553718567, identity: 0.1646106243133545]Starting Epoch: 1
[Epoch 1/10] [Batch 2289/2290] [D loss: 0.1756768822669983] [G loss: 2.3120851516723633, adv: 0.645332932472229, cycle: 0.11620527505874634, identity: 0.10093989968299866]Starting Epoch: 2
[Epoch 2/10] [Batch 2289/2290] [D loss: 0.22763144969940186] [G loss: 1.834311842918396, adv: 0.4154587984085083, cycle: 0.08839771151542664, identity: 0.10697519779205322]Starting Epoch: 3
[Epoch 3/10] [Batch 2289/2290] [D loss: 0.17478467524051666] [G loss: 1.8819745779037476, adv: 0.43066173791885376, cycle: 0.09412887692451477, identity: 0.10200481116771698]Starting Epoch: 4
[Epoch 4/10] [Batch 2289/2290] [D loss: 0.18551208078861237] [G loss: 3.349240779876709, adv: 1.1155190467834473, cycle: 0.15197968482971191, identity: 0.142784982919693]Starting Epoch: 5
[Epoch 5/10] [Batch 2289/2290] [D 