CycleGAN based on <a href="https://www.kaggle.com/code/amyjang/monet-cyclegan-tutorial/notebook">Amy Jang's notebook</a>. I tried to replicate the same U-Net in PyTorch.
Training is **super** slow, so I had to cap the number of input pictures. It also doesn't work correctly yet (look at version 13). This is my first submission here, so any help would be appreciated...

# Importing stuff


In [None]:
import random
import os

import torch
#import torch_xla
#import torch_xla.core.xla_model as xm
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np

In [None]:
# Show the installed PyTorch version.
print(torch.__version__)

# Defining input/output shapes and device

In [None]:
# Channel counts of the images going into and coming out of the generator.
NUM_CHANNELS = 3
OUTPUT_CHANNELS = 3

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#device = xm.xla_device()

# Downsampling unit

In [None]:
class Dsample(nn.Module):
    """Stride-2 convolution block that shrinks the spatial size of its input.

    The block applies a bias-free ``Conv2d`` (padding 1, stride 2), an optional
    affine ``InstanceNorm2d`` and a ``LeakyReLU`` with PyTorch's default slope.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by the convolution.
        kernel_size: kernel size of the convolution.
        normalize: when false, no instance-norm layer is created or applied.
    """

    def __init__(self, in_channels, out_channels, kernel_size, normalize=True):
        super().__init__()
        self.normalize = normalize
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=2,
            padding=1,
            bias=False,
        )
        if normalize:
            self.norm = nn.InstanceNorm2d(out_channels, affine=True)
        self.activation = nn.LeakyReLU()
        self._weight_init()

    def _weight_init(self):
        """Zero every bias and draw every other direct parameter from N(0, 0.02)."""
        for child in self.children():
            for pname, tensor in child.named_parameters(recurse=False):
                if pname == 'bias':
                    nn.init.zeros_(tensor)
                else:
                    nn.init.normal_(tensor, mean=0.0, std=0.02)

    def forward(self, x):
        """Return ``activation(norm(conv(x)))``, skipping ``norm`` when disabled."""
        features = self.conv(x)
        if self.normalize:
            features = self.norm(features)
        return self.activation(features)

# Upsampling unit

