## 測試nnUNet training code 
想要看出經過augmentation後，train的input長什麼樣子，總共放5張圖: 
1. 原圖 
2. Resample 
3. Resample + augmentation 
4. Resample + SimulateLowResolutionTransform 
5. Resample + augmentation + SimulateLowResolutionTransform 

In [None]:
import os
import socket
from typing import Union, Optional

import nnunetv2
import torch.cuda
import torch.distributed as dist
import torch.multiprocessing as mp
from batchgenerators.utilities.file_and_folder_operations import join, isfile, load_json
from nnunetv2.paths import nnUNet_preprocessed
from nnunetv2.run.load_pretrained_weights import load_pretrained_weights
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
from torch.backends import cudnn

In [None]:
import inspect
import multiprocessing
import os
import shutil
import traceback
from asyncio import sleep
from copy import deepcopy
from typing import Tuple, Union, List

import nnunetv2
import numpy as np
import torch
from batchgenerators.dataloading.data_loader import DataLoader
from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter
from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
from batchgenerators.transforms.utility_transforms import NumpyToTensor
from batchgenerators.utilities.file_and_folder_operations import load_json, join, isfile, maybe_mkdir_p, isdir, subdirs, \
    save_json
from nnunetv2.configuration import default_num_processes
from nnunetv2.inference.export_prediction import export_prediction_from_softmax
#from nnunetv2.inference.sliding_window_prediction import predict_sliding_window_return_logits, compute_gaussian
from nnunetv2.preprocessing.preprocessors.default_preprocessor import DefaultPreprocessor
from nnunetv2.utilities.file_path_utilities import get_output_folder, should_i_save_to_file, check_workers_busy
from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
from nnunetv2.utilities.json_export import recursive_fix_for_json_export
from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels, convert_labelmap_to_one_hot
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
from nnunetv2.utilities.utils import create_lists_from_splitted_dataset_folder

import matplotlib.pyplot as plt
from skimage.filters import threshold_multiotsu, gaussian, threshold_otsu, frangi
from skimage.measure import label, regionprops, regionprops_table
import time

from pathlib import Path
import numpy as np
import pandas as pd
import nibabel as nib
from scipy.ndimage import binary_dilation
import matplotlib.pyplot as plt
from scipy import ndimage  # pip install scipy
import scipy.stats
from skimage.morphology import skeletonize
from sklearn.model_selection import train_test_split  # pip install scikit-learn
from sklearn.utils import class_weight
from sklearn.metrics import roc_curve, auc, confusion_matrix, classification_report, f1_score, cohen_kappa_score, matthews_corrcoef
from itertools import cycle, product
import random
import cv2
import matplotlib
#%matplotlib notebook
%matplotlib inline
#matplotlib.use('Agg') 
import matplotlib.pyplot as plt

In [None]:
class PreprocessAdapter(DataLoader):
    """batchgenerators ``DataLoader`` that runs an nnU-Net preprocessor on one case per batch.

    Each item handed to the base class is a tuple
    ``(input files, previous-stage segmentation file or None, truncated output filename)``.
    The loader is configured with batch size 1, no shuffling and a finite pass over the data.
    """

    def __init__(self, list_of_lists: List[List[str]], list_of_segs_from_prev_stage_files: Union[List[None], List[str]],
                 preprocessor: DefaultPreprocessor, output_filenames_truncated: List[str],
                 plans_manager: PlansManager, dataset_json: dict, configuration_manager: ConfigurationManager,
                 num_threads_in_multithreaded: int = 1):
        self.preprocessor = preprocessor
        self.plans_manager = plans_manager
        self.configuration_manager = configuration_manager
        self.dataset_json = dataset_json
        self.label_manager = plans_manager.get_label_manager(dataset_json)

        cases = list(zip(list_of_lists, list_of_segs_from_prev_stage_files, output_filenames_truncated))
        super().__init__(cases,
                         1, num_threads_in_multithreaded,
                         seed_for_shuffle=1, return_incomplete=True,
                         shuffle=False, infinite=False, sampling_probabilities=None)

        # One index per input case; presumably consumed by the base class's get_indices().
        self.indices = list(range(len(list_of_lists)))

    def generate_train_batch(self):
        """Preprocess the next case and return ``{'data', 'data_properites', 'ofile'}``.

        When a previous-stage segmentation is given, it is preprocessed together with the
        images (so any cropping stays aligned), one-hot encoded over the foreground labels
        and stacked onto the image channels.  Very large arrays are written to
        ``<ofile>.npy`` and the path is returned in place of the array.
        """
        case_idx = self.get_indices()[0]
        files, seg_prev_stage, ofile = self._data[case_idx]

        data, seg, properties = self.preprocessor.run_case(files, seg_prev_stage, self.plans_manager,
                                                           self.configuration_manager,
                                                           self.dataset_json)
        if seg_prev_stage is not None:
            prev_stage_one_hot = convert_labelmap_to_one_hot(seg[0], self.label_manager.foreground_labels,
                                                             data.dtype)
            data = np.vstack((data, prev_stage_one_hot))

        # Arrays above this element count are spilled to disk because of
        # process-to-process communication limits.
        max_in_memory_elements = 2e9 / 4 * 0.85
        if np.prod(data.shape) > max_in_memory_elements:
            npy_path = ofile + '.npy'
            np.save(npy_path, data)
            data = npy_path

        return {'data': data, 'data_properites': properties, 'ofile': ofile}

In [None]:
def load_what_we_need(model_training_output_dir, use_folds, checkpoint_name):
    """Load plans, dataset.json, per-fold weights and an (unloaded) network for a trained model.

    Args:
        model_training_output_dir: trainer output folder containing ``dataset.json``,
            ``plans.json`` and ``fold_<f>/`` sub-folders.
        use_folds: iterable of folds (ints, numeric strings or ``'all'``). A single str or
            int is also accepted and treated as a one-element list.
        checkpoint_name: checkpoint file name inside every fold folder.

    Returns:
        ``(parameters, configuration_manager, inference_allowed_mirroring_axes,
        plans_manager, dataset_json, network, trainer_name)``.  ``parameters`` holds the
        ``'network_weights'`` entry of every requested fold (loaded onto the CPU).  The
        trainer name, configuration and mirroring axes are read from the first fold only.

    Raises:
        ValueError: if ``use_folds`` is empty.
    """
    # Normalise and validate folds before touching the file system: with no folds
    # the trainer/configuration names below would never be assigned.
    if isinstance(use_folds, (str, int)):
        use_folds = [use_folds]
    use_folds = list(use_folds)
    if not use_folds:
        raise ValueError('load_what_we_need: use_folds is empty, nothing to load')

    # we could also load plans and dataset_json from the init arguments in the checkpoint. Not quite sure what is the
    # best method so we leave things as they are for the moment.
    dataset_json = load_json(join(model_training_output_dir, 'dataset.json'))
    plans = load_json(join(model_training_output_dir, 'plans.json'))
    plans_manager = PlansManager(plans)

    parameters = []
    for i, f in enumerate(use_folds):
        f = int(f) if f != 'all' else f
        # NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
        checkpoint = torch.load(join(model_training_output_dir, f'fold_{f}', checkpoint_name),
                                map_location=torch.device('cpu'))
        if i == 0:
            trainer_name = checkpoint['trainer_name']
            configuration_name = checkpoint['init_args']['configuration']
            inference_allowed_mirroring_axes = checkpoint.get('inference_allowed_mirroring_axes')

        parameters.append(checkpoint['network_weights'])

    configuration_manager = plans_manager.get_configuration(configuration_name)
    # restore network (architecture only; weights stay in `parameters`)
    num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json)
    trainer_class = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"),
                                                trainer_name, 'nnunetv2.training.nnUNetTrainer')
    network = trainer_class.build_network_architecture(plans_manager, dataset_json, configuration_manager,
                                                       num_input_channels, enable_deep_supervision=False)
    return parameters, configuration_manager, inference_allowed_mirroring_axes, plans_manager, dataset_json, network, trainer_name

