In [1]:
import numpy as np
import nibabel as nib
import pydicom
import matplotlib.pyplot as plt
import os
import shutil
from glob import glob
from PIL import Image
import torch
import torchvision.transforms as transforms
from torchvision import transforms as T
from IPython import display
import SimpleITK as sitk
from ipywidgets import interact, fixed
import random 
from tqdm import tqdm
from PIL import Image
from multiprocessing import Pool

In [2]:
def output_dcm_meta(dicom_file_path):
    """Print the main geometry/acquisition tags of a single DICOM file.

    Args:
        dicom_file_path: path to one DICOM file.

    Tags that are not present in the dataset (e.g. ``SpacingBetweenSlices``,
    which may be absent from some series) are printed as ``N/A`` instead of
    aborting with ``AttributeError``.
    """
    ds = pydicom.dcmread(dicom_file_path)
    fields = (
        ("Patient's Name:", "PatientName"),
        ("Modality (Imaging Type):", "Modality"),
        ("Image Position (Patient):", "ImagePositionPatient"),
        ("Image Orientation (Patient):", "ImageOrientationPatient"),
        ("Slice Thickness:", "SliceThickness"),
        ("Pixel Spacing:", "PixelSpacing"),
        ("Spacing Between Slices:", "SpacingBetweenSlices"),
    )
    for label, keyword in fields:
        print(label, getattr(ds, keyword, "N/A"))


In [3]:
def read_dicom_series(folder_path):
    """Read the DICOM series stored in ``folder_path`` into a SimpleITK image.

    Args:
        folder_path: directory containing the slice files of one series.

    Returns:
        ``sitk.Image`` built from the file list GDCM returns for the folder.

    Raises:
        FileNotFoundError: if GDCM finds no DICOM series in ``folder_path``
            (e.g. wrong path or empty folder), instead of the opaque error
            ``Execute()`` would raise on an empty file list.
    """
    reader = sitk.ImageSeriesReader()
    dicom_files = reader.GetGDCMSeriesFileNames(folder_path)
    if not dicom_files:
        raise FileNotFoundError(f"No DICOM series found in {folder_path!r}")
    reader.SetFileNames(dicom_files)
    return reader.Execute()

In [4]:
def read_decom(patient_path):
    """Read the four MRI sequences of one patient as SimpleITK images.

    ``patient_path`` must contain one sub-folder per sequence named
    ``<patient>_<channel>``, where ``<patient>`` is the last component of
    ``patient_path`` and ``<channel>`` is one of T1, T2, T2FS, T1CE.

    Args:
        patient_path: directory of one patient.

    Returns:
        list of four images in the order [T1, T2, T2FS, T1CE].

    A warning is printed (not raised) when the folder does not contain
    exactly four entries.
    """
    channel_names = ["T1", "T2", "T2FS", "T1CE"]
    # basename(normpath(...)) tolerates trailing separators and OS-native
    # separators, unlike split('/')[-1].
    patient_name = os.path.basename(os.path.normpath(patient_path))
    entries = os.listdir(patient_path)
    if len(entries) != len(channel_names):
        print('len error:', patient_path, len(entries))
    data_all = [
        read_dicom_series(os.path.join(patient_path, patient_name + '_' + channel_name))
        for channel_name in channel_names
    ]

    print('*'*30)
    return data_all

In [5]:
def norm(imgdata):
    """Min-max normalise every 2-D slice independently to uint8 in [0, 255].

    Args:
        imgdata: array whose last two axes are the image plane. In this
            notebook it is (channels, layers, H, W); any number of leading
            axes is accepted.

    Returns:
        uint8 array of the same shape. In each slice the minimum maps to 0 and
        the maximum to 255 (values in between are truncated). Constant slices
        become all zeros.
    """
    # Work in float64 so integer inputs cannot overflow when subtracting the
    # minimum or computing the range.
    data = np.asarray(imgdata, dtype=np.float64)
    min_value = data.min(axis=(-2, -1), keepdims=True)
    value_range = data.max(axis=(-2, -1), keepdims=True) - min_value
    # Divide by the exact range (with zero guarded) instead of range + 1e-10.
    # The epsilon made the maximum land at 254.99..., which truncated to 254.
    value_range[value_range == 0] = 1.0
    scaled = (data - min_value) / value_range
    return (255 * scaled).astype(np.uint8)

