## Pix2Pix Implementation from Scratch

Resource Link: https://youtu.be/SuddDSqGRzg

Paper Link: https://arxiv.org/abs/1611.07004

In [1]:
print("Pix2Pix Implementation from Scratch using Pytorch")

Pix2Pix Implementation from Scratch using Pytorch


In [2]:
# install the albumentations library in the python environment
# %pip install albumentations

### Defining the discriminator


In [3]:
import torch
import torch.nn as nn

In [4]:
class CNNBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=2):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 4, stride, 1, bias=False, padding_mode="reflect"),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        return self.conv(x)

    


In [5]:
class Discriminator(nn.Module):
    def __init__(self, in_channels=3, features=[64,128,256,512]):   # 256 -> 30x30
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels*2, features[0], kernel_size=4, stride=2, padding=1, padding_mode="reflect"),
            nn.LeakyReLU(0.2),
        )

        layers = []
        in_channels = features[0]
        for feature in features[1:]:
            layers.append(
                CNNBlock(in_channels, feature, stride=1 if feature == features[-1] else 2)
            )
            in_channels = feature

        layers.append(
            nn.Conv2d(
                in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"
            )
        )

        self.model = nn.Sequential(*layers)     # unpack all layers and put it into nn.Sequential

    
    def forward(self, x, y):
        x = torch.cat([x,y], dim=1)
        x = self.initial(x)
        return self.model(x)



In [6]:
# Test Case: test the Discriminator 
def test():
    x = torch.randn((1,3,256,256))
    y = torch.randn((1,3,256,256))
    model = Discriminator()
    preds = model(x,y)
    print(preds.shape)

test()

torch.Size([1, 1, 30, 30])


### Defining the generator

In [7]:
import torch
import torch.nn as nn

In [8]:
class Block(nn.Module):
    def __init__(self, in_channels, out_channels, down=True, act="relu", use_dropout=False):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 4,2,1, bias=False, padding_mode="reflect") 
            if down
            else nn.ConvTranspose2d(in_channels, out_channels, 4,2,1, bias=False),

            nn.BatchNorm2d(out_channels),
            nn.ReLU() if act=="relu" else nn.LeakyReLU(0.2),
        )
        self.use_dropout = use_dropout
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = self.conv(x)
        return self.dropout(x) if self.use_dropout else x


In [9]:
class Generator(nn.Module):
    def __init__(self, in_channels=3, features=64):
        super().__init__()
        self.initial_down = nn.Sequential(
            nn.Conv2d(in_channels, features, 4,2,1, padding_mode="reflect"),
            nn.LeakyReLU(0.2),
        )       # 128

        self.down1 = Block(features, features*2, down=True, act="leaky", use_dropout=False)        # 64
        self.down2 = Block(features*2, features*4, down=True, act="leaky", use_dropout=False)        # 32      
        self.down3 = Block(features*4, features*8, down=True, act="leaky", use_dropout=False)        # 16
        self.down4 = Block(features*8, features*8, down=True, act="leaky", use_dropout=False)        # 8
        self.down5 = Block(features*8, features*8, down=True, act="leaky", use_dropout=False)        # 4
        self.down6 = Block(features*8, features*8, down=True, act="leaky", use_dropout=False)        # 2

        self.bottleneck = nn.Sequential(
            nn.Conv2d(features*8, features*8, 4,2,1, padding_mode="reflect"),
            nn.ReLU(),
        )       # 1x1

        self.up1 = Block(features*8, features*8, down=False, act="relu", use_dropout=True)
        self.up2 = Block(features*8*2, features*8, down=False, act="relu", use_dropout=True)
        self.up3 = Block(features*8*2, features*8, down=False, act="relu", use_dropout=True)
        self.up4 = Block(features*8*2, features*8, down=False, act="relu", use_dropout=False)
        self.up5 = Block(features*8*2, features*4, down=False, act="relu", use_dropout=False)
        self.up6 = Block(features*4*2, features*2, down=False, act="relu", use_dropout=False)
        self.up7 = Block(features*2*2, features, down=False, act="relu", use_dropout=False)
        
        self.final_up = nn.Sequential(
            nn.ConvTranspose2d(features*2, in_channels, 4,2,1),
            nn.Tanh(),
        )


    def forward(self, x):
        d1 = self.initial_down(x)
        d2 = self.down1(d1)
        d3 = self.down2(d2)
        d4 = self.down3(d3)
        d5 = self.down4(d4)
        d6 = self.down5(d5)
        d7 = self.down6(d6)
        bottleneck = self.bottleneck(d7)
        up1 = self.up1(bottleneck)
        up2 = self.up2(torch.cat([up1,d7], 1))
        up3 = self.up3(torch.cat([up2,d6], 1))
        up4 = self.up4(torch.cat([up3,d5], 1))
        up5 = self.up5(torch.cat([up4,d4], 1))
        up6 = self.up6(torch.cat([up5,d3], 1))
        up7 = self.up7(torch.cat([up6,d2], 1))
        return self.final_up(torch.cat([up7, d1],1))