In [None]:
def auto_detect_available_folds(model_training_output_dir, checkpoint_name):
    """Return the numeric fold ids under ``model_training_output_dir`` that hold ``checkpoint_name``.

    Scans the ``fold_*`` sub-folders (skipping ``fold_all``), keeps those containing the
    checkpoint file and returns the integer after the last ``_`` of each folder name, in the
    order ``subdirs`` lists them.  Progress messages are printed to stdout.
    """
    print('use_folds is None, attempting to auto detect available folds')
    found_folds = []
    for folder_name in subdirs(model_training_output_dir, prefix='fold_', join=False):
        if folder_name == 'fold_all':
            continue
        if not isfile(join(model_training_output_dir, folder_name, checkpoint_name)):
            continue
        found_folds.append(int(folder_name.split('_')[-1]))
    print(f'found the following folds: {found_folds}')
    return found_folds

In [None]:
import warnings

import numpy as np
import torch
from typing import Union, Tuple, List
from acvl_utils.cropping_and_padding.padding import pad_nd_image
from scipy.ndimage import gaussian_filter
from torch import nn

from nnunetv2.utilities.helpers import empty_cache, dummy_context

In [None]:
def data_translate(img, nii):
    """Reorient an array read from ``nii`` into the layout used in this notebook.

    Steps, in order: swap axes 0 and 1, flip axis 0, flip the last axis, and -- when
    ``pixdim[0]`` of the NIfTI header is positive -- also flip axis 1.  (In the NIfTI-1
    format ``pixdim[0]`` stores the qfac orientation sign, not a voxel size.)
    ``data_translate_back`` performs the inverse sequence.
    """
    reoriented = np.flip(np.flip(np.swapaxes(img, 0, 1), 0), -1)
    qfac = nii.header['pixdim'][0]
    return np.flip(reoriented, 1) if qfac > 0 else reoriented

# Orientation helpers used at prediction time.
def data_translate_back(img, nii):
    """Undo ``data_translate``: return an array to the orientation of the source NIfTI.

    Steps, in order: flip axis 1 when the header's ``pixdim[0]`` (qfac) is positive,
    flip the last axis, flip axis 0, then swap axes 0 and 1 -- the reverse of the
    sequence applied by ``data_translate``.
    """
    restored = img
    if nii.header['pixdim'][0] > 0:
        restored = np.flip(restored, 1)
    restored = np.flip(np.flip(restored, -1), 0)
    return np.swapaxes(restored, 1, 0)

