In [1]:
import numpy as np
import os
import sys
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import glob
import random
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
import albumentations as A
from albumentations.pytorch import ToTensorV2
import copy
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.optim as optim
from torchvision.utils import save_image
from tqdm import tqdm
import random
import shutil

In [None]:
# def histogram_equalization(img):
#     img_yuv = cv2.cvtColor(img, cv2.COLOR_BGR2YUV)
#     img_yuv[:,:,0] = cv2.equalizeHist(img_yuv[:,:,0])
#     img_output = cv2.cvtColor(img_yuv, cv2.COLOR_YUV2BGR)
#     return img_output

In [None]:
"""
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((256, 256)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomResizedCrop(256, scale=(0.8, 1.0)),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    transforms.ToTensor(),
])"""

In [None]:
# def preprocess_imgs(img_path):
#     img = cv2.imread(img_path)
#     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    
#     img = histogram_equalization(img)
#     img = transform(img)
    
#     return img

In [None]:
# img_path = "/kaggle/input/best-artworks-of-all-time/images/images/Vincent_van_Gogh/Vincent_van_Gogh_109.jpg"
# processed_img = preprocess_imgs(img_path)
# img = transforms.ToPILImage()(processed_img)

# **Normal Min Max Normalization**

In [None]:
# def norm_func(min_val, max_val):
#     transform = transforms.Compose([
#         transforms.Normalize((min_val,), (max_val-min_val,))
#     ])
    
#     return transform

# def transf(img, processed_img):
#     min_val = np.min(img)
#     max_val = np.max(img)
    
#     transform1 = norm_func(min_val, max_val)
#     normalized_img = transform1(processed_img)
    
#     return normalized_img

In [None]:
# output_img = np.array(transf(img, processed_img))

# plt.hist(output_img.ravel(), bins=50, density=True)
# plt.xlabel("pixel values")
# plt.ylabel("relative frequency")
# plt.title("distribution of pixels")

# **Generator**

In [2]:
class ConvBlock(nn.Module):
    """Convolution -> InstanceNorm -> optional ReLU.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by the convolution.
        down: when True a regular ``nn.Conv2d`` with reflect padding is used;
            when False an ``nn.ConvTranspose2d`` is used instead.
        use_activation: when True the block ends with an in-place ReLU,
            otherwise with ``nn.Identity``.
        **kwargs: forwarded untouched to the chosen convolution
            (e.g. ``kernel_size``, ``stride``, ``padding``, ``output_padding``).
    """

    def __init__(self, in_channels, out_channels, down=True, use_activation=True, **kwargs):
        super().__init__()

        # Pick the resampling convolution first; only the regular convolution
        # gets reflect padding (the transposed one is built with its defaults).
        if down:
            resample = nn.Conv2d(in_channels, out_channels, padding_mode="reflect", **kwargs)
        else:
            resample = nn.ConvTranspose2d(in_channels, out_channels, **kwargs)

        norm = nn.InstanceNorm2d(out_channels)
        activation = nn.ReLU(inplace=True) if use_activation else nn.Identity()

        # The attribute name and layer order define the state_dict keys
        # ("conv.0.weight", "conv.0.bias"), so they are kept as-is.
        self.conv = nn.Sequential(resample, norm, activation)

    def forward(self, x):
        """Apply convolution, normalisation and (optionally) ReLU to ``x``."""
        out = self.conv(x)
        return out

class ResidualBlock(nn.Module):
    """Two 3x3 ``ConvBlock``s wrapped in a skip connection.

    Both convolutions keep the channel count and (stride 1, padding 1) the
    spatial size, so the block output can be added to its input. The second
    ``ConvBlock`` is built without an activation.

    Args:
        channels: channel count of both the input and the output.
    """

    def __init__(self, channels):
        super().__init__()

        same_size = dict(kernel_size=3, stride=1, padding=1)
        self.block = nn.Sequential(
            ConvBlock(channels, channels, **same_size),
            ConvBlock(channels, channels, use_activation=False, **same_size),
        )

    def forward(self, x):
        """Return ``x`` plus the output of the two stacked conv blocks."""
        residual = self.block(x)
        return x + residual