In [6]:
# Tuned settings (marked "already optimised" by the author).
def perform_registration(fixed_image, moving_image, seed=None):
    """Rigidly register ``moving_image`` onto ``fixed_image`` and resample it.

    Setup:
    - a 3-D Euler (rigid) transform, initialised with
      ``CenteredTransformInitializer`` in GEOMETRY mode;
    - Mattes mutual information (64 bins) on a random 25 % voxel sample;
    - gradient descent with line search;
    - a two-level pyramid.

    Args:
        fixed_image: reference ``sitk.Image``; its grid defines the output.
        moving_image: ``sitk.Image`` to align to ``fixed_image``.
        seed: optional integer seed for the random metric sampling. ``None``
            keeps SimpleITK's default seed (wall-clock based per the SimpleITK
            docs), so repeated runs may differ slightly. Pass an int for
            reproducible registrations.

    Returns:
        ``moving_image`` resampled onto ``fixed_image``'s grid with the
        estimated transform (linear interpolation, 0.0 outside the moving
        image). Both inputs are cast to float32 first, so the result is float32.
    """
    fixed_image = sitk.Cast(fixed_image, sitk.sitkFloat32)
    moving_image = sitk.Cast(moving_image, sitk.sitkFloat32)
    # Rigid transform initialised from the image geometry.
    initial_transform = sitk.CenteredTransformInitializer(
        fixed_image,
        moving_image,
        sitk.Euler3DTransform(),
        sitk.CenteredTransformInitializerFilter.GEOMETRY)

    # Similarity metric and sampling.
    registration_method = sitk.ImageRegistrationMethod()
    registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=64)
    registration_method.SetMetricSamplingStrategy(sitk.ImageRegistrationMethod.RANDOM)
    if seed is None:
        registration_method.SetMetricSamplingPercentage(0.25)
    else:
        registration_method.SetMetricSamplingPercentage(0.25, seed)
    # Other options: sitkNearestNeighbor, sitkGaussian, sitkBSpline2, sitkBSpline3.
    registration_method.SetInterpolator(sitk.sitkLinear)
    # Line-search gradient descent worked better here than GradientDescent or
    # RegularStepGradientDescent.
    registration_method.SetOptimizerAsGradientDescentLineSearch(
        learningRate=1,
        numberOfIterations=2000,
        convergenceMinimumValue=1e-6,
        convergenceWindowSize=20)
    registration_method.SetOptimizerScalesFromPhysicalShift()

    # Multi-resolution pyramid: half resolution smoothed with sigma 2 (physical
    # units), then full resolution unsmoothed. A [4, 2, 1] pyramid with sigmas
    # [2, 1, 1] was also tried.
    registration_method.SetShrinkFactorsPerLevel(shrinkFactors=[2, 1])
    registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[2, 0])
    registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

    # Run the registration without modifying initial_transform in place.
    registration_method.SetInitialTransform(initial_transform, inPlace=False)
    final_transform = registration_method.Execute(fixed_image, moving_image)

    return sitk.Resample(moving_image, fixed_image, final_transform,
                         sitk.sitkLinear, 0.0, moving_image.GetPixelID())

In [7]:
def display_images(fixed_image_z, moving_image_z, fixed_npa, moving_npa):
    """Show one slice of the fixed and of the moving volume side by side.

    Each slice is taken along the last axis of its array
    (``volume[:, :, z]``) and transposed before display in grayscale.
    """
    plt.subplots(1, 2, figsize=(10, 8))
    panels = (
        (1, fixed_npa, fixed_image_z, "fixed image"),
        (2, moving_npa, moving_image_z, "moving image"),
    )
    for position, volume, z, title in panels:
        plt.subplot(1, 2, position)
        plt.imshow(volume[:, :, z].T, cmap=plt.cm.Greys_r)
        plt.title(title)
        plt.axis("off")

    plt.show()

def display_images_with_alpha(image_z, alpha, fixed, moving):
    """Overlay slice ``image_z`` of ``moving`` (red) on the same slice of ``fixed`` (gray).

    ``alpha`` is the opacity of the moving layer; the fixed layer uses
    ``1 - alpha``. Both volumes are indexed along their first axis.
    """
    fixed_slice, moving_slice = fixed[image_z, :, :], moving[image_z, :, :]
    layers = ((fixed_slice, plt.cm.gray, 1 - alpha),
              (moving_slice, plt.cm.Reds, alpha))
    for picture, colormap, opacity in layers:
        plt.imshow(picture, cmap=colormap, alpha=opacity)
    plt.axis("off")
    plt.show()

