In [2]:
import glob
import shutil
import SimpleITK as sitk
import numpy as np
import os
from collections import Counter
from multiprocessing import Pool
from numpy.lib.npyio import save
from numpy.lib.shape_base import _dstack_dispatcher
from numpy.lib.type_check import imag
from pathlib import Path
import pathlib
import multiprocessing
import math
import functools
from functools import partial


In [3]:


def dcm2nii(image_3D, nii_path):
    """Write a SimpleITK volume to ``nii_path`` (e.g. ``.nii`` / ``.nii.gz``).

    The image is rebuilt from its voxel array and only origin, spacing and
    direction are carried over to the file that is written.

    Args:
        image_3D: SimpleITK image, e.g. the result of ``ImageSeriesReader.Execute()``.
        nii_path: destination file path passed to ``sitk.WriteImage``.
    """
    # Capture the geometry first (all three are in x, y, z order).
    src_origin, src_spacing, src_direction = (
        image_3D.GetOrigin(),
        image_3D.GetSpacing(),
        image_3D.GetDirection(),
    )
    # Voxel data as a numpy array (z, y, x order) and back into a fresh image.
    voxels = sitk.GetArrayFromImage(image_3D)
    rebuilt = sitk.GetImageFromArray(voxels)
    rebuilt.SetSpacing(src_spacing)
    rebuilt.SetDirection(src_direction)
    rebuilt.SetOrigin(src_origin)
    sitk.WriteImage(rebuilt, nii_path)


def _nii_safe_name(text):
    '''Return ``text`` stripped of padding, with characters illegal in Windows file names replaced by "_".'''
    # DICOM string values may carry trailing spaces / NUL padding, and descriptions
    # such as "Lung 1.25/1.25" contain "/" which would otherwise create a bogus sub-path.
    cleaned = str(text).strip().strip("\x00").strip()
    return "".join("_" if ch in '\\/:*?"<>|' or ord(ch) < 32 else ch for ch in cleaned)


def _read_dicom_tag(file_reader, tag, default):
    '''Return the value of DICOM ``tag`` from ``file_reader``, or ``default`` when the tag is absent.'''
    # GetMetaData raises for a missing key, which previously aborted the whole case.
    if file_reader.HasMetaDataKey(tag):
        return file_reader.GetMetaData(tag)
    return default


def _copy_masks_to_nii_dir(DirectoryPath, nii_dir):
    '''Copy every NIfTI file directly inside DirectoryPath into nii_dir as mask1, mask2, ... keeping its extension.'''
    mask_index = 1
    # sorted() makes the mask numbering deterministic across runs.
    for file in sorted(glob.glob(DirectoryPath + "\\*")):
        if not os.path.isfile(file):
            continue
        if file.endswith(".nii.gz"):
            extension = ".nii.gz"
        elif file.endswith(".nii"):
            extension = ".nii"
        else:
            continue
        shutil.copy(file, os.path.join(nii_dir, "mask" + str(mask_index) + extension))
        mask_index += 1


def _quarantine_case(DirectoryPath, abnormal_dir):
    '''Copy the whole case directory into abnormal_dir/<case name>, at most once.'''
    case_name = os.path.basename(os.path.normpath(DirectoryPath))
    new_dir = os.path.join(abnormal_dir, case_name)
    # copytree creates new_dir (and its parents) itself and raises if it already
    # exists, so it must not be pre-created and is skipped on repeated failures.
    if not os.path.exists(new_dir):
        shutil.copytree(DirectoryPath, new_dir)


def dcmSeries2nii(DirectoryPath, abnormal_dir="D:\\EGFR1_MARKED\\POSITIVE\\3-4abnormal"):
    '''
    Convert every DICOM series found in DirectoryPath to NIfTI.

    Output goes to the sibling folder "<DirectoryPath>-nii" (created if needed):
    - NIfTI mask files found directly inside DirectoryPath are copied there as
      mask1, mask2, ... (original .nii / .nii.gz extension kept);
    - each series is written as "<series UID>(<series description>).nii".

    If a series fails to load or write, the error is printed and the whole case
    directory is copied (once) into ``abnormal_dir`` for manual inspection; the
    remaining series are still converted.

    Args:
        DirectoryPath: case directory holding the DICOM files (Windows-style path).
        abnormal_dir: folder that receives copies of cases with unreadable series.
    '''
    reader = sitk.ImageSeriesReader()
    # Series identifiers present in the folder.
    series_IDs = reader.GetGDCMSeriesIDs(DirectoryPath)
    file_reader = sitk.ImageFileReader()

    nii_dir = DirectoryPath + "-nii"
    os.makedirs(nii_dir, exist_ok=True)
    _copy_masks_to_nii_dir(DirectoryPath, nii_dir)

    for series_id in series_IDs:
        dicom_names = reader.GetGDCMSeriesFileNames(DirectoryPath + "\\", series_id)
        reader.SetFileNames(dicom_names)
        # Read only the header of the first file; it is used to name the output.
        file_reader.SetFileName(dicom_names[0])
        file_reader.ReadImageInformation()
        series_uid = _read_dicom_tag(file_reader, "0020|000e", series_id)  # uid
        series_desc = _read_dicom_tag(file_reader, "0008|103e", "")  # description
        print("series_uid:    ", series_uid)
        print("series_desc:    ", series_desc)
        try:
            image3D = reader.Execute()
            nii_name = _nii_safe_name(series_uid) + "(" + _nii_safe_name(series_desc) + ")" + ".nii"
            dcm2nii(image3D, os.path.join(nii_dir, nii_name))
        except Exception as err:
            # Keep converting the other series; set the case aside for inspection.
            print("failed to convert series", series_id, "of", DirectoryPath, ":", err)
            _quarantine_case(DirectoryPath, abnormal_dir)