class Generator(nn.Module):
    """ResNet-style image-to-image generator.

    Layout: 7x7 stem conv -> two stride-2 down-sampling ``ConvBlock``s ->
    ``num_residuals`` ``ResidualBlock``s -> two stride-2 transposed
    ``ConvBlock``s -> 7x7 output conv -> tanh.

    Args:
        img_channels: channel count of both the input and the output image.
        num_features: channel count after the stem; the down-sampling stages
            double it twice (to ``4 * num_features``) and the up-sampling
            stages halve it back.
        num_residuals: how many residual blocks sit in the bottleneck.
    """

    def __init__(self, img_channels, num_features = 64, num_residuals = 9):
        super().__init__()

        # NOTE(review): the stem is Conv + ReLU only; the architecture notes in
        # this notebook describe C7s1-k as Conv-InstanceNorm-ReLU. Left as-is
        # because previously saved checkpoints were trained with this layout.
        stem_conv = nn.Conv2d(
            img_channels,
            num_features,
            kernel_size=7,
            stride=1,
            padding=3,
            padding_mode="reflect",
        )
        self.initial = nn.Sequential(stem_conv, nn.ReLU(inplace=True))

        # Settings shared by every stride-2 resampling stage.
        resample = dict(kernel_size=3, stride=2, padding=1)

        down_plan = (
            (num_features, 2 * num_features),
            (2 * num_features, 4 * num_features),
        )
        self.down_blocks = nn.ModuleList(
            [ConvBlock(c_in, c_out, down=True, **resample) for c_in, c_out in down_plan]
        )

        bottleneck_channels = 4 * num_features
        self.residual_block = nn.Sequential(
            *(ResidualBlock(bottleneck_channels) for _ in range(num_residuals))
        )

        up_plan = (
            (4 * num_features, 2 * num_features),
            (2 * num_features, num_features),
        )
        self.up_blocks = nn.ModuleList(
            [
                ConvBlock(c_in, c_out, down=False, output_padding=1, **resample)
                for c_in, c_out in up_plan
            ]
        )

        self.last = nn.Conv2d(
            num_features,
            img_channels,
            kernel_size=7,
            stride=1,
            padding=3,
            padding_mode="reflect",
        )

    def forward(self, x):
        """Translate ``x``; tanh bounds the returned values to [-1, 1]."""
        feats = self.initial(x)
        # Down-sampling, bottleneck and up-sampling stages run strictly in sequence.
        for stage in (*self.down_blocks, self.residual_block, *self.up_blocks):
            feats = stage(feats)
        return torch.tanh(self.last(feats))

**Generator**
* 6 residual blocks for 128x128 training image
* 9 residual blocks for 256x256 or higher resolution training image
* C7s1-k -> 7x7 Convolution-InstanceNorm-ReLU layer with k-filters and stride 1
* dk -> 3x3 Convolution-InstanceNorm-ReLU layer with k-filters and stride 2
* uk -> 3x3 fractional-strided (transposed) Convolution-InstanceNorm-ReLU layer with k-filters and stride ½
* Rk -> Residual block contains 3x3 Convolution layer with same no. of filters on both layers
* Network with 6 residual blocks
* C7s1-64, d128, d256, 6*(R256), u128, u64, C7s1-3
* Network with 9 residual blocks
* C7s1-64, d128, d256, 9*(R256), u128, u64, C7s1-3

