## Painter - Cycle GAN

In [None]:
#our basic libraries
import os
import numpy as np
import matplotlib.pyplot as plt
# basic libraries related to torch
import torch
import torch.nn as nn
import torchvision
#libraries for actions on dataset
from torch.utils.data import Dataset, DataLoader 
import torchvision.transforms as transforms
from torchvision.transforms import ToTensor
import albumentations as A
from albumentations.pytorch import ToTensorV2 # for data aurmentaion
#optimizer funtion
import torch.optim as optim

# tqdm is for progress bar
from tqdm import tqdm  

from torchvision.utils import save_image
from PIL import Image

We want to convert photos into Monet-style paintings; the dataset provides the two domains as `photo` and `monet` folders.
The data is also available in TFRecord (`.tfrec`) format, but since our main aim here is to learn CycleGAN, let's use the JPEG images directly for simplicity.

Let's check whether we have GPU access and store the device in a variable.
If a GPU is available the variable is set to "cuda"; otherwise it is set to "cpu".

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device  # cell output: shows which device was selected

## Dataset

Since we are using PyTorch, let's create a dataset that returns a pair of images, one from each domain.
We want to train a model that converts a **Photo to Monet** style.
(Note: thanks to CycleGAN we need not worry about correspondence between the two images. Pix2Pix needs paired images that are tightly correlated across the two domains, but CycleGAN only needs images from two different domains.)

Below, the folder paths are discovered programmatically; you could equally copy and paste them directly from the Kaggle input panel.

In [None]:
# Locate the JPEG folders for the two domains inside the competition input
# directory: a folder whose name contains "monet" and "jpg", and one whose
# name contains "photo" and "jpg".
main = "../input/gan-getting-started"
dirs = os.listdir(main)
monets_dir = photos_dir = None
for dir_1 in dirs:
    if "monet" in dir_1 and "jpg" in dir_1:
        monets_dir = os.path.join(main, dir_1)
    if "photo" in dir_1 and "jpg" in dir_1:
        photos_dir = os.path.join(main, dir_1)
# Fail with a clear message instead of a NameError further down when a folder is missing.
if monets_dir is None or photos_dir is None:
    raise FileNotFoundError(
        f"Could not find both a monet and a photo jpg folder in {main!r}; found {dirs}"
    )
print(monets_dir, photos_dir)

Let's create a class named `mydata` that inherits from `Dataset`.
Some observations about the dataset:
1. The images are already 256×256, so we don't need any real resizing or reshaping.
2. As we already have a lot of data, let's refrain from using any augmentations/transformations.
3. The only real transform is scaling all pixel values to the range [-1, 1] (plus conversion to tensors).

