In [11]:
# Download the pix2pix dataset bundle from Kaggle into ./data and extract it there
# (the data cells below read from ./data/facades/facades).
# NOTE(review): the kaggle CLI needs API credentials configured (e.g. ~/.kaggle/kaggle.json).
# NOTE(review): on a re-run, unzip may prompt before overwriting an existing extraction — consider `unzip -o -q`.
! kaggle datasets download -p data vikramtiwari/pix2pix-dataset 
! cd ./data && unzip pix2pix-dataset.zip

In [56]:
import functools
from pathlib import Path

import cv2
import torch
import numpy as np
import segmentation_models_pytorch as smp
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
from torch import nn
from torch.utils.data import Dataset, DataLoader

import albumentations as A
from albumentations.pytorch import ToTensorV2
from torchvision.utils import save_image, make_grid


# Input data and per-run output directories (created eagerly so later cells can write into them).
DATA_DIR = Path("./data/facades/facades")
RUN_EXP_DIR = Path("./runs")
ARTIFACTS_DIR = RUN_EXP_DIR / "artifacts"
CHECKPOINT_DIR = RUN_EXP_DIR / "checkpoints"
ARTIFACTS_DIR.mkdir(exist_ok=True, parents=True)
CHECKPOINT_DIR.mkdir(exist_ok=True, parents=True)

# Training hyper-parameters.
N_EPOCHS = 300
G_LR = 2e-5  # NOTE(review): the pix2pix paper uses 2e-4 for both nets — confirm the lower rate is intended
D_LR = 2e-5
BATCH_SIZE = 16
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed every RNG the notebook draws from so that Restart & Run All is reproducible:
# torch (weight init, DataLoader shuffling; also seeds CUDA devices) and NumPy's
# global RNG (used for the dataset's random horizontal flip).
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Data

In [57]:
def flatten(t):
    """Concatenate the items of every sub-iterable in ``t`` into one flat list (one level deep)."""
    flat = []
    for group in t:
        flat.extend(group)
    return flat

In [119]:
class FacadesDataset(Dataset):
    """Paired facades images stored side by side in a single file.

    The left half of each file is used as the target and the right half as the
    generator input; ``__getitem__`` returns ``(input, target)``.

    Parameters:
        data_dir (Path) -- dataset root containing one sub-directory per stage
        stage (str)     -- sub-directory to read ("train", "val", "test"); the random
                           horizontal flip is applied only when stage == "train"
        transform       -- callable applied to each PIL half-image; ``None`` means
                           ``transforms.ToTensor()`` (float tensor in [0, 1], C x H x W)
    """

    def __init__(self, data_dir: Path, stage="train", transform=None) -> None:
        # iterdir() order is filesystem-dependent; sort for reproducible indexing.
        self._images_paths = sorted((data_dir / stage).iterdir())
        # Honour a caller-supplied transform, falling back to plain ToTensor.
        self._transform = transform if transform is not None else transforms.ToTensor()
        # Flipping is augmentation: keep validation/test samples deterministic.
        self._augment = stage == "train"

    def __getitem__(self, idx):
        # The context manager closes the file handle; convert("RGB") loads the pixels and
        # normalises palette/grayscale/RGBA files to 3 channels.
        with Image.open(self._images_paths[idx]) as raw:
            img = raw.convert("RGB")

        w, h = img.size
        half = w // 2
        img_target = img.crop((0, 0, half, h))
        img_input = img.crop((half, 0, w, h))

        if self._augment and np.random.random() < 0.5:
            # Flip both halves together so input and target stay pixel-aligned.
            img_target = img_target.transpose(Image.FLIP_LEFT_RIGHT)
            img_input = img_input.transpose(Image.FLIP_LEFT_RIGHT)

        return self._transform(img_input), self._transform(img_target)

    def __len__(self):
        return len(self._images_paths)


# Both splits currently only convert PIL images to [0, 1] float tensors. They are
# torchvision transforms because FacadesDataset calls ``transform(pil_image)``;
# an albumentations Compose needs keyword numpy arrays instead, and its ToTensorV2
# does not rescale to [0, 1].
train_transform = transforms.Compose([transforms.ToTensor()])
transform = transforms.Compose([transforms.ToTensor()])

In [120]:
train_ds = FacadesDataset(DATA_DIR, "train", train_transform)
valid_ds = FacadesDataset(DATA_DIR, "val", transform)
test_ds = FacadesDataset(DATA_DIR, "test", transform)

