In [None]:
import os
import glob
import random

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

from itertools import chain

In [None]:
# Which generator architecture to build and which checkpoint directory to load it from.
# Alternatives (commented out):
# model_name = "resnet"
# model_dir_name = "resnet"
# model_dir_name = "restnet_no_cycle_consistency"  # NOTE(review): "restnet" may be a typo for "resnet" — confirm against saved_models/
# model_dir_name = "fdcgan_no_cycle_consistency"
model_name = "fdcgan"
model_dir_name = "fdcgan"

dataset_name = "apple2orange64"

# The preprocessing uses RandomCrop, so seed every RNG the notebook has imported
# to make the displayed examples reproducible under Restart & Run All.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# --- Extract data from local files

class ExampleLoader(Dataset):
    """Load specific images from a dataset directory and preprocess them into a batch.

    Only ``get_images`` is implemented; the class does not define ``__len__`` or
    ``__getitem__``, so it is not directly usable with a ``DataLoader``.
    """

    def __init__(self, root, img_size, transformations):
        """
        Args:
            root: dataset root directory; paths given to ``get_images`` are relative to it.
            img_size: side length of the square images the transformations are expected
                to produce (used to allocate the output batch).
            transformations: sequence of torchvision transforms, composed in order.

        Raises:
            FileNotFoundError: if ``root`` is not an existing directory.
        """
        if not os.path.isdir(root):
            raise FileNotFoundError(f"Directory not found : {root}")
        self.root = root
        self.img_size = img_size
        self.transform = transforms.Compose(transformations)

    def get_images(self, image_relative_paths):
        """Load, transform and stack images into one float32 tensor.

        Args:
            image_relative_paths: iterable of image paths relative to ``self.root``.

        Returns:
            Tensor of shape ``(N, 3, img_size, img_size)``, dtype float32, where N is the
            number of paths; row ``i`` holds the transformed image of path ``i``.
        """
        # Materialize once so generators/iterators work, not only sized sequences.
        paths = list(image_relative_paths)
        batch = torch.zeros(
            (len(paths), 3, self.img_size, self.img_size),
            dtype=torch.float32,
        )
        for i, img_rel_path in enumerate(paths):
            image_path = os.path.join(self.root, img_rel_path)
            # The context manager closes the file handle promptly. convert("RGB") forces
            # 3 channels: a grayscale or RGBA file would otherwise produce a 1- or
            # 4-channel tensor that does not fit the 3-channel Normalize / batch slot.
            with Image.open(image_path) as img:
                batch[i] = self.transform(img.convert("RGB"))
        return batch


img_size = 64
img_shape = (3, img_size, img_size)
data_path = os.path.join(".", "datasets", dataset_name)

