In [60]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import nltk
import tensorflow as tf
import tensorflow_datasets as tfds
from IPython.display import clear_output
import pathlib
import os
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2

try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Device:', tpu.master())
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
except:
    strategy = tf.distribute.get_strategy()
print('Number of replicas:', strategy.num_replicas_in_sync)

AUTOTUNE = tf.data.experimental.AUTOTUNE
    
print(tf.__version__)

os.chdir('/Users/akshay/Desktop')

Number of replicas: 1
2.13.0


In [61]:
#Config
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
TRAIN_DIR = 'data/train'
VAL_DIR = 'data/val'
BATCH_SIZE = 1
LEARNING_RATE = 2e-4
LAMBDA_IDENTITY = 0
LAMBDA_CYCLE = 10 
NUM_WORKERS = 0
NUM_EPOCHS = 6
LOAD_MODEL = False
SAVE_MODEL = True
CHECKPOINT_GEN_H = 'genh.pth.tar'
CHECKPOINT_GEN_Z = 'genz.pth.tar'
CHECKPOINT_CRITIC_H = 'critich.pth.tar'
CHECKPOINT_CRITIC_Z = 'criticz.pth.tar'

transforms = A.Compose(
    [
        A.Resize(width=256, height=256),
        A.HorizontalFlip(p=0.5),
        A.ColorJitter(p=0.1),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
        ToTensorV2()
    ],
    additional_targets={'image0': 'image'}
)

In [62]:
class PhotoMonetDataset(Dataset):
    def __init__(self, root_monet, root_photo, transform=None):
        self.root_monet = root_monet
        self.root_photo = root_photo
        self.transform = transform
        
        self.monet_images = os.listdir(root_monet)
        self.photo_images = os.listdir(root_photo)
        self.length_dataset = max(len(self.photo_images), len(self.monet_images))
        self.monet_len = len(self.monet_images)
        self.photo_len = len(self.photo_images)
        
    def __len__(self):
        return self.length_dataset
    
    def __getitem__(self, index):
        monet_img = self.monet_images[index % self.monet_len]
        photo_img = self.photo_images[index % self.photo_len]
        
        monet_path = os.path.join(self.root_monet, monet_img)
        photo_path = os.path.join(self.root_photo, photo_img)
        
        monet_img = np.array(Image.open(monet_path).convert('RGB'))
        photo_img = np.array(Image.open(photo_path).convert('RGB'))
        
        if self.transform:
            augmentations = self.transform(image=monet_img, image0=photo_img)
            monet_img = augmentations['image']
            photo_img = augmentations['image0']
            
        return monet_img, photo_img

In [63]:
def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    print("=> Saving checkpoint")
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(checkpoint, filename)


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    print("=> Loading checkpoint")
    checkpoint = torch.load(checkpoint_file, map_location=config.DEVICE)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    # If we don't do this then it will just have learning rate of old checkpoint
    # and it will lead to many hours of debugging \:
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

In [64]:
#Discriminator

#Blocks are fundamental building block of ML models
#Typically consisting of 1+ neurons that work to process input data and produce output
#Tensor: A data structure. 1D tensor is vector, 2D tensor is matrix.

class Block(nn.Module): 
    def __init__(self, in_channels, out_channels, stride): #constuctor method, in_channels are number of channels in input image, out_channels denote number of channels produced by convolution
        super().__init__() #Calls constructor of parent class
        self.conv = nn.Sequential( #defines variable self.conv, which is sequential container that provide a way to build sequence of layes in neural network to organize flow of data through layers
            nn.Conv2d(in_channels, out_channels, 4, stride, 1, bias = True, padding_mode = "reflect"), #Creates 2D convolutional layer, kernel size of 4, padding of 1, padding mode reflective
            nn.InstanceNorm2d(out_channels), #adds instance normalization layer, which helps to normalize activations within each channel, improving training stability and convergence
            nn.LeakyReLU(0.2, inplace=True) #Adds leaky ReLU activation function with negative slope of 0.2, allowing small negative values to pass through, introducing some non-linearity while preventing vanishing gradient problem
        )
    
    #Defines how data flows through the layers of module during forward pass (process of moving data through model from input -> output)
    def forward(self, x): 
        return self.conv(x)
    