# Page-locked host memory only speeds up host->GPU copies; on a CPU-only machine
# pin_memory=True merely makes DataLoader emit a warning.
PIN_MEMORY = DEVICE.type == "cuda"

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=PIN_MEMORY)
valid_loader = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=PIN_MEMORY)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=PIN_MEMORY)

# Generator

In [132]:
class Generator(nn.Module):
    """Image-to-image generator: a ResNet-34-encoder U-Net from segmentation_models_pytorch.

    Maps a 3-channel image batch to a 3-channel prediction. The encoder is initialised
    from ImageNet-pretrained weights, and no output activation is configured.
    NOTE(review): the dataset feeds [0, 1] tensors without ImageNet mean/std
    normalization — confirm that is intended for the pretrained encoder.
    """

    def __init__(self) -> None:
        super().__init__()
        unet_kwargs = dict(
            encoder_name="resnet34",
            encoder_depth=5,
            encoder_weights="imagenet",
            decoder_channels=(256, 128, 64, 32, 16),
            in_channels=3,
            classes=3,
        )
        # The attribute name prefixes every state_dict key ("unet.*"); keep it stable for checkpoints.
        self.unet = smp.Unet(**unet_kwargs)

    def forward(self, x):
        """Run the U-Net on ``x``."""
        return self.unet(x)

# class Generator(nn.Module):
#     """Create a Unet-based generator"""

#     def __init__(self, input_nc=3, output_nc=3, nf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
#         """Construct a Unet generator
#         Parameters:
#             input_nc (int)  -- the number of channels in input images
#             output_nc (int) -- the number of channels in output images
#             num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7,
#                                 image of size 128x128 will become of size 1x1 # at the bottleneck
#             nf (int)       -- the number of filters in the last conv layer
#             norm_layer      -- normalization layer
#         We construct the U-Net from the innermost layer to the outermost layer.
#         It is a recursive process.
#         """
#         super().__init__()
#         # construct unet structure
#         unet_block = UnetSkipConnectionBlock(nf * 8, nf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)  # add the innermost layer
        
#         # add intermediate layers with ngf * 8 filters
#         unet_block = UnetSkipConnectionBlock(nf * 8, nf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
#         unet_block = UnetSkipConnectionBlock(nf * 8, nf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
#         unet_block = UnetSkipConnectionBlock(nf * 8, nf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
        
#         # gradually reduce the number of filters from nf * 8 to nf
#         unet_block = UnetSkipConnectionBlock(nf * 4, nf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
#         unet_block = UnetSkipConnectionBlock(nf * 2, nf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
#         unet_block = UnetSkipConnectionBlock(nf, nf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
#         self.model = UnetSkipConnectionBlock(output_nc, nf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)  # add the outermost layer

#     def forward(self, input):
#         """Standard forward"""
#         return self.model(input)
    

# class UnetSkipConnectionBlock(nn.Module):
#     """Defines the Unet submodule with skip connection.
#         X -------------------identity----------------------
#         |-- downsampling -- |submodule| -- upsampling --|
#     """

#     def __init__(self, outer_nc, inner_nc, input_nc=None,
#                  submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
#         """Construct a Unet submodule with skip connections.
#         Parameters:
#             outer_nc (int) -- the number of filters in the outer conv layer
#             inner_nc (int) -- the number of filters in the inner conv layer
#             input_nc (int) -- the number of channels in input images/features
#             submodule (UnetSkipConnectionBlock) -- previously defined submodules
#             outermost (bool)    -- if this module is the outermost module
#             innermost (bool)    -- if this module is the innermost module
#             norm_layer          -- normalization layer
#             use_dropout (bool)  -- if use dropout layers.
#         """
#         super().__init__()
#         self.outermost = outermost
#         if input_nc is None:
#             input_nc = outer_nc
#         downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
#                              stride=2, padding=1, bias=False)
#         downrelu = nn.LeakyReLU(0.2, True)
#         downnorm = norm_layer(inner_nc)
#         uprelu = nn.ReLU(True)
#         upnorm = norm_layer(outer_nc)

