# Use a pretrained model to run inference

## Model and Dataset Selection

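Select a model type, a dataset and a pretrained generator checkpoint, then execute all cells of the notebook. Generated images will be placed in `./inferences/<model-name>/<dataset-name>/<target-class>/` (`<model-name>` is prefixed with `cycle-gan/` for checkpoints trained with CycleGAN).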

**Type of Model Selection**

In [None]:
# Generator architecture to run. Later cells accept "gan", "dcgan", "fdcgan"
# or "resnet" and raise on anything else. Leave exactly one assignment
# uncommented.

# -- Takes noise vectors as input -- #

# model = "gan"
# model = "dcgan"

# -- Takes images as input -- #

model = "fdcgan"
# model = "resnet"

**Dataset Selection**

In [None]:
# Dataset to run on. Leave exactly one assignment uncommented.
# The image-input models ("fdcgan", "resnet") are rejected later in the
# notebook unless one of the apple/orange datasets is selected.

# -- Only adapted to noise vectors input -- #

# dataset = "mnist"

# -- Adapted to both noise vectors and images input -- #

dataset = "apple2orange64"
# dataset = "orange2apple64"

**Model Path**

In [None]:
import os

# Checkpoint file name (without the ".pth" extension) of the generator to load.
# generator_name = "generator_199" # no cycle-gan
generator_name = "G_AB_136" # cycle-gan
# generator_name = "G_BA_136" # cycle-gan

# Checkpoints trained with CycleGAN are looked up under a "cycle-gan/"
# sub-directory of "saved_models/".
# is_trained_with_cycleGAN = False
is_trained_with_cycleGAN = True

# path_model is reused later to build the output folder, so it is kept as a
# plain "/"-separated string.
path_model = "cycle-gan/" + model if is_trained_with_cycleGAN else model

model_path = os.path.join("saved_models", path_model, dataset, generator_name + ".pth")

# Make sure that model exists.
# Raise instead of calling exit(): inside a notebook exit() tears down the
# kernel, whereas an exception stops execution and shows which path is missing.
if not os.path.isfile(model_path):
    raise FileNotFoundError(f"Model does not exist: {model_path}")

# NOTE: the weights are actually read from disk in the "Initialising Models" cell.
print("Loaded model", model_path)

## Imports

In [None]:
import sys
import datetime
import time
import glob
import random

import numpy as np

from PIL import Image

import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
from torch.utils.data import Dataset
import torch.nn as nn
import torch.nn.functional as F
import torch

## Hyperparameters and Datasets Specific Configuration

**Global Configuration**

In [None]:
# Passed as `num_workers` to the DataLoader built further down.
n_cpu = 8               # number of cpu threads to use during batch generation
# Number of samples per DataLoader batch; the inference loop writes one PNG
# per batch.
batch_size = 1

**Datasets Specifications**

In [None]:
# Per-dataset image geometry and latent size:
#   channels   - number of image channels
#   latent_dim - length of the noise vector fed to noise-input generators
#   img_size   - side length (pixels) of the square images
if dataset == "mnist":
    channels, latent_dim, img_size = 1, 100, 64
elif dataset in ("apple2orange64", "orange2apple64"):
    channels, latent_dim, img_size = 3, 300, 64
else:
    raise Exception("Unknown dataset")

**Importing Chosen Models**

In [None]:
# Bind `Generator` to the implementation matching the selected model.
if model in ("fdcgan", "resnet"):
    # Image-to-image generators need a dataset that provides input images.
    if dataset != "apple2orange64" and dataset != "orange2apple64":
        raise Exception(f"Dataset {dataset} has no input image for the generator")
    if model == "resnet":
        from resnet import Generator
        n_residual_blocks = 9
    else:
        from fdcgan import Generator

elif model == "dcgan":
    from dcgan import Generator

elif model == "gan":
    from gan import Generator

else:
    raise Exception("Unknown model")

# (channels, height, width) of the square images handled by the generator.
img_shape = (channels, img_size, img_size)

**Setting Up Cuda**

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
cuda = torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")
# This notebook only runs inference, so the message says so.
print("Device used for inference: ", device)

## Initialising Models

In [None]:
# -- Initialize generator -- #

# Each architecture takes a different set of constructor arguments.
if model == "gan" or model == "dcgan":
    generator = Generator(img_shape, latent_dim)
elif model == "fdcgan":
    generator = Generator(img_shape)
elif model == "resnet":
    generator = Generator(img_shape, n_residual_blocks)
else:
    # Without this branch an unexpected model would surface as a NameError below.
    raise Exception("Unknown model")

# Load pretrained weights. map_location remaps the stored tensors onto the
# device selected above, so a checkpoint saved on a GPU still loads on a
# CPU-only machine.
generator.load_state_dict(torch.load(model_path, map_location=device))

generator.to(device)

# Inference only: switch the module to evaluation mode.
generator.eval()

# Tensor type used to cast inputs before feeding them to the generator.
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

## Importing Dataset

In [None]:
def to_rgb(image):
    """Return a new RGB image of the same size with `image` pasted onto it."""
    canvas = Image.new("RGB", image.size)
    canvas.paste(image)
    return canvas

class ImageDataset(Dataset):
    """Two-domain image dataset read from ``<root>/<mode>/A`` and ``<root>/<mode>/B``.

    Args:
        root: Dataset root directory.
        transforms_: List of transforms applied to every image. ``None`` is
            treated as an empty list (images are returned untransformed).
        unaligned: When true, the B image is drawn at random instead of
            being selected by ``index``.
        mode: Sub-directory of ``root`` to read (defaults to ``"train"``).

    Each item is a dict ``{"A": ..., "B": ...}`` holding the transformed images.
    """

    def __init__(self, root, transforms_=None, unaligned=False, mode="train"):
        # Compose(None) would only fail later, on the first item access, with
        # an unhelpful TypeError; fall back to an empty pipeline instead.
        self.transform = transforms.Compose(transforms_ if transforms_ is not None else [])
        self.unaligned = unaligned

        # Sorted so that index -> file mapping is stable between runs.
        self.files_A = sorted(glob.glob(os.path.join(root, "%s/A" % mode) + "/*.*"))
        self.files_B = sorted(glob.glob(os.path.join(root, "%s/B" % mode) + "/*.*"))

    @staticmethod
    def _open_rgb(path):
        """Open the image at `path`, converting it to RGB when needed."""
        image = Image.open(path)
        # Convert grayscale (or any other non-RGB) images to rgb
        if image.mode != "RGB":
            image = to_rgb(image)
        return image

    def __getitem__(self, index):
        # Indices wrap around so the shorter domain is cycled.
        image_A = self._open_rgb(self.files_A[index % len(self.files_A)])

        if self.unaligned:
            path_B = self.files_B[random.randint(0, len(self.files_B) - 1)]
        else:
            path_B = self.files_B[index % len(self.files_B)]
        image_B = self._open_rgb(path_B)

        return {"A": self.transform(image_A), "B": self.transform(image_B)}

    def __len__(self):
        # Length of the larger domain.
        return max(len(self.files_A), len(self.files_B))

# Preprocessing applied to every image of the apple/orange datasets:
# upscale to 1.12 * img_size, take a random img_size x img_size crop, flip
# horizontally at random, convert to a tensor and normalise each of the three
# channels with mean 0.5 and std 0.5.
# NOTE(review): RandomCrop and RandomHorizontalFlip make the generator input
# (and therefore the saved results) differ from run to run; this looks like a
# training-time augmentation pipeline - confirm it is wanted for inference.
transforms_ = [
    transforms.Resize(int(img_size * 1.12), Image.BICUBIC),
    transforms.RandomCrop((img_size, img_size)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

# Build `dataloader` for the selected dataset. For the apple/orange datasets
# this also defines `my_class_A` (generator input side) and `my_class_B`
# (generator output side).
if dataset == "mnist":

    os.makedirs("./datasets/mnist", exist_ok=True)

    dataloader = torch.utils.data.DataLoader(
        datasets.MNIST(
            "./datasets/mnist",
            train=True,
            download=True,
            # The image is resized once, before ToTensor. The former second
            # Resize after Normalize was a no-op at best and is rejected by
            # torchvision releases whose Resize only accepts PIL images.
            transform=transforms.Compose(
                [transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
            ),
        ),
        batch_size=batch_size,
        shuffle=True,
        num_workers=n_cpu,
    )

elif dataset == "apple2orange64" or dataset == "orange2apple64":

    # Both directions read the same folder on disk; only the roles of the
    # "A" and "B" sides are swapped.
    if dataset == "apple2orange64":
        my_class_A, my_class_B = "A", "B"
    else:
        my_class_A, my_class_B = "B", "A"

    import subprocess
    # Argument list instead of a shell string: no shell is needed to run the script.
    # The exit status is not checked, as before.
    command = ["bash", "./datasets/download_cyclegan_dataset.sh"]
    subprocess.run(command)

    # Test data loader
    dataloader = DataLoader(
        ImageDataset("./datasets/apple2orange64", transforms_=transforms_, unaligned=True, mode="validation"),
        batch_size=batch_size,
        shuffle=True,
        num_workers=n_cpu,
    )

## Inference

Unfortunately, depending on how the dataset is imported, we need separate inference loops for mnist and apple2orange, but the procedure is exactly the same.

In [None]:
# Output folder: inferences/<path_model>/<dataset>/<my_class_B>, created if missing.
# NOTE(review): my_class_B is only assigned for the apple2orange64 /
# orange2apple64 datasets, so this cell raises NameError when dataset == "mnist".
image_inference_folder = "inferences/%s/%s/%s" % (path_model, dataset, my_class_B)
os.makedirs(image_inference_folder, exist_ok=True)

print(f"Using '{model}' with '{dataset}', saving results to '{image_inference_folder}'.")

### Random vector noise input

**mnist:**

In [None]:
# if dataset == "mnist":
#     prev_time = time.time()
#     for i, (imgs, _) in enumerate(dataloader):
        
#         if (imgs.shape[0] != batch_size):
#             continue

#         # Configure input
#         real_imgs = Variable(imgs.type(Tensor))

#         # Sample noise as generator input
#         z = Variable(Tensor(np.random.normal(0, 1, (real_imgs.shape[0], latent_dim))))

#         # Generate a batch of images
#         gen_imgs = generator(z)
        
#         # Loss measures generator's ability to fool the discriminator
#         g_loss = adversarial_loss(discriminator(gen_imgs), valid)

#         g_loss.backward()
#         optimizer_G.step()

#         # ---------------------
#         #  Train Discriminator
#         # ---------------------

#         optimizer_D.zero_grad()

#         # Measure discriminator's ability to classify real from generated samples
#         real_loss = adversarial_loss(discriminator(real_imgs), valid)
#         fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
#         d_loss = (real_loss + fake_loss) / 2

#         d_loss.backward()
#         optimizer_D.step()

#         # --------------
#         #  Log Progress
#         # --------------
        
#         # Determine approximate time left
#         batches_done = epoch * len(dataloader) + i
#         batches_left = n_epochs * len(dataloader) - batches_done
#         time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
#         prev_time = time.time()

#         sys.stdout.write(
#             "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] ETA: %s"
#             % (epoch, n_epochs, i, len(dataloader), d_loss.item(), g_loss.item(), time_left))

#         batches_done = epoch * len(dataloader) + i
#         if batches_done % sample_interval == 0:
#             save_image(gen_imgs.data[:batch_size], image_progress_folder + "/%d.png" % batches_done, nrow=5, normalize=True)
    
#     if checkpoint_interval != -1 and epoch % checkpoint_interval == 0:
#         # Save model checkpoints
#         torch.save(generator.state_dict(), "saved_models/%s/%s/generator_%d.pth" % (model, dataset, epoch))
#         torch.save(discriminator.state_dict(), "saved_models/%s/%s/generator_%d.pth" % (model, dataset, epoch))
                

**orange2apple or apple2orange with gan or dcgan:**

In [None]:
# if (dataset == "orange2apple64" or dataset == "apple2orange64") and (model == "gan" or model == "dcgan"):
#     prev_time = time.time()
#     for epoch in range(start_epoch, n_epochs):
#         for i, batch in enumerate(dataloader):

#             # Set model input
#             real_imgs = Variable(batch[my_class_B].type(Tensor))
            
#             if (real_imgs.shape[0] != batch_size):
#                 continue

#             # Adversarial ground truths
#             valid = Variable(Tensor(np.ones((real_imgs.size(0), *discriminator.output_shape))), requires_grad=False)
#             fake = Variable(Tensor(np.zeros((real_imgs.size(0), *discriminator.output_shape))), requires_grad=False)

#             # -----------------
#             #  Train Generator
#             # -----------------

#             optimizer_G.zero_grad()

#             # Sample noise as generator input
#             z = Variable(Tensor(np.random.normal(0, 1, (real_imgs.shape[0], latent_dim))))

#             # Generate a batch of images
#             gen_imgs = generator(z)
            
#             # Loss measures generator's ability to fool the discriminator
#             g_loss = adversarial_loss(discriminator(gen_imgs), valid)

#             g_loss.backward()
#             optimizer_G.step()

#             # ---------------------
#             #  Train Discriminator
#             # ---------------------

#             optimizer_D.zero_grad()

#             # Measure discriminator's ability to classify real from generated samples
#             real_loss = adversarial_loss(discriminator(real_imgs), valid)
#             fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
#             d_loss = (real_loss + fake_loss) / 2

#             d_loss.backward()
#             optimizer_D.step()

#             # --------------
#             #  Log Progress
#             # --------------
            
#             # Determine approximate time left
#             batches_done = epoch * len(dataloader) + i
#             batches_left = n_epochs * len(dataloader) - batches_done
#             time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
#             prev_time = time.time()

#             sys.stdout.write(
#                 "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] ETA: %s"
#                 % (epoch, n_epochs, i, len(dataloader), d_loss.item(), g_loss.item(), time_left))

#             batches_done = epoch * len(dataloader) + i
#             if batches_done % sample_interval == 0:
#                 save_image(gen_imgs.data[:batch_size], image_progress_folder + "/%d.png" % batches_done, nrow=5, normalize=True)
        
#         if checkpoint_interval != -1 and epoch % checkpoint_interval == 0:
#             # Save model checkpoints
#             torch.save(generator.state_dict(), "saved_models/%s/%s/generator_%d.pth" % (model, dataset, epoch))
#             torch.save(discriminator.state_dict(), "saved_models/%s/%s/generator_%d.pth" % (model, dataset, epoch))
                    

### Image input

**orange2apple or apple2orange with fdcgan or resnet (improved fdcgan):**

In [None]:
# Image-to-image inference: translate every batch of the input side
# (`my_class_A`) and save the generator output as <index>.png.
if (dataset == "orange2apple64" or dataset == "apple2orange64") and (model == "fdcgan" or model == "resnet"):

    # Evaluation mode only needs to be set once, not on every batch.
    generator.eval()

    n_batches = len(dataloader)
    prev_time = time.time()
    for i, batch in enumerate(dataloader):

        # Set model input
        real_A = batch[my_class_A].type(Tensor)

        # ------------------
        #  Use Generators
        # ------------------

        # No gradients are needed for inference; skipping autograd bookkeeping
        # saves memory and time.
        with torch.no_grad():
            fake_B = generator(real_A)

        # --------------
        #  Log Progress
        # --------------

        # Determine approximate time left: batches still to process times the
        # duration of the batch just finished. (The previous formula added
        # len(dataloader) to i and therefore produced a negative ETA.)
        batches_left = n_batches - (i + 1)
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        prev_time = time.time()

        # Print log
        sys.stdout.write(
            "\r[Batch %d/%d] ETA: %s"
            % (
                i,
                n_batches,
                time_left,
            )
        )

        # --------------
        #  Save Image
        # --------------

        image_name = image_inference_folder + "/%d.png" % i
        save_image(fake_B, image_name, normalize=True)
        