In [10]:
def test_generator():
    x = torch.randn((1,3,256,256))
    model = Generator(in_channels=3, features=64)
    preds = model(x)
    print(preds.shape)

test_generator()

torch.Size([1, 3, 256, 256])


### Config Parameters

In [11]:
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2

In [12]:
DEVICE = "mps:0" if torch.backends.mps.is_available() else "cpu"
LEARNING_RATE = 2e-4
BATCH_SIZE = 16
NUM_WORKERS = 0
IMAGE_SIZE = 256
CHANNELS_IMG = 3
L1_LAMBDA = 100
NUM_EPOCHS = 200
LOAD_MODEL = False
SAVE_MODEL = True
CHECKPOINT_DISC = "disc.pth.tar"
CHECKPOINT_GEN = "gen.pth.tar"

In [13]:
both_transform = A.Compose(
    [A.Resize(width=256, height=256),], 
    additional_targets = {"image0": "image"},
)

transform_only_input = A.Compose(
    [
        # A.ColorJitter(p=0.2)
        A.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5], max_pixel_value=255.0),
        ToTensorV2(),
    ]
)

transform_only_mask = A.Compose(
    [
        A.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5], max_pixel_value=255.0),
        ToTensorV2(),
    ]
)

### Load the Dataset

In [14]:
# %pip install kagglehub

In [15]:
from PIL import Image
import numpy as np
import os
from torch.utils.data import Dataset

In [16]:
class MapDataset(Dataset):
    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.list_files = os.listdir(self.root_dir)
        print(self.list_files)

    def __len__(self):
        return len(self.list_files)

    def __getitem__(self, index):
        img_file = self.list_files[index]
        img_path = os.path.join(self.root_dir, img_file)
        image = np.array(Image.open(img_path))
        input_image = image[:, :600, :]
        target_image = image[:, 600:, :]

        augmentations = both_transform(image=input_image, image0=target_image)
        input_image, target_image = augmentations["image"], augmentations["image0"]

        input_image = transform_only_input(image=input_image)["image"]
        target_image = transform_only_input(image=target_image)["image"]

        return input_image, target_image


### Some utility functions to view intermediate results

In [17]:
import torch
from torchvision.utils import save_image

In [25]:
def save_some_examples(gen, val_loader, epoch, folder):
    x,y = next(iter(val_loader))
    x,y = x.to(DEVICE), y.to(DEVICE)
    gen.eval()
    with torch.no_grad():
        y_fake = gen(x)
        y_fake = y_fake * 0.5 + 0.5     # remove normalization
        save_image(y_fake, folder + f"/y_gen_{epoch}.png")
        save_image(x * 0.5 * 0.5, folder + f"/input_{epoch}.png")
        if epoch == 1:
            save_image(y * 0.5 * 0.5, folder + f"/label_{epoch}.png")
    gen.train()
    

In [26]:
def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    print("=> Saving checkpoint")
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(checkpoint, filename)
    


In [27]:
def load_checkpoint(checkpoint_file, model, optimizer, lr):
    print("=> Loading Checkpoint")
    checkpoint = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    # If we don't do this then it will just have learning rate of old checkpoint and it will lead to many hours of debugging.
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

### Training the model

In [28]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm


