In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
from vnet import VNetWithDiagnosis
import nibabel as nib
import os

# Process-wide environment flags for the native runtimes used by this notebook:
#   CUDA_LAUNCH_BLOCKING=1   -> synchronous CUDA launches, so errors point at the failing call (debug aid).
#   KMP_DUPLICATE_LIB_OK=TRUE -> let the Intel OpenMP runtime tolerate being loaded more than once.
# NOTE(review): these are assigned after `import torch`; both are read when the underlying runtime
# initialises, so confirm they still take effect (placing them above the imports is the safe option).
os.environ.update({"CUDA_LAUNCH_BLOCKING": "1", "KMP_DUPLICATE_LIB_OK": "TRUE"})

def load_model(model_path, device):
    """Instantiate a VNetWithDiagnosis, restore saved weights and ready it for inference.

    Args:
        model_path: path to a file written with ``torch.save`` that holds the
            model's state dict (it is passed straight to ``load_state_dict``).
        device: torch device the stored tensors are mapped to and the model is moved to.

    Returns:
        The restored model, placed on ``device`` and switched to eval mode.
    """
    net = VNetWithDiagnosis()
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
    # only load checkpoints from trusted sources (recent PyTorch offers weights_only=True).
    weights = torch.load(model_path, map_location=device)
    net.load_state_dict(weights)
    # nn.Module.to() and .eval() both return the module itself, so they can be chained.
    return net.to(device).eval()

def test_model(model, mri_tensor, diag_data, device):
    with torch.no_grad():
        mri_tensor = mri_tensor.to(device)
        diag_tensor = torch.tensor(diag_data, dtype=torch.float32).unsqueeze(0).to(device)
        output = model(mri_tensor, diag_tensor)
        return output.cpu().numpy()






In [2]:


def save_tensor_as_nifti(tensor, output_filename, affine=None):
    """Write a tensor or array to disk as a NIfTI-1 image.

    Args:
        tensor: ``torch.Tensor`` (on any device, possibly requiring grad) or
            ``numpy.ndarray``. All size-1 dimensions are squeezed away before saving.
        output_filename: destination path handed to ``nib.save``.
        affine: optional 4x4 voxel-to-world matrix. Defaults to the identity, which
            does not carry over the source scan's orientation or voxel spacing;
            pass e.g. ``nib.load(scan_path).affine`` to preserve them.

    Raises:
        TypeError: if ``tensor`` is neither a torch.Tensor nor a numpy.ndarray.
    """
    if isinstance(tensor, torch.Tensor):
        # detach() allows tensors that require grad; cpu() handles any accelerator
        # device and is a no-op for tensors already on the CPU.
        array = tensor.detach().cpu().numpy()
    elif isinstance(tensor, np.ndarray):
        array = tensor
    else:
        raise TypeError("Input must be a torch.Tensor or a numpy.ndarray.")

    array = np.squeeze(array)
    if affine is None:
        affine = np.eye(4)

    nifti_img = nib.Nifti1Image(array, affine=affine)

    nib.save(nifti_img, output_filename)
    print(f"Saved NIfTI image to {output_filename}")




In [None]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Restore the trained network from its checkpoint.
model_path = 'trained_vnet.pth'
model = load_model(model_path, device)

# Load the T1 scan and prepend two singleton dimensions (batch and channel).
# NOTE(review): relative path — only resolves when the notebook runs from its own directory.
mri_path = "../../Aims-Tbi/scan_0106_T1.nii.gz"
mri_img = nib.load(mri_path).get_fdata()
mri_tensor = torch.tensor(mri_img, dtype=torch.float32)[None, None]

# Subject covariates passed alongside the scan, in the order age, sex, tsi, scan_manufacturer.
# NOTE(review): the numeric codes used for sex / tsi / scan_manufacturer are not documented
# here — confirm they match the encoding used during training.
age, sex, tsi, scan_manufacturer = 15.18, 2.0, 2.0, 2.0
diag_data = [age, sex, tsi, scan_manufacturer]

output = test_model(model, mri_tensor, diag_data, device)
print("Model output:", output.shape)
# NOTE(review): save_tensor_as_nifti is called without the scan's affine, so the saved image
# does not inherit the input scan's orientation/spacing (nib.load(mri_path).affine holds it).
save_tensor_as_nifti(output, 'output_image.nii.gz')