In [None]:
# Sampling around aneurysm labels: crop the same locations from 5 different image variants.
def rand_crop_sampling(image_arr, label_arr, 
                       img_resample, label_resample,
                       img_aug, label_aug,
                       img_low, label_low,
                       img_auglow, label_auglow,
                       nag_sample_mask=None,
                       size=(32,32,16), sample_num=1, pos_rate=0.2, lesion_lb=3):
    """Crop patches at shared random centers from five aligned image/label volume pairs.

    The pairs are: original, resample, aug, low, auglow.  No brain mask or on-the-fly
    augmentation is applied here.

    Center selection:
      * up to ``int(sample_num * pos_rate)`` voxels where ``label_resample == lesion_lb``.
        The resampled lesion may be smaller, so positive centers are taken from it.
      * then voxels of ``nag_sample_mask``, or of ``image_arr > 0`` when that mask is
        None/empty.  Voxels with ``label_arr != 0`` are excluded whenever a lesion was found.
      * the combined list is truncated to ``sample_num`` entries, and ``sample_num``
        centers are drawn from it uniformly with replacement (``np.random.randint``).

    The same center indexes all ten arrays, so they are assumed to be 3-D and of identical
    shape -- TODO confirm for the resampled inputs.  Labels are binarised (``> 0``) before
    cropping.  Parts of a crop outside the volume are zero.

    Returns:
        Ten arrays of shape ``(sample_num, *size)`` in the order
        image/label for original, resample, aug, low and auglow.  Image crops use
        ``image_arr.dtype``; label crops use ``label_arr.dtype``.

    Raises:
        ValueError: if no candidate center exists (empty lesion and sampling masks).
    """
    size = tuple(int(s) for s in size)

    # ---- choose crop centers -------------------------------------------------
    lesion_mask = np.zeros_like(label_resample, dtype=bool)
    center_list = []

    resample_max = label_resample.max()
    if resample_max > 0 and resample_max >= lesion_lb:  # positive sampling on the lesion
        lesion_mask = label_resample == lesion_lb
        positives = np.argwhere(lesion_mask).tolist()
        random.shuffle(positives)
        center_list.extend(positives[:int(sample_num * pos_rate)])

    # Remaining sampling: custom area if it has any voxel, otherwise image foreground.
    mask = nag_sample_mask if np.any(nag_sample_mask) else image_arr > 0
    if np.any(lesion_mask):
        mask = mask & (label_arr == 0)
    negatives = np.argwhere(mask).tolist()
    random.shuffle(negatives)
    center_list.extend(negatives[:sample_num])
    center_list = center_list[:sample_num]
    del mask, negatives

    if not center_list:
        raise ValueError('rand_crop_sampling: no candidate crop center found '
                         '(lesion mask and sampling mask are both empty)')

    # ---- pad every volume ----------------------------------------------------
    # Margin sized for a 45-degree rotated crop (size / cos 45deg).  No rotation is done
    # here, but it guarantees every crop window lies inside the padded volume.
    bigger_size = np.ceil(np.array(size) / np.cos(45 * np.pi / 180)).astype('uint16')
    half = [int(v) for v in bigger_size // 2]
    pad_width = tuple((h, h) for h in half)
    # Start of a size-window centered on `center`, expressed in padded coordinates
    # relative to that center.
    window_offsets = [h - s // 2 for h, s in zip(half, size)]

    pairs = ((image_arr, label_arr),
             (img_resample, label_resample),
             (img_aug, label_aug),
             (img_low, label_low),
             (img_auglow, label_auglow))
    padded = [(np.pad(img, pad_width, 'constant'),
               np.pad(np.uint8(lbl > 0), pad_width, 'constant'))
              for img, lbl in pairs]

    # ---- crop ----------------------------------------------------------------
    image_crops = [np.zeros((sample_num, *size), dtype=image_arr.dtype) for _ in pairs]
    label_crops = [np.zeros((sample_num, *size), dtype=label_arr.dtype) for _ in pairs]

    for i, idx in enumerate(np.random.randint(0, len(center_list), sample_num)):
        center = center_list[idx]
        window = tuple(slice(c + o, c + o + s) for c, o, s in zip(center, window_offsets, size))
        for k, (pad_img, pad_lbl) in enumerate(padded):
            image_crops[k][i] = pad_img[window]
            label_crops[k][i] = pad_lbl[window]

    # Interleave as (image, label) per variant, matching the historical return order.
    return tuple(arr for pair in zip(image_crops, label_crops) for arr in pair)

In [None]:
def easy_nor(sample):
    """Min-max rescale ``sample`` to the float range [0, 255] for display.

    A constant input (max == min) is returned as all zeros instead of the NaNs
    that the 0/0 division would otherwise produce.
    """
    sample = np.asarray(sample)
    low = np.min(sample)
    span = np.max(sample) - low
    if span == 0:
        return np.zeros_like(sample, dtype=float)
    return np.clip((sample - low) / span * 255, 0, 255)

def easy_show(sample, idx):
    """Return slice ``idx`` along the last axis of ``sample`` as an (H, W, 3) uint8 grey image.

    The three channels are identical copies of the slice; values are cast with
    ``astype('uint8')``, so ``sample`` is expected to be already scaled to [0, 255].
    """
    grey = sample[:, :, idx]
    height, width = grey.shape
    return np.broadcast_to(grey[:, :, np.newaxis], (height, width, 3)).astype('uint8')

def esay_canny(targe, idx):
    """Return the uint8 Canny edge map of label slice ``idx`` (taken along the last axis).

    The slice is multiplied by 255 and cast to uint8 first, so a 0/1 label becomes a
    0/255 image before ``cv2.Canny`` runs with thresholds (128, 256).
    """
    label_slice = (targe[:, :, idx] * 255).astype('uint8')
    edges = cv2.Canny(label_slice, 128, 256).copy()
    return edges.astype('uint8')

def plot_multi_view(img1, label1, img2, label2, img3, label3, img4, label4, img5, label5,
                    overlay_all=False):
    """Plot every slice of one randomly chosen crop for five image/label crop stacks.

    Intended for the stacks returned by ``rand_crop_sampling`` (original, resample, aug,
    low, auglow).  One index ``n`` is drawn with ``random.randint(0, img1.shape[0] - 1)``
    and used for all stacks.  ``imgK[n]`` is min-max rescaled with ``easy_nor``, and its
    slices along the last axis are shown.  The per-crop layout is assumed to be (H, W, D)
    -- TODO confirm for other inputs.

    Each slice gets a 1x7 row: 'Image', 'o_Label' (image with the red Canny contour of
    label1), 'resample', 'resample_Label' (contour of label2), 'aug', 'low', 'auglow'.

    Args:
        overlay_all: if True, also draw the red label contour on the 'aug', 'low' and
            'auglow' panels.  Default False leaves them as plain images.
    """
    n = random.randint(0, img1.shape[0] - 1)
    samples = [easy_nor(img[n]) for img in (img1, img2, img3, img4, img5)]
    targets = [lbl[n] for lbl in (label1, label2, label3, label4, label5)]

    def _contoured(view, target, z):
        # Copy of the RGB view with the label slice's Canny edges painted pure red.
        marked = view.copy()
        marked[esay_canny(target, z) > 0] = (255, 0, 0)
        return marked

    plt.style.use('default')  # default matplotlib style (white background)
    for z in range(samples[0].shape[-1]):
        views = [easy_show(s, z) for s in samples]
        panels = [
            (views[0], 'Image'),
            (_contoured(views[0], targets[0], z), 'o_Label'),
            (views[1], 'resample'),
            (_contoured(views[1], targets[1], z), 'resample_Label'),
        ]
        for k, title in ((2, 'aug'), (3, 'low'), (4, 'auglow')):
            view = _contoured(views[k], targets[k], z) if overlay_all else views[k]
            panels.append((view, title))

        plt.figure(figsize=(15, 15))
        for pos, (view, title) in enumerate(panels, start=1):
            plt.subplot(1, len(panels), pos)
            plt.imshow(view)
            plt.title(title, fontsize=15)
            plt.axis('off')
        plt.show()
        
def plot_multi_view2(img1, label1, img2, label2, img3, label3, img4, label4, img5, label5):
    """Compact viewer: plot every slice of one random crop for the original and resampled stacks.

    One index ``n`` is drawn with ``random.randint(0, img1.shape[0] - 1)``.  ``img1[n]``
    and ``img2[n]`` are min-max rescaled with ``easy_nor``, and their slices along the last
    axis are shown.  Each slice gets a 1x4 row: 'Image', 'o_Label' (image with the red Canny
    contour of label1), 'resample', 'resample_Label' (contour of label2).

    ``img3``..``label5`` are accepted for signature compatibility with
    ``plot_multi_view`` but are not displayed.
    """
    n = random.randint(0, img1.shape[0] - 1)
    samples = [easy_nor(img1[n]), easy_nor(img2[n])]
    targets = [label1[n], label2[n]]

    def _red_contour(view, target, z):
        # Copy of the RGB view with the label slice's Canny edges painted pure red.
        marked = view.copy()
        marked[esay_canny(target, z) > 0] = (255, 0, 0)
        return marked

    plt.style.use('default')  # default matplotlib style (white background)
    for z in range(samples[0].shape[-1]):
        original_view = easy_show(samples[0], z)
        resample_view = easy_show(samples[1], z)
        panels = [
            (original_view, 'Image'),
            (_red_contour(original_view, targets[0], z), 'o_Label'),
            (resample_view, 'resample'),
            (_red_contour(resample_view, targets[1], z), 'resample_Label'),
        ]

        plt.figure(figsize=(15, 15))
        for pos, (view, title) in enumerate(panels, start=1):
            plt.subplot(1, len(panels), pos)
            plt.imshow(view)
            plt.title(title, fontsize=15)
            plt.axis('off')  # previously missing on the last panel
        plt.show()

In [None]:
#透過 SingleThreadedAugmenter ，這裡融入各種augmentation
import inspect
import multiprocessing
import os
import shutil
import sys
from copy import deepcopy
from datetime import datetime
from time import time, sleep
from typing import Union, Tuple, List

import numpy as np
import torch
from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter
from batchgenerators.transforms.abstract_transforms import AbstractTransform, Compose
from batchgenerators.transforms.color_transforms import BrightnessMultiplicativeTransform, \
    ContrastAugmentationTransform, GammaTransform
from batchgenerators.transforms.noise_transforms import GaussianNoiseTransform, GaussianBlurTransform
from batchgenerators.transforms.resample_transforms import SimulateLowResolutionTransform
from batchgenerators.transforms.spatial_transforms import SpatialTransform, MirrorTransform
from batchgenerators.transforms.utility_transforms import RemoveLabelTransform, RenameTransform, NumpyToTensor
from batchgenerators.utilities.file_and_folder_operations import join, load_json, isfile, save_json, maybe_mkdir_p
from nnunetv2.configuration import ANISO_THRESHOLD, default_num_processes
from nnunetv2.evaluation.evaluate_predictions import compute_metrics_on_folder
from nnunetv2.inference.export_prediction import export_prediction_from_softmax, resample_and_save
from nnunetv2.inference.sliding_window_prediction import compute_gaussian, predict_sliding_window_return_logits
from nnunetv2.paths import nnUNet_preprocessed, nnUNet_results
from nnunetv2.training.data_augmentation.compute_initial_patch_size import get_patch_size
from nnunetv2.training.data_augmentation.custom_transforms.cascade_transforms import MoveSegAsOneHotToData, \
    ApplyRandomBinaryOperatorTransform, RemoveRandomConnectedComponentFromOneHotEncodingTransform
from nnunetv2.training.data_augmentation.custom_transforms.deep_supervision_donwsampling import \
    DownsampleSegForDSTransform2
from nnunetv2.training.data_augmentation.custom_transforms.limited_length_multithreaded_augmenter import \
    LimitedLenWrapper
from nnunetv2.training.data_augmentation.custom_transforms.masking import MaskTransform
from nnunetv2.training.data_augmentation.custom_transforms.region_based_training import \
    ConvertSegmentationToRegionsTransform
from nnunetv2.training.data_augmentation.custom_transforms.transforms_for_dummy_2d import Convert2DTo3DTransform, \
    Convert3DTo2DTransform
from nnunetv2.training.dataloading.data_loader_2d import nnUNetDataLoader2D
from nnunetv2.training.dataloading.data_loader_3d import nnUNetDataLoader3D
from nnunetv2.training.dataloading.nnunet_dataset import nnUNetDataset
from nnunetv2.training.dataloading.utils import get_case_identifiers, unpack_dataset
from nnunetv2.training.logging.nnunet_logger import nnUNetLogger
from nnunetv2.training.loss.deep_supervision import DeepSupervisionWrapper
from nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler
from nnunetv2.utilities.collate_outputs import collate_outputs
from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA
from nnunetv2.utilities.file_path_utilities import should_i_save_to_file, check_workers_busy
from nnunetv2.utilities.get_network_from_plans import get_network_from_plans
from nnunetv2.utilities.helpers import empty_cache, dummy_context
from nnunetv2.utilities.label_handling.label_handling import convert_labelmap_to_one_hot, determine_num_input_channels
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
from sklearn.model_selection import KFold
from torch import autocast, nn
from torch import distributed as dist
from torch.cuda import device_count
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import CosineAnnealingLR
import time

In [None]:
def configure_rotation_dummyDA_mirroring_and_inital_patch_size(patch_size):
    """Derive the rotation ranges, dummy-2D flag, initial patch size and mirror axes for a patch size.

    Heuristic (as in nnU-Net):
      * 2-D patch: rotate around 'x' only.  The range is +-15 deg when
        ``max/min`` of the patch exceeds 1.5, else +-180 deg.  Mirror axes ``(0, 1)``.
      * 3-D patch: if ``max(patch_size) / patch_size[0] > ANISO_THRESHOLD``, use "dummy
        2-D" augmentation with +-180 deg around 'x' only.  Otherwise use +-30 deg around
        'x', 'y' and 'z'.  Mirror axes ``(0, 1, 2)``.

    The initial (pre-crop) patch size comes from ``get_patch_size`` with the old nnU-Net
    scale range ``(0.85, 1.25)``.  In dummy-2D mode its first axis is reset to
    ``patch_size[0]``.

    Returns:
        ``(rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes)`` where
        ``rotation_for_DA`` maps 'x'/'y'/'z' to ``(low, high)`` angles in radians.

    Raises:
        RuntimeError: if ``patch_size`` is neither 2-D nor 3-D.
    """
    def _symmetric(deg):
        # (-deg, +deg) in radians; same arithmetic as the original literals.
        return (-deg / 360 * 2. * np.pi, deg / 360 * 2. * np.pi)

    no_rotation = (0, 0)
    dim = len(patch_size)
    if dim not in (2, 3):
        raise RuntimeError(f'Expected a 2-D or 3-D patch_size, got {dim} dimensions: {patch_size}')

    # todo rotation should be defined dynamically based on patch size (more isotropic patch sizes = more rotation)
    if dim == 2:
        do_dummy_2d_data_aug = False
        x_range = 15. if max(patch_size) / min(patch_size) > 1.5 else 180.
        rotation_for_DA = {'x': _symmetric(x_range), 'y': no_rotation, 'z': no_rotation}
        mirror_axes = (0, 1)
    else:
        # order of the axes is determined by spacing, not image size, so e.g. a (64, 16, 128)
        # patch could still get a full 180deg in-plane rotation.
        do_dummy_2d_data_aug = (max(patch_size) / patch_size[0]) > ANISO_THRESHOLD
        if do_dummy_2d_data_aug:
            rotation_for_DA = {'x': _symmetric(180.), 'y': no_rotation, 'z': no_rotation}
        else:
            rotation_for_DA = {'x': _symmetric(30.), 'y': _symmetric(30.), 'z': _symmetric(30.)}
        mirror_axes = (0, 1, 2)

    # The scale range below is kept as in the old nnU-Net, although training uses a different one.
    initial_patch_size = get_patch_size(patch_size[-dim:],
                                        *rotation_for_DA.values(),
                                        (0.85, 1.25))
    if do_dummy_2d_data_aug:
        initial_patch_size[0] = patch_size[0]

    return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes

# Sanity check: inspect the DA configuration (rotation ranges, dummy-2D flag,
# initial patch size, mirror axes) for a (16, 32, 32) patch; the cell output follows.
configure_rotation_dummyDA_mirroring_and_inital_patch_size((16,32,32)) 

({'x': (-0.5235987755982988, 0.5235987755982988),
  'y': (-0.5235987755982988, 0.5235987755982988),
  'z': (-0.5235987755982988, 0.5235987755982988)},  
  
 False,
 array([35, 51, 42]),
 (0, 1, 2))

In [None]:
def get_training_transforms(patch_size: Union[np.ndarray, Tuple[int]],
                            rotation_for_DA: dict,
                            mirror_axes: Tuple[int, ...],
                            do_dummy_2d_data_aug: bool,
                            order_resampling_data: int = 3,
                            order_resampling_seg: int = 1,
                            border_val_seg: int = -1,
                            use_mask_for_norm: List[bool] = None,
                            is_cascaded: bool = False,
                            foreground_labels: Union[Tuple[int, ...], List[int]] = None,
                            regions: List[Union[List[int], Tuple[int, ...], int]] = None,
                            ignore_label: int = None) -> AbstractTransform:
    """Build the nnU-Net-style training augmentation pipeline.

    Adapted from nnU-Net's default training transforms. Transforms are applied in
    the order they are appended below, so the order is part of the behaviour.
    SimulateLowResolutionTransform is commented out in this version.

    Args:
        patch_size: Patch size handed to SpatialTransform. When
            ``do_dummy_2d_data_aug`` is True only ``patch_size[1:]`` is used.
        rotation_for_DA: Dict with keys 'x', 'y', 'z'; each value is a (min, max)
            rotation range forwarded as angle_x / angle_y / angle_z.
        mirror_axes: Axes for MirrorTransform; mirroring is skipped when None or empty.
        do_dummy_2d_data_aug: If True, SpatialTransform is wrapped in
            Convert3DTo2DTransform / Convert2DTo3DTransform.
        order_resampling_data: Interpolation order used for the data.
        order_resampling_seg: Interpolation order used for the segmentation.
        border_val_seg: Constant fill value for the segmentation outside the image.
        use_mask_for_norm: Per-channel flags; flagged channels get a MaskTransform.
        is_cascaded: If True, add the cascade transforms (requires ``foreground_labels``).
        foreground_labels: Labels moved into 'data' as one-hot channels in cascade mode.
        regions: If given, 'target' is converted to region channels.
        ignore_label: Appended to ``regions`` when not None.

    Returns:
        A Compose of all transforms; its last step is NumpyToTensor on 'data' and 'target'.
    """
    tr_transforms = []
    if do_dummy_2d_data_aug:
        # Augment in 2D: data is converted 3D->2D around the spatial transform (and back
        # below), and the first entry of patch_size is dropped from the spatial patch size.
        # ignore_axes is only consumed by the commented-out SimulateLowResolutionTransform,
        # so it is currently unused.
        ignore_axes = (0,)
        tr_transforms.append(Convert3DTo2DTransform())
        patch_size_spatial = patch_size[1:]
    else:
        patch_size_spatial = patch_size
        ignore_axes = None

    # Rotation + scaling only (elastic deformation disabled, p_el_per_sample=0);
    # scaling and rotation are each applied per sample with probability 0.2.
    tr_transforms.append(SpatialTransform(
        patch_size_spatial, patch_center_dist_from_border=None,
        do_elastic_deform=False, alpha=(0, 0), sigma=(0, 0),
        do_rotation=True, angle_x=rotation_for_DA['x'], angle_y=rotation_for_DA['y'], angle_z=rotation_for_DA['z'],
        p_rot_per_axis=1,  # todo experiment with this
        do_scale=True, scale=(0.7, 1.4),
        border_mode_data="constant", border_cval_data=0, order_data=order_resampling_data,
        border_mode_seg="constant", border_cval_seg=border_val_seg, order_seg=order_resampling_seg,
        random_crop=False,  # random cropping is part of our dataloaders
        p_el_per_sample=0, p_scale_per_sample=0.2, p_rot_per_sample=0.2,
        independent_scale_for_each_axis=False  # todo experiment with this
    ))

    if do_dummy_2d_data_aug:
        tr_transforms.append(Convert2DTo3DTransform())

    # Intensity augmentations; each one has its own per-sample probability.
    tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.1))
    tr_transforms.append(GaussianBlurTransform((0.5, 1.), different_sigma_per_channel=True, p_per_sample=0.2,
                                               p_per_channel=0.5))
    tr_transforms.append(BrightnessMultiplicativeTransform(multiplier_range=(0.75, 1.25), p_per_sample=0.15))
    tr_transforms.append(ContrastAugmentationTransform(p_per_sample=0.15))
    # NOTE: low-resolution simulation is disabled here.
    # tr_transforms.append(SimulateLowResolutionTransform(zoom_range=(0.5, 1), per_channel=True,
    #                                                     p_per_channel=0.5,
    #                                                     order_downsample=0, order_upsample=3, p_per_sample=0.25,
    #                                                     ignore_axes=ignore_axes))
    # Two gamma augmentations; they differ only in the second positional argument
    # (True vs False) and in their probabilities.
    tr_transforms.append(GammaTransform((0.7, 1.5), True, True, retain_stats=True, p_per_sample=0.1))
    tr_transforms.append(GammaTransform((0.7, 1.5), False, True, retain_stats=True, p_per_sample=0.3))

    if mirror_axes is not None and len(mirror_axes) > 0:
        tr_transforms.append(MirrorTransform(mirror_axes))

    # For channels flagged in use_mask_for_norm, values outside the mask are set to 0
    # (mask taken from seg channel 0 via mask_idx_in_seg=0 -- presumably; confirm).
    if use_mask_for_norm is not None and any(use_mask_for_norm):
        tr_transforms.append(MaskTransform([i for i in range(len(use_mask_for_norm)) if use_mask_for_norm[i]],
                                           mask_idx_in_seg=0, set_outside_to=0))

    # -1 is the default border fill for seg above; presumably mapped back to 0 here.
    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if is_cascaded:
        # Cascade mode: previous-stage segmentation is moved into 'data' as one-hot channels
        # (the last len(foreground_labels) channels) and then randomly perturbed.
        assert foreground_labels is not None, 'We need foreground_labels for cascade augmentations'
        tr_transforms.append(MoveSegAsOneHotToData(1, foreground_labels, 'seg', 'data'))
        tr_transforms.append(ApplyRandomBinaryOperatorTransform(
            channel_idx=list(range(-len(foreground_labels), 0)),
            p_per_sample=0.4,
            key="data",
            strel_size=(1, 8),
            p_per_label=1))
        tr_transforms.append(
            RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                channel_idx=list(range(-len(foreground_labels), 0)),
                key="data",
                p_per_sample=0.2,
                fill_with_other_class_p=0,
                dont_do_if_covers_more_than_x_percent=0.15))

    # 'seg' -> 'target' (the third argument presumably drops the old key; confirm).
    tr_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        # the ignore label must also be converted
        # Precedence: this parses as (list(regions) + [ignore_label]) if ignore_label is not None else regions.
        tr_transforms.append(ConvertSegmentationToRegionsTransform(list(regions) + [ignore_label]
                                                                   if ignore_label is not None else regions,
                                                                   'target', 'target'))

    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    tr_transforms = Compose(tr_transforms)
    return tr_transforms