In [None]:
class Usample(nn.Module):
    """Stride-2 transposed-convolution block that enlarges its input.

    The block applies a bias-free ``ConvTranspose2d`` (padding 1, stride 2), an
    affine ``InstanceNorm2d``, an optional ``Dropout(0.5)`` and a ``ReLU``.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by the transposed convolution.
        kernel_size: kernel size of the transposed convolution.
        dropout: when true, a dropout layer is created and applied before the ReLU.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dropout=False):
        super().__init__()
        self.dropout = dropout
        self.conv = nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=2,
            padding=1,
            bias=False,
        )
        self.norm = nn.InstanceNorm2d(out_channels, affine=True)
        if dropout:
            self.drop = nn.Dropout(0.5)
        self.activation = nn.ReLU()
        self._weight_init()

    def _weight_init(self):
        """Zero every bias and draw every other direct parameter from N(0, 0.02)."""
        for child in self.children():
            for pname, tensor in child.named_parameters(recurse=False):
                if pname == 'bias':
                    nn.init.zeros_(tensor)
                else:
                    nn.init.normal_(tensor, mean=0.0, std=0.02)

    def forward(self, x):
        """Return ``activation(drop(norm(conv(x))))``, skipping ``drop`` when disabled."""
        features = self.norm(self.conv(x))
        if self.dropout:
            features = self.drop(features)
        return self.activation(features)

# Generator unit

In [None]:
class GNet(nn.Module):
    """U-Net style generator: eight down blocks, seven up blocks, one output layer.

    The outputs of ``d1``..``d7`` are concatenated channel-wise onto the outputs
    of ``u7``..``u1`` respectively, which is why most up blocks take twice the
    channels the previous block produced. The final transposed convolution is
    followed by ``Tanh``.
    """

    def __init__(self):
        super().__init__()

        # Encoder. The first and last blocks are built without instance norm.
        self.d1 = Dsample(NUM_CHANNELS, 64, 4, False)
        self.d2 = Dsample(64, 128, 4)
        self.d3 = Dsample(128, 256, 4)
        self.d4 = Dsample(256, 512, 4)
        self.d5 = Dsample(512, 512, 4)
        self.d6 = Dsample(512, 512, 4)
        self.d7 = Dsample(512, 512, 4)
        self.d8 = Dsample(512, 512, 4, False)

        # Decoder. From u2 on, the input is the previous output plus a skip.
        self.u1 = Usample(512, 512, 4, True)
        self.u2 = Usample(1024, 512, 4, True)  # 512 + 512 (skip from d7)
        self.u3 = Usample(1024, 512, 4, True)  # 512 + 512 (skip from d6)
        self.u4 = Usample(1024, 512, 4)        # 512 + 512 (skip from d5)
        self.u5 = Usample(1024, 256, 4)        # 512 + 512 (skip from d4)
        self.u6 = Usample(512, 128, 4)         # 256 + 256 (skip from d3)
        self.u7 = Usample(256, 64, 4)          # 128 + 128 (skip from d2)

        # Input is 64 (u7) + 64 (skip from d1) channels.
        self.outlayer = nn.ConvTranspose2d(
            in_channels=128,
            out_channels=OUTPUT_CHANNELS,
            kernel_size=4,
            stride=2,
            padding=1,
        )
        self.activation = nn.Tanh()
        self._weight_init()

    def _weight_init(self):
        """Zero biases and draw weights from N(0, 0.02) for direct children only.

        Only parameters owned directly by a child are visited, so the nested
        ``Dsample``/``Usample`` blocks are left as they initialised themselves.
        """
        for child in self.children():
            for pname, tensor in child.named_parameters(recurse=False):
                if pname == 'bias':
                    nn.init.zeros_(tensor)
                else:
                    nn.init.normal_(tensor, mean=0.0, std=0.02)

    def forward(self, x):
        """Encode ``x``, decode it with skip connections and squash with ``Tanh``."""
        encoder = (self.d1, self.d2, self.d3, self.d4,
                   self.d5, self.d6, self.d7, self.d8)
        decoder = (self.u1, self.u2, self.u3, self.u4,
                   self.u5, self.u6, self.u7)

        # Downsampling: remember every intermediate output.
        stages = []
        for down in encoder:
            x = down(x)
            stages.append(x)
        # The output of d8 is the bottleneck itself, not a skip connection.
        stages.pop()

        # Upsampling: pair u1 with d7's output, u2 with d6's, and so on.
        for up, skip in zip(decoder, reversed(stages)):
            x = torch.cat((up(x), skip), dim=1)

        return self.activation(self.outlayer(x))

# Discriminator unit

In [None]:
class DNet(nn.Module):
    """Patch discriminator: three down blocks followed by two padded convolutions.

    The last convolution has a single output channel and no activation, so the
    network returns an un-squashed score map rather than probabilities.
    """

    def __init__(self):
        super().__init__()

        self.d1 = Dsample(NUM_CHANNELS, 64, 4, False)
        self.d2 = Dsample(64, 128, 4)
        self.d3 = Dsample(128, 256, 4)
        self.zpad1 = nn.ZeroPad2d(1)
        self.conv1 = nn.Conv2d(256, 512, 4, stride=1, bias=False)
        self.norm = nn.InstanceNorm2d(512, affine=True)
        self.activation = nn.LeakyReLU()
        self.zpad2 = nn.ZeroPad2d(1)
        self.conv2 = nn.Conv2d(512, 1, 4, stride=1)
        self._weight_init()

    def _weight_init(self):
        """Zero biases and draw weights from N(0, 0.02) for direct children only."""
        for child in self.children():
            for pname, tensor in child.named_parameters(recurse=False):
                if pname == 'bias':
                    nn.init.zeros_(tensor)
                else:
                    nn.init.normal_(tensor, mean=0.0, std=0.02)

    def forward(self, x):
        """Run ``x`` through every stage in order and return the score map."""
        pipeline = (
            self.d1, self.d2, self.d3,
            self.zpad1, self.conv1, self.norm, self.activation,
            self.zpad2, self.conv2,
        )
        for stage in pipeline:
            x = stage(x)
        return x

# Photo-to-Monet GAN

In [None]:
class MonetGAN(nn.Module):
    """Photo-to-Monet generator paired with the discriminator that judges Monets.

    ``gen`` turns a photo batch into a painting-style batch. ``disc`` returns raw
    scores that the losses below treat as logits (real = 1, generated = 0).
    """

    def __init__(self):
        super(MonetGAN, self).__init__()
        self._dev = device
        self.gen = GNet().to(self._dev)
        self.disc = DNet().to(self._dev)

    def _get_loss_d(self, monet_real, photo):
        """Return the element-wise discriminator loss.

        Args:
            monet_real: batch of real paintings (target label 1).
            photo: batch of photos; ``gen(photo)`` is the fake batch (label 0).

        Returns:
            Un-reduced BCE-with-logits values for the real batch followed by the
            generated batch, with the channel dimension averaged away.
        """
        # A discriminator update never uses generator gradients, so run the
        # generator without autograd: no graph is stored for it and backward()
        # neither walks through it nor leaves stale grads on its parameters.
        with torch.no_grad():
            fake_monet = self.gen(photo)
        gen_data = self.disc(fake_monet)
        real_data = self.disc(monet_real)

        full_data = torch.cat((real_data, gen_data))
        labels = torch.cat((torch.ones_like(real_data), torch.zeros_like(gen_data)))
        return torch.nn.BCEWithLogitsLoss(reduction='none')(full_data.mean(dim=1), labels.mean(dim=1))

    def _get_loss_g(self, photo):
        """Return the element-wise adversarial loss of the generator.

        The generated batch is scored against the label 1, i.e. the generator is
        rewarded when the discriminator takes its output for a real painting.
        """
        gen_data = self.disc(self.gen(photo))
        gen_label = torch.ones_like(gen_data)
        return torch.nn.BCEWithLogitsLoss(reduction='none')(gen_data.mean(dim=1), gen_label.mean(dim=1))

# Monet-to-Photo GAN

In [None]:
class PhotoGAN(nn.Module):
    """Monet-to-photo generator paired with the discriminator that judges photos.

    ``gen`` turns a painting batch into a photo-style batch. ``disc`` returns raw
    scores that the losses below treat as logits (real = 1, generated = 0).
    """

    def __init__(self):
        super(PhotoGAN, self).__init__()
        self._dev = device
        self.gen = GNet().to(self._dev)
        self.disc = DNet().to(self._dev)

    def _get_loss_d(self, photo_real, monet):
        """Return the element-wise discriminator loss.

        Args:
            photo_real: batch of real photos (target label 1).
            monet: batch of paintings; ``gen(monet)`` is the fake batch (label 0).

        Returns:
            Un-reduced BCE-with-logits values for the real batch followed by the
            generated batch, with the channel dimension averaged away.
        """
        # A discriminator update never uses generator gradients, so run the
        # generator without autograd: no graph is stored for it and backward()
        # neither walks through it nor leaves stale grads on its parameters.
        with torch.no_grad():
            fake_photo = self.gen(monet)
        gen_data = self.disc(fake_photo)
        real_data = self.disc(photo_real)

        full_data = torch.cat((real_data, gen_data))
        labels = torch.cat((torch.ones_like(real_data), torch.zeros_like(gen_data)))
        return torch.nn.BCEWithLogitsLoss(reduction='none')(full_data.mean(dim=1), labels.mean(dim=1))

    def _get_loss_g(self, monet):
        """Return the element-wise adversarial loss of the generator.

        The generated batch is scored against the label 1, i.e. the generator is
        rewarded when the discriminator takes its output for a real photo.
        """
        gen_data = self.disc(self.gen(monet))
        gen_label = torch.ones_like(gen_data)
        return torch.nn.BCEWithLogitsLoss(reduction='none')(gen_data.mean(dim=1), gen_label.mean(dim=1))

# Complete CycleGAN unit

In [None]:
class CycleGAN(nn.Module):
    """Two coupled GANs trained with adversarial, cycle and identity losses.

    ``mGAN.gen`` translates photo -> Monet and ``pGAN.gen`` translates
    Monet -> photo. ``plambda`` weights the cycle losses; the identity losses
    use half of it.
    """

    def __init__(self, plambda=10):
        super(CycleGAN, self).__init__()
        self._dev = device
        self.mGAN = MonetGAN().to(self._dev)
        self.pGAN = PhotoGAN().to(self._dev)
        self.plambda = plambda

    def _cycle_loss_monet(self, monet, plambda):
        """Return ``plambda * L1(monet, mGAN.gen(pGAN.gen(monet)))``."""
        # Monet -> photo needs the photo generator; the Monet generator brings
        # the result back into the painting domain.
        fake_photo = self.pGAN.gen(monet)
        cycled_monet = self.mGAN.gen(fake_photo)
        cycle_loss = torch.nn.L1Loss(reduction='mean')(monet, cycled_monet)
        return plambda * cycle_loss

    def _cycle_loss_photo(self, photo, plambda):
        """Return ``plambda * L1(photo, pGAN.gen(mGAN.gen(photo)))``."""
        # Photo -> Monet needs the Monet generator; the photo generator brings
        # the result back into the photo domain.
        fake_monet = self.mGAN.gen(photo)
        cycled_photo = self.pGAN.gen(fake_monet)
        cycle_loss = torch.nn.L1Loss(reduction='mean')(photo, cycled_photo)
        return plambda * cycle_loss

    def _identity_loss_monet(self, monet, plambda):
        """Penalise the Monet generator for changing an input that is already a Monet."""
        monet_back = self.mGAN.gen(monet)
        iden_loss = torch.nn.L1Loss(reduction='mean')(monet, monet_back)
        return plambda * iden_loss

    def _identity_loss_photo(self, photo, plambda):
        """Penalise the photo generator for changing an input that is already a photo."""
        photo_back = self.pGAN.gen(photo)
        iden_loss = torch.nn.L1Loss(reduction='mean')(photo, photo_back)
        return plambda * iden_loss

    def _total_loss_D(self, photo, monet):
        """Return the sum of both discriminator losses."""
        return self.mGAN._get_loss_d(monet, photo) + self.pGAN._get_loss_d(photo, monet)

    def _total_loss_G(self, photo, monet, mode):
        """Return the full generator loss for one of the two generators.

        Args:
            photo: batch of photos.
            monet: batch of paintings.
            mode: ``'m'`` for the photo -> Monet generator, ``'p'`` for the
                Monet -> photo generator.

        Raises:
            RuntimeError: if ``mode`` is neither ``'m'`` nor ``'p'``.
        """
        if mode == 'm':
            # The Monet generator takes photos as input.
            adversarial = self.mGAN._get_loss_g(photo)
            identity = self._identity_loss_monet(monet, 0.5 * self.plambda)
        elif mode == 'p':
            # The photo generator takes paintings as input.
            adversarial = self.pGAN._get_loss_g(monet)
            identity = self._identity_loss_photo(photo, 0.5 * self.plambda)
        else:
            raise RuntimeError("Wrong mode selected.")
        # Both cycle directions involve both generators, so each generator is
        # trained on the two of them.
        return adversarial + self._cycle_loss_monet(monet, self.plambda) +\
               self._cycle_loss_photo(photo, self.plambda) + identity

    @staticmethod
    def _optimizer_step(optimizer, loss):
        """Backpropagate the sum of ``loss`` and update only ``optimizer``'s parameters."""
        # Clear first: earlier backward passes of other losses may have left
        # gradients on these parameters.
        optimizer.zero_grad(set_to_none=True)
        # The losses are element-wise tensors; a ones-weighted backward sums them.
        loss.backward(torch.ones_like(loss))
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

    def train(self, in_data=True, iter_d=1, iter_g=1, n_epochs=100, lr=0.0002):
        """Train both GANs, or switch train/eval mode when given a bool.

        This method shadows ``nn.Module.train``, which ``nn.Module.eval()`` calls
        as ``self.train(False)``. Boolean calls are therefore forwarded to the
        base class so that ``.eval()`` and ``.train()`` keep working.

        Args:
            in_data: iterable of batches, each a dict with ``'photo'`` and
                ``'monet'`` tensors; or a bool, interpreted as the ``mode``
                argument of ``nn.Module.train``.
            iter_d: currently unused.
            iter_g: currently unused.
            n_epochs: number of passes over ``in_data``.
            lr: learning rate shared by the four Adam optimisers.

        Returns:
            ``self`` for a boolean call, otherwise ``None``.
        """
        if isinstance(in_data, bool):
            return super().train(in_data)

        def make_optimizer(module):
            return torch.optim.Adam(list(module.parameters()), lr=lr, betas=(0.5, 0.99))

        optG_M = make_optimizer(self.mGAN.gen)
        optD_M = make_optimizer(self.mGAN.disc)
        optG_P = make_optimizer(self.pGAN.gen)
        optD_P = make_optimizer(self.pGAN.disc)

        for epoch in tqdm(range(n_epochs)):
            last_losses = None
            for data in in_data:
                photo, monet = data['photo'].to(self._dev), data['monet'].to(self._dev)

                loss_G_M = self._total_loss_G(photo, monet, 'm')
                self._optimizer_step(optG_M, loss_G_M)

                loss_D_M = self.mGAN._get_loss_d(monet, photo)
                self._optimizer_step(optD_M, loss_D_M)

                loss_G_P = self._total_loss_G(photo, monet, 'p')
                self._optimizer_step(optG_P, loss_G_P)

                # The photo discriminator sees real photos and generated photos.
                loss_D_P = self.pGAN._get_loss_d(photo, monet)
                self._optimizer_step(optD_P, loss_D_P)

                last_losses = tuple(
                    t.detach() for t in (loss_D_M, loss_G_M, loss_D_P, loss_G_P)
                )

            # Checkpoint once per qualifying epoch rather than once per batch.
            if epoch % 10 == 0 and epoch != 0:
                torch.save(self.state_dict(), "./weights.pt")

            # Nothing to report when the loader produced no batch.
            if last_losses is not None:
                d_m, g_m, d_p, g_p = (t.mean().item() for t in last_losses)
                print(f"E: {epoch}; mDLoss: {d_m}; mGLoss: {g_m}; pDLoss: {d_p}; pGLoss: {g_p}")


