## 測試nnUNet training code
想要看出經過augmentation後，train的input長什麼樣子

In [1]:
import os
import socket
from typing import Union, Optional

import nnunetv2
import torch.cuda
import torch.distributed as dist
import torch.multiprocessing as mp
from batchgenerators.utilities.file_and_folder_operations import join, isfile, load_json
from nnunetv2.paths import nnUNet_preprocessed
from nnunetv2.run.load_pretrained_weights import load_pretrained_weights
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
from torch.backends import cudnn

In [2]:
def find_free_network_port() -> int:
    """Finds a free port on localhost.

    It is useful in single-node training when we don't want to connect to a real main node but have to set the
    `MASTER_PORT` environment variable.
    """
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.bind(("", 0))
    port = s.getsockname()[1]
    s.close()
    return port

In [3]:
def get_trainer_from_args(dataset_name_or_id: Union[int, str],
                          configuration: str,
                          fold: int,
                          trainer_name: str = 'nnUNetTrainer',
                          plans_identifier: str = 'nnUNetPlans',
                          use_compressed: bool = False,
                          device: torch.device = torch.device('cuda')):
    # load nnunet class and do sanity checks
    nnunet_trainer = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"),
                                                trainer_name, 'nnunetv2.training.nnUNetTrainer')
    if nnunet_trainer is None:
        raise RuntimeError(f'Could not find requested nnunet trainer {trainer_name} in '
                           f'nnunetv2.training.nnUNetTrainer ('
                           f'{join(nnunetv2.__path__[0], "training", "nnUNetTrainer")}). If it is located somewhere '
                           f'else, please move it there.')
    assert issubclass(nnunet_trainer, nnUNetTrainer), 'The requested nnunet trainer class must inherit from ' \
                                                    'nnUNetTrainer'

    # handle dataset input. If it's an ID we need to convert to int from string
    if dataset_name_or_id.startswith('Dataset'):
        pass
    else:
        try:
            dataset_name_or_id = int(dataset_name_or_id)
        except ValueError:
            raise ValueError(f'dataset_name_or_id must either be an integer or a valid dataset name with the pattern '
                             f'DatasetXXX_YYY where XXX are the three(!) task ID digits. Your '
                             f'input: {dataset_name_or_id}')

    # initialize nnunet trainer
    preprocessed_dataset_folder_base = join(nnUNet_preprocessed, maybe_convert_to_dataset_name(dataset_name_or_id))
    plans_file = join(preprocessed_dataset_folder_base, plans_identifier + '.json')
    plans = load_json(plans_file)
    dataset_json = load_json(join(preprocessed_dataset_folder_base, 'dataset.json'))
    nnunet_trainer = nnunet_trainer(plans=plans, configuration=configuration, fold=fold,
                                    dataset_json=dataset_json, unpack_dataset=not use_compressed, device=device)
    return nnunet_trainer

In [4]:
def maybe_load_checkpoint(nnunet_trainer: nnUNetTrainer, continue_training: bool, validation_only: bool,
                          pretrained_weights_file: str = None):
    if continue_training:
        expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_final.pth')
        if not isfile(expected_checkpoint_file):
            expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_latest.pth')
        # special case where --c is used to run a previously aborted validation
        if not isfile(expected_checkpoint_file):
            expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_best.pth')
        if not isfile(expected_checkpoint_file):
            print(f"WARNING: Cannot continue training because there seems to be no checkpoint available to "
                               f"continue from. Starting a new training...")
    elif validation_only:
        expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_final.pth')
        if not isfile(expected_checkpoint_file):
            raise RuntimeError(f"Cannot run validation because the training is not finished yet!")
    else:
        if pretrained_weights_file is not None:
            load_pretrained_weights(nnunet_trainer.network, pretrained_weights_file, verbose=True)
        expected_checkpoint_file = None

    if expected_checkpoint_file is not None:
        nnunet_trainer.load_checkpoint(expected_checkpoint_file)

In [5]:
def setup_ddp(rank, world_size):
    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

In [6]:
def cleanup_ddp():
    dist.destroy_process_group()