# Preprocessing pipeline:
#   1. bicubic resize so the shorter side is ~1.12x the target size,
#   2. random crop back to img_size x img_size,
#   3. PIL image -> float tensor in [0, 1],
#   4. per-channel (x - 0.5) / 0.5, i.e. rescale to [-1, 1].
# NOTE(review): presumably mirrors the training-time transforms — confirm.
# RandomHorizontalFlip is left disabled here.
transformations = [
    transforms.Resize(int(img_size * 1.12), interpolation=Image.BICUBIC),
    transforms.RandomCrop(size=(img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
]

example_loader = ExampleLoader(
    root=data_path,
    img_size=img_size,
    transformations=transformations,
)

def tensor_to_image(tensor_image):
    """Convert a ``(C, H, W)`` image tensor into an ``(H, W, C)`` NumPy array for ``imshow``.

    The tensor is detached and copied to the CPU first, so tensors that require grad
    or live on another device are accepted. Values are mapped with ``(x + 1) / 2``,
    which undoes the ``Normalize(0.5, 0.5)`` rescaling to ``[-1, 1]``. They are then
    clipped to ``[0, 1]``, the range ``plt.imshow`` expects for float RGB data.
    """
    chw = tensor_image.detach().cpu().numpy()
    hwc = chw.transpose(1, 2, 0)
    return np.clip((hwc + 1) / 2, 0, 1)

def show_examples(loader, image_relative_paths):
    """Show preprocessed images side by side in one row, each titled with its path.

    Args:
        loader: object exposing ``get_images(paths)`` that returns a batch of
            ``(3, H, W)`` tensors (e.g. ``ExampleLoader``).
        image_relative_paths: iterable of image paths relative to the loader's root.

    Raises:
        ValueError: if no paths are given (there would be nothing to plot).
    """
    paths = list(image_relative_paths)
    if not paths:
        raise ValueError("show_examples needs at least one image path")

    # Load everything in one call instead of one get_images() call per path.
    images = loader.get_images(paths)

    # squeeze=False always yields a 2-D array of Axes; without it a single path
    # returns a bare Axes object and indexing it would raise TypeError.
    fig, axs = plt.subplots(1, len(paths), figsize=(15, 6), squeeze=False)
    for ax, img_rel_path, img in zip(axs[0], paths, images):
        ax.set_title(img_rel_path)
        ax.imshow(tensor_to_image(img))
        ax.axis("off")
    plt.show()

# Preview three preprocessed validation images from domain A.
show_examples(
    example_loader,
    (
        "validation/A/n07740461_4610.jpg",
        "validation/A/n07740461_14721.jpg",
        "validation/A/n07740461_13851.jpg",
    ),
)

In [None]:
# --- Load the pretrained CycleGAN generators (A->B and B->A)

model_save_dir = os.path.join(".", "saved_models", "cycle-gan", model_dir_name, dataset_name)

# Build two untrained generators of the selected architecture; the weights are loaded below.
if model_name == "fdcgan":
    from fdcgan import Generator
    g_ab = Generator(img_shape)
    g_ba = Generator(img_shape)

elif model_name == "resnet":
    from resnet import Generator
    g_ab = Generator(img_shape, num_residual_blocks=9)
    g_ba = Generator(img_shape, num_residual_blocks=9)

else:
    raise ValueError(f"Invalid model name: {model_name!r}")

g_ab_save_path = os.path.join(model_save_dir, "G_AB_199.pth")
g_ba_save_path = os.path.join(model_save_dir, "G_BA_199.pth")

# Fail early with a clear message if either checkpoint is missing.
for save_path in (g_ab_save_path, g_ba_save_path):
    if not os.path.exists(save_path):
        raise FileNotFoundError(f"Could not find the pretrained model: {save_path}")
print(f"Loading pretrained models:\nA->B: {g_ab_save_path}\nB->A: {g_ba_save_path}")

# ExampleLoader builds its batches as CPU tensors, so map the weights to the CPU.
# Without map_location, checkpoints saved from a GPU fail to load on a CPU-only machine.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
g_ab.load_state_dict(torch.load(g_ab_save_path, map_location="cpu"))
g_ba.load_state_dict(torch.load(g_ba_save_path, map_location="cpu"))

# Switch layers such as dropout / batch norm to inference behavior.
g_ab.eval()
g_ba.eval()

In [None]:
original_a = example_loader.get_images((
    "validation/A/n07740461_4610.jpg",
    "validation/A/n07740461_14721.jpg",
    "validation/A/n07740461_13851.jpg",
))

# Translate A -> B, then map the result back B -> A (cycle reconstruction).
# No gradients are needed for visualization.
with torch.no_grad():
    generated_b = g_ab(original_a)
    recovered_a = g_ba(generated_b)

# One figure per example: original, translation, and reconstruction side by side.
panels = (
    ("Original A", original_a),
    ("Generated B", generated_b),
    ("Recovered A", recovered_a),
)
for i in range(len(original_a)):
    fig, axs = plt.subplots(1, len(panels), figsize=(15, 6))
    for ax, (title, batch) in zip(axs, panels):
        ax.set_title(title)
        ax.imshow(tensor_to_image(batch[i]))
        ax.axis("off")
    plt.show()
