In [2]:
%matplotlib inline
import os
import glob
import random
import itertools
from PIL import Image

import torch

from torchvision.utils import make_grid
from torchvision import transforms
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR

import pytorch_lightning as pl

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar

import wandb

# Fixed global seed for python/numpy/torch RNGs.
# The previous expression derived the seed from hash("kek"); string hashes are
# salted per interpreter process, so every kernel restart produced a different
# seed. The constant below is the value recorded in this notebook's saved output.
SEED = 4109070765
pl.seed_everything(SEED)

Global seed set to 4109070765


4109070765

In [15]:
# Root folder of the dataset inside the notebook's input directory.
dataset_name = "monet2photo"
DATA_PATH = f"../input/{dataset_name}"

# Sub-folder names for each split / domain.
TRAIN_DIR_A = "trainA"  # monet
TRAIN_DIR_B = "trainB"  # photos
TEST_DIR_A = "testA"  # monet
TEST_DIR_B = "testB"  # photos

# PATHS[split][domain] -> directory path, e.g. PATHS["train"]["A"].
PATHS = {
    split: {
        domain: os.path.join(DATA_PATH, folder) for domain, folder in folders.items()
    }
    for split, folders in {
        "train": {"A": TRAIN_DIR_A, "B": TRAIN_DIR_B},
        "test": {"A": TEST_DIR_A, "B": TEST_DIR_B},
    }.items()
}

In [3]:
# Weights & Biases logger; it is later handed to the Trainer via `logger=`.
# NOTE(review): log_model="all" presumably uploads every checkpoint as an
# artifact - confirm against the WandbLogger documentation.
wandb_logger = WandbLogger(project="CycleGAN", name="version2", log_model="all")

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize


[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:  ········································


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [6]:
def weights_init_normal(m):
    """Re-draw the weights of convolution-like modules from N(0, 0.02).

    Intended to be used with ``module.apply(...)``. A module is treated as a
    convolution when its class name contains the substring "Conv"; every other
    module is left untouched. Biases are not modified.
    """
    is_conv = "Conv" in m.__class__.__name__
    if not is_conv:
        return
    nn.init.normal_(m.weight.data, mean=0.0, std=0.02)

In [7]:
class ImageTransform:
    """Callable bundle of the per-phase preprocessing pipelines.

    * ``"train"``: resize to ``int(img_size * 1.21)`` with bicubic resampling,
      take a random ``img_size`` crop, convert to a tensor and normalise each
      channel with mean 0.5 / std 0.5.
    * ``"test"``: only the tensor conversion and the same normalisation.
    """

    def __init__(self, img_size=256):
        def to_normalized_tensor():
            # Fresh transform instances for every pipeline.
            return [
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]

        enlarged_size = int(img_size * 1.21)
        train_steps = [
            transforms.Resize(enlarged_size, Image.Resampling.BICUBIC),
            transforms.RandomCrop(img_size),
        ]
        train_steps.extend(to_normalized_tensor())

        self.data_transforms = {
            "train": transforms.Compose(train_steps),
            "test": transforms.Compose(to_normalized_tensor()),
        }

    def __call__(self, img, phase="train"):
        """Apply the pipeline registered for ``phase`` to ``img``."""
        pipeline = self.data_transforms[phase]
        return pipeline(img)


class MonetDataset(Dataset):
    """Index-aligned pairs of images taken from two folders (domains A and B).

    Parameters:
        path_A (str)  -- directory holding the domain-A ``*jpg`` files
        path_B (str)  -- directory holding the domain-B ``*jpg`` files
        transform     -- callable invoked as ``transform(pil_image, phase)``
        phase (str)   -- forwarded to ``transform`` (e.g. "train" or "test")

    The dataset length is the size of the smaller folder; item ``idx`` is the
    pair ``(A[idx], B[idx])``.
    """

    def __init__(self, path_A: str, path_B: str, transform, phase="train"):
        # glob returns files in arbitrary, filesystem-dependent order; sorting
        # makes indexing reproducible across runs and machines.
        self.path_A = sorted(glob.glob(os.path.join(path_A, "*jpg")))
        self.path_B = sorted(glob.glob(os.path.join(path_B, "*jpg")))
        self.transform = transform
        self.phase = phase

    def __len__(self):
        return min(len(self.path_A), len(self.path_B))

    def _load(self, path):
        """Open one image, force 3-channel RGB, and apply the transform."""
        # The context manager closes the file handle; convert() returns an
        # independent, fully loaded copy. Forcing RGB keeps grayscale/CMYK
        # JPEGs compatible with a 3-channel normalisation.
        with Image.open(path) as img:
            rgb = img.convert("RGB")
        return self.transform(rgb, self.phase)

    def __getitem__(self, idx):
        imgA = self._load(self.path_A[idx])
        imgB = self._load(self.path_B[idx])
        return imgA, imgB

In [None]:
class MonetDataModule(pl.LightningDataModule):
    """Lightning data module serving the train / test loaders of MonetDataset.

    Parameters:
        data_dir (dict)  -- nested mapping ``{"train": {"A": dir, "B": dir},
                            "test": {"A": dir, "B": dir}}``; defaults to PATHS
        batch_size (int) -- batch size used by both loaders
    """

    def __init__(self, data_dir: dict = PATHS, batch_size: int = 1):
        super().__init__()
        self.data_path = data_dir
        self.batch_size = batch_size

    def _make_loader(self, split: str, phase: str, shuffle: bool) -> DataLoader:
        """Build a DataLoader over the A/B folders registered for ``split``."""
        folders = self.data_path[split]
        dataset = MonetDataset(
            folders["A"],
            folders["B"],
            ImageTransform(),
            phase=phase,
        )
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle)

    def train_dataloader(self):
        # Shuffle so that the optimisers do not see the samples in the same
        # fixed order every epoch.
        return self._make_loader("train", phase="train", shuffle=True)

    def test_dataloader(self):
        # Evaluation order stays deterministic.
        return self._make_loader("test", phase="test", shuffle=False)