In [7]:
def run_ddp(rank, dataset_name_or_id, configuration, fold, tr, p, use_compressed, disable_checkpointing, c, val, pretrained_weights, npz, world_size):
    setup_ddp(rank, world_size)
    torch.cuda.set_device(torch.device('cuda', dist.get_rank()))

    nnunet_trainer = get_trainer_from_args(dataset_name_or_id, configuration, fold, tr, p,
                                           use_compressed)

    if disable_checkpointing:
        nnunet_trainer.disable_checkpointing = disable_checkpointing

    assert not (c and val), f'Cannot set --c and --val flag at the same time. Dummy.'

    maybe_load_checkpoint(nnunet_trainer, c, val, pretrained_weights)

    if torch.cuda.is_available():
        cudnn.deterministic = False
        cudnn.benchmark = True

    if not val:
        nnunet_trainer.run_training()

    nnunet_trainer.perform_actual_validation(npz)
    cleanup_ddp()

In [8]:
def run_training(dataset_name_or_id: Union[str, int],
                 configuration: str, fold: Union[int, str],
                 trainer_class_name: str = 'nnUNetTrainer',
                 plans_identifier: str = 'nnUNetPlans',
                 pretrained_weights: Optional[str] = None,
                 num_gpus: int = 1,
                 use_compressed_data: bool = False,
                 export_validation_probabilities: bool = False,
                 continue_training: bool = False,
                 only_run_validation: bool = False,
                 disable_checkpointing: bool = False,
                 device: torch.device = torch.device('cuda')):
    if isinstance(fold, str):
        if fold != 'all':
            try:
                fold = int(fold)
            except ValueError as e:
                print(f'Unable to convert given value for fold to int: {fold}. fold must bei either "all" or an integer!')
                raise e

    if num_gpus > 1:
        assert device.type == 'cuda', f"DDP training (triggered by num_gpus > 1) is only implemented for cuda devices. Your device: {device}"

        os.environ['MASTER_ADDR'] = 'localhost'
        if 'MASTER_PORT' not in os.environ.keys():
            port = str(find_free_network_port())
            print(f"using port {port}")
            os.environ['MASTER_PORT'] = port  # str(port)

        mp.spawn(run_ddp,
                 args=(
                     dataset_name_or_id,
                     configuration,
                     fold,
                     trainer_class_name,
                     plans_identifier,
                     use_compressed_data,
                     disable_checkpointing,
                     continue_training,
                     only_run_validation,
                     pretrained_weights,
                     export_validation_probabilities,
                     num_gpus),
                 nprocs=num_gpus,
                 join=True)
    else:
        nnunet_trainer = get_trainer_from_args(dataset_name_or_id, configuration, fold, trainer_class_name,
                                               plans_identifier, use_compressed_data, device=device)

        #已經確認load pretrain weights，所以先初始化一次
        if pretrained_weights is not None:
            nnunet_trainer.initialize()
            #因為已經做過初始化，所以

        if disable_checkpointing:
            nnunet_trainer.disable_checkpointing = disable_checkpointing

        if nnunet_trainer.network is None:
            print("Network is not initialized correctly!")
        else:
            print("Network is initialized successfully.")

        assert not (continue_training and only_run_validation), f'Cannot set --c and --val flag at the same time. Dummy.'

        maybe_load_checkpoint(nnunet_trainer, continue_training, only_run_validation, pretrained_weights)

        if torch.cuda.is_available():
            cudnn.deterministic = False
            cudnn.benchmark = True

        if not only_run_validation:
            nnunet_trainer.run_training()

        nnunet_trainer.perform_actual_validation(export_validation_probabilities)