def read_and_convert_nii_to_array(nii_path):
    """Load an image file with SimpleITK and return its voxels as a numpy array (z, y, x order)."""
    return sitk.GetArrayFromImage(sitk.ReadImage(nii_path))


def match_series_mask(DirectoryPath):
    '''
    Count (mask, series) pairs in "<DirectoryPath>-nii" whose voxel arrays have the same shape.

    Files whose *name* contains "mask" are treated as masks, all other files as
    image series. A result of 0 means no mask could be matched to any series.

    Args:
        DirectoryPath: case directory; its converted files live in "<DirectoryPath>-nii".

    Returns:
        int: number of shape-matching (mask, series) pairs.
    '''
    nii_files = glob.glob(DirectoryPath + "-nii\\*")
    # Classify by file name only: the directory part of the path may itself contain "mask".
    mask_files = [f for f in nii_files if 'mask' in os.path.basename(f)]
    series_files = [f for f in nii_files if 'mask' not in os.path.basename(f)]
    if not mask_files:
        return 0

    # Read each (large) series volume once instead of once per mask; keep only its shape.
    series_shapes = [read_and_convert_nii_to_array(f).shape for f in series_files]

    count = 0
    for mask_file in mask_files:
        mask_shape = read_and_convert_nii_to_array(mask_file).shape
        count += sum(1 for shape in series_shapes if shape == mask_shape)
    return count


def get_context_nodule_coordinate(*params,ratio=1.5):
    '''
    Enlarge each (min, max) pair in ``params`` by ``ratio`` around its centre.

    Args:
        *params: flat sequence min0, max0, min1, max1, ... (usually the six values
            xmin, xmax, ymin, ymax, zmin, zmax). An unpaired trailing value is ignored,
            so ``ratio`` MUST be passed by keyword.
        ratio: keyword-only enlargement factor forwarded to calculate_axis_coordinate.

    Returns:
        list of (lower, upper) tuples, one per pair, as computed by calculate_axis_coordinate.
    '''
    # The original call omitted `ratio`, which raised TypeError in calculate_axis_coordinate.
    return [calculate_axis_coordinate(low, high, ratio)
            for low, high in zip(params[::2], params[1::2])]

def calculate_axis_coordinate(a,b,ratio=1.5):
    '''
    Scale the interval [a, b] by ``ratio`` about its centre and return integer bounds.

    lower = floor(((a + b) - ratio * (b - a)) / 2), clamped to >= 0
    upper = floor(((a + b) + ratio * (b - a)) / 2)

    The bounds are used as numpy slice indices, so they are returned as Python ints
    (a float ratio would otherwise yield floats, which numpy rejects) and the lower
    bound is clamped at 0 (a negative slice start would wrap around to the end).

    Args:
        a: interval start (e.g. xmin).
        b: interval end (e.g. xmax).
        ratio: enlargement factor; 1 returns the interval unchanged.

    Returns:
        tuple (lower, upper) of ints.
    '''
    lower = (a + b + ratio * a - ratio * b) // 2
    upper = (a + b + ratio * b - ratio * a) // 2
    return max(0, int(lower)), int(upper)


