In [90]:
import os
import sys

import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader

sys.path.append("../")
from data_preparation import CustomDataset
from networks import UNet, Critic
from metrics import FID, psnr
from train import Trainer
from utils import Params

In [91]:
# Locations of the validation images and of the checkpoint being evaluated.
# `checkpoint` names a sub-directory of `model_dir` (it is joined onto it when
# loading the saved hyper-parameters below).
data_dir, model_dir, checkpoint = (
    "../data/val/",
    "../models/checkpoints",
    "231119195749",
)

In [92]:
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

In [93]:
# Validation pipeline: convert images to tensors and iterate in a fixed order.
# shuffle=False makes the evaluated subset (and therefore the reported metrics)
# identical on every run. Previously, shuffle=True combined with drop_last=True
# silently excluded a different random set of up to EVAL_BATCH_SIZE - 1 images
# on each run, so the FID was not reproducible.
EVAL_BATCH_SIZE = 32
transform = transforms.Compose([transforms.ToTensor()])
dataset = CustomDataset(data_dir, transform)
dataloader = DataLoader(
    dataset,
    batch_size=EVAL_BATCH_SIZE,
    shuffle=False,
    # NOTE(review): drop_last=True still skips the final partial batch
    # (len(dataset) % EVAL_BATCH_SIZE images). It is kept in case FID /
    # Trainer.evaluate assume full batches -- confirm, then prefer
    # drop_last=False so every validation image is scored.
    drop_last=True,
)

In [94]:
# Load the hyper-parameters that were saved alongside the checkpoint, so the
# networks below are rebuilt with the same dimensions they were trained with.
config_path = os.path.join(model_dir, checkpoint, "hyper_params", "params.json")
config = Params(config_path)

In [95]:
# Rebuild the generator and both critics from the checkpoint's hyper-parameters.
# Both critics take INPUT_DIM + REAL_DIM input channels -- presumably the
# generator input concatenated with a real/generated image; confirm against
# networks.Critic. They differ only in their number of down-sampling blocks.
critic_in_dim = config.INPUT_DIM + config.REAL_DIM
gen = UNet(config.INPUT_DIM, config.REAL_DIM)
global_critic = Critic(critic_in_dim, config.GLOBAL_CRITIC_NUM_DOWN_BLOCKS)
local_critic = Critic(critic_in_dim, config.LOCAL_CRITIC_NUM_DOWN_BLOCKS)

In [96]:
# Mean absolute-error (L1) reconstruction loss handed to the Trainer;
# reduction="mean" is PyTorch's default, spelled out for clarity.
recon_criterion = nn.L1Loss(reduction="mean")

In [97]:
# Evaluation metrics handed to the Trainer. The same names are presumably used
# as keys in the results returned by trainer.evaluate -- confirm in train.Trainer.
metrics = dict(fid=FID(device, config), psnr=psnr)

In [98]:
# Wrap the networks in a Trainer restored from `checkpoint` under `model_dir`;
# this notebook only uses it for evaluation.
# NOTE(review): the positional None arguments presumably fill training-only
# slots (e.g. optimizers / scheduler) -- confirm against train.Trainer's
# signature and switch to keyword arguments once the names are known.
trainer = Trainer(
    gen,
    global_critic,
    local_critic,
    None,
    None,
    None,
    recon_criterion,
    metrics,
    None,
    config,
    model_dir,
    device,
    restore_version=checkpoint,
)

In [99]:
# Evaluate the restored checkpoint on the validation loader.
# NOTE(review): the result is indexed by metric name below ("fid"); presumably it
# is a dict keyed like `metrics` -- confirm in train.Trainer.evaluate.
val_metrics = trainer.evaluate(dataloader)

In [100]:
# Report the headline metric to two decimal places.
# TODO: `metrics` also computes "psnr", but it is never reported here.
fid_value = val_metrics["fid"]
print(f"The FID of validation data is {fid_value:.2f}")

The FID of validation data is 80.71