def show_slices_2D_with_slider(img_data: np.array, merge=True, time = 0.15):
    """Play through every layer of a (channels, layers, H, W) volume.

    No interactive slider is created. Each frame replaces the previous cell
    output (``display.clear_output``) and is shown for ``time`` seconds.

    Args:
        img_data: array of shape (channels, layers, H, W).
        merge: if True, draw all channels of a layer as one colour image
            (channels become the colour axis); otherwise draw one grayscale
            panel per channel.
        time: seconds to pause on each frame.
    """
    print(f"img_data shape: {img_data.shape}")
    channels, layers, _, _ = img_data.shape
    for idx in range(layers):
        if merge:
            fig, _ = plt.subplots(figsize=(4, 4), dpi=200)
            plt.subplots_adjust(bottom=0.25)
            plt.title(f"{idx}")
            # (C, H, W) -> (W, H, C): shown transposed with the origin at the bottom.
            plt.imshow(img_data[:, idx, :, :].transpose(2, 1, 0), origin='lower')
        else:
            # plt.figure rather than plt.subplots, so no empty full-size Axes
            # sits behind the per-channel panels.
            fig = plt.figure(figsize=(6, 6), dpi=300)
            for ch in range(channels):
                plt.subplot(1, channels, ch + 1)
                plt.title(f"{idx}")
                plt.imshow(img_data[ch, idx, :, :], cmap="gray", origin="lower")
        display.clear_output(wait=True)
        plt.pause(time)
        # Close every frame except the last so figures do not accumulate; the
        # last one is shown and closed below, as before.
        if idx < layers - 1:
            plt.close(fig)

    plt.show(block=True)
    plt.close()

def show_slices_2D_with_idx(img_data: np.array, merge=True, idx = 6):
    """Display layer ``idx`` of a (channels, layers, H, W) volume.

    Args:
        img_data: array of shape (channels, layers, H, W).
        merge: if True, draw the layer's channels as one colour image
            (channels become the colour axis); otherwise draw one grayscale
            panel per channel.
        idx: layer to show; negative values count from the end as in NumPy.

    Raises:
        IndexError: if ``idx`` is outside ``[-layers, layers)``. This is
            checked before any figure is created, so no figure is leaked.
    """
    print(f"img_data shape: {img_data.shape}")
    channels, layers, _, _ = img_data.shape
    if not -layers <= idx < layers:
        raise IndexError(f"idx {idx} out of range for a volume with {layers} layers")

    if merge:
        plt.subplots(figsize=(4, 4), dpi=200)
        plt.subplots_adjust(bottom=0.25)
        plt.title(f"{idx}")
        # (C, H, W) -> (H, W, C) so the channels become the colour axis.
        plt.imshow(img_data[:, idx, :, :].transpose(1, 2, 0), cmap='gray', origin='upper')
    else:
        # plt.figure rather than plt.subplots, so no empty full-size Axes sits
        # behind the per-channel panels.
        plt.figure(figsize=(6, 6), dpi=300)
        for ch in range(channels):
            plt.subplot(1, channels, ch + 1)
            plt.title(f"{idx}")
            plt.imshow(img_data[ch, idx, :, :], cmap="gray", origin="lower")

    plt.show(block=True)
    plt.close()

In [8]:
def _same_grid(image, reference):
    """True if ``image`` has the same voxel counts and (approximately) the same spacing as ``reference``."""
    return (image.GetSize() == reference.GetSize()
            and np.allclose(image.GetSpacing(), reference.GetSpacing()))