In [9]:
def run_training_entry():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('dataset_name_or_id', type=str,
                        help="Dataset name or ID to train with")
    parser.add_argument('configuration', type=str,
                        help="Configuration that should be trained")
    parser.add_argument('fold', type=str,
                        help='Fold of the 5-fold cross-validation. Should be an int between 0 and 4.')
    parser.add_argument('-tr', type=str, required=False, default='nnUNetTrainer',
                        help='[OPTIONAL] Use this flag to specify a custom trainer. Default: nnUNetTrainer')
    parser.add_argument('-p', type=str, required=False, default='nnUNetPlans',
                        help='[OPTIONAL] Use this flag to specify a custom plans identifier. Default: nnUNetPlans')
    parser.add_argument('-pretrained_weights', type=str, required=False, default=None,
                        help='[OPTIONAL] path to nnU-Net checkpoint file to be used as pretrained model. Will only '
                             'be used when actually training. Beta. Use with caution.')
    parser.add_argument('-num_gpus', type=int, default=1, required=False,
                        help='Specify the number of GPUs to use for training')
    parser.add_argument("--use_compressed", default=False, action="store_true", required=False,
                        help="[OPTIONAL] If you set this flag the training cases will not be decompressed. Reading compressed "
                             "data is much more CPU and (potentially) RAM intensive and should only be used if you "
                             "know what you are doing")
    parser.add_argument('--npz', action='store_true', required=False,
                        help='[OPTIONAL] Save softmax predictions from final validation as npz files (in addition to predicted '
                             'segmentations). Needed for finding the best ensemble.')
    parser.add_argument('--c', action='store_true', required=False,
                        help='[OPTIONAL] Continue training from latest checkpoint')
    parser.add_argument('--val', action='store_true', required=False,
                        help='[OPTIONAL] Set this flag to only run the validation. Requires training to have finished.')
    parser.add_argument('--disable_checkpointing', action='store_true', required=False,
                        help='[OPTIONAL] Set this flag to disable checkpointing. Ideal for testing things out and '
                             'you dont want to flood your hard drive with checkpoints.')
    parser.add_argument('-device', type=str, default='cuda', required=False,
                    help="Use this to set the device the training should run with. Available options are 'cuda' "
                         "(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to set which GPU ID! "
                         "Use CUDA_VISIBLE_DEVICES=X nnUNetv2_train [...] instead!")
    args = parser.parse_args()

    assert args.device in ['cpu', 'cuda', 'mps'], f'-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {args.device}.'
    if args.device == 'cpu':
        # let's allow torch to use hella threads
        import multiprocessing
        torch.set_num_threads(multiprocessing.cpu_count())
        device = torch.device('cpu')
    elif args.device == 'cuda':
        # multithreading in torch doesn't help nnU-Net if run on GPU
        torch.set_num_threads(1)
        torch.set_num_interop_threads(1)
        device = torch.device('cuda')
    else:
        device = torch.device('mps')

    run_training(args.dataset_name_or_id, args.configuration, args.fold, args.tr, args.p, args.pretrained_weights,
                 args.num_gpus, args.use_compressed, args.npz, args.c, args.val, args.disable_checkpointing,
                 device=device)

In [10]:
if __name__ == '__main__':
    dataset_name_or_id = '74'
    configuration = '3d_fullres'
    fold = '0'
    num_gpus = 1
    trainer_class_name = 'nnUNetTrainer'
    plans_identifier = 'nnUNetPlans'
    use_compressed_data = False
    device = torch.device('cuda')
    pretrained_weights = None
    
    #run_training(dataset_name_or_id, configuration, fold, args.tr, args.p, args.pretrained_weights,
    #             args.num_gpus, args.use_compressed, args.npz, args.c, args.val, args.disable_checkpointing,
    #             device=device)

    #run_training(dataset_name_or_id, configuration, fold)
    
    #下面把run_training完全展開來看
    if isinstance(fold, str):
        if fold != 'all':
            try:
                fold = int(fold)
            except ValueError as e:
                print(f'Unable to convert given value for fold to int: {fold}. fold must bei either "all" or an integer!')
                raise e
    
    if num_gpus > 1:
        assert device.type == 'cuda', f"DDP training (triggered by num_gpus > 1) is only implemented for cuda devices. Your device: {device}"

        os.environ['MASTER_ADDR'] = 'localhost'
        if 'MASTER_PORT' not in os.environ.keys():
            port = str(find_free_network_port())
            print(f"using port {port}")
            os.environ['MASTER_PORT'] = port  # str(port)
            
    else:
        nnunet_trainer = get_trainer_from_args(dataset_name_or_id, configuration, fold, trainer_class_name,
                                               plans_identifier, use_compressed_data, device=device)

        #已經確認load pretrain weights，所以先初始化一次
        if pretrained_weights is not None:
            nnunet_trainer.initialize()
            #因為已經做過初始化，所以
        else:
            plans_manager_i, dataset_json_i, configuration_manager_i, num_input_channels_i, network_i = nnunet_trainer.initialize_network_look()
        a = aa
           
        if torch.cuda.is_available():
            cudnn.deterministic = False
            cudnn.benchmark = True

        #if not only_run_validation:
        #    nnunet_trainer.run_training()

        #nnunet_trainer.perform_actual_validation(export_validation_probabilities)
        #這邊之後就用data, target來畫圖吧!!!
        data_noaug, target_noaug, data_aug, target_aug = nnunet_trainer.run_training_ShowTraingData()

