In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import random
import os
import datetime
from tqdm import tqdm
from PIL import Image

import torch
import torch.nn as nn

from torch import optim
from torch.utils.data import Dataset, DataLoader
import sys 
sys.path.append("..") 
from model.dataset import *
from model.GaussianDiffusion_new1228 import GaussianDiffusion #可以自己调

from metrics_calculate import *

看生成图片，定性描述

In [2]:
# Prefer the second GPU as originally configured, but fall back to the first GPU
# or the CPU so the notebook does not crash on machines with fewer GPUs.
if torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

model_path = "./DiffusionSpinalMRISynthesis/model_save/diffusion_model.pth"
# The checkpoint's 'model' entry is the sampler object itself (presumably a pickled
# GaussianDiffusion -- later cells call .eval() / .ddim_sample on it).
# NOTE(review): torch.load unpickles arbitrary objects, so only load trusted files.
# On PyTorch >= 2.6 (weights_only=True by default) this call may need weights_only=False.
net_G = torch.load(model_path, map_location=device)['model']
# Alternative kept from an earlier experiment: re-wrap the trained network to control
# sampling settings and seeding, e.g.
#   GaussianDiffusion(image_size=256, model=net_G.model, timesteps=1000,
#                     sampling_timesteps=16, objective='pred_v').to(device)
# (the original note says the pred_x0 objective collapsed; pred_v / pred_noise were options).
net_G.eval()
print("模型已加载")  # "model loaded"

spinal_test_dir = "./DiffusionSpinalMRISynthesis/Data_MRI/test_spinal_MRI"
test_spinal_dataset = MRI_patient_Dataset_fortest(dir_path=spinal_test_dir)
# shuffle=False keeps the metric rows in dataset order.
test_spinal_dataloader = DataLoader(test_spinal_dataset, batch_size=64, shuffle=False, num_workers=16, pin_memory=True)

模型已加载


需要定量描述，把所有测试集的图片测试一遍，计算指标！

In [3]:
def get_data_from_batch(batch_data: torch.Tensor, device= device):
    """Split a 4-D batch along dim 1 into a (source, target) pair.

    The first three entries of dim 1 form the model input and are moved to
    `device`. The remaining entries form the reference, returned detached on
    the CPU. Dim 1 is presumably the channel axis (N, C, H, W); TODO confirm.
    """
    n_source = 3
    source = batch_data[:, :n_source, :, :].to(device)
    target = batch_data[:, n_source:, :, :].cpu().detach()
    return source, target

def Test_model_metrics(net_G:GaussianDiffusion, test_dataloader, sampling_timesteps=16, random_seed=42):
    """Sample every test batch with DDIM and score the synthetic images.

    For each batch, the source channels from ``get_data_from_batch`` condition
    ``net_G.ddim_sample`` and the remaining channels are the reference. Both
    are squeezed on dim 1 and scored with
    ``cal_metric_list(..., method='8bit', norm=True)``.

    Args:
        net_G: diffusion wrapper exposing ``ddim_sample``; switched to eval mode here.
        test_dataloader: iterable of batch tensors accepted by ``get_data_from_batch``.
        sampling_timesteps: number of DDIM steps (default 16, as before).
        random_seed: seed forwarded to ``ddim_sample`` for every batch (default 42).

    Returns:
        pandas.DataFrame with columns AUC, MSE, SSIM, PSNR and one row per value
        returned by ``cal_metric_list`` (presumably one per image). The column means
        are also printed.
    """
    metric_names = ('AUC', 'MSE', 'SSIM', 'PSNR')
    per_image = {name: [] for name in metric_names}
    net_G.eval()
    for batch_data in tqdm(test_dataloader):
        input_img, target_img = get_data_from_batch(batch_data)
        with torch.no_grad():
            fake_img = net_G.ddim_sample(shape=target_img.shape, source_img=input_img.to(device),
                                         sampling_timesteps=sampling_timesteps,
                                         random_seed=random_seed)
        fake_np = fake_img.squeeze(dim=1).detach().cpu().numpy()
        target_np = target_img.squeeze(dim=1).numpy()
        # Global metrics after min-max normalisation (norm=True); the original note says
        # the range is [-1, 1]. NOTE(review): no tanh is applied here (it was commented
        # out), whereas Test_model_onepatient applies tanh before saving -- confirm intent.
        auc, mse, ssim, psnr = cal_metric_list(fake_np, target_np, method='8bit', norm=True)
        for name, values in zip(metric_names, (auc, mse, ssim, psnr)):
            per_image[name].extend(values)

    metrics_df = pd.DataFrame(per_image)
    print(metrics_df.mean())
    return metrics_df

In [None]:
# Quantitative evaluation on the whole test set (slow: DDIM sampling for every image).
metrics_df = Test_model_metrics(net_G=net_G, test_dataloader=test_spinal_dataloader)
metrics_df