#Discriminator - Task is to classify whether a given input image is real or fake
class Discriminator(nn.Module):
    def __init__(self, in_channels=3, features = [64, 128, 256, 512]): #Constructor method, in_channels=3 for RBG, features are number of output channels in each layer of discriminator
        super().__init__()
        self.initial = nn.Sequential( #Defines initial part of discirminator network, which consists of 2 layers
            nn.Conv2d( #2D convolutional layer, features[0] output channels
                in_channels,
                features[0],
                kernel_size = 4,
                stride = 2,
                padding = 1,
                padding_mode = "reflect"
            ),
            nn.LeakyReLU(0.2, inplace=True) #Leaky ReLU actiovation function
        )
        
        layers = [] #Used to hold subsequent layers of discriminator
        in_channels = features[0]
        for feature in features[1:]: #Iterates over remaining elements in features list, excluding first one as it was used in initial layer
            layers.append(Block(in_channels, feature, stride = 1 if feature==features[-1] else 2)) #Appends Block to layers list, used to create sequence of convolutional layers with increasing features
            in_channels = feature
        layers.append(nn.Conv2d(in_channels, 1, kernel_size = 4, stride = 1, padding = 1, padding_mode = 'reflect')) #Adds final 2D convolutional layer that takes output from previous layers, used for final classification
        self.model = nn.Sequential(*layers) #Creates final discriminator model, * unpacks list, pasing each element of list as seperate arguments to nn.Sequential

    #Applies inital layers to input fata followed by main layers of discriminator, output passed through torch.sigmoid to produce final output
    #Representing discriminator's confidence in classifying input data as real or fake
    def forward(self, x): 
        x = self.initial(x)
        return torch.sigmoid(self.model(x))

In [65]:
import torch
import torch.nn as nn


class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, padding_mode="reflect", **kwargs)
            if down
            else nn.ConvTranspose2d(in_channels, out_channels, **kwargs),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(inplace=True) if use_act else nn.Identity(),
        )

    def forward(self, x):
        return self.conv(x)


class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(
            ConvBlock(channels, channels, kernel_size=3, padding=1),
            ConvBlock(channels, channels, use_act=False, kernel_size=3, padding=1),
        )

    def forward(self, x):
        return x + self.block(x)


class Generator(nn.Module):
    def __init__(self, img_channels, num_features=64, num_residuals=9):
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(
                img_channels,
                num_features,
                kernel_size=7,
                stride=1,
                padding=3,
                padding_mode="reflect",
            ),
            nn.InstanceNorm2d(num_features),
            nn.ReLU(inplace=True),
        )
        self.down_blocks = nn.ModuleList(
            [
                ConvBlock(
                    num_features, num_features * 2, kernel_size=3, stride=2, padding=1
                ),
                ConvBlock(
                    num_features * 2,
                    num_features * 4,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                ),
            ]
        )
        self.res_blocks = nn.Sequential(
            *[ResidualBlock(num_features * 4) for _ in range(num_residuals)]
        )
        self.up_blocks = nn.ModuleList(
            [
                ConvBlock(
                    num_features * 4,
                    num_features * 2,
                    down=False,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    output_padding=1,
                ),
                ConvBlock(
                    num_features * 2,
                    num_features * 1,
                    down=False,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    output_padding=1,
                ),
            ]
        )

        self.last = nn.Conv2d(
            num_features * 1,
            img_channels,
            kernel_size=7,
            stride=1,
            padding=3,
            padding_mode="reflect",
        )

    def forward(self, x):
        x = self.initial(x)
        for layer in self.down_blocks:
            x = layer(x)
        x = self.res_blocks(x)
        for layer in self.up_blocks:
            x = layer(x)
        return torch.tanh(self.last(x))


In [66]:
import torch
import sys

from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torchvision.utils import save_image


def train_fn(
    disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler
):
    H_reals = 0
    H_fakes = 0
    loop = tqdm(loader, leave=True)

    for idx, (monet, photo) in enumerate(loop):
        monet = monet.to(DEVICE)
        photo = photo.to(DEVICE)

        # Train Discriminators H and Z
        with torch.cuda.amp.autocast():
            fake_photo = gen_H(monet)
            D_H_real = disc_H(photo)
            D_H_fake = disc_H(fake_photo.detach())
            H_reals += D_H_real.mean().item()
            H_fakes += D_H_fake.mean().item()
            D_H_real_loss = mse(D_H_real, torch.ones_like(D_H_real))
            D_H_fake_loss = mse(D_H_fake, torch.zeros_like(D_H_fake))
            D_H_loss = D_H_real_loss + D_H_fake_loss

            fake_monet = gen_Z(monet)
            D_Z_real = disc_Z(monet)
            D_Z_fake = disc_Z(fake_monet.detach())
            D_Z_real_loss = mse(D_Z_real, torch.ones_like(D_Z_real))
            D_Z_fake_loss = mse(D_Z_fake, torch.zeros_like(D_Z_fake))
            D_Z_loss = D_Z_real_loss + D_Z_fake_loss

            # put it togethor
            D_loss = (D_H_loss + D_Z_loss) / 2

        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # Train Generators H and Z
        with torch.cuda.amp.autocast():
            # adversarial loss for both generators
            D_H_fake = disc_H(fake_photo)
            D_Z_fake = disc_Z(fake_monet)
            loss_G_H = mse(D_H_fake, torch.ones_like(D_H_fake))
            loss_G_Z = mse(D_Z_fake, torch.ones_like(D_Z_fake))

            # cycle loss
            cycle_monet = gen_Z(fake_photo)
            cycle_photo = gen_H(fake_monet)
            cycle_monet_loss = l1(monet, cycle_monet)
            cycle_photo_loss = l1(photo, cycle_photo)

            # identity loss (remove these for efficiency if you set lambda_identity=0)
            identity_monet = gen_Z(monet)
            identity_photo = gen_H(photo)
            identity_monet_loss = l1(monet, identity_monet)
            identity_photo_loss = l1(photo, identity_monet)

            # add all togethor
            G_loss = (
                loss_G_Z
                + loss_G_H
                + cycle_monet_loss * LAMBDA_CYCLE
                + cycle_photo_loss * LAMBDA_CYCLE
                + identity_photo_loss * LAMBDA_IDENTITY
                + identity_monet_loss * LAMBDA_IDENTITY
            )

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        if idx % 200 == 0:
            save_image(fake_photo * 0.5 + 0.5, f"saved_images/photo_{idx}.png")
            save_image(fake_monet * 0.5 + 0.5, f"saved_images/monet_{idx}.png")

        loop.set_postfix(H_real=H_reals / (idx + 1), H_fake=H_fakes / (idx + 1))