In [None]:
class mydata(Dataset):
    """Unpaired two-folder image dataset.

    Item ``idx`` is the pair ``(image_1, image_2)``, where ``image_1`` comes
    from ``img_dir_1`` and ``image_2`` from ``img_dir_2``. The folders may
    contain different numbers of files. The dataset length is the larger
    count, and each folder is indexed modulo its own size, so the smaller one
    is cycled through.

    Args:
        img_dir_1: Folder of images for the first domain.
        img_dir_2: Folder of images for the second domain.
        transforms: Optional callable invoked as
            ``transforms(image=..., image0=...)`` that returns a dict with
            ``"image"`` and ``"image0"`` keys. Without it, images are returned
            as numpy arrays of the RGB-converted PIL image.

    Raises:
        ValueError: If either folder is empty. Without this check, the modulo
            indexing would fail later with a ZeroDivisionError.
    """

    def __init__(self, img_dir_1, img_dir_2, transforms=None):
        self.img_dir_1 = img_dir_1
        self.img_dir_2 = img_dir_2
        self.transforms = transforms
        # os.listdir order is arbitrary; sort so index -> file is reproducible.
        self.images_1_names = sorted(os.listdir(img_dir_1))
        self.images_2_names = sorted(os.listdir(img_dir_2))
        if not self.images_1_names:
            raise ValueError(f"No images found in {img_dir_1!r}")
        if not self.images_2_names:
            raise ValueError(f"No images found in {img_dir_2!r}")
        self.len_images_1_names = len(self.images_1_names)
        self.len_images_2_names = len(self.images_2_names)
        self.len_dataset = max(self.len_images_1_names, self.len_images_2_names)

    def __len__(self):
        """Return the number of files in the larger of the two folders."""
        return self.len_dataset

    def __getitem__(self, idx):
        """Return the ``(image_1, image_2)`` pair for index ``idx``."""
        image_1_name = self.images_1_names[idx % self.len_images_1_names]
        image_2_name = self.images_2_names[idx % self.len_images_2_names]

        image_1_path = os.path.join(self.img_dir_1, image_1_name)
        image_2_path = os.path.join(self.img_dir_2, image_2_name)

        # Keep numpy arrays (not tensors) so raw images can be plotted directly;
        # tensor conversion, if any, is left to ``transforms``.
        with Image.open(image_1_path) as img_1:
            images_1 = np.array(img_1.convert("RGB"))
        with Image.open(image_2_path) as img_2:
            images_2 = np.array(img_2.convert("RGB"))
        if self.transforms:
            # The transform must accept the second image under the extra
            # target name ``image0``.
            augmentations = self.transforms(image=images_1, image0=images_2)
            images_1 = augmentations["image"]
            images_2 = augmentations["image0"]

        return (images_1, images_2)

In [None]:
# Untransformed dataset: each item is a (photo, monet) pair of numpy arrays.
data = mydata(img_dir_1=photos_dir, img_dir_2=monets_dir)

In [None]:
def plot_some(data):
    """Show an 8x8 grid of randomly chosen images from ``data``.

    The top half of the grid (cells 1-32) shows the second element of each
    pair; the bottom half (cells 33-64) shows the first element. With
    ``data = mydata(photos_dir, monets_dir)`` this puts Monet paintings on top
    and photos below.

    Args:
        data: Indexable dataset whose items are ``(image_1, image_2)`` pairs of
            images that ``plt.imshow`` can display after ``squeeze()``.
    """
    figure = plt.figure(figsize=(18, 18))
    rows, cols = 8, 8
    n_cells = rows * cols
    half = n_cells // 2
    for i in range(1, n_cells + 1):
        index = torch.randint(low=0, high=len(data), size=(1,)).item()
        image_1, image_2 = data[index]  # load the pair once per cell
        # Top half: second image of the pair; bottom half: first image.
        img = image_1 if i > half else image_2
        figure.add_subplot(rows, cols, i)
        plt.imshow(img.squeeze())
        plt.axis("off")
    # suptitle labels the whole figure; plt.title would only title the last subplot.
    figure.suptitle("Monets and Photos")
    plt.show()

In [None]:

# Preview a random sample of the raw images; no transform has been applied yet.
plot_some(data=data)

## DataLoader

A DataLoader is what feeds the dataset to our model, batch by batch and in a defined order.