#         if outermost:
#             upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
#                                         kernel_size=4, stride=2,
#                                         padding=1)
#             down = [downconv]
#             up = [uprelu, upconv, nn.Tanh()]
#             model = down + [submodule] + up
#         elif innermost:
#             upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
#                                         kernel_size=4, stride=2,
#                                         padding=1, bias=False)
#             down = [downrelu, downconv]
#             up = [uprelu, upconv, upnorm]
#             model = down + up
#         else:
#             upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
#                                         kernel_size=4, stride=2,
#                                         padding=1, bias=False)
#             down = [downrelu, downconv, downnorm]
#             up = [uprelu, upconv, upnorm]

#             if use_dropout:
#                 model = down + [submodule] + up + [nn.Dropout(0.5)]
#             else:
#                 model = down + [submodule] + up

#         self.model = nn.Sequential(*model)

#     def forward(self, x):
#         if self.outermost:
#             return self.model(x)
#         else:   # add skip connections
#             return torch.cat([x, self.model(x)], 1)

# Discriminator

In [133]:
class Discriminator(nn.Module):
    """PatchGAN discriminator.

    Stacks 4x4 convolutions and returns a single-channel map of per-patch
    probabilities (the head ends in ``nn.Sigmoid``) instead of one score per image.
    With the default settings and a 256x256 input, the map is 30x30.

    Parameters:
        input_nc (int)  -- number of input channels
        ndf (int)       -- filters in the first conv layer; deeper layers use ndf * min(2**level, 8)
        n_layers (int)  -- depth: the first conv, n_layers - 1 further stride-2 blocks,
                           then one stride-1 block before the output conv
        norm_layer      -- normalization layer constructor, called with the channel count
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        super().__init__()
        kernel_size, padding = 4, 1

        def width(level):
            # Channel count at a given depth, capped at 8 * ndf.
            return ndf * min(2 ** level, 8)

        def conv_norm_act(in_ch, out_ch, stride):
            # bias=False: the norm layer that follows re-centres the activations anyway.
            return [
                nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
                norm_layer(out_ch),
                nn.LeakyReLU(0.2, True),
            ]

        # Stem: the first layer is not normalized.
        layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=kernel_size, stride=2, padding=padding),
            nn.LeakyReLU(0.2, True),
        ]
        # Stride-2 blocks that progressively widen the feature maps.
        for level in range(1, n_layers):
            layers.extend(conv_norm_act(width(level - 1), width(level), stride=2))
        # One more widening block at stride 1, then project to a 1-channel probability map.
        layers.extend(conv_norm_act(width(max(n_layers - 1, 0)), width(n_layers), stride=1))
        layers.extend([
            nn.Conv2d(width(n_layers), 1, kernel_size=kernel_size, stride=1, padding=padding),
            nn.Sigmoid(),
        ])
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Return the per-patch probability map for ``input``."""
        return self.model(input)

# Train

In [134]:
# Instantiate both networks on DEVICE (generator first, then discriminator).
generator = Generator().to(DEVICE)
# The discriminator sees an (input, output) image pair stacked along channels: 3 + 3 = 6.
discriminator = Discriminator(input_nc=6, ndf=64).to(DEVICE)

# Adam with beta1 = 0.5, as in the pix2pix training setup.
ADAM_BETAS = (0.5, 0.999)
G_optimizer = torch.optim.Adam(generator.parameters(), lr=G_LR, betas=ADAM_BETAS)
D_optimizer = torch.optim.Adam(discriminator.parameters(), lr=D_LR, betas=ADAM_BETAS)

# The discriminator ends in nn.Sigmoid (it outputs probabilities), so plain BCE is used.
adversarial_loss = nn.BCELoss()
l1_loss = nn.L1Loss()

# Weight of the L1 reconstruction term relative to the adversarial term (lambda in pix2pix).
L1_LAMBDA = 100.0


def generator_loss(generated_image, target_img, G, real_target, l1_weight=L1_LAMBDA):
    """pix2pix generator objective: fool the discriminator while staying close to the target in L1.

    Parameters:
        generated_image -- generator output
        target_img      -- ground-truth image, same shape as ``generated_image``
        G               -- discriminator probabilities for (input, generated) pairs
        real_target     -- "real" labels (ones) shaped like ``G``
        l1_weight       -- multiplier on the L1 term (defaults to ``L1_LAMBDA`` = 100)
    Returns:
        Scalar tensor ``BCE(G, real_target) + l1_weight * L1(generated_image, target_img)``.
    """
    adversarial_term = adversarial_loss(G, real_target)
    reconstruction_term = l1_loss(generated_image, target_img)
    return adversarial_term + l1_weight * reconstruction_term