In [8]:
class ImagePool:
    """History buffer of previously generated images.

    Discriminators can be updated with a mix of fresh and older generator
    outputs instead of only the outputs of the latest generators.
    """

    def __init__(self, pool_size):
        """Create the buffer.

        Parameters:
            pool_size (int) -- capacity of the buffer; with ``pool_size == 0``
                               no buffer is created and ``query`` is a no-op
        """
        self.pool_size = pool_size
        if self.pool_size > 0:
            # Start with an empty buffer.
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        """Exchange freshly generated images with the buffer.

        Parameters:
            images -- batch of the latest generated images

        Returns a batch of the same length. While the buffer is not full every
        incoming image is stored and returned unchanged. Afterwards each image
        is, with probability 1/2, swapped with a randomly chosen stored image
        (the stored one is returned, the new one takes its slot); otherwise
        the incoming image is returned and the buffer is left as is.
        """
        if self.pool_size == 0:
            return images

        selected = []
        for current in images:
            # Work on the raw data with a leading batch dimension of 1.
            current = torch.unsqueeze(current.data, 0)

            if self.num_imgs < self.pool_size:
                # Still filling the buffer: store and pass through.
                self.num_imgs += 1
                self.images.append(current)
                selected.append(current)
                continue

            if random.uniform(0, 1) > 0.5:
                # randint is inclusive on both ends.
                slot = random.randint(0, self.pool_size - 1)
                stored = self.images[slot].clone()
                self.images[slot] = current
                selected.append(stored)
            else:
                selected.append(current)

        return torch.cat(selected, 0)

In [9]:
class ResidualBlock(nn.Module):
    """Two (3x3 reflect-padded conv -> InstanceNorm -> ReLU) stages plus a skip.

    The channel count and spatial size are preserved, so the stage output can
    be added to the input.

    NOTE(review): the residual branch ends with a ReLU, so only non-negative
    values are ever added to the input - confirm this is intended.
    """

    def __init__(self, in_features: int):
        super().__init__()

        def conv_norm_relu():
            # One stage; channels in == channels out.
            return [
                nn.Conv2d(
                    in_features,
                    in_features,
                    kernel_size=3,
                    padding=1,
                    padding_mode="reflect",
                ),
                nn.InstanceNorm2d(in_features),
                nn.ReLU(inplace=True),
            ]

        stages = conv_norm_relu() + conv_norm_relu()
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        residual = self.block(x)
        return residual + x


class DiscriminatorBlock(nn.Module):
    """Conv2d (padding 1) -> optional InstanceNorm2d -> LeakyReLU(0.2).

    Parameters:
        in_filters (int)  -- input channels
        out_filters (int) -- output channels
        kernel_size (int) -- convolution kernel size
        stride (int)      -- convolution stride
        normalize (bool)  -- insert the InstanceNorm2d layer when True
    """

    def __init__(
        self,
        in_filters: int,
        out_filters: int,
        kernel_size: int,
        stride: int,
        normalize: bool = True,
    ):
        super().__init__()

        convolution = nn.Conv2d(
            in_filters,
            out_filters,
            kernel_size=kernel_size,
            stride=stride,
            padding=1,
        )
        optional_norm = [nn.InstanceNorm2d(out_filters)] if normalize else []
        activation = nn.LeakyReLU(0.2, inplace=True)

        self.layers = nn.Sequential(convolution, *optional_norm, activation)

    def forward(self, x):
        """x: torch.Tensor"""
        return self.layers(x)

