## You can test and visualize our model using this notebook. You should run each cell in sequence.

In [None]:
import torch
import sys
from utils import save_checkpoint, load_checkpoint
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
import config
from tqdm import tqdm
import numpy as np
from model_D import Discriminator
from model_G import Generator
from model_R import Regressor
from PIL import Image
import os
from torch.utils.data import Dataset
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
# Build the CycleGAN models and regressors on config.DEVICE and load the
# trained generator weights.
disc_H = Discriminator(in_channels=3).to(config.DEVICE)
disc_C = Discriminator(in_channels=3).to(config.DEVICE)
gen_C = Generator(img_channels=3, num_residuals=9).to(config.DEVICE)
gen_H = Generator(img_channels=3, num_residuals=9).to(config.DEVICE)

# map_location remaps the saved tensors onto config.DEVICE, so checkpoints
# written on a GPU can also be loaded on a CPU-only machine.
reg_C = Regressor().to(config.DEVICE)
reg_C.load_state_dict(torch.load('cartoon_torch.pt', map_location=config.DEVICE))
reg_H = Regressor().to(config.DEVICE)
reg_H.load_state_dict(torch.load('human_torch.pt', map_location=config.DEVICE))

opt_disc = optim.Adam(
    list(disc_H.parameters()) + list(disc_C.parameters()),
    lr=config.LEARNING_RATE,
    betas=(0.5, 0.999),
)
opt_gen = optim.Adam(
    list(gen_C.parameters()) + list(gen_H.parameters()),
    lr=config.LEARNING_RATE,
    betas=(0.5, 0.999),
)

# load_checkpoint takes an optimizer argument; opt_gen is passed for both
# generators.
load_checkpoint(config.CHECKPOINT_G_H, gen_H, opt_gen, config.LEARNING_RATE)
load_checkpoint(config.CHECKPOINT_G_C, gen_C, opt_gen, config.LEARNING_RATE)

# The generators are only used for inference in this notebook. eval() turns
# off train-only behaviour (dropout, batch-norm statistic updates) in case
# the Generator contains such layers.
gen_H.eval()
gen_C.eval()


In [None]:
class Test(Dataset):
    """Paired test set of original (human) images and generated (blond) images.

    Items are paired by position in the *sorted* file listings of the two
    folders: item ``i`` combines the i-th human file with the i-th blond file.
    When the folders hold different numbers of files, the shorter listing
    wraps around (index modulo its length), and the dataset length is the
    larger of the two counts.

    Args:
        root_human: folder containing the input images.
        root_blond: folder containing the generated images.
        transform: optional callable invoked as
            ``transform(image=blond_img, image0=human_img)`` that returns a
            dict with ``"image"`` and ``"image0"`` keys.

    Raises:
        ValueError: if exactly one of the folders contains no files, so
            images cannot be paired.
    """

    def __init__(self, root_human, root_blond, transform=None):
        self.root_blond = root_blond
        self.root_human = root_human
        self.transform = transform

        # os.listdir returns entries in arbitrary order. Sort both listings so
        # the pairing is deterministic. Presumably the generated files keep
        # the input file names, so sorted order lines them up (verify).
        self.blond_images = self._list_files(root_blond)
        self.human_images = self._list_files(root_human)

        self.blond_len = len(self.blond_images)
        self.human_len = len(self.human_images)
        self.length_dataset = max(self.blond_len, self.human_len)

        # With only one folder empty, __getitem__ would divide by zero.
        if self.length_dataset and not (self.blond_len and self.human_len):
            raise ValueError(
                f"cannot pair images: {root_human!r} has {self.human_len} "
                f"files but {root_blond!r} has {self.blond_len}"
            )

    @staticmethod
    def _list_files(root):
        """Return the sorted names of regular files directly inside *root*.

        Sub-directories (e.g. Jupyter's ``.ipynb_checkpoints``) are skipped
        because they cannot be opened as images.
        """
        return sorted(
            name for name in os.listdir(root)
            if os.path.isfile(os.path.join(root, name))
        )

    @staticmethod
    def _load_rgb(path):
        """Open the image at *path*, convert it to RGB and return a numpy array."""
        with Image.open(path) as img:
            return np.array(img.convert("RGB"))

    def __len__(self):
        return self.length_dataset

    def __getitem__(self, index):
        """Return ``(human_img, blond_img, human_name, blond_name)`` for *index*."""
        blond_name = self.blond_images[index % self.blond_len]
        blond_path = os.path.join(self.root_blond, blond_name)

        human_name = self.human_images[index % self.human_len]
        human_path = os.path.join(self.root_human, human_name)

        blond_img = self._load_rgb(blond_path)
        human_img = self._load_rgb(human_path)

        if self.transform:
            augmentations = self.transform(image=blond_img, image0=human_img)
            blond_img = augmentations["image"]
            human_img = augmentations["image0"]

        return human_img, blond_img, human_name, blond_name

