In [None]:
# Evaluation script for trained MMGSN and Pix2Pix models
# This script loads pre-trained models and generates synthetic images on test data
import os
# Restrict the process to physical GPU 2. This is set before `import torch` so it is in
# place before CUDA is initialised; the single visible GPU is then addressed as 'cuda:0'.
os.environ["CUDA_VISIBLE_DEVICES"] = "2" 
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import random
# NOTE(review): duplicate of the `import os` above (harmless).
import os
import datetime
# NOTE(review): the name `time` bound here is rebound further down by
# `from time import time`, so after this cell `time` is the function, not the module.
import time
from tqdm import tqdm
from PIL import Image

import torch
import torch.nn as nn

from torch import optim
from torch.utils.data import Dataset, DataLoader
import sys 
# Make the parent directory importable so the project packages `utils` and `model` resolve.
sys.path.append("..") 
# NOTE(review): the star imports below are where names such as MRI_2Dpng_Dataset,
# test_2Dpng_transforms and norm_layerHW presumably come from — confirm which module
# defines each one.
from utils.dataset_png import *
from model.MMGSN_new.syn_model_2D_T2FS import *
from utils.test_metrics_2D import *
from utils.FolderDataLoader import *
from utils.metrics_calculate_2D import *

# Rebinds `time` to the function time.time(); the test functions below call `time()` directly.
from time import time
# Print how many CUDA devices torch can see.
print(torch.cuda.device_count()) 

In [None]:
# Load pre-trained models for comparison
# The checkpoint is loaded onto the CPU and its 'model' entry is used as the model object;
# the test functions below move it to the requested device themselves.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code — only load
# checkpoints from a trusted source.
# NOTE(review): recent PyTorch releases changed the torch.load default to weights_only=True,
# which may refuse to unpickle a full model object — confirm against the installed version.

model_STS = torch.load("/your/path/MRI_SYNS/model_save/STS_foldAll.pt", map_location='cpu')['model']

# model_pix2pix = torch.load("/your/path/MRI_SYNS/model_save/pix2pix_foldAll.pt", map_location='cpu')['model']

# NOTE(review): `device` is defined here, but the calls visible in this notebook pass the
# literal 'cuda:0' instead of this variable.
device = 'cuda:0'
spinal_test_dir = "/your/path/MRI_SYNS/Data/your_regis_imgdata" # Test dataset path


In [None]:
@torch.no_grad()
def Test_model_patients(spinal_test_dir, model_img, device='cuda:0', batch_size=128, num_workers=32):
    """Benchmark batched inference of ``model_img`` over a test image directory.

    Every batch is pushed through the model and converted with
    ``norm_layerHW(..., method='8bit')``; the generated images are discarded.
    The function only reports peak GPU memory and the elapsed wall-clock time.

    Args:
        spinal_test_dir: Directory passed to ``MRI_2Dpng_Dataset`` as ``data_dirpath``.
        model_img: Model to benchmark. It is moved to ``device`` and put in eval mode.
        device: Device used for inference (string or ``torch.device``).
        batch_size: Batch size of the DataLoader (default 128, the previous fixed value).
        num_workers: Number of DataLoader worker processes (default 32, the previous
            fixed value).

    Returns:
        None. Results are printed.
    """
    test_patients_dataset = MRI_2Dpng_Dataset(data_dirpath=spinal_test_dir, transform=test_2Dpng_transforms)
    test_dataloader = DataLoader(
        test_patients_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )

    model_img.to(device)
    model_img.eval()

    torch_device = torch.device(device)
    on_cuda = torch_device.type == 'cuda'
    if on_cuda:
        # Peak counters are cumulative for the whole process. Reset them so the numbers
        # printed below belong to this run only (matters when comparing several models).
        torch.cuda.reset_peak_memory_stats(torch_device)

    start_time = time()
    for batch_data in tqdm(test_dataloader):
        # The first two channels of each sample are the model input.
        input_img = batch_data[:, :2, :, :].to(device)

        fake_img = model_img(input_img).squeeze(dim=1)
        # Copying to the CPU waits for the GPU work of this batch, so the timing
        # below covers the actual computation.
        fake_img = norm_layerHW(fake_img.detach().cpu().numpy(), method='8bit')
    # Stop the clock before querying and printing the memory statistics.
    end_time = time()

    if on_cuda:
        # The CUDA memory statistics are only defined for CUDA devices.
        allocated_memory = torch.cuda.max_memory_allocated(torch_device)
        reserved_memory = torch.cuda.max_memory_reserved(torch_device)

        print(f"已分配显存: {allocated_memory / (1024 ** 2):.2f} MB")
        print(f"保留显存: {reserved_memory / (1024 ** 2):.2f} MB")
    print(end_time - start_time)


