In [None]:
import torch
import torch.nn as nn 
import numpy as np  
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
import os
from torch.utils.data import Dataset, DataLoader  
import torch.optim as optim 
from tqdm import tqdm 
from torchvision.utils import save_image
import shutil

In [None]:
# Release cached GPU memory that may be left over from an earlier run in this kernel.
torch.cuda.empty_cache()
# NOTE: this cell previously also contained a bare `torch.no_grad()` statement.
# That call only constructs a context-manager object and discards it; it never
# disabled autograd. It has been removed so the cell does not suggest that
# gradients are switched off (the training cells below need them).

**Configuration**

In [None]:
# ---- Runtime and data location ----
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
TRAIN_DIR = "../input/gan-getting-started/"

# ---- Training schedule ----
NUM_EPOCHS = 2
BATCH_SIZE = 1
LEARNING_RATE = 1e-4
NUM_WORKERS = 0

# ---- Checkpointing ----
LOAD_MODEL = False
SAVE_MODEL = True
CHECKPOINT_GEN_P = "/kaggle/gen_P.pth.tar"
CHECKPOINT_GEN_M = "/kaggle/gen_M.pth.tar"
CHECKPOINT_CRITIC_P = "/kaggle/critic_P.pth.tar"
CHECKPOINT_CRITIC_M = "/kaggle/critic_M.pth.tar"

# ---- Loss weights (consumed by train_fn) ----
LAMBDA_CYCLE = 20        # multiplies each cycle-consistency loss
LAMBDA_IDENTITY = 10     # multiplies each identity loss
LAMBDA_GP = 100          # multiplies the gradient penalty inside the Wasserstein-style critic term
LAMBDA_W = 1e-4          # multiplies the Wasserstein-style critic term in the discriminator loss
LAMBDA_MEAN = 1          # multiplies the -mean(critic score) term in the generator adversarial loss
LAMBDA_cycle_MP = 1e-4   # multiplies the gradient penalty added to each cycle loss

print(DEVICE)

In [None]:
def _build_pipeline(with_flip):
    """Assemble the albumentations pipeline shared by training and generation.

    Every pipeline resizes to 256x256, normalises each channel with mean 0.5 and
    std 0.5 (relative to a 255 maximum pixel value) and converts to a tensor.
    When `with_flip` is true a horizontal flip with probability 0.5 is inserted
    after the resize. A second image can be passed under the key "image0".
    """
    steps = [A.Resize(width=256, height=256)]
    if with_flip:
        steps.append(A.HorizontalFlip(p=0.5))
    steps.extend(
        [
            A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
            ToTensorV2(),
        ]
    )
    return A.Compose(steps, additional_targets={"image0": "image"})


# Training pipeline: includes the random horizontal flip.
transforms = _build_pipeline(with_flip=True)

# Generation pipeline: same preprocessing without the random flip.
transforms_generate = _build_pipeline(with_flip=False)

**Discriminator**

In [None]:
class Block(nn.Module):
    """4x4 convolution -> instance norm -> LeakyReLU(0.2) unit used by the discriminator.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by the convolution.
        stride: stride of the convolution.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        # The convolution has no bias term and uses reflect padding of 1.
        convolution = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=stride,
            padding=1,
            bias=False,
            padding_mode="reflect",
        )
        normalisation = nn.InstanceNorm2d(out_channels)
        activation = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.conv = nn.Sequential(convolution, normalisation, activation)

    def forward(self, x):
        """Apply the conv/norm/activation stack to `x`."""
        features = self.conv(x)
        return features

In [None]:
class Discriminator(nn.Module):
    """Fully convolutional critic that maps an image to a grid of raw scores.

    The network is a strided 4x4 convolution + LeakyReLU stem, followed by one
    `Block` per entry of `features[1:]`, followed by a final 4x4 convolution
    with a single output channel. No sigmoid is applied to the output.

    Args:
        in_channels: number of channels of the input image.
        features: channel widths of the successive stages. Every `Block`
            downsamples with stride 2 except the last one, which uses stride 1.
    """

    # The default is a tuple rather than a list so it cannot be mutated between calls.
    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        super().__init__()
        features = list(features)
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels, features[0], kernel_size=4, stride=2, padding=1, padding_mode="reflect"),
            nn.LeakyReLU(0.2, inplace=True),
        )

        layers = []
        in_channels = features[0]
        last_index = len(features) - 1
        for index, feature in enumerate(features[1:], start=1):
            # Pick the stride-1 stage by position, not by value: comparing
            # `feature == features[-1]` would also give stride 1 to any earlier
            # stage that happens to have the same width as the last one.
            stride = 1 if index == last_index else 2
            layers.append(Block(in_channels, feature, stride=stride))
            in_channels = feature

        # Project the last feature map down to a single score channel.
        layers.append(nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw score map for the batch `x`."""
        x = self.initial(x)
        return self.model(x)