In [None]:
batch_size = 1
# Preview loader over the untransformed dataset (batches hold the raw images).
loader = DataLoader(
    data,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

In [None]:
# Pull a single (photo, monet) batch from the preview loader to sanity-check it.
# The temporary iterator (and its worker processes) is discarded right away.
phot,mon = next(iter(loader))

In [None]:
# Display the sampled photo; squeeze() drops the batch dimension of size 1.
photo_to_show = phot.squeeze()
plt.imshow(photo_to_show)
plt.axis("off")
plt.title("Photo")

In [None]:
# Display the sampled Monet painting; squeeze() drops the batch dimension of size 1.
monet_to_show = mon.squeeze()
plt.imshow(monet_to_show)
plt.axis("off")
plt.title("Monet")

# Discriminator

The discriminator repeats the same layer pattern several times, so I will create a separate class (`Block`) containing those layers; we can then use it in the main discriminator class as many times as we want.

In [None]:
class Block(nn.Module):
    """Discriminator stage: 4x4 reflect-padded conv -> InstanceNorm -> LeakyReLU(0.2).

    Args:
        in_channels: Number of input feature maps.
        out_channels: Number of output feature maps.
        stride: Convolution stride. With kernel 4 and padding 1, stride 2
            halves the spatial size and stride 1 shrinks it by one pixel.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=stride,
            padding=1,
            bias=True,
            padding_mode="reflect",
        )
        norm = nn.InstanceNorm2d(out_channels)
        act = nn.LeakyReLU(0.2)
        # The attribute name and layer order define the state_dict keys
        # (``conv.0.weight`` ...); keep them stable so checkpoints still load.
        self.conv = nn.Sequential(conv, norm, act)

    def forward(self, x):
        """Apply conv -> norm -> activation."""
        return self.conv(x)

In [None]:
class Discriminator(nn.Module):
    """PatchGAN-style discriminator that returns a map of per-patch scores.

    Layout:
    1. A 4x4 stride-2 conv plus LeakyReLU, with no normalisation.
    2. One ``Block`` per remaining entry of ``features``: stride 2, except
       stride 1 for the last one.
    3. A 4x4 conv down to a single channel, followed by a sigmoid.

    With the default ``features``, a 256x256 input yields a 30x30 score map.

    Args:
        in_channels: Channels of the input image.
        features: Channel widths of the successive stages. Any sequence of
            ints is accepted; the default is a tuple so it is not a shared
            mutable default argument.
    """

    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        super().__init__()
        features = list(features)
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels, features[0], kernel_size=4, stride=2, padding=1, padding_mode="reflect"),
            # Slope 0.2, matching the LeakyReLU inside every ``Block``
            # (nn.LeakyReLU's default slope would be 0.01).
            nn.LeakyReLU(0.2),
        )
        layers = []
        for i in range(1, len(features)):
            stride = 1 if i == len(features) - 1 else 2
            layers.append(Block(features[i - 1], features[i], stride=stride))
        # Final projection to a single channel of patch scores.
        layers.append(
            nn.Conv2d(features[-1], 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect")
        )
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return sigmoid-squashed patch scores in (0, 1) for the image batch ``x``."""
        x = self.initial(x)
        x = self.model(x)
        return torch.sigmoid(x)

# Generator
Now we come to the most important part of the model: the generator.

In [None]:
class ConvBlock(nn.Module):
    """Generator building block: conv (or transposed conv) -> InstanceNorm -> optional LeakyReLU.

    Args:
        in_channels: Number of input feature maps.
        out_channels: Number of output feature maps.
        down: If ``True``, use a reflect-padded ``nn.Conv2d``. If ``False``,
            use an ``nn.ConvTranspose2d`` (upsampling).
        use_act: If ``True``, append an in-place ``LeakyReLU``; otherwise an
            ``nn.Identity``.
        **kwargs: Forwarded to the conv layer (kernel_size, stride, padding,
            output_padding, ...).
    """

    def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
        super().__init__()
        if down:
            conv_layer = nn.Conv2d(in_channels, out_channels, padding_mode="reflect", **kwargs)
        else:
            conv_layer = nn.ConvTranspose2d(in_channels, out_channels, **kwargs)
        norm_layer = nn.InstanceNorm2d(out_channels)
        if use_act:
            act_layer = nn.LeakyReLU(inplace=True)
        else:
            act_layer = nn.Identity()
        # ``self.conv`` stays a 3-element Sequential so the state_dict keys
        # (``conv.0.weight`` ...) match previously saved checkpoints.
        self.conv = nn.Sequential(conv_layer, norm_layer, act_layer)

    def forward(self, x):
        """Apply conv -> norm -> activation."""
        return self.conv(x)

In [None]:
class ResBlock(nn.Module):
    """Residual unit: two 3x3 ``ConvBlock``s (the second without activation) plus a skip connection.

    Both convs keep the channel count and, with padding 1, the spatial size,
    so the block output can be added to its input.
    """

    def __init__(self, channels):
        super().__init__()
        first = ConvBlock(channels, channels, kernel_size=3, padding=1)
        second = ConvBlock(channels, channels, use_act=False, kernel_size=3, padding=1)
        self.block = nn.Sequential(first, second)

    def forward(self, x):
        """Return ``x`` plus the residual computed by the two conv blocks."""
        residual = self.block(x)
        return x + residual

In [None]:
class Generator(nn.Module):
    """ResNet-style CycleGAN generator that maps an image to an image of the same size.

    Pipeline:
    1. 7x7 stem conv.
    2. Two stride-2 downsampling ``ConvBlock``s.
    3. ``num_residuals`` ``ResBlock``s.
    4. Two stride-2 transposed-conv upsampling ``ConvBlock``s.
    5. 7x7 conv back to ``img_channels``, then ``tanh`` (outputs in [-1, 1]).

    Args:
        img_channels: Channels of the input and output images.
        num_feature: Width of the stem; the bottleneck uses ``4 * num_feature``.
        num_residuals: Number of residual blocks at the bottleneck.
    """

    def __init__(self, img_channels, num_feature=64, num_residuals=9):
        super().__init__()
        f1, f2, f4 = num_feature, num_feature * 2, num_feature * 4
        self.initial = nn.Sequential(
            nn.Conv2d(img_channels, f1, kernel_size=7, stride=1, padding=3, padding_mode="reflect"),
            nn.ReLU(inplace=True),
        )
        # Two stride-2 convs: H x W -> H/4 x W/4, channels f1 -> f4.
        self.down_blocks = nn.ModuleList([
            ConvBlock(f1, f2, kernel_size=3, stride=2, padding=1),
            ConvBlock(f2, f4, kernel_size=3, stride=2, padding=1),
        ])
        # Bottleneck: residual blocks keep both shape and channel count.
        self.resblocks = nn.Sequential(*(ResBlock(f4) for _ in range(num_residuals)))
        # Two transposed convs; with output_padding=1 each exactly doubles H and W.
        self.up_blocks = nn.ModuleList([
            ConvBlock(f4, f2, down=False, kernel_size=3, stride=2, padding=1, output_padding=1),
            ConvBlock(f2, f1, down=False, kernel_size=3, stride=2, padding=1, output_padding=1),
        ])
        # Project back to image channels.
        self.last = nn.Conv2d(f1, img_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect")

    def forward(self, x):
        """Translate the image batch ``x``; the result is squashed into [-1, 1] by tanh."""
        x = self.initial(x)
        for stage in (*self.down_blocks, self.resblocks, *self.up_blocks):
            x = stage(x)
        return torch.tanh(self.last(x))

In [None]:
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2

# ---- optimisation ----
batch_size = 1
num_workers = 4
num_epochs = 2
learning_rate = 1e-5

# ---- CycleGAN loss weights ----
lambda_cycle = 10       # multiplies the two cycle-consistency (L1) terms
lambda_identity = 0.5   # multiplies the two identity (L1) terms

# ---- checkpointing ----
# NOTE(review): ../input is usually read-only on Kaggle, so saving to these
# paths with save_model=True may fail — confirm the intended output location.
load_model = False
save_model = True
checkpoint_g_photo = "../input/weights1/g_photo.pth.tar"
checkpoint_g_monet = "../input/weights1/g_monet.pth.tar"
checkpoint_d_photo = "../input/weights1/d_photo.pth.tar"
checkpoint_d_monet = "../input/weights1/d_monet.pth.tar"

# Transform shared by both domains:
#   1. resize to 256x256;
#   2. scale pixels from [0, 255] to [-1, 1], i.e. (x / 255 - 0.5) / 0.5;
#   3. convert to a tensor.
# ``additional_targets`` lets the second image be passed in as ``image0``.
transformer = A.Compose(
    transforms=[
        A.Resize(height=256, width=256),
        A.Normalize(
            mean=[0.5, 0.5, 0.5],
            std=[0.5, 0.5, 0.5],
            max_pixel_value=255,
        ),
        ToTensorV2(),
    ],
    additional_targets={"image0": "image"},
)

In [None]:
import copy

def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Write ``model`` and ``optimizer`` state dicts to ``filename`` with ``torch.save``.

    The saved object is a dict with two keys: ``"state_dict"`` (the model's
    state dict) and ``"optimizer"`` (the optimizer's state dict).
    """
    print("=> Saving checkpoint")
    torch.save(
        {"state_dict": model.state_dict(), "optimizer": optimizer.state_dict()},
        filename,
    )


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore ``model`` and ``optimizer`` from a ``save_checkpoint``-style file.

    Tensors are mapped onto the notebook-wide ``device``. Afterwards, every
    optimizer param group's learning rate is overwritten with ``lr``.
    """
    print("=> Loading checkpoint")
    state = torch.load(checkpoint_file, map_location=device)
    model_state, optim_state = state["state_dict"], state["optimizer"]
    model.load_state_dict(model_state)
    optimizer.load_state_dict(optim_state)
    # Force the requested learning rate; otherwise the restored optimizer
    # silently keeps the rate stored in the checkpoint.
    for group in optimizer.param_groups:
        group.update(lr=lr)


def seed_everything(seed=42):
    """Seed every RNG used here and make cuDNN deterministic.

    This function:
    - seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices);
    - sets the ``PYTHONHASHSEED`` environment variable;
    - switches cuDNN to deterministic mode with autotuning disabled.

    Note: setting ``PYTHONHASHSEED`` at runtime does not change the hash seed
    of the already-running interpreter.

    Args:
        seed: Seed value applied to every generator.
    """
    # ``random`` is not imported at the top of this notebook; without this
    # local import the call below raises NameError.
    import random

    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [None]:
# Output folders for the sample images that train() saves every 100 batches.
# exist_ok=True makes the cell safe to re-run (plain `mkdir` errors if they exist).
os.makedirs("new_photos", exist_ok=True)
os.makedirs("new_monets", exist_ok=True)

# Training Function

In [None]:
def train(d_photo, d_monet, g_photo, g_monet, d_optim, g_optim, l1, mse, loader, g_scaler, d_scaler, i):
    """Run one CycleGAN training epoch over ``loader`` with mixed precision.

    For every ``(photo, monet)`` batch:
      1. Discriminator step. Each discriminator is trained with ``mse``
         against targets of 1 for real images and 0 for generated ones.
      2. Generator step. The loss is the sum of:
         - the adversarial terms (targets of 1 on fakes);
         - ``lambda_cycle`` times the cycle-consistency L1 terms;
         - ``lambda_identity`` times the identity L1 terms.
    Every 100 batches, the current fakes are saved to ``./new_photos`` and
    ``./new_monets``, rescaled from the generators' [-1, 1] range to [0, 1].

    Args:
        d_photo, d_monet: Discriminators for the photo / Monet domain.
        g_photo: Generator Monet -> photo.
        g_monet: Generator photo -> Monet.
        d_optim, g_optim: Optimizers over both discriminators / both generators.
        l1, mse: Loss modules (L1 for cycle/identity, MSE for adversarial terms).
        loader: Yields ``(photo, monet)`` batches.
        g_scaler, d_scaler: Gradient scalers for the generator / discriminator step.
        i: Epoch index; only used in the sample-image file names.
    """
    loop = tqdm(loader, leave=True)

    for idx, (photo, monet) in enumerate(loop):
        # The batch must live on the same device as the models.
        monet = monet.to(device)
        photo = photo.to(device)

        # ---- Discriminator step ----
        with torch.cuda.amp.autocast():
            # d_photo: real photos -> 1, photos generated from Monets -> 0.
            real_photo = photo
            fake_photo = g_photo(monet)
            d_photo_real = d_photo(real_photo)
            # detach(): keep the discriminator loss from back-propagating into the generator.
            d_photo_fake = d_photo(fake_photo.detach())
            d_photo_real_loss = mse(d_photo_real, torch.ones_like(d_photo_real))
            d_photo_fake_loss = mse(d_photo_fake, torch.zeros_like(d_photo_fake))
            # Real AND fake term: without the fake term d_photo never learns to reject fakes.
            d_photo_loss = d_photo_real_loss + d_photo_fake_loss

            # d_monet: real Monets -> 1, Monets generated from photos -> 0.
            real_monet = monet
            fake_monet = g_monet(photo)
            d_monet_real = d_monet(real_monet)
            d_monet_fake = d_monet(fake_monet.detach())
            d_monet_real_loss = mse(d_monet_real, torch.ones_like(d_monet_real))
            d_monet_fake_loss = mse(d_monet_fake, torch.zeros_like(d_monet_fake))
            d_monet_loss = d_monet_fake_loss + d_monet_real_loss

            d_loss = (d_photo_loss + d_monet_loss) / 2

        d_optim.zero_grad()
        d_scaler.scale(d_loss).backward()
        d_scaler.step(d_optim)
        d_scaler.update()

        # ---- Generator step ----
        with torch.cuda.amp.autocast():
            # Adversarial loss: the generators want the discriminators to output 1 on fakes.
            d_photo_fake = d_photo(fake_photo)
            d_monet_fake = d_monet(fake_monet)
            g_photo_loss = mse(d_photo_fake, torch.ones_like(d_photo_fake))
            g_monet_loss = mse(d_monet_fake, torch.ones_like(d_monet_fake))

            # Cycle loss: monet -> photo -> monet and photo -> monet -> photo
            # should reconstruct the original input.
            cycle_monet = g_monet(fake_photo)
            cycle_photo = g_photo(fake_monet)
            cycle_monet_loss = l1(real_monet, cycle_monet)
            cycle_photo_loss = l1(real_photo, cycle_photo)

            # Identity loss: an image already in a generator's target domain should pass through unchanged.
            identity_monet = g_monet(real_monet)
            identity_photo = g_photo(real_photo)
            identity_monet_loss = l1(real_monet, identity_monet)
            identity_photo_loss = l1(real_photo, identity_photo)

            g_loss = (
                g_monet_loss
                + g_photo_loss
                + cycle_monet_loss * lambda_cycle
                + cycle_photo_loss * lambda_cycle
                + identity_monet_loss * lambda_identity
                + identity_photo_loss * lambda_identity
            )

        g_optim.zero_grad()
        g_scaler.scale(g_loss).backward()
        g_scaler.step(g_optim)
        g_scaler.update()

        if idx % 100 == 0:
            # Generator outputs are tanh-scaled to [-1, 1]; map to [0, 1] for save_image.
            save_image(fake_photo * 0.5 + 0.5, f"./new_photos/photos_{i}_{idx}.jpg")
            save_image(fake_monet * 0.5 + 0.5, f"./new_monets/monets_{i}_{idx}.jpg")

In [None]:
def main():
    """Build the CycleGAN networks and train them for ``num_epochs`` epochs.

    Reads the module-level settings: ``learning_rate``, ``batch_size``,
    ``num_workers``, ``load_model`` / ``save_model`` and the checkpoint paths.
    Both discriminators share ``d_optim`` and both generators share
    ``g_optim``. When ``save_model`` is set, all four checkpoints are written
    after every epoch.
    """
    # One discriminator and one generator per domain.
    d_photo = Discriminator(in_channels=3).to(device)
    d_monet = Discriminator(in_channels=3).to(device)
    g_photo = Generator(img_channels=3).to(device)
    g_monet = Generator(img_channels=3).to(device)
    # Each optimizer covers both networks of its kind.
    d_optim = optim.Adam(
        list(d_photo.parameters()) + list(d_monet.parameters()),
        lr=learning_rate,
        betas=(0.5, 0.999),
    )
    g_optim = optim.Adam(
        list(g_photo.parameters()) + list(g_monet.parameters()),
        lr=learning_rate,
        betas=(0.5, 0.999),
    )
    l1 = nn.L1Loss()    # cycle-consistency and identity terms
    mse = nn.MSELoss()  # adversarial terms

    data_n = mydata(photos_dir, monets_dir, transforms=transformer)
    # NOTE(review): shuffle=False keeps the photo/Monet pairing and order
    # identical every epoch; shuffle=True is the usual choice — confirm intent.
    loader = DataLoader(
        data_n,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,  # the configured value, not a hard-coded 4
        pin_memory=True,
    )

    # Gradient scalers for float16 mixed-precision training, one per optimizer.
    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()

    if load_model:
        # Each pair of checkpoints stores the same shared optimizer, so the
        # second load just restores that optimizer state again.
        load_checkpoint(checkpoint_g_photo, g_photo, g_optim, learning_rate)
        load_checkpoint(checkpoint_g_monet, g_monet, g_optim, learning_rate)
        load_checkpoint(checkpoint_d_photo, d_photo, d_optim, learning_rate)
        load_checkpoint(checkpoint_d_monet, d_monet, d_optim, learning_rate)

    for epoch in range(num_epochs):
        train(d_photo, d_monet, g_photo, g_monet, d_optim, g_optim, l1, mse, loader, g_scaler, d_scaler, epoch)

        if save_model:
            save_checkpoint(g_photo, g_optim, filename=checkpoint_g_photo)
            save_checkpoint(g_monet, g_optim, filename=checkpoint_g_monet)
            save_checkpoint(d_photo, d_optim, filename=checkpoint_d_photo)
            save_checkpoint(d_monet, d_optim, filename=checkpoint_d_monet)


In [None]:
# if __name__ == "__main__":
#     main()

In [None]:
# Rebuild both generators and restore the photo -> Monet generator from its checkpoint.
g_photo = Generator(img_channels=3).to(device)
g_monet = Generator(img_channels=3).to(device)
# The saved optimizer state covers both generators' parameters (training built
# it over g_photo + g_monet), so it must be rebuilt the same way to load.
optimizer = g_optim = optim.Adam(
    list(g_photo.parameters()) + list(g_monet.parameters()),
    lr=learning_rate,
    betas=(0.5, 0.999),
)
# Alternative location: checkpoint_g_monet = "../input/weights-painter/g_monet.pth.tar"
# map_location lets a GPU-saved checkpoint load in a CPU-only session as well.
checkpoint = torch.load(checkpoint_g_monet, map_location=device)
g_monet.load_state_dict(checkpoint["state_dict"])
optimizer.load_state_dict(checkpoint["optimizer"])

In [None]:
# Show the restored generator's layer structure.
print(repr(g_monet))

In [None]:
# Same photo/Monet folders as ``data``, but passed through ``transformer``
# (resize, scale to [-1, 1], convert to tensors) so items can be fed to the generator.
data_n = mydata(img_dir_1=photos_dir, img_dir_2=monets_dir, transforms=transformer)

In [None]:
# Folder for the generated Monet-style images; exist_ok=True makes re-runs safe.
os.makedirs("../images", exist_ok=True)

In [None]:
# Work from inside the output folder. An explicit os.chdir avoids relying on
# IPython's `cd` automagic, which only works when automagic is on and no
# variable named `cd` exists.
os.chdir("../images")


In [None]:
# Translate every photo to Monet style and save each result as a JPEG in ../images.
g_monet.eval()
with torch.no_grad():  # inference only: no need to build autograd graphs
    for i in range(len(data_n)):
        photo_initial, _ = data_n[i]
        photo_initial = photo_initial.to(device).unsqueeze(0)  # add a batch dimension
        monet_got = g_monet(photo_initial)
        # Generator output is tanh-scaled to [-1, 1]; map to [0, 1] for save_image.
        save_image(monet_got * 0.5 + 0.5, f"../images/monets_{i}.jpg")

In [None]:
import shutil
# Bundle the generated images into /kaggle/working/images.zip for submission.
shutil.make_archive(
    base_name="/kaggle/working/images",
    format="zip",
    root_dir="/kaggle/images",
)