In [None]:
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [None]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
%cd '//content//drive//My Drive//Pix2Pix'
%ls

/content/drive/My Drive/Pix2Pix
config.py  dataset.py    discriminator_model.py  gen.pth.tar   train.py
[0m[01;34mdata[0m/      disc.pth.tar  generator_model.py      [01;34m__pycache__[0m/  utils.py


In [None]:
import os

train_folder = "//content//drive//My Drive//Pix2Pix//data//maps//1000dataset//"
val_folder = "//content//drive//My Drive//Pix2Pix//data//maps//val//"

tfiles = len(os.listdir(train_folder))
vfiles = len(os.listdir(val_folder))

print(tfiles, vfiles)

1000 3


In [None]:
# !unzip "/content/drive/My Drive/2DGAN.zip" -d "/content/drive/My Drive/2DGAN"

In [None]:
root_dir = '/content/drive/MyDrive/Pix2Pix'

In [None]:
import os
import sys
from torch.utils import data
from torch.autograd import Variable
import torch

sys.path.insert(0,root_dir)
sys.path.append(root_dir +"/src/")


import config

utils

In [None]:
import torch
from torchvision.utils import save_image

def save_some_examples(gen, val_loader, epoch, folder):

    for i, (x,y) in enumerate(val_loader):

        x, y = x.to(config.DEVICE), y.to(config.DEVICE)
        gen.eval()
        with torch.no_grad():
            y_fake = gen(x)
            y_fake = y_fake * 0.5 + 0.5  # remove normalization#
            save_image(y_fake, folder + f"/y_gen_{i}_{epoch}.png")
            save_image(x * 0.5 + 0.5, folder + f"/input_{i}_{epoch}.png")
            if epoch == 1:
                save_image(y * 0.5 + 0.5, folder + f"/label_{i}_{epoch}.png")
        gen.train()

        # Change the value of i to how many ever images you would like to view
        if i == 3:
          break


def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    print("=> Saving checkpoint")
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(checkpoint, filename)


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    print("=> Loading checkpoint")
    checkpoint = torch.load(checkpoint_file, map_location=config.DEVICE)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    # If we don't do this then it will just have learning rate of old checkpoint
    # and it will lead to many hours of debugging \:
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr



dataset

In [None]:
from PIL import Image
import numpy as np
import os
from torch.utils.data import Dataset
import config

class MapDataset(Dataset):
    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.list_files = os.listdir(self.root_dir)
        print(self.list_files)

    def __len__(self):
        return len(self.list_files)

    def __getitem__(self, index):
        img_file = self.list_files[index]
        img_path = os.path.join(self.root_dir, img_file)
        image = np.array(Image.open(img_path))
        input_image = image[:, :600, :]
        target_image = image[:, 600:, :]

        augmentations = config.both_transform(image=input_image, image0=target_image)
        input_image, target_image = augmentations["image"], augmentations["image0"]

        input_image = config.transform_only_input(image=input_image)["image"]
        target_image = config.transform_only_mask(image=target_image)["image"]

        return input_image, target_image



generator

In [None]:
import torch
import torch.nn as nn


class Block(nn.Module):
    def __init__(self, in_channels, out_channels, down=True, act="relu", use_dropout=False):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False, padding_mode="reflect")
            if down
            else nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU() if act == "relu" else nn.LeakyReLU(0.2),
        )

        self.use_dropout = use_dropout
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = self.conv(x)
        return self.dropout(x) if self.use_dropout else x


class Generator(nn.Module):
    def __init__(self, in_channels=3, features=64):
        super().__init__()
        self.initial_down = nn.Sequential(
            nn.Conv2d(in_channels, features, 4, 2, 1, padding_mode="reflect"),
            nn.LeakyReLU(0.2),
        )

        self.down1 = Block(features, features*2, down=True, act="leaky", use_dropout=False)
        self.down2 = Block(features*2, features*4, down=True, act="leaky", use_dropout=False)
        self.down3 = Block(features*4, features*8, down=True, act="leaky", use_dropout=False)
        self.down4 = Block(features*8, features*8, down=True, act="leaky", use_dropout=False)
        self.down5 = Block(features*8, features*8, down=True, act="leaky", use_dropout=False)
        self.down6 = Block(features*8, features*8, down=True, act="leaky", use_dropout=False)
        self.bottleneck = nn.Sequential(
            nn.Conv2d(features*8, features*8, 4, 2, 1, padding_mode="reflect"), nn.ReLU()
        )

        self.up1 = Block(features*8, features*8, down=False, act="relu", use_dropout=True)
        self.up2 = Block(features*8*2, features*8, down=False, act="relu", use_dropout=True)
        self.up3 = Block(features*8*2, features*8, down=False, act="relu", use_dropout=True)
        self.up4 = Block(features*8*2, features*8, down=False, act="relu", use_dropout=False)
        self.up5 = Block(features*8*2, features*4, down=False, act="relu", use_dropout=False)
        self.up6 = Block(features*4*2, features*2, down=False, act="relu", use_dropout=False)
        self.up7 = Block(features*2*2, features, down=False, act="relu", use_dropout=False)
        self.final_up = nn.Sequential(
            nn.ConvTranspose2d(features*2, in_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        d1 = self.initial_down(x)
        d2 = self.down1(d1)
        d3 = self.down2(d2)
        d4 = self.down3(d3)
        d5 = self.down4(d4)
        d6 = self.down5(d5)
        d7 = self.down6(d6)
        bottleneck = self.bottleneck(d7)
        up1 = self.up1(bottleneck)
        up2 = self.up2(torch.cat([up1, d7], 1))
        up3 = self.up3(torch.cat([up2, d6], 1))
        up4 = self.up4(torch.cat([up3, d5], 1))
        up5 = self.up5(torch.cat([up4, d4], 1))
        up6 = self.up6(torch.cat([up5, d3], 1))
        up7 = self.up7(torch.cat([up6, d2], 1))
        return self.final_up(torch.cat([up7, d1], 1))


def test():
    x = torch.randn((1, 3, 256, 256))
    model = Generator(in_channels=3, features=64)
    preds = model(x)
    print(preds.shape)


if __name__ == "__main__":
    test()

torch.Size([1, 3, 256, 256])


discriminator

In [None]:
import torch
import torch.nn as nn

class CNNBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=2):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 4, stride, bias=False, padding_mode="reflect"),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        return self.conv(x)

class Discriminator(nn.Module):
    def __init__(self, in_channels=3, features=[64, 128, 256, 512]):
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(
                in_channels*2,
                features[0],
                kernel_size=4,
                stride=2,
                padding=1,
                padding_mode="reflect",
            ),
            nn.LeakyReLU(0.2),
        )

        layers = []
        in_channels = features[0]
        for feature in features[1:]:
            layers.append(
                CNNBlock(in_channels, feature, stride=1 if feature == features[-1] else 2),
            )
            in_channels = feature

        layers.append(
            nn.Conv2d(
                in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"
            ),
        )

        self.model = nn.Sequential(*layers)

    def forward(self, x, y):
        x = torch.cat([x, y], dim=1)
        x = self.initial(x)
        return self.model(x)



def test():
    x = torch.randn((1, 3, 256, 256))
    y = torch.randn((1, 3, 256, 256))
    model = Discriminator(in_channels=3)
    preds = model(x, y)
    print(preds.shape)


if __name__ == "__main__":
    test()

torch.Size([1, 1, 26, 26])


train

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import config
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt



def train_fn(disc, gen, loader, opt_disc, opt_gen, l1, bce, g_scaler, d_scaler):
    loop = tqdm(loader, leave=True)

    for idx, (x, y) in enumerate(loop):
        x, y = x.to(config.DEVICE), y.to(config.DEVICE)

        # Train Discriminator
        with torch.cuda.amp.autocast():
            y_fake = gen(x)
            D_real = disc(x, y)
            D_fake = disc(x, y_fake.detach())
            D_real_loss = bce(D_real, torch.ones_like(D_real))
            D_fake_loss = bce(D_fake, torch.zeros_like(D_fake))
            D_loss = (D_real_loss + D_fake_loss) / 2

        disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # Train generator
        with torch.cuda.amp.autocast():
            D_fake = disc(x, y_fake)
            G_fake_loss = bce(D_fake, torch.ones_like(D_fake))
            L1 = l1(y_fake, y) * config.L1_LAMBDA
            G_loss = G_fake_loss + L1

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()
    print(D_loss, G_loss)
    return D_loss, G_loss




# Train model

In [None]:
disc = Discriminator(in_channels=3).to(config.DEVICE)
gen = Generator(in_channels=3).to(config.DEVICE)
opt_disc = optim.Adam(disc.parameters(), lr=config.LEARNING_RATE, betas=(0.5, 0.999))
opt_gen = optim.Adam(gen.parameters(), lr=config.LEARNING_RATE, betas=(0.5, 0.999))
BCE = nn.BCEWithLogitsLoss()
L1_LOSS = nn.L1Loss()

D_LOSS_history = []
G_LOSS_history = []

if config.LOAD_MODEL:
    load_checkpoint(config.CHECKPOINT_GEN, gen, opt_gen, config.LEARNING_RATE)
    load_checkpoint(config.CHECKPOINT_DISC, disc, opt_disc, config.LEARNING_RATE)

train_dataset = MapDataset(root_dir="//content//drive//My Drive//Pix2Pix//data//maps//1000dataset")
train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=config.NUM_WORKERS)
g_scaler = torch.cuda.amp.GradScaler()
d_scaler = torch.cuda.amp.GradScaler()
val_dataset = MapDataset(root_dir="//content//drive//My Drive//Pix2Pix//data//maps//val")
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)