In [None]:
def set_right(img):
    """Reorient a 3D volume: move axis 0 to the end, then reverse all three axes.

    Equivalent to ``np.transpose(img, (1, 2, 0))`` followed by flips along axes
    0, 1 and 2. In this notebook it is applied to the preprocessed volume before
    it is compared with the original NIfTI array (presumably to match that
    array's orientation). The result is a view, not a copy.
    """
    moved = np.transpose(img, (1, 2, 0))
    reversed_all = moved[::-1, ::-1, ::-1]
    return reversed_all

In [None]:
def _center_pad_to_shape(arr, target_shape):
    """Zero-pad ``arr`` on every axis so that its shape becomes ``target_shape``.

    Each axis gets ``diff // 2`` voxels before and ``diff - diff // 2`` after, so an odd
    difference puts the extra voxel at the end. ``arr`` is returned unchanged when it
    already has the target shape.

    Raises:
        ValueError: if the ranks differ or ``arr`` is larger than ``target_shape`` on an axis.
    """
    target_shape = tuple(target_shape)
    if arr.ndim != len(target_shape):
        raise ValueError(f'cannot pad an array of rank {arr.ndim} to shape {target_shape}')
    if arr.shape == target_shape:
        return arr
    pad_width = []
    for current, target in zip(arr.shape, target_shape):
        diff = target - current
        if diff < 0:
            raise ValueError(f'cannot pad shape {arr.shape} down to {target_shape}')
        pad_width.append((diff // 2, diff - diff // 2))
    return np.pad(arr, pad_width=pad_width, mode='constant', constant_values=0)


def predict_from_raw_data(ID,
                          list_of_lists_or_source_folder: Union[str, List[List[str]]],
                          Mask_list_of_lists_or_Mask_folder: Union[str, List[List[str]]],
                          output_folder: str,
                          model_training_output_dir: str,
                          use_folds: Union[Tuple[int, ...], str] = None,
                          tile_step_size: float = 0.5,
                          use_gaussian: bool = True,
                          use_mirroring: bool = True,
                          perform_everything_on_gpu: bool = True,
                          verbose: bool = True,
                          save_probabilities: bool = False,
                          overwrite: bool = True,
                          checkpoint_name: str = 'checkpoint_best.pth',
                          num_processes_preprocessing: int = default_num_processes,
                          num_processes_segmentation_export: int = default_num_processes,
                          folder_with_segs_from_prev_stage: str = None,
                          num_parts: int = 1,
                          part_id: int = 0,
                          desired_gpu_index: int = 2,
                          device: torch.device = torch.device('cuda')):
    """Preprocess one selected case and plot original / resampled / augmented crops of it.

    Despite its nnU-Net-derived name this function does not run the network. It:

    * loads the model configuration via ``load_what_we_need``;
    * runs nnU-Net preprocessing (resampling) on the selected case via ``PreprocessAdapter``;
    * reorients the result with ``set_right``;
    * centre zero-pads the original NIfTI volume and the resampled volumes to a common shape;
    * crops them with ``rand_crop_sampling`` and shows them with ``plot_multi_view2``.

    Args:
        ID: 1-based index of the case to show, counted after the ``part_id::num_parts`` split.
        list_of_lists_or_source_folder: Image folder, or list of per-case file lists.
        Mask_list_of_lists_or_Mask_folder: Label folder, or list of per-case file lists,
            paired with the images by position.
        output_folder: Created if missing; the call arguments are saved there as JSON.
        model_training_output_dir: Trained-model folder passed to ``load_what_we_need``.
        use_folds: Folds to load; auto-detected when None.
        tile_step_size, use_gaussian, use_mirroring, save_probabilities,
            num_processes_segmentation_export: Only recorded in the saved argument JSON.
        perform_everything_on_gpu: Forced to False when ``device`` is not CUDA; printed.
        verbose: Forwarded to the preprocessor.
        overwrite: If False, cases whose ``<output_folder>/<case><file_ending>`` exists are skipped.
        checkpoint_name: Checkpoint name passed to ``load_what_we_need``.
        num_processes_preprocessing: Upper bound on preprocessing workers.
        folder_with_segs_from_prev_stage: Used to build per-case previous-stage file names only;
            cascaded models are otherwise not handled here.
        num_parts, part_id: The case list is split as ``cases[part_id::num_parts]``.
        desired_gpu_index: CUDA device index used when ``device`` is CUDA.
        device: Torch device.

    Raises:
        ValueError: If ``ID`` is not between 1 and the number of selectable cases.
    """
    print("\n#######################################################################\nPlease cite the following paper "
          "when using nnU-Net:\n"
          "Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). "
          "nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. "
          "Nature methods, 18(2), 203-211.\n#######################################################################\n")

    # Pin CUDA execution to a specific GPU (0-based index); anything else runs off-GPU.
    if device.type == 'cuda':
        device = torch.device(type='cuda', index=desired_gpu_index)
    else:
        perform_everything_on_gpu = False

    print('device:', device, ' desired_gpu_index:', desired_gpu_index)

    # Store the input arguments so it is clear what was used for this run.
    # Keep this an explicit loop: locals() inside a comprehension would see the
    # comprehension's own scope on Python < 3.12.
    my_init_kwargs = {}
    for k in inspect.signature(predict_from_raw_data).parameters.keys():
        my_init_kwargs[k] = locals()[k]
    my_init_kwargs = deepcopy(my_init_kwargs)  # don't unintentionally change anything in-place
    recursive_fix_for_json_export(my_init_kwargs)
    maybe_mkdir_p(output_folder)
    save_json(my_init_kwargs, join(output_folder, 'predict_from_raw_data_args.json'))

    if use_folds is None:
        use_folds = auto_detect_available_folds(model_training_output_dir, checkpoint_name)

    # Load the model configuration (plans, dataset.json, network, ...) from the training output dir.
    parameters, configuration_manager, inference_allowed_mirroring_axes, \
    plans_manager, dataset_json, network, trainer_name = \
        load_what_we_need(model_training_output_dir, use_folds, checkpoint_name)

    # Cascaded models (inference of a previous stage) are intentionally not handled here.

    file_ending = dataset_json['file_ending']

    # Sort out input filenames. Images and masks are converted independently so that a
    # folder path works for either argument.
    if isinstance(list_of_lists_or_source_folder, str):
        list_of_lists_or_source_folder = create_lists_from_splitted_dataset_folder(list_of_lists_or_source_folder,
                                                                                   file_ending)
    if isinstance(Mask_list_of_lists_or_Mask_folder, str):
        Mask_list_of_lists_or_Mask_folder = create_lists_from_splitted_dataset_folder(
            Mask_list_of_lists_or_Mask_folder, file_ending)

    print(f'There are {len(list_of_lists_or_source_folder)} cases in the source folder')
    if len(Mask_list_of_lists_or_Mask_folder) != len(list_of_lists_or_source_folder):
        print(f'WARNING: {len(list_of_lists_or_source_folder)} image cases but '
              f'{len(Mask_list_of_lists_or_Mask_folder)} mask cases; images and masks are paired by position.')
    # Split images AND masks identically so that position i refers to the same case in both lists.
    list_of_lists_or_source_folder = list_of_lists_or_source_folder[part_id::num_parts]
    Mask_list_of_lists_or_Mask_folder = Mask_list_of_lists_or_Mask_folder[part_id::num_parts]
    # Strip "_0000" (5 chars) plus the file ending to get the case identifier.
    caseids = [os.path.basename(i[0])[:-(len(file_ending) + 5)] for i in list_of_lists_or_source_folder]
    print(f'I am process {part_id} out of {num_parts} (max process ID is {num_parts - 1}, we start counting with 0!)')
    print(f'There are {len(caseids)} cases that I would like to predict')
    print('選中的case為:', str(ID))

    # ID is 1-based; without this check ID=0 would silently select the last case via index -1.
    num_selectable = min(len(caseids), len(Mask_list_of_lists_or_Mask_folder))
    if not 1 <= ID <= num_selectable:
        raise ValueError(f'ID is 1-based and must be between 1 and {num_selectable}, got {ID}')

    print('list_of_lists_or_source_folder example:', list_of_lists_or_source_folder[ID-1])
    print('Mask_list_of_lists_or_Mask_folder:', type(Mask_list_of_lists_or_Mask_folder), Mask_list_of_lists_or_Mask_folder[ID-1])

    # Keep only the selected case.
    list_of_lists_or_source_folder = [list_of_lists_or_source_folder[ID-1]]
    Mask_list_of_lists_or_Mask_folder = [Mask_list_of_lists_or_Mask_folder[ID-1]]
    print('Mask_list_of_lists_or_Mask_folder:', type(Mask_list_of_lists_or_Mask_folder), Mask_list_of_lists_or_Mask_folder)
    caseids = [caseids[ID-1]]

    output_filename_truncated = [join(output_folder, i) for i in caseids]
    seg_from_prev_stage_files = [join(folder_with_segs_from_prev_stage, i + file_ending) if
                                 folder_with_segs_from_prev_stage is not None else None for i in caseids]
    # Remove already predicted files from the lists; masks are filtered too so they stay aligned.
    if not overwrite:
        tmp = [isfile(i + file_ending) for i in output_filename_truncated]
        not_existing_indices = [i for i, j in enumerate(tmp) if not j]

        output_filename_truncated = [output_filename_truncated[i] for i in not_existing_indices]
        list_of_lists_or_source_folder = [list_of_lists_or_source_folder[i] for i in not_existing_indices]
        Mask_list_of_lists_or_Mask_folder = [Mask_list_of_lists_or_Mask_folder[i] for i in not_existing_indices]
        seg_from_prev_stage_files = [seg_from_prev_stage_files[i] for i in not_existing_indices]
        print(f'overwrite was set to {overwrite}, so I am only working on cases that haven\'t been predicted yet. '
              f'That\'s {len(not_existing_indices)} cases.')
        if not list_of_lists_or_source_folder:
            print('The selected case has already been predicted; nothing to visualise.')
            return
    print('list_of_lists_or_source_folder:', list_of_lists_or_source_folder)
    print('seg_from_prev_stage_files:', seg_from_prev_stage_files)
    preprocessor = configuration_manager.preprocessor_class(verbose=verbose)

    print('開始前處裡!!!')
    num_processes = max(1, min(num_processes_preprocessing, len(list_of_lists_or_source_folder)))
    print('num_processes:', num_processes)
    # The mask list is the label folder; it is handed to PreprocessAdapter alongside the images.
    ppa = PreprocessAdapter(list_of_lists_or_source_folder, Mask_list_of_lists_or_Mask_folder, preprocessor,
                            output_filename_truncated, plans_manager, dataset_json,
                            configuration_manager, num_processes)

    # Iterating this generator runs the preprocessing. SingleThreadedAugmenter is used instead of
    # MultiThreadedAugmenter (per the original author, the multi-threaded one would use the GPU).
    mta = SingleThreadedAugmenter(ppa, NumpyToTensor())

    # Augmentation pipeline (nnU-Net training transforms).
    patch_size = (160, 512, 512)

    rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \
        configure_rotation_dummyDA_mirroring_and_inital_patch_size(patch_size)

    print('rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes:', rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes)

    # NOTE(review): the patch size below has 5 entries (two leading 1s, then the spatial sizes in
    # reversed order), and do_dummy_2d_data_aug is forced to False instead of using the value
    # computed above. Confirm this is what SpatialTransform expects.
    tr_transforms = get_training_transforms(
        (1, 1, initial_patch_size[-1], initial_patch_size[1], initial_patch_size[0]), rotation_for_DA, mirror_axes,
        do_dummy_2d_data_aug=False, order_resampling_data=3, order_resampling_seg=1)

    # TODO: mta_tr is never iterated below, so the "augmentation" views currently show the
    # un-augmented resampled data (see data_aug in the loop).
    mta_tr = SingleThreadedAugmenter(ppa, tr_transforms)

    print('go go go!!!')
    with torch.no_grad():
        for preprocessed, nii_path, label_path in zip(mta, list_of_lists_or_source_folder,
                                                      Mask_list_of_lists_or_Mask_folder):
            data = preprocessed['data']
            # e.g. torch.Size([2, 136, 490, 490]); channel 0 = image, channel 1 = label
            print('data.shape:', data.shape)
            img_resample = data[0, :, :, :].numpy().copy()
            label_resample = data[1, :, :, :].numpy().copy()
            transposed_img_resample = set_right(img_resample)
            transposed_label_resample = set_right(label_resample)

            # TODO: should come from the augmentation pipeline (mta_tr / tr_transforms);
            # for now it is the same resampled data as above.
            data_aug = preprocessed['data']
            img_aug = data_aug[0, :, :, :].numpy().copy()
            label_aug = data_aug[1, :, :, :].numpy().copy()
            transposed_img_aug = set_right(img_aug)
            transposed_label_aug = set_right(label_aug)

            # Load the original image and label from NIfTI.
            img_nii = nib.load(str(nii_path[0]))
            img_o = np.array(img_nii.dataobj)
            img_o = data_translate(img_o, img_nii)

            label_nii = nib.load(str(label_path[0]))
            label_o = np.array(label_nii.dataobj)
            label_o = data_translate(label_o, label_nii)

            # Compare the shape after resampling with the original.
            print('original size:', img_o.shape, ' resample size:', transposed_img_resample.shape)

            # Centre zero-pad every volume, axis by axis, up to the largest size on that axis, so
            # volumes that are larger on some axes and smaller on others still line up.
            common_shape = tuple(max(sizes) for sizes in zip(img_o.shape, label_o.shape,
                                                              transposed_img_resample.shape,
                                                              transposed_label_resample.shape,
                                                              transposed_img_aug.shape,
                                                              transposed_label_aug.shape))
            img_o = _center_pad_to_shape(img_o, common_shape)
            label_o = _center_pad_to_shape(label_o, common_shape)
            transposed_img_resample = _center_pad_to_shape(transposed_img_resample, common_shape)
            transposed_label_resample = _center_pad_to_shape(transposed_label_resample, common_shape)
            transposed_img_aug = _center_pad_to_shape(transposed_img_aug, common_shape)
            transposed_label_aug = _center_pad_to_shape(transposed_label_aug, common_shape)

            print('original size:', img_o.shape, ' after padding size:', transposed_img_resample.shape)

            ofile = preprocessed['ofile']
            print(f'\nPredicting {os.path.basename(ofile)}:')
            print(f'perform_everything_on_gpu: {perform_everything_on_gpu}')
            print('configuration_manager.patch_size:', configuration_manager.patch_size)

            # Crop all views with rand_crop_sampling (lesion_lb=1, pos_rate=1 -- presumably crops
            # around label 1) and plot them. The two low-resolution view pairs are placeholders:
            # SimulateLowResolutionTransform is not applied, so the resampled arrays are passed.
            (image_crop_stack, label_crop_stack,
             image_crop_resample, label_crop_resample,
             image_crop_aug, label_crop_aug,
             image_crop_low, label_crop_low,
             image_crop_auglow, label_crop_auglow) = rand_crop_sampling(
                img_o, label_o,
                transposed_img_resample, transposed_label_resample,
                transposed_img_aug, transposed_label_aug,
                transposed_img_resample, transposed_label_resample,
                transposed_img_resample, transposed_label_resample,
                lesion_lb=1,
                nag_sample_mask=transposed_label_resample > 0,
                size=(32, 32, 16), sample_num=1, pos_rate=1)

            print("image_crop_stack shape =", image_crop_stack.shape)
            print("label_crop_stack shape =", label_crop_stack.shape, label_crop_stack.min(), label_crop_stack.max())

            plot_multi_view2(image_crop_stack, label_crop_stack,
                             image_crop_resample, label_crop_resample,
                             image_crop_aug, label_crop_aug,
                             image_crop_low, label_crop_low,
                             image_crop_auglow, label_crop_auglow
                             )


In [None]:
def predict_entry_point():
    """Command-line entry point: parse arguments and visualise one case with predict_from_raw_data.

    It performs these steps:

    * validates the arguments, reporting problems through ``parser.error`` before any side effects;
    * resolves the trained-model folder from -d/-tr/-p/-c with ``get_output_folder``;
    * creates the output folder and configures torch threading for the chosen device;
    * calls ``predict_from_raw_data`` for the case selected with ``-case_id`` (1-based, default 1).
    """
    import argparse
    parser = argparse.ArgumentParser(description='Use this to run inference with nnU-Net. This function is used when '
                                                 'you want to manually specify a folder containing a trained nnU-Net '
                                                 'model. This is useful when the nnunet environment variables '
                                                 '(nnUNet_results) are not set.')
    parser.add_argument('-i', type=str, required=True,
                        help='input folder. Remember to use the correct channel numberings for your files (_0000 etc). '
                             'File endings must be the same as the training dataset!')
    parser.add_argument('-v', type=str, required=True,
                        help='input label (mask) folder, paired with -i by position. Remember to use the correct '
                             'channel numberings for your files (_0000 etc). '
                             'File endings must be the same as the training dataset!')
    parser.add_argument('-o', type=str, required=True,
                        help='Output folder. If it does not exist it will be created. Predicted segmentations will '
                             'have the same name as their source images.')
    parser.add_argument('-d', type=str, required=True,
                        help='Dataset with which you would like to predict. You can specify either dataset name or id')
    parser.add_argument('-p', type=str, required=False, default='nnUNetPlans',
                        help='Plans identifier. Specify the plans in which the desired configuration is located. '
                             'Default: nnUNetPlans')
    parser.add_argument('-tr', type=str, required=False, default='nnUNetTrainer',
                        help='What nnU-Net trainer class was used for training? Default: nnUNetTrainer')
    parser.add_argument('-c', type=str, required=True,
                        help='nnU-Net configuration that should be used for prediction. Config must be located '
                             'in the plans specified with -p')
    parser.add_argument('-f', nargs='+', type=str, required=False, default=(0, 1, 2, 3, 4),
                        help='Specify the folds of the trained model that should be used for prediction. '
                             'Default: (0, 1, 2, 3, 4)')
    parser.add_argument('-step_size', type=float, required=False, default=0.5,
                        help='Step size for sliding window prediction. The larger it is the faster but less accurate '
                             'the prediction. Default: 0.5. Cannot be larger than 1. We recommend the default.')
    parser.add_argument('--disable_tta', action='store_true', required=False, default=False,
                        help='Set this flag to disable test time data augmentation in the form of mirroring. Faster, '
                             'but less accurate inference. Not recommended.')
    parser.add_argument('--verbose', action='store_true', help="Set this if you like being talked to. You will have "
                                                               "to be a good listener/reader.")
    parser.add_argument('--save_probabilities', action='store_true',
                        help='Set this to export predicted class "probabilities". Required if you want to ensemble '
                             'multiple configurations.')
    parser.add_argument('--continue_prediction', action='store_true',
                        help='Continue an aborted previous prediction (will not overwrite existing files)')
    parser.add_argument('-chk', type=str, required=False, default='checkpoint_final.pth',
                        help='Name of the checkpoint you want to use. Default: checkpoint_final.pth')
    parser.add_argument('-npp', type=int, required=False, default=3,
                        help='Number of processes used for preprocessing. More is not always better. Beware of '
                             'out-of-RAM issues. Default: 3')
    parser.add_argument('-nps', type=int, required=False, default=3,
                        help='Number of processes used for segmentation export. More is not always better. Beware of '
                             'out-of-RAM issues. Default: 3')
    parser.add_argument('-prev_stage_predictions', type=str, required=False, default=None,
                        help='Folder containing the predictions of the previous stage. Required for cascaded models.')
    parser.add_argument('-num_parts', type=int, required=False, default=1,
                        help='Number of separate nnUNetv2_predict call that you will be making. Default: 1 (= this one '
                             'call predicts everything)')
    parser.add_argument('-part_id', type=int, required=False, default=0,
                        help='If multiple nnUNetv2_predict exist, which one is this? IDs start with 0 can end with '
                             'num_parts - 1. So when you submit 5 nnUNetv2_predict calls you need to set -num_parts '
                             '5 and use -part_id 0, 1, 2, 3 and 4. Simple, right? Note: You are yourself responsible '
                             'to make these run on separate GPUs! Use CUDA_VISIBLE_DEVICES (google, yo!)')
    parser.add_argument('-desired_gpu_index', type=int, default=2, required=False,
                        help="This to set which GPU ID!")
    parser.add_argument('-case_id', type=int, required=False, default=1,
                        help='1-based index of the case to visualise. Default: 1')
    parser.add_argument('-device', type=str, default='cuda', required=False,
                        help="Use this to set the device the inference should run with. Available options are 'cuda' "
                             "(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to set which GPU ID! "
                             "Use CUDA_VISIBLE_DEVICES=X nnUNetv2_predict [...] instead!")

    args = parser.parse_args()

    # Validate before creating any folders. parser.error is used instead of assert so the
    # checks survive `python -O`.
    if args.part_id >= args.num_parts:
        # slightly passive agressive haha
        parser.error('Do you even read the documentation? See nnUNetv2_predict -h.')
    if args.device not in ('cpu', 'cuda', 'mps'):
        parser.error(f'-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {args.device}.')
    if args.case_id < 1:
        parser.error(f'-case_id is 1-based and must be >= 1. Got: {args.case_id}')

    args.f = [i if i == 'all' else int(i) for i in args.f]

    model_folder = get_output_folder(args.d, args.tr, args.p, args.c)

    maybe_mkdir_p(args.o)  # no-op when the folder already exists

    if args.device == 'cpu':
        # let's allow torch to use hella threads
        torch.set_num_threads(multiprocessing.cpu_count())
        device = torch.device('cpu')
    elif args.device == 'cuda':
        # multithreading in torch doesn't help nnU-Net if run on GPU
        torch.set_num_threads(1)
        torch.set_num_interop_threads(1)
        device = torch.device('cuda')
    else:
        device = torch.device('mps')

    predict_from_raw_data(args.case_id,
                          args.i,
                          args.v,
                          args.o,
                          model_folder,
                          args.f,
                          args.step_size,
                          use_gaussian=True,
                          use_mirroring=not args.disable_tta,
                          perform_everything_on_gpu=True,
                          verbose=args.verbose,
                          save_probabilities=args.save_probabilities,
                          overwrite=not args.continue_prediction,
                          checkpoint_name=args.chk,
                          num_processes_preprocessing=args.npp,
                          num_processes_segmentation_export=args.nps,
                          folder_with_segs_from_prev_stage=args.prev_stage_predictions,
                          num_parts=args.num_parts,
                          part_id=args.part_id,
                          desired_gpu_index=args.desired_gpu_index,
                          device=device)

這邊把原本的 vessel 輸入改成 label 資料夾（第二個路徑參數），用來展示影像疊上 label 的情況

In [None]:
if __name__ == "__main__":
    from multiprocessing import Pool

    # Visualise one external-test case. ID is the 1-based case index (must be >= 1).
    predict_from_raw_data(
        ID=3,
        list_of_lists_or_source_folder='/data/chuan/nnUNet/nnUNet_raw/Dataset058_DeepAneurysm/Normalized_Image_External_Test',
        Mask_list_of_lists_or_Mask_folder='/data/chuan/nnUNet/nnUNet_raw/Dataset058_DeepAneurysm/Label_External_Test/',
        output_folder='/data/chuan/nnUNet/nnUNet_inference/Dataset058_DeepAneurysm/3d_fullres/nnResUNet/External_Test_tsetshow',
        model_training_output_dir='/data/chuan/nnUNet/nnUNet_results/Dataset058_DeepAneurysm/nnUNetTrainer__nnUNetPlans__3d_fullres',
        use_folds=(1,),
        tile_step_size=0.25,
        use_gaussian=True,
        use_mirroring=False,
        perform_everything_on_gpu=True,
        verbose=True,
        save_probabilities=False,
        overwrite=False,
        checkpoint_name='checkpoint_best.pth',
        num_processes_preprocessing=3,
        num_processes_segmentation_export=3,
        desired_gpu_index=0,
    )