# T2FS seems to work best as the registration reference.
# Using T1CE as the reference and registering T1, T2 and T2FS onto it did not work at all.
# No resize is needed after registration: resampling onto the T2FS grid already
# unifies the geometry, so only normalisation (norm) remains.
def regi_patient(patient_path, regi_method = perform_registration):
    """Load one patient's four sequences and bring them onto the T2FS grid.

    T1CE is always registered onto T2FS. T1 and T2 are registered onto T2FS
    only when their grid differs from it. The grid differs when the voxel
    counts (``GetSize``) or the pixel spacing (``GetSpacing``, compared with
    ``np.allclose``) differ. Both must match for the arrays to be stackable
    and voxel-aligned.

    Args:
        patient_path: patient directory, passed to ``read_decom``.
        regi_method: callable ``(fixed_image=..., moving_image=...) -> sitk.Image``.

    Returns:
        tuple of numpy arrays ``(T1, T2, T2FS, T1CE, T1CE_registered)``.
        T1CE is the original, unregistered volume.
    """
    print(f"Path: {patient_path}")
    T1_img, T2_img, T2FS_img, T1CE_img = read_decom(patient_path)
    T1CE_regi_img = regi_method(fixed_image=T2FS_img, moving_image=T1CE_img)

    if not _same_grid(T1_img, T2FS_img):
        T1_img = regi_method(fixed_image=T2FS_img, moving_image=T1_img)

    if not _same_grid(T2_img, T2FS_img):
        T2_img = regi_method(fixed_image=T2FS_img, moving_image=T2_img)

    # Convert to numpy arrays after registration.
    T1_img_array = sitk.GetArrayFromImage(T1_img)
    T2_img_array = sitk.GetArrayFromImage(T2_img)
    T2FS_img_array = sitk.GetArrayFromImage(T2FS_img)
    T1CE_img_array = sitk.GetArrayFromImage(T1CE_img)
    T1CE_regi_img_array = sitk.GetArrayFromImage(T1CE_regi_img)

    return T1_img_array, T2_img_array, T2FS_img_array, T1CE_img_array, T1CE_regi_img_array



In [None]:
# Placeholder: a single patient folder with <name>_T1/_T2/_T2FS/_T1CE sub-folders.
patient_path = "./your_test"
(T1_img_array, T2_img_array, T2FS_img_array,
 T1CE_img_array, T1CE_regi_img_array) = regi_patient(patient_path)

In [None]:
# Preview one slice as an RGB image built from (registered T1CE, T2FS, T1):
# save it to disk and display it.
preview_idx = 6  # layer used for both the saved PNG and the displayed figure
img_data_channels = norm(np.stack([T1_img_array, T2_img_array, T2FS_img_array, T1CE_regi_img_array], axis=0))
# img_data_channels channels: 0=T1, 1=T2, 2=T2FS, 3=registered T1CE.
merge_img_data = img_data_channels[[3, 2, 0], :]
img_test = Image.fromarray(merge_img_data[:, preview_idx, :, :].transpose(1, 2, 0))
img_test.save("output.png")
show_slices_2D_with_idx(merge_img_data, idx=preview_idx)

In [None]:
# Interactive overlay: registered T1CE (red) over T1 (gray).
# image_z selects the slice along axis 0; alpha sets the T1CE opacity.
n_overlay_slices = len(T1_img_array)
interact(
    display_images_with_alpha,
    image_z=(0, n_overlay_slices - 1),
    alpha=(0.0, 1.0, 0.05),
    fixed=fixed(T1_img_array),
    moving=fixed(T1CE_regi_img_array),
);

In [37]:
main_dir = "./DiffusionSpinalMRISynthesis/Data_MRI"


# Worker executed by each process of the Pool in main_multiprocess.
def save_array_img_multiprocess(args):
    """Register one patient and save every layer as a 4-channel PNG.

    Args:
        args: tuple ``(patient_path, output_path, save_test_path)``.

    For every layer, writes ``<patient>_sagittal_<idx>.png`` into
    ``output_path``, with channels (T1, T2, T2FS, registered T1CE). Also writes
    one 3-channel preview ``<patient>_test_<k>.png`` into ``save_test_path``,
    with channels (registered T1CE, T2FS, T1). ``k`` is 6, or the last layer
    for shorter volumes. Any exception for a patient is printed with its cause
    and swallowed, so one bad case does not stop the pool.
    """
    patient_path, output_path, save_test_path = args
    try:
        patient_name = os.path.basename(os.path.normpath(patient_path))
        T1_img_array, T2_img_array, T2FS_img_array, _, T1CE_regi_img_array = regi_patient(patient_path)
        # (channels, layers, H, W)
        img_data_channels = norm(np.stack([T1_img_array, T2_img_array, T2FS_img_array, T1CE_regi_img_array], axis=0))
        n_layers = img_data_channels.shape[1]
        # Clamp the preview layer so volumes with < 7 layers are still exported
        # (previously the IndexError here skipped the whole patient).
        preview_idx = min(6, n_layers - 1)
        test_img = Image.fromarray(img_data_channels[[3, 2, 0], preview_idx].transpose(1, 2, 0))
        test_img.save(os.path.join(save_test_path, f"{patient_name}_test_{preview_idx}.png"))
        for idx in range(n_layers):
            img = Image.fromarray(img_data_channels[:, idx, :, :].transpose(1, 2, 0))
            img.save(os.path.join(output_path, f"{patient_name}_sagittal_{idx}.png"))
    except Exception as exc:
        print("error: ", patient_path, repr(exc))