Using device: cuda:0

#######################################################################
Please cite the following paper when using nnU-Net:
Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. Nature methods, 18(2), 203-211.
#######################################################################



NameError: name 'aa' is not defined

In [11]:
plans_manager_i

{'dataset_name': 'Dataset074_DeepAneurysm', 'plans_name': 'nnUNetPlans', 'original_median_spacing_after_transp': [0.6999997496604919, 0.44920000433921814, 0.44920000433921814], 'original_median_shape_after_transp': [133, 512, 512], 'image_reader_writer': 'SimpleITKIO', 'transpose_forward': [0, 1, 2], 'transpose_backward': [0, 1, 2], 'configurations': {'2d': {'data_identifier': 'nnUNetPlans_2d', 'preprocessor_name': 'DefaultPreprocessor', 'batch_size': 14, 'patch_size': [384, 384], 'median_image_size_in_voxels': [434.0, 402.0], 'spacing': [0.44920000433921814, 0.44920000433921814], 'normalization_schemes': ['NoNormalization'], 'use_mask_for_norm': [False], 'UNet_class_name': 'ResidualEncoderUNet', 'UNet_base_num_features': 32, 'n_conv_per_stage_encoder': [2, 2, 2, 2, 2, 2, 2], 'n_conv_per_stage_decoder': [2, 2, 2, 2, 2, 2], 'num_pool_per_axis': [6, 6], 'pool_op_kernel_sizes': [[1, 1], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2]], 'conv_kernel_sizes': [[3, 3], [3, 3], [3, 3], [3, 3], 

In [12]:
dataset_json_i

{'TASK': 'DeepAneurysm',
 'No.': '074',
 'channel_names': {'0': 'MRA_BRAIN'},
 'labels': {'background': 0, 'Aneurysm': 1},
 'numTraining': 3116,
 'file_ending': '.nii.gz',
 'val_as_test': 'n'}

In [13]:
configuration_manager_i

{'data_identifier': 'nnUNetPlans_3d_fullres', 'preprocessor_name': 'DefaultPreprocessor', 'batch_size': 2, 'patch_size': [16, 32, 32], 'median_image_size_in_voxels': [16.0, 32.0, 32.0], 'spacing': [0.6999997496604919, 0.44920000433921814, 0.44920000433921814], 'normalization_schemes': ['NoNormalization'], 'use_mask_for_norm': [False], 'UNet_class_name': 'ResidualEncoderUNetClassifier', 'UNet_base_num_features': 32, 'n_conv_per_stage_encoder': [2, 2, 2, 2, 2], 'n_conv_per_stage_decoder': [2, 2, 2, 2], 'num_pool_per_axis': [2, 3, 3], 'pool_op_kernel_sizes': [[1, 1, 1], [1, 2, 2], [2, 2, 2], [2, 2, 2], [1, 2, 2]], 'conv_kernel_sizes': [[1, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]], 'unet_max_num_features': 512, 'resampling_fn_data': 'resample_data_or_seg_to_shape', 'resampling_fn_seg': 'resample_data_or_seg_to_shape', 'resampling_fn_data_kwargs': {'is_seg': False, 'order': 3, 'order_z': 0, 'force_separate_z': None}, 'resampling_fn_seg_kwargs': {'is_seg': True, 'order': 1, 'order_

In [14]:
num_input_channels_i

1

In [24]:
from torchinfo import summary
summary(network_i, (1200, 1, 16, 32, 32))

Layer (type:depth-idx)                                                      Output Shape              Param #
ResidualEncoderUNetClassifier                                               [1200, 2, 16, 32, 32]     --
├─UNetDecoder: 1-1                                                          --                        (recursive)
│    └─ResidualEncoder: 2-1                                                 [1200, 32, 16, 32, 32]    --
│    │    └─StackedConvBlocks: 3-1                                          [1200, 32, 16, 32, 32]    384
│    │    └─Sequential: 3-2                                                 --                        33,125,888
├─UNetDecoder: 1-2                                                          [1200, 2, 16, 32, 32]     33,126,272
│    └─ModuleList: 2-11                                                     --                        (recursive)
│    │    └─ConvTranspose3d: 3-3                                            [1200, 256, 4, 4, 4]      524,544
│    └─Mod

In [None]:
#data是訓練的資料，data: torch.Size([600, 1, 16, 32, 32])
#target因為deep supervision，所以為一個list有4層，這邊取第一層與label對在一起看

In [None]:
from scipy import ndimage, interp  # pip install scipy
import scipy.stats
from sklearn.model_selection import train_test_split  # pip install scikit-learn
from sklearn.utils import class_weight
from sklearn.metrics import roc_curve, auc, confusion_matrix, classification_report, f1_score, cohen_kappa_score
from skimage.util import montage  # pip install scikit-image
import skimage
import pickle
import cv2
import csv
import math
import datetime
from itertools import zip_longest
import json
import random
import numpy as np
import matplotlib
#%matplotlib notebook
%matplotlib inline
#matplotlib.use('Agg') 
import matplotlib.pyplot as plt

In [None]:
def plot_multi_view(img, label):
    #先決定影像是否為正規化後，是的話就不用再做正規化, 先決定影像是否為4d，畫三張，原圖,label,疊圖
    #dim=3為標註123放在一起，dim=4為一標註一層，統一把標註轉為1層1個吧
    # 生成 1 到 100 之間的隨機正整數（包括 1 和 100）
    random_integer = random.randint(0, img.shape[0]-1)
    sample = img[random_integer,0,:,:,:]
    targe = label[random_integer,0,:,:,:]
    
    #為了展示正規化
    sample = (((sample - np.min(sample))/(np.max(sample) - np.min(sample)))*255).copy()
    sample[sample<0] = 0
    sample[sample>255] = 255
    #rands = np.random.randint(0, img.shape[-1], num) #multi view
    for idx in range(sample.shape[0]):
        show = np.expand_dims(sample[idx,:,:], axis=-1).copy()
        y_i, x_i, z_i = show.shape
        show_one = show.copy()
        show = np.concatenate([show, show_one],2)
        show = np.concatenate([show, show_one],2)
        show = show.astype('uint8')
        show_label = np.zeros((y_i, x_i)).astype('uint8')
        show_IL = show.copy()

        y_nor1 = targe[idx,:,:]    
        y_color = (y_nor1*255).astype('uint8')
        y_th = cv2.Canny(y_color, 128, 256).copy()
        y_th = y_th.astype('uint8')

        y_c, x_c = np.where(y_nor1>0)
        if len(y_c) > 0:    
            show_label[y_c,x_c] = 1
        y_c, x_c = np.where(y_th>0)
        if len(y_c) > 0:    
            show_IL[y_c,x_c,0] = 255
            show_IL[y_c,x_c,1] = 0
            show_IL[y_c,x_c,2] = 0
                
        plt.style.use('default') #使用背景色，繪圖風格
        plt.figure(figsize=(20, 20)) #show 2view
        plt.subplot(1,3,1)
        plt.imshow(show)
        plt.title('Image', fontsize=20)
        plt.axis('off')
        plt.subplot(1,3,2)
        plt.imshow(show_label, cmap='bone')
        plt.title('Label', fontsize=20)
        plt.axis('off')
        plt.subplot(1,3,3)
        plt.imshow(show_IL)
        plt.title('ImgLabel', fontsize=20)
        plt.axis('off')        
        plt.show()

In [None]:
def set_right(img):
    img = np.transpose(img, (1, 2, 0))
    img = np.flip(img,0)
    img = np.flip(img,1)
    img = np.flip(img,-1)
    return img

def easy_nor(sample):
    sample = (((sample - np.min(sample))/(np.max(sample) - np.min(sample)))*255).copy()
    sample[sample<0] = 0
    sample[sample>255] = 255
    return sample

def easy_show(sample, idx):
    show = np.expand_dims(sample[:,:,idx], axis=-1).copy()
    #print(show.shape)
    y_i, x_i, z_i = show.shape
    show_one = show.copy()
    show = np.concatenate([show, show_one],2)
    show = np.concatenate([show, show_one],2)
    show = show.astype('uint8')
    return show

def esay_canny(targe, idx):
    y_nor1 = targe[:,:,idx]    
    y_color = (y_nor1*255).astype('uint8')
    y_th = cv2.Canny(y_color, 128, 256).copy()
    y_th = y_th.astype('uint8')
    return y_th

def plot_multi_view2(img1, label1, img2, label2, random_integer):
    #先決定影像是否為正規化後，是的話就不用再做正規化, 先決定影像是否為4d，畫三張，原圖,label,疊圖
    #dim=3為標註123放在一起，dim=4為一標註一層，統一把標註轉為1層1個吧
    # 生成 1 到 100 之間的隨機正整數（包括 1 和 100） => 改成指定
    #random_integer = random.randint(0, img1.shape[0]-1)
    sample1 = img1[random_integer,0,:,:,:]
    targe1 = label1[random_integer,0,:,:,:]
    sample2 = img2[random_integer,0,:,:,:]
    targe2 = label2[random_integer,0,:,:,:]
    
    sample1 = set_right(sample1)
    targe1 = set_right(targe1)
    sample2 = set_right(sample2)
    targe2 = set_right(targe2)

    #為了展示正規化
    sample1 = easy_nor(sample1)
    sample2 = easy_nor(sample2)
    
    for idx in range(sample1.shape[-1]):
        show1 = easy_show(sample1, idx)
        show_IL1 = show1.copy()
        show2 = easy_show(sample2, idx)
        show_IL2 = show2.copy()
        
        y_th1 = esay_canny(targe1, idx)
        y_th2 = esay_canny(targe2, idx)

        y_c, x_c = np.where(y_th1>0)
        if len(y_c) > 0:    
            show_IL1[y_c,x_c,0] = 255
            show_IL1[y_c,x_c,1] = 0
            show_IL1[y_c,x_c,2] = 0
            
        y_c, x_c = np.where(y_th2>0)
        if len(y_c) > 0:    
            show_IL2[y_c,x_c,0] = 255
            show_IL2[y_c,x_c,1] = 0
            show_IL2[y_c,x_c,2] = 0
                
        plt.style.use('default') #使用背景色，繪圖風格
        plt.figure(figsize=(15, 15)) #show 2view
        plt.subplot(1,4,1)
        plt.imshow(show1)
        plt.title('Image', fontsize=15)
        plt.axis('off')
        plt.subplot(1,4,2)
        plt.imshow(show_IL1)
        plt.title('o_Label', fontsize=15)
        plt.axis('off') 
        plt.subplot(1,4,3)
        plt.imshow(show2)
        plt.title('aug', fontsize=15)
        plt.axis('off') 
        plt.subplot(1,4,4)
        plt.imshow(show_IL2)
        plt.title('aug_Label', fontsize=15)
        plt.show()

In [None]:
#先只保留有動脈瘤的case
img_noaug = data_noaug.numpy()
label_noaug = target_noaug[0].numpy()
img_aug = data_aug.numpy()
label_aug = target_aug[0].numpy()

# 1. 過濾標註全為 0 的張量
non_zero_indices = [i for i in range(label_noaug.shape[0]) if not np.all(label_noaug[i] == 0)]
filtered_label_noaug = label_noaug[non_zero_indices]
filtered_img_noaug = img_noaug[non_zero_indices]
filtered_label_aug = label_aug[non_zero_indices]
filtered_img_aug = img_aug[non_zero_indices]

# 2. 按標註值的總和排序
# 計算每張標註的總和，並按總和降序排序
sums = filtered_label_noaug.sum(axis=(1, 2, 3, 4))  # 計算每張的標註總和
sorted_indices = np.argsort(-sums)         # 按總和降序排序

sorted_label_noaug = filtered_label_noaug[sorted_indices]
sorted_img_noaug = filtered_img_noaug[sorted_indices]
sorted_label_aug = filtered_label_aug[sorted_indices]
sorted_img_aug = filtered_img_aug[sorted_indices]


# 輸出的目標尺寸
output_shape = (24, 42, 42)

# 計算輸入的中心點
input_shape = sorted_img_noaug.shape[2:]  # (35, 55, 42)
center = [dim // 2 for dim in input_shape]  # 中心點索引

# 計算裁剪範圍
crop_ranges = [(center[i] - output_shape[i] // 2, center[i] + output_shape[i] // 2) for i in range(3)]

# 使用切片進行裁剪
sorted_label_noaug = sorted_label_noaug[
    :,  # 保留 batch
    :,  # 保留 channel
    crop_ranges[0][0]:crop_ranges[0][1],  # 第三維
    crop_ranges[1][0]:crop_ranges[1][1],  # 第四維
    crop_ranges[2][0]:crop_ranges[2][1]   # 第五維
]


sorted_img_noaug = sorted_img_noaug[
    :,  # 保留 batch
    :,  # 保留 channel
    crop_ranges[0][0]:crop_ranges[0][1],  # 第三維
    crop_ranges[1][0]:crop_ranges[1][1],  # 第四維
    crop_ranges[2][0]:crop_ranges[2][1]   # 第五維
]


# 目標大小
target_shape = (24, 42, 42)

# 計算填充量
z_pad = (target_shape[0] - sorted_img_aug.shape[2]) // 2
y_pad = (target_shape[1] - sorted_img_aug.shape[3]) // 2
x_pad = (target_shape[2] - sorted_img_aug.shape[4]) // 2

# 如果目標大小與輸入大小不是偶數差，保證填充後大小正確
z_extra = (target_shape[0] - sorted_img_aug.shape[2]) % 2
y_extra = (target_shape[1] - sorted_img_aug.shape[3]) % 2
x_extra = (target_shape[2] - sorted_img_aug.shape[4]) % 2

# 使用 np.pad 對每張數據進行填充
sorted_label_aug = np.pad(
    sorted_label_aug,
    pad_width=((0, 0),  # 不填充 batch 維度
               (0, 0),  # 不填充通道維度
               (z_pad, z_pad + z_extra),  # z 軸填充
               (y_pad, y_pad + y_extra),  # y 軸填充
               (x_pad, x_pad + x_extra)),  # x 軸填充
    mode='constant',  # 填充模式為常數（默認填充 0）
    constant_values=0  # 填充值為 0
)


# 使用 np.pad 對每張數據進行填充
sorted_img_aug = np.pad(
    sorted_img_aug,
    pad_width=((0, 0),  # 不填充 batch 維度
               (0, 0),  # 不填充通道維度
               (z_pad, z_pad + z_extra),  # z 軸填充
               (y_pad, y_pad + y_extra),  # y 軸填充
               (x_pad, x_pad + x_extra)),  # x 軸填充
    mode='constant',  # 填充模式為常數（默認填充 0）
    constant_values=0  # 填充值為 0
)


# 處理後的結果
print("處理後的標註大小:", sorted_label_noaug.shape)
print("處理後的影像大小:", sorted_img_aug.shape)

In [None]:
plot_multi_view2(sorted_img_noaug, sorted_label_noaug, sorted_img_aug, sorted_label_aug, 298)