In [29]:
def train_fn(disc, gen, loader, opt_disc, opt_gen, l1, bce, g_scalar, d_scalar):
    loop = tqdm(loader, leave=True)

    for idx, (x,y) in enumerate(loop):
        x,y = x.to(DEVICE), y.to(DEVICE)

        # Train Discriminator
        with torch.amp.autocast(device_type="mps", dtype=torch.bfloat16):
            y_fake = gen(x)
            D_real = disc(x,y)
            D_fake = disc(x, y_fake.detach())
            D_real_loss = bce(D_real, torch.ones_like(D_real))
            D_fake_loss = bce(D_fake, torch.zeros_like(D_fake))
            D_loss = (D_real_loss + D_fake_loss)/2

        disc.zero_grad()
        d_scalar.scale(D_loss).backward()
        d_scalar.step(opt_disc)
        d_scalar.update()

        # Train the Generator
        with torch.amp.autocast(device_type="mps", dtype=torch.bfloat16):
            D_fake = disc(x, y_fake)
            G_fake_loss = bce(D_fake, torch.ones_like(D_fake))
            L1 = l1(y_fake, y) * L1_LAMBDA
            G_loss = G_fake_loss + L1

        opt_gen.zero_grad()
        g_scalar.scale(G_loss).backward()
        g_scalar.step(opt_gen)
        g_scalar.update()
    

In [30]:
def main():
    disc = Discriminator(in_channels=3).to(DEVICE)
    gen = Generator(in_channels=3).to(DEVICE)
    opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5,0.999))
    opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5,0.999))
    BCE = nn.BCEWithLogitsLoss()
    L1_LOSS = nn.L1Loss()

    if LOAD_MODEL:
        load_checkpoint(CHECKPOINT_GEN, gen, opt_gen, LEARNING_RATE)
        load_checkpoint(CHECKPOINT_DISC, disc, opt_disc, LEARNING_RATE)

    train_dataset = MapDataset(root_dir="pix2pix_dataset/maps/maps/train")
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
    g_scalar = torch.amp.GradScaler("mps")
    d_scalar = torch.amp.GradScaler("mps")

    val_dataset = MapDataset(root_dir="pix2pix_dataset/maps/maps/val")
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

    for epoch in range(NUM_EPOCHS):
        train_fn(disc, gen, train_loader, opt_disc, opt_gen, L1_LOSS, BCE, g_scalar, d_scalar)

        if SAVE_MODEL and epoch % 5 == 0:
            save_checkpoint(gen, opt_gen, filename=CHECKPOINT_GEN)
            save_checkpoint(disc, opt_disc, filename=CHECKPOINT_DISC)

        save_some_examples(gen, val_loader, epoch, folder="evaluation")



In [31]:
main()