def classify_nodule_and_relabel(valid_z_pathes,img_array,mask_array):
    '''
    Group foreground slices of ``mask_array`` into nodules and relabel them in place.

    Slices are visited in ascending order. A slice continues the current nodule when
    it directly follows the previous slice (index difference of 1) AND its foreground
    (voxels == 1) overlaps the previous slice's voxels carrying the current label;
    otherwise a new nodule starts. All foreground voxels of a slice are rewritten to
    that slice's nodule label (1, 2, ...).

    NOTE: ``mask_array`` is modified in place (the caller's array changes).
    NOTE(review): all foreground voxels of one slice share one label, so two separate
    nodules on the same slice are merged — confirm this is acceptable.

    Args:
        valid_z_pathes: indices along axis 0 of mask_array that contain foreground;
            order and duplicates do not matter.
        img_array: unused; kept so existing callers keep working.
        mask_array: label array indexed by slice along axis 0; foreground voxels == 1.

    Returns:
        int: number of nodules found (0 for no slices); labels are 1..count.
    '''
    z_slices = sorted(set(valid_z_pathes))
    if not z_slices:
        return 0

    nodule_label = 1  # the first slice keeps its foreground value 1
    for prev_z, cur_z in zip(z_slices, z_slices[1:]):
        cur_fg = mask_array[cur_z] == 1
        # The previous slice's voxels were already (re)labelled with nodule_label.
        continues_nodule = (cur_z - prev_z == 1 and
                            np.any(cur_fg & (mask_array[prev_z] == nodule_label)))
        if not continues_nodule:
            nodule_label += 1
        mask_array[cur_z][cur_fg] = nodule_label
    return nodule_label


def save_np_array(array,path):
    '''
    Save ``array`` with ``np.save`` after creating the parent directory of ``path``.

    Args:
        array: array-like to store.
        path: destination file (str or pathlib.Path). np.save appends ".npy" if the
            name does not already end with it.
    '''
    path = Path(path)  # also accept plain strings
    path.parent.mkdir(parents=True, exist_ok=True)
    np.save(path, array)
    
def make_dataset(DirectoryPath:Path,slice_path_list,mask_path_list,ratio=1.5):
    '''
    Build per-nodule training artifacts for one case from paired (series, mask) files.

    For the pair at position ``index`` and each nodule label ``i`` produced by
    classify_nodule_and_relabel, the following are written under DirectoryPath:

    - point_cloud/position{index}_{i}.npy: (N, 3) voxel coordinates (z, y, x) of the
      nodule, each divided by the volume size along that axis (segmentation label).
    - detection/detection{index}_{i}.npy: [xmin, xmax, ymin, ymax, zmin, zmax],
      inclusive voxel indices (3D detection label).
    - nodule/nodule{index}_{i}.nii: tight bounding-box crop (classification).
    - context_nodule/nodule{index}_{i}.nii: bounding box enlarged by ``ratio`` around
      its centre and clipped to the volume (classification with context).

    Every object is a lung nodule, so no class label is stored.

    Args:
        DirectoryPath: case directory (str or Path) that receives the outputs.
        slice_path_list: image volume paths.
        mask_path_list: mask volume paths, same order and length as slice_path_list;
            voxels equal to 1 are treated as nodule.
        ratio: enlargement factor for the context crop.

    Raises:
        ValueError: if slice_path_list and mask_path_list differ in length.
    '''
    if len(slice_path_list) != len(mask_path_list):
        raise ValueError("slice_path_list and mask_path_list must have the same length")

    root = Path(DirectoryPath)  # callers may pass a plain str
    nodule_path = root / "nodule"
    context_nodule_path = root / "context_nodule"
    position_path = root / 'point_cloud'
    detection_path = root / 'detection'
    # sitk.WriteImage does not create directories (save_np_array does so for .npy files).
    nodule_path.mkdir(parents=True, exist_ok=True)
    context_nodule_path.mkdir(parents=True, exist_ok=True)

    for index, (slice_path, mask_path) in enumerate(zip(slice_path_list, mask_path_list)):
        img_array = sitk.GetArrayFromImage(sitk.ReadImage(str(slice_path)))
        mask_array = sitk.GetArrayFromImage(sitk.ReadImage(str(mask_path)))

        # Indices along axis 0 (z of the z, y, x array) that contain nodule voxels.
        nodule_slices = np.unique(np.nonzero(mask_array == 1)[0]).tolist()
        # Relabels mask_array in place so nodule k's voxels equal k.
        # NOTE(review): assumes the raw mask only uses 0/1; other values could collide with labels.
        nodule_count = classify_nodule_and_relabel(nodule_slices, img_array, mask_array)
        volume_shape = np.array(mask_array.shape)

        for i in range(1, nodule_count + 1):
            voxels = np.argwhere(mask_array == i)  # (N, 3) rows of (z, y, x)
            if voxels.size == 0:
                continue

            # Point cloud normalised per axis; no class is stored (all are nodules).
            save_np_array(voxels / volume_shape, position_path / f"position{index}_{i}.npy")

            # np.where/argwhere index lists are not sorted along y/x, so take real min/max.
            zmin, ymin, xmin = voxels.min(axis=0)
            zmax, ymax, xmax = voxels.max(axis=0)
            detection_label = np.array([xmin, xmax, ymin, ymax, zmin, zmax])
            save_np_array(detection_label, detection_path / f"detection{index}_{i}.npy")

            cropped_nodule = img_array[zmin:zmax + 1, ymin:ymax + 1, xmin:xmax + 1]
            context = get_context_nodule_coordinate(xmin, xmax, ymin, ymax, zmin, zmax, ratio=ratio)
            # Integer, non-negative bounds: a negative start would wrap around in numpy slicing.
            (cx0, cx1), (cy0, cy1), (cz0, cz1) = [(max(0, int(lo)), int(hi)) for lo, hi in context]
            cropped_context = img_array[cz0:cz1 + 1, cy0:cy1 + 1, cx0:cx1 + 1]

            sitk.WriteImage(sitk.GetImageFromArray(cropped_nodule),
                            str(nodule_path / f"nodule{index}_{i}.nii"))
            sitk.WriteImage(sitk.GetImageFromArray(cropped_context),
                            str(context_nodule_path / f"nodule{index}_{i}.nii"))