In [None]:
# Run the batched benchmark (prints peak GPU memory and elapsed seconds) for model_STS.
Test_model_patients(spinal_test_dir, model_img=model_STS, device='cuda:0')

# Test_model_patients(spinal_test_dir, model_img=model_pix2pix, device='cuda:0')

In [None]:
# Process images individually and save pictures for each patient.
@torch.no_grad()
def Test_model_onepatient(output_path, spinal_test_dir, model_img, device='cuda:1', model_name='STS'):
    """Run the model on every test sample one at a time and save comparison images.

    For each sample, the first two channels are fed to ``model_img`` and the third
    channel is used as ground truth. Ground truth and generated image are transposed,
    concatenated along the last axis, and saved as
    ``<output_path>/<patient>/<patient>@true_<model_name>@<last filename field>``,
    where the fields come from splitting ``dataset.fnames[i]`` on ``'@'``.

    Args:
        output_path: Root directory for the saved images. If ``None``, nothing is
            written and only the elapsed time is printed.
        spinal_test_dir: Directory passed to ``MRI_2Dpng_Dataset`` as ``data_dirpath``.
        model_img: Model to evaluate. It is moved to ``device`` and put in eval mode.
        device: Device used for inference. NOTE(review): the default ``'cuda:1'`` is
            kept for backward compatibility, but it does not exist when only one GPU
            is visible (this notebook sets ``CUDA_VISIBLE_DEVICES`` to a single id);
            pass the device explicitly.
        model_name: Tag written into the output filename. The default ``'STS'``
            reproduces the previous fixed ``@true_STS@`` naming; pass e.g.
            ``'pix2pix'`` when saving results of another model.

    Returns:
        None. The total elapsed time in seconds is printed.
    """
    save_images = output_path is not None
    if save_images:
        os.makedirs(output_path, exist_ok=True)
    test_onepatient_dataset = MRI_2Dpng_Dataset(data_dirpath=spinal_test_dir, transform=test_2Dpng_transforms)

    model_img.to(device)
    model_img.eval()

    start_time = time()
    for i in tqdm(range(len(test_onepatient_dataset))):
        # Fetch the sample once: indexing the dataset loads and transforms the image,
        # so indexing it separately for input and target doubles that work.
        sample = test_onepatient_dataset[i]
        input_img = sample[0:2, :, :].unsqueeze(0).to(device)  # [1, 2, H, W]
        target_img = sample[2:3, :, :].unsqueeze(0)  # [1, 1, H, W]

        fake_img = model_img(input_img).squeeze(dim=1)
        fake_img = norm_layerHW(fake_img.detach().cpu().numpy(), method='8bit')

        # Rescale the target to 8 bit (assumes it is normalised to [-1, 1] — TODO confirm).
        # Clipping keeps slightly out-of-range values from wrapping around in the uint8 cast.
        target_img = np.clip(255 * (target_img.squeeze().cpu().numpy() + 1) / 2, 0, 255).astype(np.uint8)
        # Concatenate ground truth and generated image for easy visual comparison
        img_pred = Image.fromarray(np.concatenate([
            np.transpose(target_img),
            np.transpose(fake_img[0, :, :])
        ], axis=-1))

        if save_images:
            # Filename fields are '@'-separated: the 4th field from the end names the
            # patient folder and the last field is reused as the file suffix.
            name_fields = test_onepatient_dataset.fnames[i].split('@')
            patient_id, name_suffix = name_fields[-4], name_fields[-1]
            patient_save_path = os.path.join(output_path, patient_id)
            os.makedirs(patient_save_path, exist_ok=True)
            # Descriptive filename: patientID@true_modelname@suffix
            i_savepath_pred = os.path.join(patient_save_path, patient_id + '@true_' + model_name + '@' + name_suffix)
            img_pred.save(i_savepath_pred)
    end_time = time()
    print(end_time - start_time)

In [None]:
# Save per-patient comparison images (ground truth next to model_STS output) under output_path.
output_path = "/your/path/Spinal_T2FS_Simulator/results_mmgsn_internal_test"
Test_model_onepatient(output_path, spinal_test_dir, model_img=model_STS,device='cuda:0')

# output_path = "/your/path/MRI_SYNS/results_pix2pix_internal"
# Test_model_onepatient(output_path, spinal_test_dir, model_img=model_pix2pix,device='cuda:0')