In [1]:
import json
import os

import numpy as np
import torch
from sklearn.metrics import mean_squared_error
from torch.utils.data import DataLoader

from data_utils import get_dataset_path, unnormalize
from data_utils.BrainDataset import BrainDataset
from data_utils.BrainSampler import BrainSampler
from gan import GAN

from torchmetrics.image.fid import FrechetInceptionDistance

from skimage.metrics import *

In [2]:
device = torch.device("cuda")
model_name = "8006146278150"
model_dict = torch.load(os.path.join(get_dataset_path(), f"../models", f"{model_name}/model.pt"), weights_only=True)
options = json.load(open(os.path.join(get_dataset_path(), "../models", f"{model_name}/options.json"), "r"))
model = GAN(options["tile_size"], 3, options["latent_dim"])
model.load_state_dict(model_dict)
model = model.eval()
model = model.to(device)

data = BrainDataset([(1947, 97, None), (1947, 160, "cerebellum")], options["map_type"], options["resolution"])
sampler = BrainSampler(data, tile_size=options["tile_size"], map_type=options["map_type"])
loader = DataLoader(data, sampler=sampler, batch_size=1000)
data_iter = iter(loader)


In [3]:
torch.manual_seed(0)
fid = FrechetInceptionDistance(feature=64)
real_images = next(data_iter)[0].to(device)
fake_images = model.sample(1000)
for real, fake in zip(real_images, fake_images):
    r, f = real.unsqueeze(0), fake.unsqueeze(0)
    r = torch.tensor(unnormalize(r.permute(0,2,3,1).cpu(), options["map_type"])).permute(0,3,1,2)
    f = torch.tensor(unnormalize(f.permute(0,2,3,1).cpu(), options["map_type"])).permute(0,3,1,2)
    fid.update(r, real=True)
    fid.update(f, real=False)
    del r,f
print(fid.compute())


tensor(4.1811)


In [4]:
model_name = "8180058912588"
device = torch.device("cuda")
model_dict = torch.load(os.path.join(get_dataset_path(), f"../models", f"{model_name}/model.pt"), weights_only=True)
options = json.load(open(os.path.join(get_dataset_path(), "../models", f"{model_name}/options.json"), "r"))
model = GAN(options["tile_size"], 3, options["latent_dim"], options["conditional"], options["condition_dim"])
model.load_state_dict(model_dict)
model = model.eval()
model = model.to(device)

data = BrainDataset(options["training_brains"], options["map_type"], options["resolution"])
sampler = BrainSampler(data, tile_size=options["tile_size"], map_type=options["map_type"])
loader = DataLoader(data, sampler=sampler, batch_size=1000)
data_iter = iter(loader)



In [5]:
torch.random.manual_seed(0)
trans, fom = next(data_iter)
trans = trans.to(device)
samples = model.sample(trans.shape[0], trans)
mse = []
ssim = []
psnr = []
for real, fake in zip(fom, samples):
    r = unnormalize(real.unsqueeze(0).permute(0,2,3,1).cpu(), "fom")
    f = unnormalize(fake.unsqueeze(0).permute(0,2,3,1).cpu(), "fom")
    mse.append(mean_squared_error(r, f))
    ssim.append(structural_similarity(r[0], f[0], multichannel=True, channel_axis=2))
    psnr.append(peak_signal_noise_ratio(r[0], f[0]))
print(f"Mean MSE: {np.mean(mse)} STD MSE: {np.std(mse)}")
print(f"Mean SSIM: {np.mean(ssim)} STD SSIM: {np.std(ssim)}")
print(f"Mean PSNR: {np.mean(psnr)} STD PSNR: {np.std(psnr)}")
del samples, trans, fom


Mean MSE: 1340.8286918131512 STD MSE: 1903.379041055308
Mean SSIM: 0.44578721369522356 STD SSIM: 0.3520543266763236
Mean PSNR: 27.68133953737192 STD PSNR: 14.951478892899537