# Input Pipeline

In [None]:
class MonetDataset(Dataset):
    """Unpaired painting/photo dataset read from two image folders.

    Images are read from ``<data_dir>/monet_jpg/`` and ``<data_dir>/photo_jpg/``.
    Indices wrap around inside each folder, so both folders may hold a different
    number of files.

    Args:
        data_dir: folder containing the ``monet_jpg`` and ``photo_jpg`` folders.
        transform: optional callable applied to each loaded RGB image.
        num_samples: length reported by ``len()``, i.e. the number of samples
            per epoch. ``None`` uses the size of the larger folder.

    Raises:
        ValueError: if either folder contains no entries.
    """

    def __init__(self, data_dir, transform=None, num_samples=1000):
        super(MonetDataset, self).__init__()
        self.transform = transform
        self.num_samples = num_samples
        self._dev = device

        self.monet_dir = data_dir + "/monet_jpg/"
        self.photo_dir = data_dir + "/photo_jpg/"
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible between runs.
        self.monet_list = sorted(os.listdir(self.monet_dir))
        self.photo_list = sorted(os.listdir(self.photo_dir))
        if not self.monet_list or not self.photo_list:
            raise ValueError("monet_jpg and photo_jpg must each contain at least one image.")

    def __len__(self):
        if self.num_samples is None:
            return max(len(self.monet_list), len(self.photo_list))
        return self.num_samples

    @staticmethod
    def _load_rgb(folder, file_name):
        """Open one image, convert it to RGB and release the file handle."""
        with Image.open(os.path.join(folder, file_name)) as handle:
            return handle.convert('RGB')

    def __getitem__(self, idx):
        """Return ``{'monet': ..., 'photo': ...}`` for ``idx`` (wrapped per folder)."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        monet = self._load_rgb(self.monet_dir, self.monet_list[idx % len(self.monet_list)])
        photo = self._load_rgb(self.photo_dir, self.photo_list[idx % len(self.photo_list)])

        # The constructor allows transform=None; return the raw images then.
        if self.transform is not None:
            monet = self.transform(monet)
            photo = self.transform(photo)

        return {'monet': monet, 'photo': photo}

In [None]:
# Turn every image into a 3 x 256 x 256 tensor with values in [-1, 1].
# Resize(256) alone only fixes the shorter edge, so a non-square image would
# keep a longer side; the centre crop makes the result square in every case
# and leaves images that are already 256 x 256 untouched.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

In [None]:
# Folder that holds the monet_jpg/ and photo_jpg/ sub-folders.
data_dir = '../input/gan-getting-started'

ds = MonetDataset(data_dir, transform=transform)
# One sample per batch, reshuffled every epoch, loaded by two worker processes.
dsloader = DataLoader(ds, batch_size=1, shuffle=True, num_workers=2)

# Instantiate CycleGAN

In [None]:
# Build the model. Un-comment the two lines below to resume from saved weights.
CGAN_inst = CycleGAN()
#load_path = "./weights.pt"
#CGAN_inst.load_state_dict(torch.load(load_path))

# Check Output Before Training

In [None]:
# Show the first photo of the loader next to what the generator makes of it.
_, ax = plt.subplots(1, 2, figsize=(5, 2), dpi=150)
for i, img in enumerate(dsloader):
    example_photo = img['photo']

    # CHW tensor in [-1, 1] -> HWC uint8 array in [0, 255].
    example_photo_toplot = np.transpose(example_photo[0].detach().cpu().numpy(), [1, 2, 0])
    img = (example_photo_toplot * 127.5 + 127.5).astype(np.uint8)

    generated = CGAN_inst.mGAN.gen(example_photo.to(device))[0]
    prediction = generated.detach().cpu().numpy()
    prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
    prediction = np.transpose(prediction, [1, 2, 0])

    for panel, picture, caption in zip(ax, (img, prediction), ("Input Photo", "Monet-esque")):
        panel.imshow(picture)
        panel.set_title(caption)
        panel.axis("off")
    # Only the first batch is previewed.
    break
plt.show()

# Train

In [None]:
# Run eight epochs over dsloader with the default learning rate.
CGAN_inst.train(dsloader, n_epochs=8)

# Save Weights

In [None]:
# Store the parameters of all four networks in one file.
torch.save(CGAN_inst.state_dict(), './weights.pt')

# Check Output

In [None]:
# Show up to five photos, each next to the generator's rendering of it.
_, ax = plt.subplots(5, 2, figsize=(6, 12), dpi=150)
# zip() stops after five rows without pulling a sixth batch from the loader.
for i, img in zip(range(5), dsloader):
    example_photo = img['photo']

    # CHW tensor in [-1, 1] -> HWC uint8 array in [0, 255].
    example_photo_toplot = np.transpose(example_photo[0].detach().cpu().numpy(), [1, 2, 0])
    im = (example_photo_toplot * 127.5 + 127.5).astype(np.uint8)

    generated = CGAN_inst.mGAN.gen(example_photo.to(device))[0]
    prediction = generated.detach().cpu().numpy()
    prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
    prediction = np.transpose(prediction, [1, 2, 0])

    for panel, picture, caption in zip(ax[i], (im, prediction), ("Input Photo", "Monet-esque")):
        panel.imshow(picture)
        panel.set_title(caption)
        panel.axis("off")
plt.show()

# Prepare submission ZIP

In [None]:
# Create the output folder; unlike `mkdir`, this does not fail on a re-run.
os.makedirs("../images", exist_ok=True)

# NOTE(review): the loop writes one image per sample of dsloader, i.e. len(ds)
# images in total - confirm that this meets the competition's required count.
# Inference only, so no autograd graph is needed.
with torch.no_grad():
    for i, img in enumerate(dsloader):
        # Translate the photo of THIS batch (not a leftover from an earlier cell).
        photo = img['photo']
        prediction = CGAN_inst.mGAN.gen(photo.to(device))[0].detach().cpu().numpy()
        # CHW array in [-1, 1] -> HWC uint8 array in [0, 255].
        prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
        prediction = np.transpose(prediction, [1, 2, 0])
        im = Image.fromarray(prediction)
        im.save("../images/" + str(i+1) + ".jpg")

In [None]:
import shutil
# Pack the contents of /kaggle/images into /kaggle/working/images.zip.
shutil.make_archive("/kaggle/working/images", 'zip', "/kaggle/images")