# Test pipeline

### Definições iniciais

In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from scipy.linalg import sqrtm
from tqdm import trange
import random
random.seed(5)
import matplotlib.pyplot as plt
import numpy as np

In [12]:
from datasets import lungCTData
from model import Generator
from metrics import calculate_fid, FeatureExtractor, get_features
from metrics import my_ssim, crop_center_array

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#device = 'cpu'

### Dados

In [4]:
# Ajustar isso aqui para pasta com dados de teste
processed_data_folder = '/mnt/shared/ctdata_thr25'
#path to model
trained_gen_path = './model_thr_25k/models/model_thr_25k_gen_trained.pt'

In [5]:
dataset_test = lungCTData(processed_data_folder=processed_data_folder,mode='test',start=20000,end=21000)

In [6]:
gen = Generator()
gen.load_state_dict(torch.load(trained_gen_path, weights_only=True))
gen.to(device)
gen.eval()
batch_size = 10
data_loader_test = DataLoader(dataset_test, batch_size=batch_size, shuffle=False)

# Calcula o FID

In [7]:
# Define a rede Inception V3
model_inception = torch.hub.load('pytorch/vision:v0.10.0', 'inception_v3', pretrained=True)
model_inception.eval()

# Define modelo para feature extraction
feature_extractor = FeatureExtractor(model_inception)

Using cache found in /home/leticia/.cache/torch/hub/pytorch_vision_v0.10.0


In [8]:
#gera dados sintéticos
fake_data_imgs = np.empty((len(dataset_test),512,512))
real_data_imgs = np.empty((len(dataset_test),512,512))
fake_data_features = np.empty((len(dataset_test),2048))
real_data_features = np.empty((len(dataset_test),2048))
counter = 0
with torch.no_grad():
    for batch in data_loader_test:
        input_img_batch = batch[0]
        input_mask_batch = batch[1]
        
        input_img = input_img_batch.to(device)
        input_mask = input_mask_batch.to(device)
        
        gen_img = gen(input_mask)
        fake_data_imgs[counter*batch_size:(counter+1)*batch_size,:,:] = np.squeeze(gen_img.detach().cpu().numpy())
        real_data_imgs[counter*batch_size:(counter+1)*batch_size,:,:] = np.squeeze(input_img.detach().cpu().numpy())

        features_fake = get_features(feature_extractor, gen_img, choose_transform=2, device=device)
        features_real = get_features(feature_extractor, input_img, choose_transform=2, device=device)
        fake_data_features[counter*batch_size:(counter+1)*batch_size,:] = features_fake
        real_data_features[counter*batch_size:(counter+1)*batch_size,:] = features_real
        counter=counter+1

In [9]:
# Obtém as distribuições para os dados reais e sintéticos
mu1, sigma1 = np.mean(np.squeeze(real_data_features),axis=0), np.cov(np.squeeze(real_data_features))
mu2, sigma2 = np.mean(np.squeeze(fake_data_features),axis=0), np.cov(np.squeeze(fake_data_features))

# Calcula o FID entre os dados reais e sintéticos
fid_value = calculate_fid(mu1, sigma1, mu2, sigma2)
print(f"FID: {fid_value:.2f}")

FID: 70.70


# Calcula o SSIM

In [10]:
# Calcula SSIM entre dados reais e sintéticos
ssim = np.zeros(len(dataset_test))
luminance = np.zeros(len(dataset_test))
contraste = np.zeros(len(dataset_test))
struct_similarity = np.zeros(len(dataset_test))
for i in range(len(dataset_test)):
    ssim[i], luminance[i], contraste[i], struct_similarity[i] = my_ssim(np.squeeze(real_data_imgs[i]), np.squeeze(fake_data_imgs[i]), 0.01, 0.03, 0.03)

print(np.mean(ssim), np.std(ssim))

0.9198216603737535 0.05271807711304367


In [15]:
# Calcula SSIM entre dados reais e sintéticos
ssim = np.zeros(len(dataset_test))
luminance = np.zeros(len(dataset_test))
contraste = np.zeros(len(dataset_test))
struct_similarity = np.zeros(len(dataset_test))
for i in range(len(dataset_test)):
    ssim[i], luminance[i], contraste[i], struct_similarity[i] = my_ssim(np.squeeze(crop_center_array(real_data_imgs[i], 256, 256)), np.squeeze(crop_center_array(fake_data_imgs[i], 256, 256)), 0.01, 0.03, 0.03)

print(np.mean(ssim), np.std(ssim))

0.9271554151377082 0.05912170287379875


# 