In [5]:
# Persist the per-image metrics; the DataFrame index is written as the first column.
metrics_csv_path = "./DiffusionSpinalMRISynthesis/Test_models/metrics/diffusion_test.csv"
os.makedirs(os.path.dirname(metrics_csv_path), exist_ok=True)  # a fresh checkout may lack this folder
metrics_df.to_csv(metrics_csv_path)

In [None]:
# Reload the saved per-image metrics, dropping the index column written by to_csv.
metrics_df = (
    pd.read_csv("./DiffusionSpinalMRISynthesis/Test_models/metrics/diffusion_test.csv")
    .drop(columns="Unnamed: 0")
)
# One row per metric: mean and sample standard deviation over all test images.
df_metrics_sum = pd.DataFrame({
    'metrics_mean': metrics_df.mean(axis=0),
    'metrics_std': metrics_df.std(axis=0),
})
df_metrics_sum.to_csv("./DiffusionSpinalMRISynthesis/Test_models/metrics/diffusion_1231_randomseed42_test_sum.csv")
df_metrics_sum

In [7]:
# Destination root for the per-patient synthetic slices (16-step DDIM sampling).
output_path = "./DiffusionSpinalMRISynthesis/results_output_pictures/diffusion_500epoch_test_spinal_16ddim"
# Source directory of the test slices.
spinal_test_dir = "./DiffusionSpinalMRISynthesis/Data_MRI/test_spinal_MRI"

def Test_model_onepatient(net_G:GaussianDiffusion, patient_id, output_path=output_path,
                          test_dir=None, sampling_timesteps=16, random_seed=666):
    """Synthesise the T1CE slices of one test patient and save them as images.

    All slices of ``patient_id`` are sampled with ``net_G.ddim_sample``, passed
    through tanh, converted with ``norm_layerHW(..., method='8bit')`` and written to
    ``<output_path>/<patient_id>/<patient_id>_T1CE_pred_<suffix>``. Here <suffix>
    is the text after the last '_' of the source file paired with that slice.

    Args:
        net_G: diffusion wrapper exposing ``ddim_sample``; switched to eval mode here.
        patient_id: patient prefix (first three '_'-separated fields of a file name).
        output_path: root directory for the per-patient output folders.
        test_dir: directory holding the test slices. If None, the global
            ``spinal_test_dir`` is read at call time, as before.
        sampling_timesteps: number of DDIM steps (default 16).
        random_seed: seed forwarded to ``ddim_sample`` (default 666). The metric
            evaluation uses 42, so the saved images may differ from the scored ones.

    Raises:
        ValueError: if the number of matched file names differs from the number of
            dataset slices. Names are paired with slices by position, so a mismatch
            would mislabel the outputs.
    """
    # Previously `glob` was only available implicitly through star imports.
    from glob import glob

    if test_dir is None:
        test_dir = spinal_test_dir
    os.makedirs(output_path, exist_ok=True)

    # NOTE(review): the '*id*' pattern can also match other patients whose ID contains
    # this one. Pairing by index also assumes glob returns files in the same order the
    # dataset uses -- confirm against MRI_patient_Dataset_fortestpatient.
    fnames = glob(os.path.join(test_dir, '*' + patient_id + '*'))
    dataset = MRI_patient_Dataset_fortestpatient(dir_path=test_dir, patient_id=patient_id)
    n_slices = len(dataset)
    if len(fnames) != n_slices:
        raise ValueError(
            f"patient {patient_id!r}: {len(fnames)} matching files but {n_slices} dataset slices"
        )
    if n_slices == 0:
        return  # nothing to synthesise; stacking an empty list would fail

    # Load each slice once, then split into source (first 3 entries of dim 0 per slice)
    # and target (the rest).
    samples = torch.stack([dataset[i] for i in range(n_slices)], dim=0)
    input_img = samples[:, :3]
    target_img = samples[:, 3:]

    net_G.eval()
    with torch.no_grad():
        fake_tensor = net_G.ddim_sample(shape=target_img.shape, source_img=input_img.to(device),
                                        sampling_timesteps=sampling_timesteps,
                                        random_seed=random_seed).squeeze(dim=1)
        fake_np = torch.tanh(fake_tensor).cpu().numpy()
    fake_img_norm = norm_layerHW(fake_np, method='8bit')

    patient_save_path = os.path.join(output_path, patient_id)
    os.makedirs(patient_save_path, exist_ok=True)
    for fname, slice_img in zip(fnames, fake_img_norm):
        slice_suffix = fname.split('_')[-1]
        save_path = os.path.join(patient_save_path, patient_id + '_T1CE_pred_' + slice_suffix)
        Image.fromarray(slice_img).save(save_path)



In [None]:
# A patient ID is the first three '_'-separated fields of a file name. Malformed names
# still raise IndexError, as before. Sorting makes the processing order reproducible:
# iterating a set of str varies between runs because of hash randomisation.
patient_id_list = sorted({
    '_'.join(fname.split('_')[i] for i in range(3))
    for fname in os.listdir(spinal_test_dir)
})
for patient_id in tqdm(patient_id_list, desc="patients"):
    Test_model_onepatient(net_G=net_G, patient_id=patient_id)