for epoch in range(config.NUM_EPOCHS):
    D_LOSS, G_LOSS = train_fn(disc, gen, train_loader, opt_disc, opt_gen, L1_LOSS, BCE, g_scaler, d_scaler)

    if config.SAVE_MODEL and epoch % 5 == 0:
        save_checkpoint(gen, opt_gen, filename=config.CHECKPOINT_GEN)
        save_checkpoint(disc, opt_disc, filename=config.CHECKPOINT_DISC)

    save_some_examples(gen, val_loader, epoch, folder="//content//drive//My Drive//Pix2Pix//data//maps//evaluation")

    D_LOSS_history.append(D_LOSS.cpu().detach().numpy())
    G_LOSS_history.append(G_LOSS.cpu().detach().numpy())

plt.plot(np.arange(config.NUM_EPOCHS), D_LOSS_history, 'bo')
plt.plot(np.arange(config.NUM_EPOCHS), G_LOSS_history, 'ro')

plt.legend(['D_loss','G_loss'], loc='upper right')

plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()


=> Loading checkpoint
=> Loading checkpoint
['43.jpg', '44.jpg', '48.jpg', '49.jpg', '50.jpg', '51.jpg', '52.jpg', '45.jpg', '46.jpg', '47.jpg', '53.jpg', '54.jpg', '55.jpg', '56.jpg', '61.jpg', '62.jpg', '63.jpg', '64.jpg', '65.jpg', '66.jpg', '67.jpg', '68.jpg', '69.jpg', '70.jpg', '71.jpg', '72.jpg', '73.jpg', '74.jpg', '75.jpg', '76.jpg', '34.jpg', '33.jpg', '36.jpg', '35.jpg', '29.jpg', '30.jpg', '32.jpg', '31.jpg', '25.jpg', '21.jpg', '17.jpg', '13.jpg', '9.jpg', '5.jpg', '1.jpg', '26.jpg', '27.jpg', '28.jpg', '22.jpg', '23.jpg', '24.jpg', '18.jpg', '19.jpg', '20.jpg', '14.jpg', '15.jpg', '16.jpg', '10.jpg', '11.jpg', '12.jpg', '6.jpg', '7.jpg', '8.jpg', '2.jpg', '3.jpg', '4.jpg', '81.jpg', '82.jpg', '83.jpg', '84.jpg', '200.jpg', '201.jpg', '202.jpg', '203.jpg', '204.jpg', '205.jpg', '206.jpg', '207.jpg', '208.jpg', '209.jpg', '210.jpg', '211.jpg', '212.jpg', '213.jpg', '214.jpg', '215.jpg', '216.jpg', '217.jpg', '218.jpg', '223.jpg', '222.jpg', '219.jpg', '220.jpg', '221.jpg', 

100%|██████████| 63/63 [00:42<00:00,  1.48it/s]


tensor(0.1502, device='cuda:0', grad_fn=<DivBackward0>) tensor(9.9686, device='cuda:0', grad_fn=<AddBackward0>)
=> Saving checkpoint
=> Saving checkpoint


 14%|█▍        | 9/63 [00:03<00:19,  2.79it/s]


KeyboardInterrupt: ignored

## Test here

In [None]:
class TestDataset(Dataset):
    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.list_files = os.listdir(self.root_dir)
        print(self.list_files)

    def __len__(self):
        return len(self.list_files)

    def __getitem__(self, index):
        img_file = self.list_files[index]
        img_path = os.path.join(self.root_dir, img_file)
        image = np.array(Image.open(img_path))
        input_image = image[:, :600, :]

        augmentations = config.both_transform(image=input_image)
        input_image = augmentations["image"]

        input_image = config.transform_only_input(image=input_image)["image"]

        return input_image


def save_test_examples(gen, test_loader, epoch, folder):

    for i, x in enumerate(test_loader):

        x = x.to(config.DEVICE)
        gen.eval()
        with torch.no_grad():
            y_fake = gen(x)
            y_fake = y_fake * 0.5 + 0.5  # remove normalization#
            save_image(y_fake, folder + f"/y_gen_{i}_{epoch}.png")
            save_image(x * 0.5 + 0.5, folder + f"/input_{i}_{epoch}.png")

        gen.train()


disc = Discriminator(in_channels=3).to(config.DEVICE)
gen = Generator(in_channels=3).to(config.DEVICE)

if config.LOAD_MODEL:
    load_checkpoint(config.CHECKPOINT_GEN, gen, opt_gen, config.LEARNING_RATE)
    load_checkpoint(config.CHECKPOINT_DISC, disc, opt_disc, config.LEARNING_RATE)

test_dataset = TestDataset(root_dir="//content//drive//My Drive//Pix2Pix//data//maps//test")
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# Test model now
save_test_examples(gen, test_loader, epoch, folder="//content//drive//My Drive//Pix2Pix//data//maps//testresult")



=> Loading checkpoint
=> Loading checkpoint
['6.jpg', '5.jpg', '4.jpg', '2.jpg', '3.jpg', '1.jpg']