In [10]:
class GeneratorResNet(nn.Module):
    """
    ResNet-style generator: a 7x7 stem, two stride-2 downsampling convolutions,
    ``n_residual_blocks`` residual blocks, two upsample+conv stages and a 7x7
    output convolution followed by tanh.
    The idea is adapted from Justin Johnson's neural style transfer project
    (https://github.com/jcjohnson/fast-neural-style)

    Parameters:
        input_channels (int)    -- channels of the input (and output) image
        n_residual_blocks (int) -- number of residual blocks in the bottleneck
    """

    def __init__(self, input_channels: int, n_residual_blocks: int = 9):
        super().__init__()

        width = 64

        # Stem: reflect-padded 7x7 convolution.
        modules = [
            nn.Conv2d(
                input_channels,
                width,
                kernel_size=7,
                padding=3,
                padding_mode="reflect",
            ),
            nn.ReLU(inplace=True),
        ]

        # Downsampling: each stage halves the resolution and doubles the width.
        for _ in range(2):
            modules.extend(
                [
                    nn.Conv2d(width, width * 2, kernel_size=3, stride=2, padding=1),
                    nn.InstanceNorm2d(width * 2),
                    nn.ReLU(inplace=True),
                ]
            )
            width *= 2

        # Bottleneck of residual blocks at constant width.
        for _ in range(n_residual_blocks):
            modules.append(ResidualBlock(width))

        # Upsampling: Upsample + conv instead of a transposed convolution,
        # see https://distill.pub/2016/deconv-checkerboard/
        for _ in range(2):
            modules.extend(
                [
                    nn.Upsample(scale_factor=2),
                    nn.Conv2d(width, width // 2, kernel_size=3, stride=1, padding=1),
                    nn.InstanceNorm2d(width // 2),
                    nn.ReLU(inplace=True),
                ]
            )
            width //= 2

        # Output head: back to the input channel count, squashed by tanh.
        modules.extend(
            [
                nn.Conv2d(width, input_channels, 7, padding=3, padding_mode="reflect"),
                nn.Tanh(),
            ]
        )

        self.layers = nn.Sequential(*modules)

        self.apply(weights_init_normal)

    def forward(self, x):
        return self.layers(x)


class Discriminator(nn.Module):
    """Stack of four DiscriminatorBlocks followed by a 1-channel 4x4 convolution.

    Channel widths grow 64 -> 128 -> 256 -> 512; the first three blocks use
    stride 2, the fourth stride 1, and the first block skips normalisation.
    The output is a single-channel map of per-patch scores.
    """

    def __init__(self, input_channels: int):
        super().__init__()

        width = 64
        # (in_channels, out_channels, stride, normalize) per block
        block_specs = [
            (input_channels, width, 2, False),
            (width, width * 2, 2, True),
            (width * 2, width * 4, 2, True),
            (width * 4, width * 8, 1, True),
        ]

        modules = []
        for n_in, n_out, stride, normalize in block_specs:
            modules.append(
                DiscriminatorBlock(
                    n_in, n_out, kernel_size=4, stride=stride, normalize=normalize
                )
            )
        modules.append(nn.Conv2d(width * 8, 1, kernel_size=4, stride=1, padding=1))

        self.layers = nn.Sequential(*modules)

        self.apply(weights_init_normal)

    def forward(self, image):
        return self.layers(image)

In [11]:
def test_generator():
    """Smoke test: the generator must preserve the shape of a 256x256 RGB batch."""
    generator = GeneratorResNet(3)
    dummy_batch = torch.randn(1, 3, 256, 256)
    generated = generator(dummy_batch)
    assert generated.shape == dummy_batch.shape
    print("gen is ok")


def test_discriminator():
    """Smoke test: a 256x256 RGB image must map to a 1x30x30 score map."""
    discriminator = Discriminator(3)
    scores = discriminator(torch.randn(1, 3, 256, 256))
    expected_shape = torch.Size([1, 1, 30, 30])
    assert scores.shape == expected_shape
    print("dis is ok")


# Run both shape smoke tests; each prints a confirmation line when its assert passes.
test_generator()
test_discriminator()

gen is ok
dis is ok


In [12]:
def set_requires_grad(nets, requires_grad):
    """Set ``requires_grad`` on every parameter of every network in ``nets``.

    Parameters:
        nets                 -- iterable of modules exposing ``parameters()``
        requires_grad (bool) -- value assigned to each parameter's flag
    """
    all_params = (param for net in nets for param in net.parameters())
    for param in all_params:
        param.requires_grad = requires_grad

In [14]:
class CycleGAN(pl.LightningModule):
    """CycleGAN trained with two optimizers (generators / discriminators).

    Parameters:
        G_X -- generator mapping domain Y (style) to domain X (base)
        G_Y -- generator mapping domain X (base) to domain Y (style)
        D_X -- discriminator for domain X images
        D_Y -- discriminator for domain Y images
        lr (float)                -- Adam learning rate for both optimizers
        betas (tuple)             -- Adam betas for both optimizers
        cyclic_loss_coef (float)  -- weight of the cycle-consistency loss
        identity_loss_coef (float)-- weight of the identity loss

    ``training_step`` expects batches ``(x_batch, y_batch)`` with ``x_batch``
    from the dataset's folder A and ``y_batch`` from folder B.
    """

    def __init__(
        self,
        G_X,
        G_Y,
        D_X,
        D_Y,
        lr=2e-4,
        betas=(0.5, 0.999),
        cyclic_loss_coef=10,
        identity_loss_coef=5,
        *args,
        **kwargs
    ):
        super().__init__()

        self.G_X = G_X  # style to base X
        self.G_Y = G_Y  # base to style Y
        self.D_X = D_X  # detect based X
        self.D_Y = D_Y  # detect styled Y

        # History buffers of generated images used for the discriminator update.
        self.fakePoolX = ImagePool(50)
        self.fakePoolY = ImagePool(50)
        self.identity_loss = nn.L1Loss()
        self.gan_loss = nn.MSELoss()
        self.cycle_loss = nn.L1Loss()

        # The networks are already registered as sub-modules; keeping them out
        # of hparams avoids pickling whole nn.Modules into checkpoints and
        # sending them to the logger as "hyperparameters".
        self.save_hyperparameters(ignore=["G_X", "G_Y", "D_X", "D_Y"])

    def configure_optimizers(self):
        """Return [generator Adam, discriminator Adam] and their LR schedulers.

        The schedule keeps the base LR for the first 100 epochs and then
        decays the multiplier linearly (``1 - (epoch + 1 - 100) / 101``).
        """
        optG = Adam(
            itertools.chain(self.G_X.parameters(), self.G_Y.parameters()),
            lr=self.hparams.lr,
            betas=self.hparams.betas,
        )

        optD = Adam(
            itertools.chain(self.D_X.parameters(), self.D_Y.parameters()),
            lr=self.hparams.lr,
            betas=self.hparams.betas,
        )

        gamma = lambda epoch: 1 - max(0, epoch + 1 - 100) / 101

        schG = LambdaLR(optG, lr_lambda=gamma)
        schD = LambdaLR(optD, lr_lambda=gamma)

        return [optG, optD], [schG, schD]

    @staticmethod
    def _real_targets(prediction):
        """Smoothed "real" targets ~ Uniform[0.7, 1.0), shaped like ``prediction``."""
        return torch.rand_like(prediction) * 0.3 + 0.7

    @staticmethod
    def _fake_targets(prediction):
        """All-zero "fake" targets shaped like ``prediction``."""
        return torch.zeros_like(prediction)

    def _adversarial(self, prediction, real):
        """LSGAN loss of ``prediction`` against real (smoothed) or fake targets."""
        make_targets = self._real_targets if real else self._fake_targets
        return self.gan_loss(prediction, make_targets(prediction))

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Run the generator update (optimizer 0) or discriminator update (optimizer 1)."""
        x_batch, y_batch = batch  # A and B folders

        # Discriminator weights only need gradients during their own update.
        discriminator_requires_grad = optimizer_idx == 1
        set_requires_grad([self.D_X, self.D_Y], discriminator_requires_grad)

        if optimizer_idx == 0:
            return self._generator_step(x_batch, y_batch)
        return self._discriminator_step(x_batch, y_batch)

    def _generator_step(self, x_batch, y_batch):
        """Identity + adversarial + cycle-consistency loss for both generators."""
        # For painting->photo, we find that it is helpful to introduce an additional
        # loss to encourage the mapping to preserve color composition between the input
        # and output. In particular, we adopt the technique of Taigman et al.
        # https://arxiv.org/pdf/1611.02200.pdf
        loss_identity = self.identity_loss(
            self.G_X(x_batch), x_batch
        ) + self.identity_loss(self.G_Y(y_batch), y_batch)

        x_batch_hat = self.G_X(y_batch)  # Y -> X
        y_batch_hat = self.G_Y(x_batch)  # X -> Y

        # Adversarial loss: generators want their fakes to be scored as real.
        loss_gan = self._adversarial(
            self.D_X(x_batch_hat), real=True
        ) + self._adversarial(self.D_Y(y_batch_hat), real=True)

        # Cycle consistency: translating there and back must reproduce the
        # input, i.e. G_X(G_Y(x)) ~ x and G_Y(G_X(y)) ~ y.
        loss_cycle = self.cycle_loss(
            self.G_X(y_batch_hat), x_batch
        ) + self.cycle_loss(self.G_Y(x_batch_hat), y_batch)

        loss_generator = (
            loss_gan
            + self.hparams.cyclic_loss_coef * loss_cycle
            + self.hparams.identity_loss_coef * loss_identity
        )

        self.log("generator/loss", loss_generator)
        self.log("generator/adversarial", loss_gan)
        self.log("generator/cycle", loss_cycle)
        self.log("generator/identity", loss_identity)

        return {
            "loss": loss_generator,
            "adversarial": loss_gan,
            "cycle": loss_cycle,
            "identity": loss_identity,
        }

    def _discriminator_step(self, x_batch, y_batch):
        """Real-vs-fake loss for both discriminators, with fakes drawn via the pools."""
        # The generators are not being updated here, so no graph is needed.
        with torch.no_grad():
            x_batch_hat = self.G_X(y_batch)
            y_batch_hat = self.G_Y(x_batch)

        if self.global_step % 499 == 0:
            self._log_sample_grid(x_batch, y_batch_hat, y_batch, x_batch_hat)

        # Mix the fresh fakes with previously generated ones; the real batches
        # are used as they are.
        pooled_x_hat = self.fakePoolX.query(x_batch_hat)
        pooled_y_hat = self.fakePoolY.query(y_batch_hat)

        loss_discriminator = (
            self._adversarial(self.D_X(x_batch), real=True)
            + self._adversarial(self.D_X(pooled_x_hat), real=False)
            + self._adversarial(self.D_Y(y_batch), real=True)
            + self._adversarial(self.D_Y(pooled_y_hat), real=False)
        )

        self.log("discriminator/loss", loss_discriminator)

        return {"loss": loss_discriminator}

    def _log_sample_grid(self, x_batch, y_batch_hat, y_batch, x_batch_hat):
        """Send a grid [x, G_Y(x), y, G_X(y)] to wandb as integer pixel values."""
        tensors = [t.detach().cpu() for t in (x_batch, y_batch_hat, y_batch, x_batch_hat)]
        grid = make_grid(torch.cat(tensors), nrow=4, padding=0)
        pixels = grid.permute(1, 2, 0).numpy()
        # Undo the mean-0.5 / std-0.5 normalisation and rescale to 0..255.
        pixels = pixels * 0.5 + 0.5
        pixels = pixels * 255.0
        pixels = pixels.astype(int)
        wandb.log(
            {
                "test_images": wandb.Image(
                    pixels,
                    caption="Two on the left: Monet to real image. On the right: vice a versa",
                )
            }
        )

In [12]:
# Checkpoint callback configured with every_n_epochs=25.
# NOTE(review): the same assignment is repeated in the model/trainer setup cell.
checkpoint_callback = ModelCheckpoint(every_n_epochs=25)

In [None]:
# Data module with its default arguments (PATHS, batch_size=1).
dm = MonetDataModule()

# Two generators and two discriminators, all operating on 3-channel images.
G_base = GeneratorResNet(3)
G_style = GeneratorResNet(3)
D_base = Discriminator(3)
D_style = Discriminator(3)

# Checkpoint callback configured with every_n_epochs=25.
checkpoint_callback = ModelCheckpoint(every_n_epochs=25)
# model_list = [G_base, G_style, D_base, D_style]
# path_list = ["g_base139.pt", "g_style139.pt", "d_base139.pt", "d_style139.pt"]

# LightningModule  --------------------------------------------------------------
# Positional order matches CycleGAN(G_X, G_Y, D_X, D_Y).
model = CycleGAN(G_base, G_style, D_base, D_style)

# Trainer  --------------------------------------------------------------
# NOTE(review): gpus=0 presumably keeps training on the CPU - confirm this is
# intended for a 200-epoch run.
trainer = Trainer(
    logger=wandb_logger,
    max_epochs=200,
    gpus=0,
    reload_dataloaders_every_n_epochs=5,
    num_sanity_val_steps=0,
    callbacks=[checkpoint_callback, TQDMProgressBar()],
)

In [None]:
# Start training; the loaders come from the data module.
trainer.fit(model, datamodule=dm)

Training: 300it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