![Generator](https://miro.medium.com/v2/resize:fit:1400/format:webp/1*PVBSmRcCz9xfw-fCNi_q5g.png)

# **Discriminator**

In [3]:
class Block(nn.Module):
    """4x4 reflect-padded Conv -> InstanceNorm -> LeakyReLU(0.2).

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        stride: stride of the 4x4 convolution (2 halves the spatial size).
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 4, stride, 1, bias = True, padding_mode="reflect"),
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        """Apply conv, instance norm and leaky ReLU to ``x``."""
        return self.conv(x)

class Discriminator(nn.Module):
    """Patch-based discriminator returning a map of per-patch probabilities.

    Layout: an un-normalised stride-2 conv + LeakyReLU stem, one ``Block`` per
    remaining entry of ``features`` (stride 2, except the last one which uses
    stride 1), then a 4x4 conv down to a single channel followed by a sigmoid.

    Args:
        in_channels: channel count of the input image.
        features: channel widths of the successive layers. Any sequence is
            accepted; it is not modified.
    """

    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        super().__init__()

        # Work on a private copy so the caller's sequence is never shared.
        features = list(features)

        # No InstanceNorm on the first layer.
        self.initial = nn.Sequential(
            nn.Conv2d(
                in_channels = in_channels,
                out_channels = features[0],
                kernel_size = 4,
                stride = 2,
                padding = 1,
                padding_mode = "reflect",
            ),
            nn.LeakyReLU(0.2),
        )

        layers = []
        last_idx = len(features) - 1
        for idx in range(1, len(features)):
            # Only the final block keeps the resolution. The choice is made by
            # position, not by comparing widths, so a width that happens to
            # equal the last one (e.g. [64, 128, 128]) still gets stride 2.
            stride = 1 if idx == last_idx else 2
            layers.append(Block(features[idx - 1], features[idx], stride))

        layers.append(
            nn.Conv2d(
                in_channels = features[-1],
                out_channels =  1,
                kernel_size = 4,
                stride = 1,
                padding = 1,
                padding_mode = "reflect",
            )
        )

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the sigmoid-squashed patch score map for ``x``."""
        x = self.initial(x)
        x = self.model(x)
        return torch.sigmoid(x)

**Discriminator**
* 70x70 PatchGAN
* Ck -> 4x4 Convolution-InstanceNorm-LeakyReLU layer with k-filters and stride 2
* Discriminator Network
* C64, C128, C256, C512
* After the last layer, apply a Convolution to produce 1-dimensional output
* Do not use InstanceNorm for the first C64 layer
* Use Leaky-ReLU with slope 0.2

![Discriminator](https://miro.medium.com/v2/resize:fit:828/format:webp/1*46CddTc5JwkFW_pQb4nGZQ.png)

**config.py**

In [4]:
def get_X_domain(x_path, n, m = 200):
    """Copy ``n`` randomly chosen files from ``x_path`` into ``./x_images``.

    The copies are named ``x_0``, ``x_1``, ... (no extension). Files are drawn
    without replacement, so no source file is copied twice as long as
    ``n`` does not exceed the number of files available; only the surplus
    beyond that is drawn with replacement.

    Args:
        x_path: directory holding the source images.
        n: number of files to copy; values <= 0 copy nothing.
        m: unused; kept so existing calls keep working.

    Raises:
        ValueError: if ``n > 0`` and ``x_path`` contains no regular files.
    """
    out_dir = 'x_images'
    os.makedirs(out_dir, exist_ok=True)
#     if not os.path.exists('test_dir'):
#         os.makedirs('test_dir')

    # Sorted so that, for a fixed random seed, the selection does not depend
    # on the arbitrary order os.listdir() returns; sub-directories are skipped
    # because shutil.copy cannot copy them.
    candidates = sorted(
        name for name in os.listdir(x_path)
        if os.path.isfile(os.path.join(x_path, name))
    )
    if n > 0 and not candidates:
        raise ValueError(f"no files to sample from in {x_path!r}")

    unique_count = max(0, min(n, len(candidates)))
    chosen = random.sample(candidates, unique_count)
    # Only reached when more files are requested than exist.
    chosen += [random.choice(candidates) for _ in range(n - unique_count)]

    for i, img in enumerate(chosen):
        shutil.copy(os.path.join(x_path, img), os.path.join(out_dir, f'x_{i}'))
        if i%100 == 0: print(i)

In [5]:
# Build the X-domain training set: copy 1350 randomly picked files from the
# photo directory into ./x_images (progress is printed every 100 files).
get_X_domain("/kaggle/input/gan-getting-started/photo_jpg", 1350)

0
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300


In [11]:
# Device string used for `.to(DEVICE)` and as `map_location` when loading checkpoints.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# TRAIN_DIR_X = "/kaggle/input/images-r/train_set"
# X-domain images. Presumably the `x_images` directory written by get_X_domain()
# when the notebook runs from /kaggle/working — TODO confirm.
TRAIN_DIR_X = "/kaggle/working/x_images"
# TRAIN_DIR_Y = "/kaggle/input/images/dataset/y_domain"
# Y-domain images.
TRAIN_DIR_Y = "/kaggle/input/best-artworks-of-all-time/images/images/Vincent_van_Gogh/"
# VAL_DIR = "/kaggle/input/images/dataset/x_domain/test_dir"
BATCH_SIZE = 1
# Learning rate for both Adam optimizers; also re-applied after a checkpoint is loaded.
LEARNING_RATE = 2e-4
# Weights of the identity and cycle-consistency L1 terms in the generator loss.
LAMBDA_IDENTITY = 5.0
LAMBDA_CYCLE = 10
NUM_WORKERS = 2
# Number of epochs main() trains for.
NUM_EPOCHS = 50
# When True, main() restores all four networks from the LOAD_CHECKPOINT_* files
# before training, so those files must exist.
LOAD_MODEL = True
# When True, main() writes the four CHECKPOINT_* files after every epoch.
SAVE_MODEL = True
# Output checkpoint names (relative to the current working directory).
CHECKPOINT_GEN_X = "genx.pth.tar"
CHECKPOINT_GEN_Y = "geny.pth.tar"
CHECKPOINT_CRITIC_X = "criticx.pth.tar"
CHECKPOINT_CRITIC_Y = "criticy.pth.tar"
# Input checkpoint paths read when LOAD_MODEL is True.
LOAD_CHECKPOINT_GEN_X = "/kaggle/input/models/genx.pth.tar"
LOAD_CHECKPOINT_GEN_Y = "/kaggle/input/models/geny.pth.tar"
LOAD_CHECKPOINT_CRITIC_X = "/kaggle/input/models/criticx.pth.tar"
LOAD_CHECKPOINT_CRITIC_Y = "/kaggle/input/models/criticy.pth.tar"

# Shared preprocessing for both domains: resize to 256x256, random horizontal
# flip, then scale 0..255 pixels to [-1, 1] (mean 0.5 / std 0.5 on the 0..1
# range), which matches the generator's tanh output range.
# `additional_targets` registers a second image input named "image0" so one
# call can transform a Y image ("image") and an X image ("image0") together.
_transforms = A.Compose(
    [
        A.Resize(width=256, height=256),
        A.HorizontalFlip(p=0.5),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
        ToTensorV2(),
    ],
    additional_targets={"image0": "image"},
)

**utils.py**

In [7]:
def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Write the model and optimizer state dicts to ``filename``.

    The file holds a dict with two keys: ``"state_dict"`` (model weights) and
    ``"optimizer"`` (optimizer state), serialised with ``torch.save``.

    Args:
        model: module whose ``state_dict()`` is stored.
        optimizer: optimizer whose ``state_dict()`` is stored.
        filename: destination path.
    """
    print("==> Saving checkpoint")
    payload = dict(
        state_dict=model.state_dict(),
        optimizer=optimizer.state_dict(),
    )
    torch.save(payload, filename)
    
def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore ``model`` and ``optimizer`` from a file written by ``save_checkpoint``.

    Tensors are mapped onto the module-level ``DEVICE``. After the optimizer
    state is restored, every parameter group's learning rate is overwritten
    with ``lr`` so the stored learning rate does not carry over.

    Args:
        checkpoint_file: path of the checkpoint to read.
        model: module receiving the stored ``"state_dict"``.
        optimizer: optimizer receiving the stored ``"optimizer"`` state.
        lr: learning rate to set on all parameter groups afterwards.
    """
    print("==> Loading checkpoint")
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    stored = torch.load(checkpoint_file, map_location=DEVICE)

    model.load_state_dict(stored["state_dict"])
    optimizer.load_state_dict(stored["optimizer"])

    for group in optimizer.param_groups:
        group.update(lr=lr)
        
def seed_everything(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: integer used for ``PYTHONHASHSEED`` and every RNG seeded here.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    # Same seed for every generator, applied in a fixed order.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Deterministic kernels on, auto-tuner off.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

**dataset.py**

In [8]:
class XYDataset(Dataset):
    """Unpaired two-domain image dataset.

    Item ``i`` pairs the ``i``-th Y image with the ``i``-th X image; the
    shorter directory listing wraps around (modulo its length), so the dataset
    is as long as the larger of the two directories.

    Args:
        root_y: directory with the Y-domain images.
        root_x: directory with the X-domain images.
        transform: optional callable invoked as
            ``transform(image=<Y array>, image0=<X array>)`` and expected to
            return a mapping with ``"image"`` and ``"image0"`` entries.
    """

    def __init__(self, root_y, root_x, transform=None):
        self.root_y, self.root_x = root_y, root_x
        self.transform = transform

        self.y_images = os.listdir(root_y)
        self.x_images = os.listdir(root_x)

        self.y_len = len(self.y_images)
        self.x_len = len(self.x_images)
        self.length_dataset = max(self.y_len, self.x_len)

    def __len__(self):
        """Size of the larger of the two domains."""
        return self.length_dataset

    def __getitem__(self, index):
        """Return ``(y_img, x_img)`` for ``index``, transformed when a transform is set."""
        # Wrap the index independently for each domain.
        y_name = self.y_images[index % self.y_len]
        x_name = self.x_images[index % self.x_len]

        y_arr = np.array(Image.open(os.path.join(self.root_y, y_name)).convert("RGB"))
        x_arr = np.array(Image.open(os.path.join(self.root_x, x_name)).convert("RGB"))

        if not self.transform:
            return y_arr, x_arr

        transformed = self.transform(image=y_arr, image0=x_arr)
        return transformed["image"], transformed["image0"]

**train.py**

In [12]:
def train_fn(disc_X, disc_Y, gen_Y, gen_X, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler,
             save_every=50, sample_dir="saved_images"):
    """Run one epoch of CycleGAN training over ``loader``.

    For every batch the two discriminators are updated first (real images are
    pushed towards 1, detached fakes towards 0), then both generators are
    updated on adversarial + cycle-consistency + identity losses. Both steps
    run under autocast and go through their own gradient scaler.

    Args:
        disc_X, disc_Y: discriminators for the X and Y domains.
        gen_Y: generator mapping X images to the Y domain.
        gen_X: generator mapping Y images to the X domain.
        loader: iterable yielding ``(y, x)`` batches.
        opt_disc, opt_gen: optimizers for the discriminators / generators.
        l1: loss used for the cycle and identity terms.
        mse: loss used for the adversarial terms.
        d_scaler, g_scaler: gradient scalers for the two optimisation steps.
        save_every: write a sample pair every this many batches; a falsy
            value (0 / None) disables sample saving.
        sample_dir: directory the sample images are written to; created when
            missing.
    """
    # Make sure the sample directory exists so the function also works when
    # called without main()'s setup.
    if save_every:
        os.makedirs(sample_dir, exist_ok=True)

    # Running sums of the X discriminator's mean outputs, shown on the progress bar.
    X_reals = 0
    X_fakes = 0
    loop = tqdm(loader, leave = True)

    for idx, (y, x) in enumerate(loop):
        y = y.to(DEVICE)
        x = x.to(DEVICE)

        # Discriminator
        with torch.cuda.amp.autocast():
            fake_x = gen_X(y)
            D_X_real = disc_X(x)
            # detach(): the discriminator step must not back-propagate into the generator.
            D_X_fake = disc_X(fake_x.detach())
            X_reals += D_X_real.mean().item()
            X_fakes += D_X_fake.mean().item()
            D_X_real_loss = mse(D_X_real, torch.ones_like(D_X_real))
            D_X_fake_loss = mse(D_X_fake, torch.zeros_like(D_X_fake))
            D_X_loss = D_X_real_loss + D_X_fake_loss

            fake_y = gen_Y(x)
            D_Y_real = disc_Y(y)
            D_Y_fake = disc_Y(fake_y.detach())
            D_Y_real_loss = mse(D_Y_real, torch.ones_like(D_Y_real))
            D_Y_fake_loss = mse(D_Y_fake, torch.zeros_like(D_Y_fake))
            D_Y_loss = D_Y_real_loss + D_Y_fake_loss

            D_loss = (D_X_loss + D_Y_loss) / 2

        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # Generators
        with torch.cuda.amp.autocast():
            # adversarial loss: the (non-detached) fakes should be scored as real
            D_X_fake = disc_X(fake_x)
            D_Y_fake = disc_Y(fake_y)

            loss_G_X = mse(D_X_fake, torch.ones_like(D_X_fake))
            loss_G_Y = mse(D_Y_fake, torch.ones_like(D_Y_fake))

            # cycle loss: translating there and back should reproduce the input
            cycle_y = gen_Y(fake_x)
            cycle_x = gen_X(fake_y)
            cycle_y_loss = l1(y, cycle_y)
            cycle_x_loss = l1(x, cycle_x)

            # identity loss: a generator fed its own target domain should change nothing
            identity_y = gen_Y(y)
            identity_x = gen_X(x)
            identity_y_loss = l1(y, identity_y)
            identity_x_loss = l1(x, identity_x)

            # add all together
            G_loss = (
                loss_G_Y
                + loss_G_X
                + cycle_y_loss * LAMBDA_CYCLE
                + cycle_x_loss * LAMBDA_CYCLE
                + identity_x_loss * LAMBDA_IDENTITY
                + identity_y_loss * LAMBDA_IDENTITY
            )

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        # The interval is its own parameter; it used to be tied to NUM_EPOCHS,
        # which changed the sampling cadence whenever the epoch count changed.
        if save_every and idx % save_every == 0:
            # Undo the [-1, 1] normalisation before writing the images.
            save_image(x * 0.5 + 0.5, os.path.join(sample_dir, f"x_{idx}.png"))
            save_image(fake_y * 0.5 + 0.5, os.path.join(sample_dir, f"fake_y_{idx}.png"))

        loop.set_postfix(X_real = X_reals / (idx + 1), X_fake = X_fakes / (idx + 1))
            
def main():
    """Build the networks, optionally restore them, and train for NUM_EPOCHS epochs.

    Creates ``saved_images`` when missing, builds both discriminators and
    both generators on ``DEVICE`` with one Adam optimizer per pair, loads the
    four checkpoints when ``LOAD_MODEL`` is set, then runs ``train_fn`` once
    per epoch and writes the four checkpoints after each epoch when
    ``SAVE_MODEL`` is set.
    """
    # A validation split (cyclegan_test/x, cyclegan_test/y) with its own loader
    # was sketched here at one point; it is currently not used.
    sample_dir = 'saved_images'
    if not os.path.exists(sample_dir):
        os.makedirs(sample_dir)

    # Construction order is kept (disc_X, disc_Y, gen_Y, gen_X) so random
    # initialisation is unchanged for a given seed.
    disc_X = Discriminator(in_channels = 3).to(DEVICE)
    disc_Y = Discriminator(in_channels = 3).to(DEVICE)
    gen_Y = Generator(img_channels=3, num_residuals=9).to(DEVICE)
    gen_X = Generator(img_channels=3, num_residuals=9).to(DEVICE)

    # One optimizer per network pair, sharing the same Adam settings.
    adam_settings = dict(lr=LEARNING_RATE, betas=(0.5, 0.999))
    opt_disc = optim.Adam([*disc_X.parameters(), *disc_Y.parameters()], **adam_settings)
    opt_gen = optim.Adam([*gen_X.parameters(), *gen_Y.parameters()], **adam_settings)

    L1 = nn.L1Loss()   # cycle-consistency & identity
    mse = nn.MSELoss() # adversarial

    if LOAD_MODEL:
        restore_plan = (
            (LOAD_CHECKPOINT_GEN_X, gen_X, opt_gen),
            (LOAD_CHECKPOINT_GEN_Y, gen_Y, opt_gen),
            (LOAD_CHECKPOINT_CRITIC_X, disc_X, opt_disc),
            (LOAD_CHECKPOINT_CRITIC_Y, disc_Y, opt_disc),
        )
        for ckpt_path, network, optimizer in restore_plan:
            load_checkpoint(ckpt_path, network, optimizer, LEARNING_RATE)

    dataset = XYDataset(root_x=TRAIN_DIR_X, root_y=TRAIN_DIR_Y, transform=_transforms)
    loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )

    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()

    # (network, optimizer, file) triples written after every epoch, in this order.
    save_plan = (
        (gen_X, opt_gen, CHECKPOINT_GEN_X),
        (gen_Y, opt_gen, CHECKPOINT_GEN_Y),
        (disc_X, opt_disc, CHECKPOINT_CRITIC_X),
        (disc_Y, opt_disc, CHECKPOINT_CRITIC_Y),
    )

    for epoch in range(NUM_EPOCHS):
        print(f"===> Epoch: {epoch+1}")
        train_fn(
            disc_X, disc_Y, gen_Y, gen_X,
            loader,
            opt_disc, opt_gen,
            L1, mse,
            d_scaler, g_scaler,
        )

        if not SAVE_MODEL:
            continue
        for network, optimizer, target in save_plan:
            save_checkpoint(network, optimizer, filename = target)
            



In [None]:
# Start training. With LOAD_MODEL = True this first restores the four networks
# from the LOAD_CHECKPOINT_* files, so those paths must exist.
main()

In [2]:
# List the current working directory (long format, hidden entries included).
! ls -lA

total 8
drwxr-xr-x 2 root root 4096 May 10 12:22 .virtual_documents
---------- 1 root root  263 May 10 12:22 __notebook_source__.ipynb