In [None]:
# The test set reads from two folders:
#   root_human - input images selected from the CelebA dataset;
#   root_blond - images StarGAN generated from those inputs.
val_dataset = Test(
    root_human="/home/zluan/ECE_228_project/CycleGAN/stargan_input",
    root_blond="/home/zluan/ECE_228_project/CycleGAN/stargan_output",
    transform=config.transforms,
)

# Two sequential (unshuffled) views of the same dataset: one sample per batch,
# and a second loader using the batch size and worker count from config.
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, pin_memory=True)
loader = DataLoader(
    val_dataset,
    batch_size=config.BATCH_SIZE,
    num_workers=config.NUM_WORKERS,
    shuffle=False,
    pin_memory=True,
)

In [None]:
# Load the original input images and the blond human images
# Collect every batch from `loader` so later cells can index individual
# samples. Batches are kept where the loader produced them (CPU) instead of
# being moved to config.DEVICE here. Holding the whole dataset on the GPU is
# unnecessary, because later cells move single samples to the device as needed.
loop = tqdm(loader, leave=True)
star_im_out = []       # generated (blond) image batches
star_im_out_name = []  # file names for each entry of star_im_out
star_im_in = []        # original input image batches
star_im_in_name = []   # file names for each entry of star_im_in

for human, blond, human_name, blond_name in loop:
    star_im_out.append(blond)
    star_im_out_name.append(blond_name)
    star_im_in.append(human)
    star_im_in_name.append(human_name)
        

In [None]:
# Index of the sample shown by the visualization cell below; that cell
# advances it each time it is run.
i = 0

##### You can visualize the results here. Expect some decent results as well as some more distorted ones: as mentioned in our report, the results are affected by many factors (such as complicated backgrounds, eyeglasses, darker skin tones, etc.).
##### The first row shows the output of our model and the input blond human image.
##### The second row shows the output of our model and the input human image from the CelebA dataset.
##### Feel free to run the cell below multiple times to check out different results.

In [None]:
# Show sample `i`, then advance `i` so the next run shows the following
# sample. The index wraps around instead of running off the end.
idx = i % len(star_im_out)
print(idx)
print(star_im_out_name[idx])

# Panel order: [model(blond), blond, model(human), human].
panels = []
with torch.no_grad():  # inference only; no autograd graph needed
    for batch in (star_im_out[idx], star_im_in[idx]):
        # Takes the first sample of the batch (assumes batch_size 1; TODO
        # confirm) and maps it with *0.5+0.5. Presumably config.transforms
        # normalizes to [-1, 1], so this yields [0, 1] (verify). This is also
        # the input the generator receives here.
        initial = batch[0].to(config.DEVICE) * 0.5 + 0.5
        generated = gen_C(initial) * 0.5 + 0.5
        panels.extend([generated, initial])

for position, image in enumerate(panels, start=1):
    plt.subplot(2, 2, position)
    # CHW tensor -> HWC array; clip to the [0, 1] range imshow expects for floats.
    plt.imshow(np.clip(np.transpose(image.cpu().numpy(), (1, 2, 0)), 0, 1))
    plt.axis('off')
plt.show()

i = idx + 1

