In [None]:
# 3rd Party dependencies.
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision.transforms.v2 as transforms

from torch.utils.data import DataLoader
from tqdm import tqdm

# 1st Party dependencies.
from dataset.facades_dataset import FacadesDataset
from pix2pix.generator import Generator, GeneratorLoss
from pix2pix.discriminator import Discriminator, DiscriminatorLoss

%matplotlib inline

In [None]:
def _segmentation_of(sample):
    """Return the segmentation map (model input) contained in `sample`.

    `sample` may be either a bare segmentation map or a 2-item `(image, label)`
    pair, in which case `label` is used. NOTE(review): assumes dataset items are
    `(image, label)` pairs with the segmentation as `label` — confirm against
    FacadesDataset.
    """
    if isinstance(sample, (tuple, list)):
        _, label = sample
        return label
    return sample


def _generate(model, segmentation):
    """Run `model` on one unbatched segmentation map and return its first output sample.

    A batch dimension is added before the call and the first element of the
    output is returned. If the model exposes parameters, the input is moved to
    the device of its first parameter so GPU models work too. Runs under
    `torch.no_grad()` because the result is only displayed. The model's
    train/eval mode is left untouched.
    """
    batch = torch.as_tensor(segmentation).unsqueeze(0)
    first_parameter = next(iter(model.parameters()), None) if hasattr(model, 'parameters') else None
    if first_parameter is not None:
        batch = batch.to(first_parameter.device)
    with torch.no_grad():
        output = model(batch)
    return output[0]


def _as_display_image(tensor):
    """Convert one channel-first `(C, H, W)` tensor/array into an `(H, W, C)` numpy image for imshow.

    Channels are reversed, which for 3-channel input is the same swap as
    `cv2.COLOR_BGR2RGB`. NOTE(review): this assumes the data and model outputs
    are in BGR order. A single-channel result is squeezed to `(H, W)` because
    imshow rejects `(H, W, 1)`.
    """
    array = torch.as_tensor(tensor).detach().cpu().numpy()
    image = np.transpose(array, (1, 2, 0))[..., ::-1]
    if image.shape[-1] == 1:
        image = image[..., 0]
    return image


def visualise_models_comparison(pix2pix_model, cyclegan_model, images, figsize=(10, 10)):
    """Plot, for each sample, its segmentation next to the Pix2Pix and CycleGAN generations.

    Args:
        pix2pix_model: callable model mapping a batched segmentation tensor to generated images.
        cyclegan_model: callable model mapping a batched segmentation tensor to generated images.
        images: iterable of samples, one figure row each. Every sample is either a
            channel-first segmentation map or an `(image, label)` pair whose `label`
            is the segmentation.
        figsize: matplotlib figure size in inches.

    Raises:
        ValueError: if `images` is empty (there is nothing to plot).
    """
    samples = list(images)
    if not samples:
        raise ValueError('images must contain at least one sample to plot.')

    column_names = ['Segmentation', 'Pix2Pix Generation', 'CycleGAN Generation']
    # squeeze=False keeps `axes` 2-D even for a single row, so axes[row, col] always works.
    figure, axes = plt.subplots(
        nrows=len(samples), ncols=len(column_names), figsize=figsize, squeeze=False
    )

    for row, sample in enumerate(samples):
        segmentation = _segmentation_of(sample)
        panels = (
            segmentation,
            _generate(pix2pix_model, segmentation),
            _generate(cyclegan_model, segmentation),
        )
        for ax, panel in zip(axes[row], panels):
            ax.imshow(_as_display_image(panel))
            ax.axis('off')

    # Column titles only on the top row.
    for ax, name in zip(axes[0], column_names):
        ax.set_title(name, fontweight='bold')

    # tight_layout recomputes the subplot spacing, so a prior subplots_adjust would be discarded.
    figure.tight_layout(pad=1)
    plt.show()