['63.jpg', '823.jpg', '189.jpg', '77.jpg', '837.jpg', '638.jpg', '604.jpg', '162.jpg', '176.jpg', '88.jpg', '610.jpg', '348.jpg', '360.jpg', '406.jpg', '412.jpg', '374.jpg', '1019.jpg', '599.jpg', '1025.jpg', '1031.jpg', '228.jpg', '214.jpg', '572.jpg', '566.jpg', '200.jpg', '957.jpg', '943.jpg', '994.jpg', '758.jpg', '980.jpg', '770.jpg', '764.jpg', '765.jpg', '771.jpg', '981.jpg', '759.jpg', '995.jpg', '942.jpg', '956.jpg', '567.jpg', '201.jpg', '215.jpg', '573.jpg', '229.jpg', '1030.jpg', '1024.jpg', '598.jpg', '1018.jpg', '413.jpg', '375.jpg', '361.jpg', '407.jpg', '349.jpg', '177.jpg', '611.jpg', '89.jpg', '605.jpg', '163.jpg', '639.jpg', '188.jpg', '836.jpg', '76.jpg', '822.jpg', '62.jpg', '74.jpg', '834.jpg', '60.jpg', '820.jpg', '48.jpg', '808.jpg', '149.jpg', '613.jpg', '175.jpg', '161.jpg', '607.jpg', '388.jpg', '439.jpg', '377.jpg', '411.jpg', '405.jpg', '363.jpg', '1032.jpg', '1026.jpg', '559.jpg', '203.jpg', '565.jpg', '571.jpg', '217.jpg', '940.jpg', '798.jpg', '954.jpg',

100%|██████████| 69/69 [00:39<00:00,  1.76it/s]


=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 69/69 [00:38<00:00,  1.80it/s]
100%|██████████| 69/69 [00:38<00:00,  1.81it/s]
100%|██████████| 69/69 [00:38<00:00,  1.81it/s]
100%|██████████| 69/69 [00:38<00:00,  1.79it/s]
100%|██████████| 69/69 [00:38<00:00,  1.78it/s]


=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 69/69 [00:39<00:00,  1.77it/s]
100%|██████████| 69/69 [00:38<00:00,  1.80it/s]
100%|██████████| 69/69 [00:38<00:00,  1.79it/s]
100%|██████████| 69/69 [00:38<00:00,  1.78it/s]
100%|██████████| 69/69 [00:38<00:00,  1.81it/s]


=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 69/69 [00:37<00:00,  1.82it/s]
100%|██████████| 69/69 [00:37<00:00,  1.84it/s]
100%|██████████| 69/69 [00:37<00:00,  1.84it/s]
100%|██████████| 69/69 [00:37<00:00,  1.85it/s]
100%|██████████| 69/69 [00:37<00:00,  1.84it/s]


=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 69/69 [00:38<00:00,  1.81it/s]
100%|██████████| 69/69 [00:38<00:00,  1.80it/s]
100%|██████████| 69/69 [00:38<00:00,  1.81it/s]
100%|██████████| 69/69 [00:37<00:00,  1.83it/s]
100%|██████████| 69/69 [00:37<00:00,  1.83it/s]


=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 69/69 [00:37<00:00,  1.83it/s]
100%|██████████| 69/69 [00:38<00:00,  1.81it/s]
100%|██████████| 69/69 [00:39<00:00,  1.76it/s]
100%|██████████| 69/69 [00:38<00:00,  1.81it/s]
100%|██████████| 69/69 [00:38<00:00,  1.80it/s]


=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 69/69 [00:38<00:00,  1.77it/s]
100%|██████████| 69/69 [00:38<00:00,  1.79it/s]
100%|██████████| 69/69 [00:38<00:00,  1.79it/s]
100%|██████████| 69/69 [00:38<00:00,  1.79it/s]
100%|██████████| 69/69 [00:38<00:00,  1.78it/s]


=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 69/69 [00:39<00:00,  1.75it/s]
100%|██████████| 69/69 [00:38<00:00,  1.78it/s]
100%|██████████| 69/69 [00:38<00:00,  1.78it/s]
100%|██████████| 69/69 [00:38<00:00,  1.79it/s]
100%|██████████| 69/69 [00:39<00:00,  1.75it/s]


=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 69/69 [00:38<00:00,  1.81it/s]
100%|██████████| 69/69 [00:38<00:00,  1.81it/s]
100%|██████████| 69/69 [00:39<00:00,  1.76it/s]
100%|██████████| 69/69 [00:37<00:00,  1.82it/s]
100%|██████████| 69/69 [00:37<00:00,  1.83it/s]


=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 69/69 [00:37<00:00,  1.83it/s]
100%|██████████| 69/69 [00:37<00:00,  1.83it/s]
100%|██████████| 69/69 [00:37<00:00,  1.83it/s]
100%|██████████| 69/69 [00:37<00:00,  1.83it/s]
100%|██████████| 69/69 [00:37<00:00,  1.83it/s]


=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 69/69 [00:37<00:00,  1.83it/s]
100%|██████████| 69/69 [00:37<00:00,  1.83it/s]
100%|██████████| 69/69 [00:37<00:00,  1.83it/s]
100%|██████████| 69/69 [00:37<00:00,  1.83it/s]
100%|██████████| 69/69 [00:37<00:00,  1.83it/s]


=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 69/69 [18:28<00:00, 16.06s/it]    
100%|██████████| 69/69 [00:38<00:00,  1.80it/s]
 10%|█         | 7/69 [00:04<00:37,  1.65it/s]


KeyboardInterrupt: 