def discriminator_loss(output, label):
    """Binary cross-entropy between discriminator probabilities ``output`` and 0/1 ``label`` maps."""
    return adversarial_loss(output, label)

In [135]:
from IPython.display import clear_output  # used for the live loss plot below

SAMPLE_EVERY = 50  # save checkpoints and a validation sample grid every N epochs (and after epoch 1)
N_SAMPLE_ROWS = 3  # validation examples per grid; each row is (input, target, generated)


def save_validation_samples(epoch):
    """Write an (input | target | generated) grid for the first validation batch to ARTIFACTS_DIR.

    The generator runs in eval mode under ``torch.no_grad()``, so sampling neither builds an
    autograd graph nor updates BatchNorm running statistics with validation data. It is
    switched back to train mode afterwards, even if saving fails.
    """
    generator.eval()
    try:
        with torch.no_grad():
            for input_img, target_img in valid_loader:
                input_img = input_img.to(DEVICE)
                target_img = target_img.to(DEVICE)
                generated_img = generator(input_img)
                # Guard against a last/small batch with fewer than N_SAMPLE_ROWS images.
                n_rows = min(N_SAMPLE_ROWS, input_img.size(0))
                triplets = [[input_img[i], target_img[i], generated_img[i]] for i in range(n_rows)]
                save_image(flatten(triplets), str(ARTIFACTS_DIR / f"pix2pix_sample_{epoch}.png"), nrow=3)
                break  # one batch is enough for a visual check
    finally:
        generator.train()


D_epoch_losses, G_epoch_losses = [], []
for epoch in tqdm(range(1, N_EPOCHS + 1)):
    D_loss_list, G_loss_list = [], []

    for input_img, target_img in train_loader:
        input_img = input_img.to(DEVICE)
        target_img = target_img.to(DEVICE)

        # generator forward pass (reused by both the discriminator and generator steps)
        generated_img = generator(input_img)

        # train discriminator with fake/generated images; detach so D's loss does not reach G
        D_fake_output = discriminator(torch.cat((input_img, generated_img.detach()), 1))
        # Label maps take the discriminator's output shape (30x30 for 256x256 inputs with the
        # default PatchGAN) rather than a hard-coded size; labels need no gradient.
        D_fake_loss = discriminator_loss(D_fake_output, torch.zeros_like(D_fake_output))

        # train discriminator with real images
        D_real_output = discriminator(torch.cat((input_img, target_img), 1))
        D_real_loss = discriminator_loss(D_real_output, torch.ones_like(D_real_output))

        # average discriminator loss
        D_total_loss = (D_real_loss + D_fake_loss) / 2
        D_loss_list.append(D_total_loss.item())

        D_optimizer.zero_grad()
        D_total_loss.backward()
        D_optimizer.step()

        # train generator against "real" labels so it learns to fool the updated discriminator
        G = discriminator(torch.cat((input_img, generated_img), 1))
        G_loss = generator_loss(generated_img, target_img, G, torch.ones_like(G))
        G_loss_list.append(G_loss.item())

        G_optimizer.zero_grad()
        G_loss.backward()
        G_optimizer.step()

    D_epoch_losses.append(np.mean(D_loss_list))
    G_epoch_losses.append(np.mean(G_loss_list))

    if epoch % SAMPLE_EVERY == 0 or epoch == 1:
        torch.save(generator.state_dict(), CHECKPOINT_DIR / f"generator_epoch_{epoch}.pth")
        torch.save(discriminator.state_dict(), CHECKPOINT_DIR / f"discriminator_epoch_{epoch}.pth")
        save_validation_samples(epoch)

    # Redraw progress in place. Clear first so this epoch's print survives until the next one,
    # then show and close the figure so curves don't pile up on one axes.
    clear_output(wait=True)
    print(f"[Epoch {epoch}/{N_EPOCHS}] D_loss: {D_epoch_losses[-1]} G_loss: {G_epoch_losses[-1]}")
    fig, ax = plt.subplots()
    ax.plot(D_epoch_losses)
    ax.plot(G_epoch_losses)
    ax.set(xlabel="epoch", ylabel="loss")
    ax.legend(['Discriminator loss', 'Generator loss'])
    plt.show()
    plt.close(fig)
    plt.legend(['Discriminator loss', 'Generator loss'])

  0%|          | 0/300 [00:09<?, ?it/s]


KeyboardInterrupt: 