In [None]:
def test_discriminator():
    """Smoke check: feed one random 3x256x256 image to a fresh Discriminator and print the output shape."""
    dummy_batch = torch.randn(1, 3, 256, 256)
    critic = Discriminator(in_channels=3)
    scores = critic(dummy_batch)
    print(scores.shape)


test_discriminator()

**Generator**

In [None]:
class ConvBlock(nn.Module):
    """Bias-free (transposed) convolution -> instance norm -> optional ReLU.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by the convolution.
        down: if true use `nn.Conv2d` with reflect padding, otherwise use
            `nn.ConvTranspose2d`.
        use_act: if true finish with an in-place ReLU, otherwise with an identity.
        **kwargs: forwarded unchanged to the convolution constructor
            (kernel_size, stride, padding, ...).
    """

    def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
        super().__init__()
        if down:
            conv_layer = nn.Conv2d(in_channels, out_channels, bias=False, padding_mode="reflect", **kwargs)
        else:
            conv_layer = nn.ConvTranspose2d(in_channels, out_channels, bias=False, **kwargs)

        if use_act:
            activation = nn.ReLU(inplace=True)
        else:
            activation = nn.Identity()

        self.conv = nn.Sequential(conv_layer, nn.InstanceNorm2d(out_channels), activation)

    def forward(self, x):
        """Apply the conv/norm/activation stack to `x`."""
        out = self.conv(x)
        return out

class ResidualBlock(nn.Module):
    """Two 3x3 `ConvBlock`s (the second without activation) wrapped in a skip connection.

    Args:
        channels: number of channels, identical for input and output.
    """

    def __init__(self, channels):
        super().__init__()
        activated = ConvBlock(channels, channels, kernel_size=3, padding=1)
        linear = ConvBlock(channels, channels, use_act=False, kernel_size=3, padding=1)
        self.block = nn.Sequential(activated, linear)

    def forward(self, x):
        """Return `x` plus the output of the two-convolution branch."""
        residual = self.block(x)
        return x + residual

