In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import random
import os
import datetime
from tqdm import tqdm
from PIL import Image

import torch
import torch.nn as nn

from torch import optim
from torch.utils.data import Dataset, DataLoader
import sys 
sys.path.append("..") 
from model.dataset import *
from model.GaussianDiffusion_new1228 import GaussianDiffusion #可以自己调

from metrics_calculate import *

TODO: inspect the generated images and describe their quality qualitatively (看生成图片，定性描述).

In [2]:
# Use the first GPU when available; otherwise fall back to CPU so the notebook
# can still be executed (slowly) on a machine without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model_path = "./DiffusionSpinalMRISynthesis/model_save/diffusion_model_mask_channels.pth"

# The checkpoint stores a pickled model object under the 'model' key, so full
# unpickling is required (newer PyTorch defaults to weights_only=True, which
# would refuse to load it).
# SECURITY: unpickling can execute arbitrary code -- only load trusted checkpoints.
checkpoint = torch.load(model_path, map_location=device, weights_only=False)
saved_diffusion = checkpoint['model']

# Re-wrap the saved network (`saved_diffusion.model`) in the current
# GaussianDiffusion class. Per the original author's note, this is what makes the
# random seed controllable (presumably via the `random_seed` argument of
# `ddim_sample` -- TODO confirm).
net_G = GaussianDiffusion(
    image_size=256,
    model=saved_diffusion.model,
    timesteps=1000,           # number of diffusion steps
    sampling_timesteps=16,    # DDIM sampling for faster inference
    objective='pred_v',       # options tried: pred_v (used), pred_x0 (collapsed), pred_noise
).to(device)
net_G.eval()
print("模型已加载")

模型已加载


In [3]:
spinal_test_dir = "./DiffusionSpinalMRISynthesis//Data_MRI/test_spinal_MRI"
test_spinal_dataset = MRI_patient_Dataset_fortest(dir_path=spinal_test_dir)
# Requesting more worker processes than CPU cores only produces a DataLoader
# warning and oversubscribes the machine, so cap the worker count.
num_loader_workers = min(16, os.cpu_count() or 1)
test_spinal_dataloader = DataLoader(
    test_spinal_dataset,
    batch_size=64,
    shuffle=False,  # fixed order so results are comparable across input states
    num_workers=num_loader_workers,
    # Pinned host memory only helps host->GPU copies; skip it on CPU-only machines.
    pin_memory=torch.cuda.is_available(),
)

# Input channels zeroed ("masked") for each input state. Input channel order is
# 0 = T1, 1 = T2, 2 = T2FS; the trailing comment lists the modalities KEPT.
_MASKED_INPUT_CHANNELS = {
    0: (),        # T1 T2 T2FS
    1: (2,),      # T1 T2
    2: (1,),      # T1 T2FS
    3: (0,),      # T2 T2FS
    4: (1, 2),    # T1
    5: (0, 2),    # T2
    6: (0, 1),    # T2FS
}

def get_data_from_batch(batch_data: torch.Tensor, input_state, device= device):
    """Split a batch into the masked source input and the target images.

    Args:
        batch_data: 4-D tensor, presumably (batch, channel, H, W). Channels 0-2 are
            the source modalities (T1, T2, T2FS); channels 3 onward are the target.
        input_state: which source modalities to keep (a key of
            ``_MASKED_INPUT_CHANNELS``, 0-6); dropped modalities are zeroed.
        device: device the source tensor is moved to.

    Returns:
        ``(img_input, img_target)``: the masked source on ``device`` and the target
        moved to CPU and detached.

    Raises:
        ValueError: if ``input_state`` is not a known state (previously such values
            silently kept all modalities).
    """
    if input_state not in _MASKED_INPUT_CHANNELS:
        raise ValueError(
            f"unknown input_state {input_state!r}; expected one of {sorted(_MASKED_INPUT_CHANNELS)}"
        )
    # Slicing returns a view; clone so masking does not write into the caller's batch.
    img_input = batch_data[:, :3, :, :].clone()
    img_target = batch_data[:, 3:, :, :].cpu().detach()
    for channel in _MASKED_INPUT_CHANNELS[input_state]:
        img_input[:, channel, :, :] = 0
    return (img_input.to(device), img_target)

def Test_model_metrics_legacy(net_G:GaussianDiffusion, test_dataloader, input_state):
    """Earlier evaluation routine based on ``cal_metric`` (MSE, SSIM, PSNR, NRMSE).

    This routine used to share its name with the later ``Test_model_metrics``
    definition in the same cell, which silently shadowed it and made it
    unreachable. It is kept under a distinct name for reference and comparison.

    No ``random_seed`` is passed to ``ddim_sample`` here, so sampled images (and
    therefore the metrics) may differ between runs -- assumption, TODO confirm.

    Args:
        net_G: diffusion model used for DDIM sampling.
        test_dataloader: iterable of batches accepted by ``get_data_from_batch``.
        input_state: source-modality masking state passed to ``get_data_from_batch``.

    Returns:
        ``(metrics_mean_df, metrics_df)``: a one-row frame of per-metric means
        indexed ``input_state_<state>``, and a frame of per-image metrics.
    """
    mse_all, ssim_all, psnr_all, nrmse_all = [], [], [], []
    print(f"输入状态：{input_state}")
    net_G.eval()
    for batch_data in test_dataloader:
        # input_img is already on the target device; target_img is a detached CPU tensor.
        input_img, target_img = get_data_from_batch(batch_data, input_state)
        with torch.no_grad():
            fake_img = net_G.ddim_sample(
                shape=target_img.shape, source_img=input_img, sampling_timesteps=16
            ).detach().cpu().numpy()
        mse_list, ssim_list, psnr_list, nrmse_list = cal_metric(fake_img, target_img.numpy())
        mse_all.extend(mse_list)
        ssim_all.extend(ssim_list)
        psnr_all.extend(psnr_list)
        nrmse_all.extend(nrmse_list)
    print(f"MSE_imgs: {np.mean(mse_all)}, SSIM_imgs: {np.mean(ssim_all)}, PSNR_imgs: {np.mean(psnr_all)}, NRMSE_imgs: {np.mean(nrmse_all)}")
    metrics_mean_df = pd.DataFrame(
        {
            "MSE_imgs": [np.mean(mse_all)],
            "SSIM_imgs": [np.mean(ssim_all)],
            "PSNR_imgs": [np.mean(psnr_all)],
            "NRMSE_imgs": [np.mean(nrmse_all)],
        },
        index=[f"input_state_{input_state}"],
    )
    metrics_df = pd.DataFrame({'MSE': mse_all, 'SSIM': ssim_all, 'PSNR': psnr_all, 'NRMSE': nrmse_all})
    return metrics_mean_df, metrics_df

def Test_model_metrics(net_G:GaussianDiffusion, test_dataloader, input_state,
                       sampling_timesteps=16, random_seed=42):
    """Sample synthetic images for every test batch and compute per-image metrics.

    Args:
        net_G: diffusion model used for DDIM sampling.
        test_dataloader: iterable of batches accepted by ``get_data_from_batch``.
        input_state: source-modality masking state passed to ``get_data_from_batch``.
        sampling_timesteps: number of DDIM sampling steps (default 16, as before).
        random_seed: seed forwarded to ``ddim_sample`` (default 42, as before).

    Returns:
        ``(df_metrics, df_metrics_sum)``: per-image AUC/MSE/SSIM/PSNR, and a two-row
        summary indexed ``input_<state>_mean`` / ``input_<state>_std``.
    """
    per_image = {'AUC': [], 'MSE': [], 'SSIM': [], 'PSNR': []}
    print(f"输入状态：{input_state}")
    net_G.eval()
    for batch_data in tqdm(test_dataloader):
        # input_img is already on the target device; target_img is a detached CPU tensor.
        input_img, target_img = get_data_from_batch(batch_data, input_state)
        with torch.no_grad():
            fake_img = net_G.ddim_sample(
                shape=target_img.shape,
                source_img=input_img,
                sampling_timesteps=sampling_timesteps,
                random_seed=random_seed,
            ).squeeze(dim=1).detach().cpu().numpy()
        # Global metrics with min-max normalisation ('8bit' method); the images are
        # presumably in [-1, 1] -- see cal_metric_list for the exact scaling.
        auc, mse, ssim, psnr = cal_metric_list(
            fake_img, target_img.squeeze(dim=1).numpy(), method='8bit', norm=True
        )
        per_image['AUC'].extend(auc)
        per_image['MSE'].extend(mse)
        per_image['SSIM'].extend(ssim)
        per_image['PSNR'].extend(psnr)

    df_metrics = pd.DataFrame(per_image)
    # Stack column-wise mean and std, then transpose so each statistic is a row.
    df_metrics_sum = (
        pd.concat([df_metrics.mean(axis=0), df_metrics.std(axis=0)], axis=1)
        .rename(columns={0: f"input_{input_state}_mean", 1: f"input_{input_state}_std"})
        .T
    )
    print(df_metrics_sum)

    return df_metrics, df_metrics_sum

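Note: (N)RMSE is not a suitable metric here (the NRMSE in the sample output below evaluates to `inf`), so MSE is reported instead.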
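Sample output with `method='8bit'`, `norm=True` (it still lists NRMSE, so it appears to predate the current `Test_model_metrics`):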
AUC        0.976465
MSE      216.410033
SSIM       0.816135
PSNR      27.058400
NRMSE           inf
dtype: float64


In [4]:
def test_all_state_input(save_path="./DiffusionSpinalMRISynthesis/Test_models/metrics/",
                         input_states=range(7)):
    """Evaluate ``net_G`` on the spinal test set for each input state and save CSVs.

    Writes one per-image metrics CSV per state, plus a combined mean/std summary CSV.

    Args:
        save_path: output directory; created if it does not exist.
        input_states: input-state ids to evaluate (default: all 7 states, 0-6).

    Returns:
        pd.DataFrame: the stacked mean/std summary rows for all evaluated states
        (also written to disk).
    """
    # to_csv does not create missing directories.
    os.makedirs(save_path, exist_ok=True)
    summaries = []
    for state in input_states:
        df_metrics, df_metrics_sum = Test_model_metrics(net_G, test_spinal_dataloader, state)
        summaries.append(df_metrics_sum)
        df_metrics.to_csv(os.path.join(
            save_path, f"1231_diffusion_maskchannels_metrics_state_input_randomseed42_{state}.csv"))
    allstates_summary = pd.concat(summaries, axis=0)
    allstates_summary.to_csv(os.path.join(save_path, "diffusion_maskchannels_allstates_input_sum_1231.csv"))
    return allstates_summary

In [None]:
# Run the full evaluation: DDIM sampling over the whole test set, once per input state.
allstates_summary = test_all_state_input(save_path="./DiffusionSpinalMRISynthesis/Test_models/metrics/")