def run_multi_process(item_list, n_proc, func, with_proc_num=False):
    """Split ``item_list`` into ``n_proc`` chunks and map ``func`` over them in a process pool.

    Each worker call receives one chunk (a slice of ``item_list``), or a
    ``(chunk_number, chunk)`` tuple when ``with_proc_num`` is true.

    Returns:
        list with one ``func`` result per chunk, in chunk order.
    """
    work = chunk(item_list, n_proc)
    payloads = list(enumerate(work)) if with_proc_num else work
    with multiprocessing.Pool(processes=n_proc) as pool:
        return pool.map(func, payloads)


def chunk(list, n):
    """Split a sequence into ``n`` contiguous, nearly equal slices (some may be empty).

    Slice k covers positions floor(k/n * len) up to floor((k+1)/n * len).
    Returns an empty list when ``n`` <= 0.
    """
    def _edge(k):
        # Position k/n of the way through the sequence, rounded down.
        return math.floor(k / n * len(list))

    return [list[_edge(k):_edge(k + 1)] for k in range(n)]


def _pair_series_with_masks(task):
    '''Return (series_paths, mask_paths): each mask in "<task>-nii" paired with every series of identical array shape.'''
    nii_files = sorted(glob.glob(task + "-nii\\*"))
    # Classify by file name only; the directory part may itself contain "mask".
    mask_files = [f for f in nii_files if 'mask' in os.path.basename(f)]
    series_files = [f for f in nii_files if 'mask' not in os.path.basename(f)]
    if not mask_files:
        return [], []
    # Read each series once and keep only its shape.
    series_shapes = [(f, read_and_convert_nii_to_array(f).shape) for f in series_files]
    slice_paths, mask_paths = [], []
    for mask_file in mask_files:
        mask_shape = read_and_convert_nii_to_array(mask_file).shape
        for series_file, shape in series_shapes:
            if shape == mask_shape:
                slice_paths.append(series_file)
                mask_paths.append(mask_file)
    return slice_paths, mask_paths


def data_processing(task,dcm_to_nii=False,
                    not_match_log="C:\\Users\\wyh196646\\Desktop\\test\\not_match.txt"):
    '''
    Run the whole pipeline for one case directory: DICOM -> NIfTI (optional) and NIfTI -> dataset.

    Args:
        task: path of a single case directory.
        dcm_to_nii: if True, first convert the case's DICOM series with dcmSeries2nii.
        not_match_log: text file to which the case path is appended when no mask
            matches any series by shape.
    '''
    # NOTE(review): only paths ending with "resistant mutation" are processed — confirm
    # this filter matches the case-folder naming.
    if not task.endswith("resistant mutation"):
        return
    print("开始处理病例：    ", task)
    if dcm_to_nii:
        dcmSeries2nii(task)

    slice_paths, mask_paths = _pair_series_with_masks(task)
    if not slice_paths:
        # `with` guarantees the log file is closed (the old handle was leaked).
        with open(not_match_log, "a+") as f:
            f.write(task + "\n")
        return
    make_dataset(Path(task), slice_paths, mask_paths)




In [None]:
if __name__ == '__main__':
    directory = r"C:\Users\wyh196646\Desktop\test\resistant mutation\*"
    case_list = glob.glob(directory)
    # data_processing handles ONE case path. run_multi_process(..., with_proc_num=True)
    # would hand each worker a (proc_num, [paths...]) tuple and crash on task.endswith,
    # so map over individual cases instead.
    # NOTE(review): with the Windows "spawn" start method, workers must be able to import
    # data_processing; functions defined inside a notebook often are not importable,
    # so this may need to run as a .py script — confirm in your environment.
    n_proc = max(1, min(50, len(case_list)))  # no point in more workers than cases
    with multiprocessing.Pool(processes=n_proc) as pool:
        pool.map(partial(data_processing, dcm_to_nii=True), case_list)
    print('all data have been processed')
