# pix2pix


In [None]:
import torch
from torch import nn
import numpy as np
import os
import PIL
from PIL import Image
import albumentations as A
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from albumentations.pytorch import ToTensorV2
import torch.optim as optim
from torch.autograd import Variable

# Compute device used by every cell below. The second GPU ("cuda:1") is
# preferred, as in the original setup, but only when it actually exists:
# on a single-GPU machine "cuda:1" is an invalid ordinal, so use "cuda:0".
# Without CUDA everything runs on the CPU.
device = torch.device(
    ("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
    if torch.cuda.is_available()
    else "cpu"
)

## Some utils

In [None]:
from torchvision.utils import save_image

def save_some_examples(gen, val_loader, epoch, path=""):
    """Save generator outputs for the first batch of ``val_loader`` as images.

    Writes ``y_gen_{epoch}.png`` and ``input_{epoch}.png`` (and, only when
    ``epoch == 1``, ``label_{epoch}.png``). Tensors are mapped back from the
    normalized [-1, 1] range to [0, 1] before saving.

    Args:
        gen: generator module; it is called on the input batch.
        val_loader: iterable yielding ``(input, target)`` batches; only the
            first batch is used.
        epoch: value embedded in the output file names.
        path: string prefix prepended verbatim to each file name (include a
            trailing separator when it is a directory). Defaults to ``""``,
            i.e. the current working directory.
    """
    x, y = next(iter(val_loader))
    x, y = x.to(device), y.to(device)

    # Remember the caller's mode so it can be restored afterwards instead of
    # unconditionally forcing the generator into training mode.
    was_training = gen.training
    gen.eval()
    try:
        with torch.no_grad():
            y_fake = gen(x)
            y_fake = y_fake * 0.5 + 0.5  # remove normalization
            save_image(y_fake, path + f"y_gen_{epoch}.png")
            save_image(x * 0.5 + 0.5, path + f"input_{epoch}.png")
            # NOTE(review): the label image is written only for epoch == 1;
            # a caller that never passes epoch 1 will never get a label file.
            if epoch == 1:
                save_image(y * 0.5 + 0.5, path + f"label_{epoch}.png")
    finally:
        # Restore the previous mode even if inference or saving failed.
        gen.train(was_training)


def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Serialize the model and optimizer state dicts to ``filename``.

    The file holds a dict with two entries: ``"state_dict"`` (from
    ``model.state_dict()``) and ``"optimizer"`` (from
    ``optimizer.state_dict()``).
    """
    print("=> Saving checkpoint")
    torch.save(
        dict(
            state_dict=model.state_dict(),
            optimizer=optimizer.state_dict(),
        ),
        filename,
    )


def load_checkpoint(checkpoint_file, model, optimizer):
    """Restore ``model`` and ``optimizer`` in place from a checkpoint file.

    The file is expected to hold a dict with the keys ``"state_dict"`` and
    ``"optimizer"``. Tensors are mapped onto the module-level ``device``
    while loading.
    """
    print("=> Loading checkpoint")
    saved = torch.load(checkpoint_file, map_location=device)
    # Model first, then optimizer.
    for receiver, key in ((model, "state_dict"), (optimizer, "optimizer")):
        receiver.load_state_dict(saved[key])

In [None]:
class AnimeDataset(Dataset):
    """Dataset of side-by-side image pairs stored as single files.

    Every file in ``root_dir`` is read as one image and split at column
    ``img_size``: the left part is the target (label) and the right part is
    the model input. Both parts are resized to 256x256, normalized to
    [-1, 1] and returned as tensors.

    Args:
        root_dir: directory containing the paired images.
        img_size: width in pixels of the left (target) part of each file.
            This is the size on disk; the size fed to the model is 256.
    """

    def __init__(self, root_dir, img_size=512):
        self.img_size = img_size
        self.root_dir = root_dir
        # os.listdir gives no ordering guarantee; sort so that an index
        # always refers to the same file.
        self.list_files = sorted(os.listdir(self.root_dir))

        # Geometric transforms shared by input and target. The random flip
        # must be drawn once per pair, otherwise the input and its label end
        # up mirrored relative to each other.
        # NOTE(review): recent albumentations versions check that all image
        # targets passed to one call have the same shape, so both halves of
        # a file should be the same size (image width == 2 * img_size).
        self.transform_both = A.Compose(
            [
                A.Resize(width=256, height=256),
                A.HorizontalFlip(p=0.5),
            ],
            additional_targets={"image0": "image"},
        )
        # Per-image value transforms: scale to [-1, 1] and convert to tensor.
        self.transform_input = A.Compose([
                A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0,),
                ToTensorV2(),
            ])
        self.transform_target = A.Compose([
                A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0,),
                ToTensorV2(),
            ])

    def __len__(self):
        """Return the number of files found in ``root_dir``."""
        return len(self.list_files)

    def __getitem__(self, index):
        """Return the ``(input, target)`` tensor pair for file ``index``."""
        img_file = self.list_files[index]
        img_path = os.path.join(self.root_dir, img_file)
        # Close the file handle promptly and force 3 channels so grayscale or
        # RGBA files do not break the slicing and 3-channel normalization.
        with Image.open(img_path) as img:
            image = np.array(img.convert("RGB"))

        # Layout on disk: label first (left), input second (right).
        # Swap the two slices if your data is stored the other way round.
        input_image = image[:, self.img_size:, :]
        target_image = image[:, :self.img_size, :]

        paired = self.transform_both(image=input_image, image0=target_image)
        input_image = self.transform_input(image=paired["image"])["image"]
        target_image = self.transform_target(image=paired["image0"])["image"]

        return input_image, target_image

## Discriminator architecture
Where Ck = Convolution-BatchNorm-ReLU layer
with k filters

The 70 × 70 PatchGAN discriminator architecture is:
C64-C128-C256-C512

In [None]:
class CBlock(nn.Module):
    """Convolution -> BatchNorm -> LeakyReLU(0.2) block.

    The convolution uses a 4x4 kernel, padding 1 and no bias; ``stride``
    defaults to 2.
    """

    def __init__(self, in_ch, out_ch, stride=2):
        super().__init__()
        layers = [
            nn.Conv2d(
                in_ch,
                out_ch,
                kernel_size=4,
                stride=stride,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(out_ch),
            nn.LeakyReLU(negative_slope=0.2),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the three layers in order and return the result."""
        out = self.block(x)
        return out
    
    
class Desc(nn.Module):
    """Conditional patch discriminator (C64-C128-C256-C512 for ``d=64``).

    The input image and the (real or generated) target image are
    concatenated along the channel axis and passed through a stack of
    convolutions ending in a single-channel map of raw scores. No sigmoid is
    applied here.

    Args:
        d: number of feature maps of the first convolution; later layers use
            multiples of it.
        in_channels: total channels after concatenation, i.e. input-image
            channels plus target-image channels. Defaults to 6 (3 + 3); use
            5 for the colorization setup (2-channel input, 3-channel target).
    """

    def __init__(self, d=64, in_channels=6):
        super(Desc, self).__init__()
        self.model = nn.Sequential(
            # The first layer has no BatchNorm.
            nn.Conv2d(in_channels, d, 4, 2, 1),
            nn.LeakyReLU(0.2),
            CBlock(d, d*2),
            CBlock(d*2, d*4),
            CBlock(d*4, d*8, 1),  # stride 1
            nn.Conv2d(d*8, 1, 4, 1, 1)
        )

    def forward(self, x, y):
        """Score the pair: ``x`` is the conditioning input, ``y`` the target."""
        xy = torch.cat([x, y], dim=1)
        return self.model(xy)

## Generator architecture
U-net generator

In [None]:
class CDBlock(nn.Module):
    """Transposed convolution -> BatchNorm -> optional Dropout(0.5) -> ReLU.

    The transposed convolution uses a 4x4 kernel, padding 1 and no bias.
    The dropout layer is always constructed but only applied when
    ``use_dropout`` is true.
    """

    def __init__(self, in_ch, out_ch, stride=2, use_dropout=False):
        super().__init__()
        self.conv = nn.ConvTranspose2d(
            in_ch,
            out_ch,
            kernel_size=4,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.batch = nn.BatchNorm2d(out_ch)
        self.drop = nn.Dropout(p=0.5)
        self.act = nn.ReLU()
        self.use_dropout = use_dropout

    def forward(self, x):
        """Up-sample ``x`` and return the activated feature map."""
        normed = self.batch(self.conv(x))
        return self.act(self.drop(normed) if self.use_dropout else normed)

    
class Gen(nn.Module):
    """U-Net generator with eight down-sampling and eight up-sampling stages.

    Every stage changes the spatial size by a factor of 2 (4x4 kernels,
    stride 2, padding 1), so a 256x256 input reaches 1x1 at the bottleneck.
    Each decoder stage after the first receives the previous decoder output
    concatenated with the mirrored encoder output (skip connection). The
    output goes through ``tanh`` and therefore lies in [-1, 1].

    Args:
        d: number of feature maps of the first encoder stage; deeper stages
            use multiples of it, capped at ``d * 8``.
        in_channels: channels of the input image. Defaults to 3; use 2 for
            the colorization setup.
        out_channels: channels of the generated image. Defaults to 3.
    """

    def __init__(self, d=64, in_channels=3, out_channels=3):
        super(Gen, self).__init__()
        # encoder (the first stage has no BatchNorm)
        self.down1 = nn.Sequential(
            nn.Conv2d(in_channels, d, 4, 2, 1),
            nn.LeakyReLU(0.2),
        )
        self.down2 = CBlock(d, d*2)
        self.down3 = CBlock(d*2, d*4)
        self.down4 = CBlock(d*4, d*8)
        self.down5 = CBlock(d*8, d*8)
        self.down6 = CBlock(d*8, d*8)
        self.down7 = CBlock(d*8, d*8)
        # bottleneck: plain convolution + ReLU, no BatchNorm
        self.down8 = nn.Sequential(
            nn.Conv2d(d*8, d*8, 4, 2, 1), nn.ReLU()
        )
        # decoder; input widths are doubled from up2 on because of the
        # concatenated skip connections. Dropout is used in the first three.
        self.up1 = CDBlock(d*8, d*8, use_dropout=True)
        self.up2 = CDBlock(d*8*2, d*8, use_dropout=True)
        self.up3 = CDBlock(d*8*2, d*8, use_dropout=True)
        self.up4 = CDBlock(d*8*2, d*8)
        self.up5 = CDBlock(d*8*2, d*4)
        self.up6 = CDBlock(d*4*2, d*2)
        self.up7 = CDBlock(d*2*2, d)
        self.up8 = nn.ConvTranspose2d(d*2, out_channels, 4, 2, 1)

    def forward(self, x):
        """Translate the input batch ``x`` into an output image batch."""
        # Encoder: keep every intermediate result for the skip connections.
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)
        d7 = self.down7(d6)
        d8 = self.down8(d7)

        # Decoder: up-stage k is fed [previous up output, encoder d(8-k+1)].
        u1 = self.up1(d8)
        u2 = self.up2(torch.cat([u1, d7], 1))
        u3 = self.up3(torch.cat([u2, d6], 1))
        u4 = self.up4(torch.cat([u3, d5], 1))
        u5 = self.up5(torch.cat([u4, d4], 1))
        u6 = self.up6(torch.cat([u5, d3], 1))
        u7 = self.up7(torch.cat([u6, d2], 1))
        u8 = self.up8(torch.cat([u7, d1], 1))

        return torch.tanh(u8)

## Weights init

In [None]:
def weights_init(m, mean=0.0, std=0.02):
    """Re-draw the weights of a (transposed) convolution from N(mean, std).

    Intended for ``Module.apply``: modules that are neither ``nn.Conv2d``
    nor ``nn.ConvTranspose2d`` are left untouched. Only ``weight`` is
    changed; biases are not modified.
    """
    conv_types = (nn.ConvTranspose2d, nn.Conv2d)
    if not isinstance(m, conv_types):
        return
    nn.init.normal_(m.weight.data, mean, std)
        
# Build both networks on the target device and re-initialize their
# convolution weights. `.to()` and `.apply()` both return the module itself,
# so the calls can be chained.
netG = Gen().to(device).apply(weights_init)
netD = Desc().to(device).apply(weights_init)
print('weights initialized')

weights initialized


# Define params and train

In [None]:
# Width (in pixels) at which every stored file is split into label / input.
im_size = 512

# Training data. Note that the validation loader wraps the very same dataset
# object; it only differs in not shuffling.
dataset = AnimeDataset(root_dir="sketch/", img_size=im_size)
loader = DataLoader(dataset=dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(dataset=dataset, batch_size=16, shuffle=False)

# Losses: adversarial term on raw discriminator scores, plus an L1
# reconstruction term that is multiplied by L1_LAMBDA in the training loop.
L1_LAMBDA = 100
BCE_loss = nn.BCEWithLogitsLoss().to(device)
L1_loss = nn.L1Loss().to(device)

# Adam optimizer, same hyper-parameters for both networks
optG = optim.Adam(
    netG.parameters(),
    lr=0.0002,
    betas=(0.5, 0.999),
)
optD = optim.Adam(
    netD.parameters(),
    lr=0.0002,
    betas=(0.5, 0.999),
)

In [None]:
# Training loop: one discriminator step followed by one generator step per
# batch. Every 10th epoch (0, 10, ..., 100) sample images and checkpoints of
# both networks are written.
epochs_num = 101
for epoch in range(epochs_num):
    for img, target in loader:
        img, target = img.to(device), target.to(device)

        ### train discriminator ###
        # real images
        D_real = netD(img, target)
        real_loss = BCE_loss(D_real, torch.ones_like(D_real))

        # fake loss; detach() keeps this loss from back-propagating into netG
        fake = netG(img)
        D_fake = netD(img, fake.detach())
        fake_loss = BCE_loss(D_fake, torch.zeros_like(D_fake))

        D_loss = (real_loss + fake_loss)/2
        # Also clears the gradients that the previous generator step
        # accumulated in netD's parameters.
        netD.zero_grad()
        D_loss.backward()
        optD.step()

        ### train generator ###
        # Re-score the fake with the freshly updated discriminator, this time
        # without detach() so gradients reach netG.
        D_fake = netD(img, fake)
        G_fake_loss = BCE_loss(D_fake, torch.ones_like(D_fake))
        L1 = L1_loss(fake, target) * L1_LAMBDA

        G_loss = G_fake_loss + L1
        netG.zero_grad()
        G_loss.backward()
        optG.step()

    if epoch % 10 == 0:
        # The file-name prefix is a required argument of save_some_examples;
        # "" writes the images into the current working directory.
        # NOTE(review): save_some_examples only writes the label image when
        # epoch == 1, which never happens on this 10-epoch schedule.
        save_some_examples(netG, val_loader, epoch, "")
        save_checkpoint(netG, optG, filename="netG.pth.tar")
        save_checkpoint(netD, optD, filename="netD.pth.tar")
    print(epoch)

# Evaluate

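Do we really need `model.eval()`? For pix2pix, not necessarily: the original paper runs the generator at inference exactly as during training, i.e. with dropout active and batch normalization using the statistics of the test batch.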

In [None]:
# Evaluation: restore the generator and write inputs, outputs and labels for
# up to five test batches.
load_checkpoint("netG.pth.tar", netG, optG)
# NOTE(review): "/test/" is an absolute path while training used the relative
# "sketch/" - confirm this is the intended location.
dataset = AnimeDataset("/test/", im_size)
loader = DataLoader(dataset, batch_size=16)
# netG.eval() is intentionally not called, so dropout stays active and
# BatchNorm uses the statistics of each test batch (and keeps updating its
# running statistics). Call netG.eval() before the loop, and netG.train()
# after it, if deterministic outputs are wanted instead.
# Iterate the loader itself: calling next(iter(loader)) inside the loop would
# restart it every time and save the same first batch five times.
for i, (x, y) in zip(range(5), loader):
    x, y = x.to(device), y.to(device)
    with torch.no_grad():
        y_fake = netG(x)
        y_fake = y_fake * 0.5 + 0.5  # remove normalization
        save_image(y_fake, f"y_gen_{i}.png")
        save_image(x * 0.5 + 0.5, f"input_{i}.png")
    save_image(y * 0.5 + 0.5, f"label_{i}.png")

=> Loading checkpoint
