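# Test pipeline

Avaliação do gerador treinado no conjunto de teste (`mode='test'`), usando as métricas FID e SSIM (imagem completa e apenas o centro).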

### Definições iniciais

In [1]:
# Standard library
import random

# Third-party
import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy.linalg import sqrtm
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import trange
# NOTE(review): plt, sqrtm, Dataset, transforms and trange are not used by the
# visible cells of this notebook - confirm before removing.

# Seed every random number generator the evaluation may draw from (Python,
# NumPy and PyTorch; torch.manual_seed also seeds all CUDA devices) so that
# Restart & Run All reproduces the reported numbers.
SEED = 5
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [12]:
from datasets import lungCTData
from model import Generator
from metrics import my_fid_pipeline, my_ssim_pipeline

In [3]:
# Evaluate on the GPU when PyTorch can see one; otherwise fall back to the CPU.
# To force CPU evaluation, replace the right-hand side with torch.device('cpu').
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

### Dados

In [4]:
import os

# Folder with the processed test data. Set the CTDATA_DIR environment variable
# to point at another location instead of editing this cell.
processed_data_folder = os.environ.get('CTDATA_DIR', '/mnt/shared/ctdata_thr25')
# Path to the trained generator weights (loaded below with load_state_dict).
# Override with the GEN_WEIGHTS_PATH environment variable.
trained_gen_path = os.environ.get(
    'GEN_WEIGHTS_PATH', './model_thr_25k/models/model_thr_25k_gen_trained.pt'
)

In [5]:
# Index range of the processed data used as the test split, passed to lungCTData
# as start/end. NOTE(review): whether `end` is exclusive depends on lungCTData -
# confirm.
TEST_START, TEST_END = 20000, 21000
dataset_test = lungCTData(
    processed_data_folder=processed_data_folder,
    mode='test',
    start=TEST_START,
    end=TEST_END,
)

In [6]:
# Build the generator and load its trained weights onto the evaluation device.
gen = Generator()
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# from a GPU run also loads on a CPU-only machine. weights_only=True restricts
# unpickling to tensors and plain containers instead of arbitrary objects.
gen.load_state_dict(
    torch.load(trained_gen_path, map_location=device, weights_only=True)
)
gen.to(device)
# Switch layers such as dropout/batch-norm to their evaluation behavior.
gen.eval()

# Batch size shared by the test loader and the metric pipelines below.
batch_size = 10
# shuffle=False keeps the evaluation order fixed across runs.
data_loader_test = DataLoader(dataset_test, batch_size=batch_size, shuffle=False)

### Calcula o FID

In [8]:
# FID as computed by metrics.my_fid_pipeline over the test loader, presumably
# comparing real test images with generator outputs (for standard FID, lower
# is better).
fid_value = my_fid_pipeline(
    dataset_test,
    data_loader_test,
    device,
    gen,
    batch_size,
)
print(fid_value)

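### Calcula o SSIM

O SSIM é calculado duas vezes: na imagem completa (`bComplete=True`) e apenas na região central (`bComplete=False`).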

In [10]:
# Full image: SSIM together with the luminance, contraste and struct_similarity
# outputs of metrics.my_ssim_pipeline (presumably one value per test image -
# TODO confirm).
ssim, luminance, contraste, struct_similarity = my_ssim_pipeline(
    dataset_test, data_loader_test, device, gen, batch_size, bComplete=True
)
# The centre-only cell below rebinds these same names, so keep this run's
# results under a dedicated name as well.
ssim_full = {
    'ssim': ssim,
    'luminance': luminance,
    'contrast': contraste,
    'structure': struct_similarity,
}
print('SSIM (full image): mean =', np.mean(ssim), 'std =', np.std(ssim))

0.9198216603737535 0.05271807711304367


In [15]:
# Partial image (centre only): same metrics as the full-image cell, with
# bComplete=False so my_ssim_pipeline evaluates only the central region.
ssim, luminance, contraste, struct_similarity = my_ssim_pipeline(
    dataset_test, data_loader_test, device, gen, batch_size, bComplete=False
)
# These names are shared with the full-image cell, so keep this run's results
# under a dedicated name as well.
ssim_center = {
    'ssim': ssim,
    'luminance': luminance,
    'contrast': contraste,
    'structure': struct_similarity,
}
print('SSIM (centre only): mean =', np.mean(ssim), 'std =', np.std(ssim))

0.9271554151377082 0.05912170287379875


# 