In [None]:
class Generator(nn.Module):
    """Image-to-image network: 7x7 stem, two stride-2 downsamplings, residual
    blocks, two stride-2 transposed-convolution upsamplings and a 7x7 output
    convolution followed by tanh.

    Args:
        img_channels: number of channels of both the input and the output image.
        num_features: channel width after the stem; doubled by each downsampling.
        num_residuals: number of `ResidualBlock`s at the lowest resolution.
    """

    def __init__(self, img_channels, num_features = 64, num_residuals=9):
        super().__init__()
        base = num_features
        mid = num_features * 2
        deep = num_features * 4

        stem_conv = nn.Conv2d(
            img_channels, base, kernel_size=7, stride=1, padding=3, padding_mode="reflect", bias=False
        )
        self.initial = nn.Sequential(stem_conv, nn.InstanceNorm2d(base), nn.ReLU(inplace=True))

        downsample = dict(kernel_size=3, stride=2, padding=1)
        self.down_blocks = nn.ModuleList(
            [ConvBlock(base, mid, **downsample), ConvBlock(mid, deep, **downsample)]
        )

        self.res_blocks = nn.Sequential(*(ResidualBlock(deep) for _ in range(num_residuals)))

        upsample = dict(down=False, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.up_blocks = nn.ModuleList(
            [ConvBlock(deep, mid, **upsample), ConvBlock(mid, base, **upsample)]
        )

        self.last = nn.Conv2d(base, img_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect")

    def forward(self, x):
        """Translate the batch `x`; the output is passed through tanh."""
        hidden = self.initial(x)
        # Run the encoder stages, the residual trunk and the decoder stages in order.
        for stage in (*self.down_blocks, self.res_blocks, *self.up_blocks):
            hidden = stage(hidden)
        return torch.tanh(self.last(hidden))

In [None]:
def test_generator():
    """Smoke check: run a random batch through a fresh Generator, print the
    output shape and verify that it matches the input shape."""
    img_channels = 3
    img_size = 256
    x = torch.randn((2, img_channels, img_size, img_size))
    # The residual count must be passed by keyword: Generator's second positional
    # parameter is `num_features`, so `Generator(img_channels, 9)` built a
    # 9-channel-wide network instead of one with 9 residual blocks.
    gen = Generator(img_channels, num_residuals=9)
    # No gradients are needed for a shape check; this avoids storing activations.
    with torch.no_grad():
        out = gen(x)
    print(out.shape)
    assert out.shape == x.shape, f"unexpected generator output shape {tuple(out.shape)}"


test_generator()

**Saving and loading model**

In [None]:
def save_checkpoint(model, optimizer, filename = './mycheckpoint.pth.tar'):
    """Write the state dicts of `model` and `optimizer` to `filename`.

    The file holds a dict with the keys 'state_dict' (model) and 'optimizer'.
    """
    print( '===> Saving checkpoint ....')
    torch.save(
        dict(state_dict=model.state_dict(), optimizer=optimizer.state_dict()),
        filename,
    )

In [None]:
def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore `model` and `optimizer` from `checkpoint_file`, then force the learning rate to `lr`.

    The checkpoint is expected to contain the keys 'state_dict' and 'optimizer';
    tensors are mapped onto `DEVICE` while loading.
    """
    print('===> Loading checkpoint ....')
    saved = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(saved['state_dict'])
    optimizer.load_state_dict(saved['optimizer'])
    # The restored optimizer carries the learning rate it was saved with;
    # replace it with the requested one in every parameter group.
    for group in optimizer.param_groups:
        group['lr'] = lr

**Dataset** 

In [None]:
class MonetPictureDataset(Dataset):
    """Pairs of (monet, picture) images read from two directories.

    The dataset length is the size of the larger directory; the smaller one is
    cycled through with a modulo index.

    Args:
        root_monet: directory containing the Monet images.
        root_picture: directory containing the picture images.
        transform: optional callable invoked as
            `transform(image=monet, image0=picture)` that returns a mapping
            with the keys "image" and "image0".
    """

    def __init__(self, root_monet, root_picture, transform=None):
        self.root_monet, self.root_picture = root_monet, root_picture
        self.transform = transform

        self.monet_images = os.listdir(root_monet)
        self.picture_images = os.listdir(root_picture)
        self.monet_len = len(self.monet_images)
        self.picture_len = len(self.picture_images)
        self.length_dataset = max(self.monet_len, self.picture_len)

    def __len__(self):
        return self.length_dataset

    @staticmethod
    def _load_rgb(folder, file_name):
        """Open `folder/file_name` and return it as an RGB numpy array."""
        return np.array(Image.open(os.path.join(folder, file_name)).convert("RGB"))

    def __getitem__(self, index):
        # Wrap the index independently for each domain.
        monet_name = self.monet_images[index % self.monet_len]
        picture_name = self.picture_images[index % self.picture_len]

        monet_img = self._load_rgb(self.root_monet, monet_name)
        picture_img = self._load_rgb(self.root_picture, picture_name)

        if not self.transform:
            return monet_img, picture_img

        transformed = self.transform(image=monet_img, image0=picture_img)
        return transformed["image"], transformed["image0"]

**Training**

In [None]:
def gradient_penalty(critic, real, fake, device="cpu"):
    """Gradient penalty: mean of (||d critic(x_hat) / d x_hat||_2 - 1)^2 over the batch.

    `x_hat` is a random per-sample interpolation between `real` and `fake`.

    Args:
        critic: callable applied to the interpolated batch.
        real: 4-D batch (batch, channels, height, width).
        fake: batch with the same shape as `real`.
        device: device on which the interpolation coefficients are created.

    Returns:
        A scalar tensor. It is differentiable with respect to the parameters of
        `critic`, so adding it to a loss actually regularises `critic`.
    """
    with torch.cuda.amp.autocast():  # keep the same mixed-precision context as the training step
        batch_size, _, _, _ = real.shape  # also enforces a 4-D input
        # One coefficient per sample, broadcast over channels and pixels.
        alpha = torch.rand((batch_size, 1, 1, 1), device=device)

        # Detach both inputs and make the interpolation a leaf that requires grad:
        #  * autograd.grad then works even when neither input requires grad, and
        #  * the penalty back-propagates only into `critic`, not into whichever
        #    network produced `fake`.
        interpolated_images = (real.detach() * alpha + fake.detach() * (1 - alpha)).requires_grad_(True)

        # Calculate critic scores
        mixed_scores = critic(interpolated_images)

        # Take the gradient of the scores with respect to the images.
        # create_graph=True is essential: with create_graph=False the returned
        # gradient has no grad_fn, so the penalty is a constant and contributes
        # nothing when the loss is back-propagated.
        gradient = torch.autograd.grad(
            inputs=interpolated_images,
            outputs=mixed_scores,
            grad_outputs=torch.ones_like(mixed_scores),
            create_graph=True,
            retain_graph=True,
        )[0]
        # reshape (not view) so a non-contiguous gradient is handled as well
        gradient = gradient.reshape(gradient.shape[0], -1)
        gradient_norm = gradient.norm(2, dim=1)
        return torch.mean((gradient_norm - 1) ** 2)

In [None]:
def train_fn(disc_P, disc_M, gen_M, gen_P, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler):
    """Run one pass over `loader`, updating the discriminators and then the generators on every batch.

    Args:
        disc_P: network scoring `picture` batches and `gen_P` outputs.
        disc_M: network scoring `monet` batches and `gen_M` outputs.
        gen_M: network applied to `picture` to produce `fake_monet`.
        gen_P: network applied to `monet` to produce `fake_picture`.
        loader: iterable yielding `(monet, picture)` batches.
        opt_disc: optimizer stepped after the discriminator loss is back-propagated.
        opt_gen: optimizer stepped after the generator loss is back-propagated.
        l1: loss callable used for the cycle and identity terms (main passes nn.L1Loss()).
        mse: loss callable used against all-ones / all-zeros targets (main passes nn.MSELoss()).
        d_scaler: gradient scaler used for the discriminator step (main passes torch.cuda.amp.GradScaler()).
        g_scaler: gradient scaler used for the generator step (main passes torch.cuda.amp.GradScaler()).

    Returns:
        None. Running averages are only shown on the tqdm progress bar.
    """
    # Running sums; they are divided by the number of batches seen so far when displayed.
    P_reals = 0
    P_fakes = 0
    Gen_Loss = 0
    Disc_Loss = 0 
    loop = tqdm(loader, leave=True)

    for idx, (monet, picture) in enumerate(loop):
        monet = monet.to(DEVICE)
        picture = picture.to(DEVICE)

        # ---- Train discriminators disc_P and disc_M ----
        with torch.cuda.amp.autocast(): # mixed precision for speed
            fake_picture = gen_P(monet)
            D_P_real = disc_P(picture)
            D_P_fake = disc_P(fake_picture.detach()) # detached here; fake_picture itself is reused in the generator step below

            P_reals += D_P_real.mean().item()
            P_fakes += D_P_fake.mean().item()

            # Least-squares terms: real scores are pushed towards 1, fake scores towards 0.
            D_P_real_loss = mse(D_P_real, torch.ones_like(D_P_real))
            D_P_fake_loss = mse(D_P_fake, torch.zeros_like(D_P_fake))

            # Wasserstein-style term: -(mean real score - mean fake score) + LAMBDA_GP * gradient penalty.
            gp_P = gradient_penalty(disc_P,picture,fake_picture, device = DEVICE)
            D_P_loss_W = -(torch.mean(D_P_real.view(-1))-torch.mean(D_P_fake.view(-1))) + LAMBDA_GP*gp_P # scores are flattened before averaging

            D_P_loss = D_P_real_loss + D_P_fake_loss + LAMBDA_W * D_P_loss_W

            # Same computation for the Monet domain.
            fake_monet = gen_M(picture)
            D_M_real = disc_M(monet)
            D_M_fake = disc_M(fake_monet.detach())
            D_M_real_loss = mse(D_M_real, torch.ones_like(D_M_real))
            D_M_fake_loss = mse(D_M_fake, torch.zeros_like(D_M_fake))

            gp_M = gradient_penalty(disc_M,monet,fake_monet, device = DEVICE)
            D_M_loss_W = -(torch.mean(D_M_real.view(-1))-torch.mean(D_M_fake.view(-1))) + LAMBDA_GP*gp_M 
            
            D_M_loss = D_M_real_loss + D_M_fake_loss + LAMBDA_W * D_M_loss_W

            # Put it together: average of the two discriminator losses.
            D_loss = (D_P_loss + D_M_loss)/2 # the 1/2 weighting is a tunable choice

            Disc_Loss += D_loss.item()

        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward(retain_graph=True) # NOTE(review): retain_graph=True looks unnecessary here, since the discriminator forward passes are recomputed below; confirm before removing
        d_scaler.step(opt_disc)
        d_scaler.update()

        # ---- Train generators gen_P and gen_M ----
        with torch.cuda.amp.autocast():
            # Adversarial loss for both generators (scores recomputed with the updated discriminators).
            D_P_fake = disc_P(fake_picture)
            D_M_fake = disc_M(fake_monet)
            
            loss_G_P_W = -torch.mean(D_P_fake)
            loss_G_M_W = -torch.mean(D_M_fake)

            # Least-squares term (fake scores pushed towards 1) plus LAMBDA_MEAN * (-mean score).
            loss_G_P = mse(D_P_fake, torch.ones_like(D_P_fake)) + LAMBDA_MEAN* loss_G_P_W
            loss_G_M = mse(D_M_fake, torch.ones_like(D_M_fake)) + LAMBDA_MEAN* loss_G_M_W
            
            # Cycle loss: translate each fake back to its source domain.
            cycle_monet = gen_M(fake_picture)
            cycle_picture = gen_P(fake_monet)

            # NOTE(review): the generators are passed as the `critic` argument here, so the
            # penalty is computed on generator outputs; confirm this is intended.
            gp_gen_P = gradient_penalty(gen_P,picture,cycle_picture, device = DEVICE)
            gp_gen_M = gradient_penalty(gen_M,monet,cycle_monet, device = DEVICE)

            cycle_monet_loss = l1(monet, cycle_monet) + LAMBDA_cycle_MP * gp_gen_M
            cycle_picture_loss = l1(picture, cycle_picture) + LAMBDA_cycle_MP * gp_gen_P

            # Identity loss (these four lines can be dropped for efficiency if LAMBDA_IDENTITY is 0).
            identity_monet = gen_M(monet)
            identity_picture = gen_P(picture)
            identity_monet_loss = l1(monet, identity_monet)
            identity_picture_loss = l1(picture, identity_picture)

            # Add everything together.
            G_loss = (
                loss_G_M
                + loss_G_P
                + cycle_monet_loss * LAMBDA_CYCLE
                + cycle_picture_loss * LAMBDA_CYCLE
                + identity_picture_loss * LAMBDA_IDENTITY
                + identity_monet_loss * LAMBDA_IDENTITY
            )
            
            Gen_Loss += G_loss.item()

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward(retain_graph = True) # NOTE(review): nothing in this function reuses the graph after this call; retain_graph may be unnecessary
        g_scaler.step(opt_gen)
        g_scaler.update() 
        
        # Show running means over the batches processed so far.
        loop.set_postfix(P_real=P_reals/(idx+1), P_fake=P_fakes/(idx+1), G_L = Gen_Loss/(idx+1), D_L = Disc_Loss/(idx+1) )

In [None]:
def main(disc_P, disc_M, gen_P, gen_M, opt_disc, opt_gen):
    """Train the four networks for NUM_EPOCHS epochs.

    Optionally restores every network from its checkpoint first (LOAD_MODEL)
    and writes all four checkpoints after every epoch (SAVE_MODEL).

    Args:
        disc_P, disc_M: the two discriminators.
        gen_P, gen_M: the two generators.
        opt_disc: optimizer shared by both discriminators.
        opt_gen: optimizer shared by both generators.
    """
    l1_loss = nn.L1Loss()
    mse_loss = nn.MSELoss()

    # (checkpoint path, network, optimizer) in the order used for loading and saving.
    checkpoint_plan = (
        (CHECKPOINT_GEN_P, gen_P, opt_gen),
        (CHECKPOINT_GEN_M, gen_M, opt_gen),
        (CHECKPOINT_CRITIC_P, disc_P, opt_disc),
        (CHECKPOINT_CRITIC_M, disc_M, opt_disc),
    )

    if LOAD_MODEL:
        for path, network, optimizer in checkpoint_plan:
            load_checkpoint(path, network, optimizer, LEARNING_RATE)

    train_set = MonetPictureDataset(
        root_monet=TRAIN_DIR+"/monet_jpg",
        root_picture=TRAIN_DIR+"/photo_jpg",
        transform=transforms,
    )
    train_loader = DataLoader(
        train_set,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )

    # One gradient scaler per optimizer for mixed-precision training.
    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()

    for _ in range(NUM_EPOCHS):
        train_fn(
            disc_P, disc_M, gen_M, gen_P,
            train_loader, opt_disc, opt_gen,
            l1_loss, mse_loss, d_scaler, g_scaler,
        )

        if not SAVE_MODEL:
            continue
        for path, network, optimizer in checkpoint_plan:
            save_checkpoint(network, optimizer, filename=path)

In [None]:
# Networks: one discriminator and one generator per image domain.
disc_P = Discriminator(in_channels=3).to(DEVICE)
disc_M = Discriminator(in_channels=3).to(DEVICE)
gen_M = Generator(img_channels=3, num_residuals=9).to(DEVICE)
gen_P = Generator(img_channels=3, num_residuals=9).to(DEVICE)

# Both optimizers use the same Adam hyper-parameters.
_ADAM_SETTINGS = dict(lr=LEARNING_RATE, betas=(0.5, 0.999))

# A single optimizer updates both discriminators ...
opt_disc = optim.Adam([*disc_P.parameters(), *disc_M.parameters()], **_ADAM_SETTINGS)

# ... and a single optimizer updates both generators.
opt_gen = optim.Adam([*gen_M.parameters(), *gen_P.parameters()], **_ADAM_SETTINGS)

In [None]:
# Start training; arguments are passed by keyword so the P/M order cannot be mixed up.
main(
    disc_P=disc_P,
    disc_M=disc_M,
    gen_P=gen_P,
    gen_M=gen_M,
    opt_disc=opt_disc,
    opt_gen=opt_gen,
)

**Load trained model**

In [None]:
# Restore every network (and its optimizer) from the checkpoint written during training.
for _path, _network, _optimizer in (
    (CHECKPOINT_GEN_P, gen_P, opt_gen),
    (CHECKPOINT_GEN_M, gen_M, opt_gen),
    (CHECKPOINT_CRITIC_P, disc_P, opt_disc),
    (CHECKPOINT_CRITIC_M, disc_M, opt_disc),
):
    load_checkpoint(_path, _network, _optimizer, LEARNING_RATE)

**Generate images**

In [None]:
# Dataset for image generation: same folders as training, but with the
# flip-free `transforms_generate` pipeline.
dataset = MonetPictureDataset(
    root_monet=TRAIN_DIR+"/monet_jpg",
    root_picture=TRAIN_DIR+"/photo_jpg",
    transform=transforms_generate,
)

loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)

In [None]:
# Put every network into evaluation mode before generating images.
for _net in (disc_P, disc_M, gen_M):
    _net.eval()
gen_P.eval()  # kept as the cell's final expression so the displayed output is unchanged

In [None]:
# Create the output directory; exist_ok=True makes the cell safe to re-run and
# avoids the check-then-create race of testing os.path.exists first.
os.makedirs('/kaggle/images', exist_ok=True)

In [None]:
# Inference only: disable gradient tracking for the whole loop.
with torch.no_grad():
    for idx, (monet, picture) in enumerate(loader):
        # gen_M is applied to the picture batch; its output is tanh-bounded.
        fake_picture = gen_M(picture.to(DEVICE))
        # Undo the mean-0.5 / std-0.5 normalisation before writing the file.
        rescaled = fake_picture * 0.5 + 0.5
        save_image(rescaled, f'/kaggle/images/picture_{idx}.jpg')

In [None]:
# Bundle the generated images into /kaggle/working/images.zip.
shutil.make_archive(
    base_name="/kaggle/working/images",
    format='zip',
    root_dir="/kaggle/images",
)