def main():
    disc_H = Discriminator(in_channels=3).to(DEVICE)
    disc_Z = Discriminator(in_channels=3).to(DEVICE)
    gen_Z = Generator(img_channels=3, num_residuals=9).to(DEVICE)
    gen_H = Generator(img_channels=3, num_residuals=9).to(DEVICE)
    opt_disc = optim.Adam(
        list(disc_H.parameters()) + list(disc_Z.parameters()),
        lr=LEARNING_RATE,
        betas=(0.5, 0.999),
    )

    opt_gen = optim.Adam(
        list(gen_Z.parameters()) + list(gen_H.parameters()),
        lr=LEARNING_RATE,
        betas=(0.5, 0.999),
    )

    L1 = nn.L1Loss()
    mse = nn.MSELoss()

    if LOAD_MODEL:
        load_checkpoint(
            CHECKPOINT_GEN_H,
            gen_H,
            opt_gen,
            LEARNING_RATE,
        )
        load_checkpoint(
            CHECKPOINT_GEN_Z,
            gen_Z,
            opt_gen,
            config.LEARNING_RATE,
        )
        load_checkpoint(
            CHECKPOINT_CRITIC_H,
            disc_H,
            opt_disc,
            LEARNING_RATE,
        )
        load_checkpoint(
            CHECKPOINT_CRITIC_Z,
            disc_Z,
            opt_disc,
            LEARNING_RATE,
        )

    dataset = PhotoMonetDataset(
        root_photo=TRAIN_DIR + "/photos",
        root_monet=TRAIN_DIR + "/monets",
        transform=transforms,
    )
    val_dataset = PhotoMonetDataset(
        root_photo=VAL_DIR + "/photos",
        root_monet=VAL_DIR + "/monets",
        transform=transforms,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False,
        pin_memory=True,
    )
    loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )
    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()

    for epoch in range(NUM_EPOCHS):
        train_fn(
            disc_H,
            disc_Z,
            gen_Z,
            gen_H,
            loader,
            opt_disc,
            opt_gen,
            L1,
            mse,
            d_scaler,
            g_scaler
        )

        if SAVE_MODEL:
            save_checkpoint(gen_H, opt_gen, filename=CHECKPOINT_GEN_H)
            save_checkpoint(gen_Z, opt_gen, filename=CHECKPOINT_GEN_Z)
            save_checkpoint(disc_H, opt_disc, filename=CHECKPOINT_CRITIC_H)
            save_checkpoint(disc_Z, opt_disc, filename=CHECKPOINT_CRITIC_Z)


if __name__ == "__main__":
    main()

100%|█████████| 1407/1407 [4:50:02<00:00, 12.37s/it, H_fake=0.329, H_real=0.828]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 1407/1407 [2:23:22<00:00,  6.11s/it, H_fake=0.16, H_real=0.933]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|█████████| 1407/1407 [2:22:13<00:00,  6.06s/it, H_fake=0.162, H_real=0.918]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|████████| 1407/1407 [2:09:28<00:00,  5.52s/it, H_fake=0.0703, H_real=0.944]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|████████| 1407/1407 [2:09:35<00:00,  5.53s/it, H_fake=0.0426, H_real=0.965]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|████████| 1407/1407 [2:09:38<00:00,  5.53s/it, H_fake=0.0227, H_real=0.983]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


In [74]:
def display_generated_samples(ds, model, n_samples):
    ds_iter = iter(ds)
    for n_sample in range(n_samples):
        example_sample = next(ds_iter)
        generated_sample = model.predict(example_sample)
    
        plt.subplot(121)
        plt.title("Input image")
        plt.imshow(example_sample[0] * 0.5 + 0.5)
        plt.axis('off')

        plt.subplot(122)
        plt.title("Generated image")
        plt.imshow(generated_sample[0] * 0.5 + 0.5)
        plt.axis('off')
        plt.show()

MONET_FILENAMES = tf.io.gfile.glob(str('/Users/akshay/Downloads/gan-getting-started/monet_jpg/*.jpg'))
PHOTO_FILENAMES = tf.io.gfile.glob(str('/Users/akshay/Downloads/gan-getting-started/photo_jpg/*.jpg'))