def main_multiprocess(path_name,  processes=32, main_dir = main_dir):
    """Register and export every patient folder under ``main_dir/path_name`` in parallel.

    Args:
        path_name: sub-directory of ``main_dir`` containing one folder per
            patient. Leading/trailing separators (e.g. ``'/train'``) are ignored.
        processes: maximum number of worker processes.
        main_dir: root data directory.

    Per-layer PNGs go to ``<input>_regi``; one preview PNG per patient goes to
    ``<input>_regi_forsee``. Returns early (printing a message) when no
    patient folders are found.
    """
    # os.path.join inserts the separator that plain concatenation missed
    # ('.../Data_MRI' + 'name'). Stripping separators keeps '/train' relative.
    input_path = os.path.join(main_dir, path_name.strip('/\\'))
    output_path = input_path + '_regi'
    save_test_path = input_path + '_regi_forsee'
    os.makedirs(output_path, exist_ok=True)
    os.makedirs(save_test_path, exist_ok=True)
    args_list = [
        (os.path.join(input_path, patient_name), output_path, save_test_path)
        for patient_name in sorted(os.listdir(input_path))
        if os.path.isdir(os.path.join(input_path, patient_name))
    ]
    if not args_list:
        # Pool(processes=0) would raise ValueError.
        print("no patient folders found in:", input_path)
        return
    processes = min(processes, len(args_list))
    with Pool(processes=processes) as pool:
        # imap_unordered yields as each patient finishes, so the bar tracks real
        # progress. pool.map consumed the tqdm iterator up front, so the bar
        # filled immediately.
        for _ in tqdm(pool.imap_unordered(save_array_img_multiprocess, args_list),
                      total=len(args_list)):
            pass

In [None]:
# Placeholder: name of the sub-folder (relative to main_dir) holding the patient folders.
main_multiprocess(path_name='your_img_path_name')

In [2]:
# Split patients into train/test sets by patient ID (80/20), so all images of a
# patient end up in the same split.
all_img_dir = "./DiffusionSpinalMRISynthesis/all_spinal_MRI"
all_img_pathlist = os.listdir(all_img_dir)

# Group files by exact patient ID (second '_'-separated field of the file name).
# This replaces glob('*' + id + '*'), which also matched other IDs containing the
# ID as a substring (e.g. '12' inside '112') and leaked patients across splits.
paths_by_patient = {}
for pathname in sorted(all_img_pathlist):
    paths_by_patient.setdefault(pathname.split('_')[1], []).append(os.path.join(all_img_dir, pathname))

patients_id = sorted(paths_by_patient)  # ascending
print("总患者数: ", len(patients_id))
random.seed(42)
ratio = 0.8
train_patient_id = random.sample(patients_id, int(ratio * len(patients_id)))
# sorted(): iteration order of a set of strings is not stable across runs.
test_patient_id = sorted(set(patients_id) - set(train_patient_id))

train_dir = "./DiffusionSpinalMRISynthesis/Data_MRI/train_spinal_MRI"
test_dir = "./DiffusionSpinalMRISynthesis/Data_MRI/test_spinal_MRI"
train_id_pathlist = [path for pid in train_patient_id for path in paths_by_patient[pid]]
print("train_id_pathlist: ", len(train_id_pathlist))
test_id_pathlist = [path for pid in test_patient_id for path in paths_by_patient[pid]]
print("test_id_pathlist: ", len(test_id_pathlist))

copy_files = False  # set True to copy the split files into train_dir / test_dir
if copy_files:
    for target_dir, pathlist in ((train_dir, train_id_pathlist), (test_dir, test_id_pathlist)):
        os.makedirs(target_dir, exist_ok=True)
        for path in pathlist:
            shutil.copy(path, target_dir)

总患者数:  1012
train_id_pathlist:  10875
test_id_pathlist:  2712
