### 初始劃設置 / 11/07 : Out of memory


In [1]:
import SimpleITK as sitk
import numpy as np
import os
import torch
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import glob
import matplotlib.pyplot as plt
import random
import torchvision.transforms as T

In [2]:
run_mode = True  # True: 執行模式, False: 除錯模式

if run_mode :
    dir_plot = False
    debug_print = False
    debug_plot = False
else:
    dir_plot = True
    debug_print = False
    debug_plot = True


### 讀取資料Luna16、製作Mask、正規化、肺實質切割

In [3]:
#/*將HU值固定在[-1024, 3071]*/
def sanitize_hu(ct):#numpy
    ct_n = ct.astype(np.float32, copy=True) #//deepcopy
    ct_n  = np.clip(ct_n , -1024.0, 3071.0) #//取上下限
    return ct_n 
def sanitize_hu_torch(ct: torch.Tensor): #torch
    ct_n = ct.float()
    ct_n = torch.clamp(ct_n, -1024.0, 3071.0) # //取上下限
    return ct_n
def window_to_uint8(ct_n, wl=-600, ww=1500):
    low, high = wl - ww/2, wl + ww/2 
    ct_n2 = np.clip(ct_n, low, high) 
    ct_n2 = ((ct_n2 - low) / ww * 255).astype('uint8') #//正規化至 [0~255]
    return ct_n2
def window_to_uint8_torch(ct_n: torch.Tensor, wl: float = -600.0, ww: float = 1500.0) -> torch.Tensor:
    low = wl - ww / 2.0 
    high = wl + ww / 2.0
    ct_n2 = torch.clamp(ct_n, low, high)
    ct_n2 = (ct_n2 - low) / ww * 255.0
    ct_n2 = ct_n2.to(torch.uint8) 
    return ct_n2

In [4]:
# --- 1. Mask 生成函式 (使用 SimpleITK 處理) ---
def get_mask(itk_image: sitk.Image, annotations: pd.DataFrame, series_uid: str) -> np.ndarray:
    """
    根據 SimpleITK 影像的元資料和標註資訊，生成更貼近原始結節形狀的 Mask。
    
    :param itk_image: SimpleITK 讀取的 CT 影像物件。
    :param annotations: 包含所有 LUNA16 標註的 DataFrame。
    :param series_uid: 當前 CT 掃描的唯一 ID。
    :return: 與 image 具有相同 shape (Z, 512, 512) 的 NumPy mask 陣列。
    """
    # 建立一個全黑 (0) 的 Mask 影像，用於累積所有結節
    final_mask_image = sitk.Image(itk_image.GetSize(), sitk.sitkUInt8)
    final_mask_image.CopyInformation(itk_image) 
    
    # 獲取體素間距 (SpacingX, SpacingY, SpacingZ)
    spacing = itk_image.GetSpacing()
    
    # 過濾出當前 series_uid 的所有標註
    series_annotations = annotations[annotations['seriesuid'] == series_uid]

    if series_annotations.empty:
        return sitk.GetArrayFromImage(final_mask_image)
    
    # 遍歷所有結節標註
    for _, row in series_annotations.iterrows():
        coordX, coordY, coordZ = row['coordX'], row['coordY'], row['coordZ']
        diameter_mm = row['diameter_mm']
        radius_mm = diameter_mm / 2.0
        
        physical_point = (coordX, coordY, coordZ)

        try:
            # 1. 獲取標註中心點的體素索引 (Voxel Index)
            voxel_index = itk_image.TransformPhysicalPointToIndex(physical_point)
            seed_point = voxel_index

            # 2. 設置區域生長濾波器
            # 閾值範圍 (例如：-100 HU 到 100 HU，這是軟組織的典型範圍)
            # 由於結節的HU值變化很大，這個閾值需要根據結節類型調整
            lower_threshold = -300 # 假設結節密度在-300以上
            upper_threshold = 300  # 假設結節密度在300以下

            # 使用 Connected Threshold 濾波器：
            connected_threshold = sitk.ConnectedThresholdImageFilter()
            connected_threshold.SetLower(lower_threshold)
            connected_threshold.SetUpper(upper_threshold)
            connected_threshold.AddSeed(seed_point) # 種子點 (X, Y, Z)

            # 執行分割
            nodule_mask = connected_threshold.Execute(itk_image)

            # 3. 確保 Mask 只是結節區域，去除雜訊 (可選：Morphological Filtering)
            # 例如：使用 BinaryClosing 填充小孔洞，讓形狀更完整
            radius_pixels = 2 # 使用一個小的半徑
            nodule_mask = sitk.BinaryFillhole(nodule_mask) # 填充空洞

            # 4. 合併到最終 Mask
            # BinaryOrImageFilter 會將兩個 Mask 影像中所有非零的像素合併
            or_filter = sitk.OrImageFilter()
            final_mask_image = or_filter.Execute(final_mask_image, nodule_mask)

        except Exception as e:
            print(f"處理 Series UID {series_uid} 標註 ({coordX}, {coordY}, {coordZ}) 時發生錯誤: {e}")
            continue

    # 將最終的 SimpleITK Mask 轉換為 NumPy 陣列 (Z, Y, X)
    mask_array = sitk.GetArrayFromImage(final_mask_image)
    return mask_array.astype(np.float32)

In [5]:
from skimage import measure
from scipy import ndimage as ndi
from skimage.filters import roberts
from skimage.measure import label,regionprops
from skimage.segmentation import clear_border
from skimage.morphology import  convex_hull_image
from skimage.morphology import disk, binary_closing
#/*ct_n2為正規化好shape = (Z, Y, X)*/
def lung_segmentation(ct_org):
    im = ct_org.copy()  #//單張單張做 -> 可用Parallel Processing加速

    #//Step1 : 二值化
    binary_thr = im <  175   #前面有正規化至0~255 
    
    #//Step2 : 清理邊界
    cleared = clear_border(binary_thr) 
    
    #//Step3 : 標記
    label_image = label(cleared)  
    
    #//Step4 : 找出最大的兩個label -> 即肺部
    areas = [r.area for r in regionprops(label_image)]
    areas.sort()
    labels = []
    if len(areas) > 2:
        for region in regionprops(label_image):
            if region.area < areas[-2]:
                for coordinates in region.coords:
                    label_image[coordinates[0], coordinates[1]] = 0
            else:
                coordinates = region.coords[0]
                labels.append(label_image[coordinates[0], coordinates[1]])
    else:
        labels = [1, 2]
        
    #//Step5 : 填充兩肺空洞
    r = label_image == labels[0]
    l = label_image == labels[1]
    r_edges = roberts(r)
    l_edges = roberts(l)
    r = ndi.binary_fill_holes(r_edges)
    l = ndi.binary_fill_holes(l_edges)
    
    #//Step6 : 平滑包覆
    r = convex_hull_image(r)
    l = convex_hull_image(l)
    
    #//Step7 : 閉合兩肺
    sum_of_lr = r + l
    binary = sum_of_lr > 0
    selem = disk(10) #radius =10
    binary_c = binary_closing(binary, selem)
    
    #//Step8 : im即為原圖ROI區域
    noise = binary_c == 0 #//binary_c為0的地方，noise為True->找出背景雜訊
    im[noise] = 0 #//把不要的地方設成0 -> 劃出我們要得ROI (肺實質)
    
    
    #//Cache
    #case_img[z_idx]=ct_org  #原圖
    #lungs_img[z_idx]=im     #原圖的ROI
           #原圖    #原圖的ROI
    return ct_org , im 

In [6]:
from concurrent.futures import ThreadPoolExecutor # 引入多執行緒執行器
from typing import Tuple
def segment_single_slice(slice_n: np.ndarray) -> np.ndarray:
    """
    執行單一切片的肺部切割，並返回切割後的影像。
    這是將要平行執行的函式。
    """
    # 這裡只取回切割後的影像，忽略 lung_mask
    _, seg_slice = lung_segmentation(slice_n)
    return seg_slice

In [7]:

# --- 1. 定義 LUNA16 資料集類別 ---
class Luna16MaskDataset(Dataset):
    def __init__(self, data_dir="data/subset0", annotations_file='annotations.csv'):
        self.max_threads = os.cpu_count() or 4
        
        self.data_dir = data_dir
        
        # 使用 glob 找到所有 .mhd 檔案
        mhd_paths = glob.glob(os.path.join(data_dir, '*.mhd'))
        self.full_paths = mhd_paths

        if not self.full_paths:
            print(f"警告：在路徑 '{data_dir}' 中找不到任何 .mhd 檔案。")

        # 讀取 LUNA16 標註 CSV 檔案
        try:
            self.annotations = pd.read_csv(annotations_file)
        except FileNotFoundError:
            print(f"錯誤：找不到標註檔案 '{annotations_file}'。Mask 將返回全零。")
            self.annotations = pd.DataFrame({'seriesuid': [], 'coordX': [], 'coordY': [], 'coordZ': [], 'diameter_mm': []})
        except Exception as e:
            print(f"讀取標註檔案時發生錯誤：{e}。Mask 將返回全零。")
            self.annotations = pd.DataFrame({'seriesuid': [], 'coordX': [], 'coordY': [], 'coordZ': [], 'diameter_mm': []})

    def __len__(self):
        return len(self.full_paths)

    def __getitem__(self, idx):
        file_path = self.full_paths[idx]
        
        # 從檔案名稱中提取 seriesuid
        file_name = os.path.basename(file_path)
        print(f'✅️ file_name = {file_name}')
        series_uid = os.path.splitext(file_name)[0]
        print(f"✅️ Processing series_uid: {series_uid}")
        
        # 1. 讀取 CT 影像
        itk_image = sitk.ReadImage(file_path)
        
        # 轉換為 PyTorch Tensor (Z, Y, X)
        ct = sitk.GetArrayFromImage(itk_image)
        ct_n = sanitize_hu(ct)
        numpy_array = window_to_uint8(ct_n) 

        # 2. 生成 Mask
        mask_array = get_mask(itk_image, self.annotations, series_uid)

        # 3. 肺切割 ---
        z_len = numpy_array.shape[0]
        H, W = numpy_array.shape[1], numpy_array.shape[2]
        with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
            slices_to_segment = [numpy_array[z_idx] for z_idx in range(z_len)]
            segmented_slices_list = list(executor.map(segment_single_slice, slices_to_segment))

        lung_images = np.stack(segmented_slices_list, axis=0)

        # 檢查輸出形狀
        assert lung_images.shape == numpy_array.shape, "Segmented images shape mismatch!"
        assert numpy_array.shape == mask_array.shape, "Image and Mask shape mismatch!"
        
        #轉成pytorch訓練的tensor
        image_tensor = torch.from_numpy(lung_images.astype(np.float32))
        mask_tensor = torch.from_numpy(mask_array.astype(np.float32))

        return {
            'image': image_tensor,          # Shape: (Z, 512, 512), dtype=torch.float32
            'mask': mask_tensor,            # Shape: (Z, 512, 512), dtype=torch.float32 (0.0 或 1.0)
            'series_uid': series_uid,
        }



In [8]:
LUNA16_DATA_DIR = r"D:\Daniel\LDCT\luna16-yolo\dataset_luna\subset0"
LUNA16_ANNOTATIONS_CSV = r"D:\Daniel\LDCT\luna16-yolo\CSVFILES\annotations.csv" 

In [9]:
# 建立 Dataset 實例
luna_dataset = Luna16MaskDataset(
    data_dir=LUNA16_DATA_DIR,
    annotations_file=LUNA16_ANNOTATIONS_CSV
)

print(f"找到 {len(luna_dataset)} 個 CT 掃描檔案。")

找到 226 個 CT 掃描檔案。


In [10]:
import torch
import matplotlib.pyplot as plt
import numpy as np
# 確保 luna_dataloader, luna_dataset, debug_print, debug_plot 已定義

# 建立 DataLoader (batch_size=1 避免不同 Z 維度導致的錯誤)
luna_dataloader = DataLoader(luna_dataset, batch_size=1, shuffle=False)
'''
# 迭代 DataLoader 進行檢查
for i, data in enumerate(luna_dataloader):
    image = data['image'][0]
    mask = data['mask'][0]
    series_uid = data['series_uid'][0]

    if dir_plot:
        for z in range(image.shape[0]):
            plt.figure(figsize=(10, 5))
            plt.subplot(1, 2, 1)
            plt.imshow(image[z, :, :].cpu().numpy(), cmap='gray')
            plt.title(f"Image Slice {z}: {series_uid}")
            plt.subplot(1, 2, 2)
            plt.imshow(mask[z, :, :].cpu().numpy(), cmap='gray')
            plt.title(f"Mask Slice {z}: {series_uid}")
            if not os.path.exists("debug_output_img"):
                os.makedirs("debug_output_img")
            plt.savefig(f"debug_output_img/patch_and_mask_slice_{series_uid}_{z}.png") # 替換 plt.show()，將圖形儲存為檔案
            plt.close() # 關閉圖形，避免它在執行環境中顯示


    Z, H, W = image.shape
    
    if debug_print:
        print(f"--- 處理第 {i+1} 個檔案 ---")
        print(f"Series UID: {series_uid}")
        print(f"Image Tensor 形狀: {image.shape}")
        print(f"Mask Tensor 形狀: {mask.shape}")
        print(f"Mask 內部包含的標註體素數量: {torch.sum(mask).item()}")
    
    if debug_plot:
        # 修正 1: 動態計算中間切片索引
        mid_slice = Z // 2 # 取 Z 軸的中央切片
        
        # 為了更精確檢查 Mask 效果，我們應該找到有結節的切片
        nodule_slices = (torch.sum(mask, dim=(1, 2)) > 0.5).nonzero(as_tuple=False)
        
        if len(nodule_slices) > 0:
            # 優先顯示第一個有結節的切片
            display_slice_idx = nodule_slices[0].item()
            slice_description = f"Nodule Slice ({display_slice_idx})"
        else:
            # 如果沒有結節，則顯示中間切片
            display_slice_idx = mid_slice
            slice_description = f"Middle Slice ({display_slice_idx})"

        # 修正 2: 獨立繪圖
        plt.figure(figsize=(10, 5))
        
        # 繪製 Image
        plt.subplot(1, 2, 1)
        # 使用 .numpy() 轉換，並指定灰度圖
        plt.imshow(image[display_slice_idx, :, :].cpu().numpy(), cmap='gray')
        plt.title(f"Image ({slice_description}): {series_uid}")
        plt.axis('off')

        # 繪製 Mask
        plt.subplot(1, 2, 2)
        # ⚠️ 建議將 Mask 繪製為與 Image 重疊的透明圖層，以確認位置和外型
        mask_np = mask[display_slice_idx, :, :].cpu().numpy()
        
        # 繪製 Image (背景)
        #plt.imshow(image[display_slice_idx, :, :].cpu().numpy(), cmap='gray')
        
        # 繪製 Mask (前景，使用紅色或黃色疊加)
        # Mask 使用 alpha 通道實現透明度
        plt.imshow(mask_np, cmap='gray')
        
        plt.title(f"Mask Overlay ({slice_description})")
        plt.axis('off')
        if dir_plot : 
            if not os.path.exists("debug_nodule"):
                os.makedirs("debug_nodule")
            plt.savefig(f"debug_nodule/patch_and_mask_slice_{slice_description}.png") # 替換 plt.show()，將圖形儲存為檔案
        plt.close() # 關閉圖形，避免它在執行環境中顯示
        

    
    # 檢查形狀是否符合 (Z, 512, 512)
    is_correct_shape = (image.dim() == 3) and (image.shape[1] == 512) and (image.shape[2] == 512)
    print(f"空間形狀是否符合 (Z, 512, 512): {is_correct_shape}")
'''
print("\n資料讀取、Mask 生成與轉換範例完成。")


資料讀取、Mask 生成與轉換範例完成。


### 製作資料集 (3, 64, 64)

In [11]:
import torch
import numpy as np
import random
import torchvision.transforms.functional as F_T
import torch.nn.functional as F

# 設定參數
PATCH_SIZE = 64
Z_DEPTH = 3 
OUTPUT_MASK_DEPTH = 1 # 只預測中間切片 (Z=1) 的 Mask

# -----------------------------------------------------------
# 1. Patch 提取與平衡採樣主函式 (Segmentation專用)
# -----------------------------------------------------------
def extract_balanced_segmentation_patches(image: torch.Tensor, mask: torch.Tensor, max_patches_per_scan: int = 40) -> tuple:
    """
    從單個 CT 掃描中提取平衡的正負樣本 Patch (3, 64, 64) 及其對應的 Mask (1, 64, 64)。
    
    :param image: 單個 CT 掃描影像 Tensor (已分割肺部, 0-255)，shape (Z, 512, 512)。
    :param mask: 單個 Mask Tensor (0/1)，shape (Z, 512, 512)。
    :param max_patches_per_scan: 每個掃描最多提取的 Patch 總數 (N_pos + N_neg)。
    :return: 包含平衡 Patch 和對應 Mask 的元組 (patches, masks)。
             patches shape: (N, 3, 64, 64); masks shape: (N, 1, 64, 64)
    """
    Z, H, W = image.shape
    
    # 查找結節中心點 (用於正樣本採樣)
    positive_indices = (mask > 0).nonzero(as_tuple=False) 

    extracted_patches = []
    extracted_masks = []
    
    # ---------------------------------------------------
    # I. 正樣本採樣 (Positive Patch Sampling)
    # ---------------------------------------------------
    # 目標：正樣本數量為 max_patches_per_scan // 2
    num_positive_target = max_patches_per_scan // 2
    
    # 遍歷潛在的正樣本體素中心
    if len(positive_indices) > 0:
        # 隨機選擇一部分中心點作為初始採樣點
        sample_indices = random.sample(positive_indices.tolist(), min(len(positive_indices), num_positive_target * 2))
        
        for z_c, y_c, x_c in sample_indices:
            if z_c < 1 or z_c >= Z - 1: # 確保 Z-1 和 Z+1 存在
                continue

            # 隨機偏移中心點
            offset_y = random.randint(-PATCH_SIZE // 4, PATCH_SIZE // 4)
            offset_x = random.randint(-PATCH_SIZE // 4, PATCH_SIZE // 4)
            y_start = y_c - PATCH_SIZE // 2 + offset_y
            x_start = x_c - PATCH_SIZE // 2 + offset_x

            # 邊界檢查與裁剪
            y_start = np.clip(y_start, 0, H - PATCH_SIZE).item()
            x_start = np.clip(x_start, 0, W - PATCH_SIZE).item()
            
            # 裁剪 Image Patch (3, 64, 64)
            patch_image = image[z_c - 1 : z_c + 2, 
                                y_start : y_start + PATCH_SIZE, 
                                x_start : x_start + PATCH_SIZE].float()
            
            # 裁剪 Mask Patch (只取中間切片: 1, 64, 64)
            # unsqueeze(0) 將 (64, 64) 變為 (1, 64, 64)
            patch_mask = mask[z_c, y_start : y_start + PATCH_SIZE, 
                             x_start : x_start + PATCH_SIZE].unsqueeze(0).float()
            
            # 確保該 Patch 確實包含結節體素 (防止邊界誤採樣)
            if torch.sum(patch_mask) > 0:
                 extracted_patches.append(patch_image)
                 extracted_masks.append(patch_mask)
                 if len(extracted_patches) >= num_positive_target:
                    break
    
    # 確保我們有原始正樣本可供增強
    original_positive_patches = extracted_patches.copy()
    original_positive_masks = extracted_masks.copy()

    # ---------------------------------------------------
    # II. 負樣本採樣 (Negative Patch Sampling)
    # ---------------------------------------------------
    num_neg_target = max_patches_per_scan - len(original_positive_patches) # 負樣本數量是總目標減去已有的正樣本
    negative_patches = []
    negative_masks = []

    # 設定檢查閾值
    # 假設 Patch 總像素是 3*64*64 = 12288
    TOTAL_PIXELS = Z_DEPTH * PATCH_SIZE * PATCH_SIZE 
    # 要求 Patch 中至少要有 20% 的非零像素 (即有效組織)
    MIN_NON_ZERO_RATIO = 0.2 
    MIN_NON_ZERO_COUNT = int(TOTAL_PIXELS * MIN_NON_ZERO_RATIO) 

    # 嘗試採樣負樣本，直到達到所需的數量
    # 增加嘗試次數的上限，以防很難找到合格的 Patch
    MAX_ATTEMPTS = num_neg_target * 20 
    attempts = 0

    while len(negative_patches) < num_neg_target and attempts < MAX_ATTEMPTS:
        attempts += 1
        
        # 隨機採樣 Patch 中心和起始點
        z_c = random.randint(1, Z - 2) 
        y_start = random.randint(0, H - PATCH_SIZE)
        x_start = random.randint(0, W - PATCH_SIZE)
        
        # 裁剪 Image Patch (3, 64, 64)
        patch_image = image[z_c - 1 : z_c + 2, 
                            y_start : y_start + PATCH_SIZE, 
                            x_start : x_start + PATCH_SIZE].float()
        
        # 裁剪 Mask Patch (用於檢查是否為負樣本)
        patch_mask = mask[z_c, y_start : y_start + PATCH_SIZE, 
                        x_start : x_start + PATCH_SIZE].unsqueeze(0).float()
        
        # 條件 1: 必須是負樣本 (Mask 區域內沒有任何結節像素)
        is_negative = (torch.sum(patch_mask) == 0)
        
        # 條件 2: 硬負樣本檢查 - 必須包含足夠多的非零像素 (非背景)
        # 檢查 Patch 影像中，像素值 > 0 的數量
        non_zero_count = torch.sum(patch_image > 0).item()
        is_hard_negative = (non_zero_count >= MIN_NON_ZERO_COUNT)
        
        if is_negative and is_hard_negative:
            negative_patches.append(patch_image)
            negative_masks.append(patch_mask) # 負樣本的 Mask 是一個全零的 (1, 64, 64)

    # 提示: 如果嘗試次數達到上限，但未達到目標數量，則退出迴圈，使用現有數量。
    if attempts == MAX_ATTEMPTS:
        print(f"⚠️ 警告: 在 {series_uid} 中，只找到 {len(negative_patches)} 個合格的硬負樣本。")

    # ---------------------------------------------------
    # III. 增強正樣本並與負樣本數量平衡
    # ---------------------------------------------------
    
    # 確定最終的 Patch 列表和 Mask 列表
    final_patches = negative_patches + original_positive_patches
    final_masks = negative_masks + original_positive_masks
    
    # 計算還需要多少個增強的正樣本 Patch 才能達到平衡 (這裡定義平衡為 N_pos == N_neg)
    num_pos_original = len(original_positive_patches)
    num_neg_current = len(negative_patches)

    target_pos_count = num_neg_current
    augmentation_needed = max(0, target_pos_count - num_pos_original)
    
    if num_pos_original > 0 and augmentation_needed > 0:
        for i in range(augmentation_needed):
            # 從原始正樣本中隨機選擇一個 Patch 和 Mask
            idx = random.randint(0, num_pos_original - 1)
            original_patch = original_positive_patches[idx]
            original_mask = original_positive_masks[idx]
            
            # --- 執行同步幾何增強 ---
            
            # 1. 隨機旋轉
            angle = random.uniform(-15, 15)
            # 對 Image (3, 64, 64) 和 Mask (1, 64, 64) 應用相同的旋轉
            # F_T.rotate 自動處理了 C, H, W
            augmented_patch = F_T.rotate(original_patch, angle, interpolation=F_T.InterpolationMode.BILINEAR)
            augmented_mask = F_T.rotate(original_mask, angle, interpolation=F_T.InterpolationMode.NEAREST) # Mask 必須使用 NEAREST 插值！
            
            # 2. 隨機翻轉
            if random.random() < 0.5:
                augmented_patch = F_T.hflip(augmented_patch)
                augmented_mask = F_T.hflip(augmented_mask)
            if random.random() < 0.5:
                augmented_patch = F_T.vflip(augmented_patch)
                augmented_mask = F_T.vflip(augmented_mask)
            
            # 將增強後的數據加入最終列表
            final_patches.append(augmented_patch)
            final_masks.append(augmented_mask)

    # ---------------------------------------------------
    # IV. 合併與輸出
    # ---------------------------------------------------
    if final_patches:
        patches_tensor = torch.stack(final_patches, dim=0) # Shape: (N, 3, 64, 64)
        masks_tensor = torch.stack(final_masks, dim=0)     # Shape: (N, 1, 64, 64)
        
        # 確保 Mask 仍然是二元的 (NEAREST插值後可能出現浮點數)
        masks_tensor = (masks_tensor > 0.5).float() 
        
        # 隨機打亂順序
        permutation = torch.randperm(patches_tensor.size(0))
        patches_tensor = patches_tensor[permutation]
        masks_tensor = masks_tensor[permutation]
        
        return patches_tensor, masks_tensor
    else:
        return torch.empty((0, Z_DEPTH, PATCH_SIZE, PATCH_SIZE), dtype=torch.float32), \
               torch.empty((0, OUTPUT_MASK_DEPTH, PATCH_SIZE, PATCH_SIZE), dtype=torch.float32)



In [12]:
import torch
import numpy as np
from multiprocessing import Pool, cpu_count
import itertools
# 假設 extract_balanced_segmentation_patches 已經在頂層定義
# from your_module import extract_balanced_segmentation_patches 

# -----------------------------------------------------------
# 1. 多行程工作函式 (必須在程式碼頂層定義)
# -----------------------------------------------------------
def process_single_scan_task(task):
    """
    由多行程 Pool 執行的工作函式。
    負責執行 extract_balanced_segmentation_patches 並返回結果。
    
    :param task: 包含 (image_np, mask_np, series_uid, max_patches_per_scan) 的元組。
    :return: 包含 (patches_tensor, masks_tensor, series_uid, num_patches, num_pos_patches) 的元組。
    """
    # 將輸入元組解包
    image_np, mask_np, series_uid, max_patches_per_scan = task

    # 註：這裡假設 extract_balanced_segmentation_patches 可以處理 NumPy 陣列輸入，
    # 且返回 PyTorch Tensor 輸出 (patches, masks)。
    # 如果您的提取函式只接受 Tensor，請在呼叫前加上：
    # image_tensor = torch.from_numpy(image_np)
    # mask_tensor = torch.from_numpy(mask_np)

    try:
        # ❗️ 調用目標 Patch 提取函式
        patches, masks = extract_balanced_segmentation_patches(
            image=image_np, 
            mask=mask_np, 
            max_patches_per_scan=max_patches_per_scan 
        )
    except Exception as e:
        # 處理任何程序中的錯誤
        print(f"子程序處理 {series_uid} 時發生錯誤: {e}")
        patches = torch.empty(0, 3, 64, 64)
        masks = torch.empty(0, 1, 64, 64)

    num_patches = patches.shape[0]
    num_pos_patches = 0
    
    if num_patches > 0:
        # 驗證形狀並計算統計數據
        assert patches.shape[1:] == (3, 64, 64), f"Patch Shape Error for {series_uid}!"
        assert masks.shape[1:] == (1, 64, 64), f"Mask Shape Error for {series_uid}!"
        
        # 統計資訊：現在我們統計的是包含結節的 Patch 數量
        num_pos_patches = (torch.sum(torch.sum(masks, dim=(1, 2, 3)) > 0.5)).item()
        print(f"提取 Segmentation Patch: {series_uid}")
        print(f"  -> 成功提取 {num_patches} 個 Patch ({num_pos_patches} 個包含結節)")
    else:
        print(f"提取 Segmentation Patch: {series_uid}")
        print("  -> 未提取到 Patch。")
        
    # 返回所有結果，供主程序合併
    return patches, masks, series_uid, num_patches, num_pos_patches

# -----------------------------------------------------------
# 2. 主程序執行邏輯 (進行平行化)
# -----------------------------------------------------------

# 假設 luna_dataloader 已經定義
MAX_PATCHES_PER_SCAN = 20
tasks = []

# --- 階段 1: 載入 DataLoader 資料並準備任務列表 (在主程序中執行) ---
print("--- 階段 1: 載入 DataLoader 資料 ---")
for i, data in enumerate(luna_dataloader):
    # 從 DataLoader 取得 PyTorch Tensor (假設是 (1, Z, 512, 512) 的批次)
    image = data['image'][0]
    mask = data['mask'][0]
    series_uid = data['series_uid'][0]
    
    # 關鍵：將 PyTorch Tensor 轉換為 CPU NumPy 陣列 (以便跨程序安全傳輸)
    # 如果 data['image'] 是 CUDA Tensor, 需要先 .cpu()
    image_np = image.cpu().numpy() if image.is_cuda else image.numpy()
    mask_np = mask.cpu().numpy() if mask.is_cuda else mask.numpy()
    
    # 將所有參數包裝成一個任務元組
    tasks.append((image_np, mask_np, series_uid, MAX_PATCHES_PER_SCAN))
    
print(f"總共收集了 {len(tasks)} 個 CT 掃描任務。")

# --- 階段 2: 使用 Multiprocessing Pool 進行平行處理 ---
print("\n--- 階段 2: 啟動多行程 Patch 提取 ---")
num_processes = cpu_count() - 1 if cpu_count() > 1 else 1 # 使用所有核心數-1
if len(tasks) < num_processes:
    num_processes = len(tasks)
    
print(f"使用 {num_processes} 個程序進行 Patch 提取...")

all_results = []
with Pool(processes=num_processes) as pool:
    # pool.map 會將 tasks 列表中的每個元素作為 process_single_scan_task 的單一輸入參數
    all_results = pool.map(process_single_scan_task, tasks)

# --- 階段 3: 彙整所有 Patch ---
print("\n--- 階段 3: 彙整所有 Patch ---")
all_patches_list = []
all_masks_list = []
total_patches = 0
total_pos_patches = 0

for patches, masks, series_uid, num_patches, num_pos_patches in all_results:
    if num_patches > 0:
        all_patches_list.append(patches)
        all_masks_list.append(masks)
        total_patches += num_patches
        total_pos_patches += num_pos_patches
        
# 最終合併所有 CT 掃描的 Patch
if all_patches_list:
    final_train_patches = torch.cat(all_patches_list, dim=0) # (N_total, 3, 64, 64)
    final_train_masks = torch.cat(all_masks_list, dim=0) # (N_total, 1, 64, 64)
    
    print("\n--- 最終 Segmentation Patch 數據統計 ---")
    print(f"總共提取的 Patch 數量: {total_patches}")
    print(f"其中包含結節的 Patch 數量: {total_pos_patches}")
    print(f"不含結節的 Patch 數量: {total_patches - total_pos_patches}")
    
else:
    print("\n未提取到任何 Patch 數據。請檢查資料集和標註。")

--- 階段 1: 載入 DataLoader 資料 ---
✅️ file_name = 1.3.6.1.4.1.14519.5.2.1.6279.6001.100225287222365663678666836860.mhd
✅️ Processing series_uid: 1.3.6.1.4.1.14519.5.2.1.6279.6001.100225287222365663678666836860


  l = convex_hull_image(l)


✅️ file_name = 1.3.6.1.4.1.14519.5.2.1.6279.6001.100332161840553388986847034053.mhd
✅️ Processing series_uid: 1.3.6.1.4.1.14519.5.2.1.6279.6001.100332161840553388986847034053
✅️ file_name = 1.3.6.1.4.1.14519.5.2.1.6279.6001.100398138793540579077826395208.mhd
✅️ Processing series_uid: 1.3.6.1.4.1.14519.5.2.1.6279.6001.100398138793540579077826395208
✅️ file_name = 1.3.6.1.4.1.14519.5.2.1.6279.6001.100530488926682752765845212286.mhd
✅️ Processing series_uid: 1.3.6.1.4.1.14519.5.2.1.6279.6001.100530488926682752765845212286
✅️ file_name = 1.3.6.1.4.1.14519.5.2.1.6279.6001.100620385482151095585000946543.mhd
✅️ Processing series_uid: 1.3.6.1.4.1.14519.5.2.1.6279.6001.100620385482151095585000946543
✅️ file_name = 1.3.6.1.4.1.14519.5.2.1.6279.6001.100621383016233746780170740405.mhd
✅️ Processing series_uid: 1.3.6.1.4.1.14519.5.2.1.6279.6001.100621383016233746780170740405
✅️ file_name = 1.3.6.1.4.1.14519.5.2.1.6279.6001.100684836163890911914061745866.mhd
✅️ Processing series_uid: 1.3.6.1.4.1.145

  r = convex_hull_image(r)


✅️ file_name = 1.3.6.1.4.1.14519.5.2.1.6279.6001.109002525524522225658609808059.mhd
✅️ Processing series_uid: 1.3.6.1.4.1.14519.5.2.1.6279.6001.109002525524522225658609808059
✅️ file_name = 1.3.6.1.4.1.14519.5.2.1.6279.6001.109882169963817627559804568094.mhd
✅️ Processing series_uid: 1.3.6.1.4.1.14519.5.2.1.6279.6001.109882169963817627559804568094
✅️ file_name = 1.3.6.1.4.1.14519.5.2.1.6279.6001.110678335949765929063942738609.mhd
✅️ Processing series_uid: 1.3.6.1.4.1.14519.5.2.1.6279.6001.110678335949765929063942738609
✅️ file_name = 1.3.6.1.4.1.14519.5.2.1.6279.6001.111017101339429664883879536171.mhd
✅️ Processing series_uid: 1.3.6.1.4.1.14519.5.2.1.6279.6001.111017101339429664883879536171
✅️ file_name = 1.3.6.1.4.1.14519.5.2.1.6279.6001.111172165674661221381920536987.mhd
✅️ Processing series_uid: 1.3.6.1.4.1.14519.5.2.1.6279.6001.111172165674661221381920536987
✅️ file_name = 1.3.6.1.4.1.14519.5.2.1.6279.6001.111258527162678142285870245028.mhd
✅️ Processing series_uid: 1.3.6.1.4.1.145

MemoryError: Unable to allocate 484. MiB for an array with shape (484, 512, 512) and data type float32

In [None]:
if debug_print:
    nslice = 0
    patch_chw = final_train_patches[nslice] #(N_total,3,64,64) -> (3,64,64)
    mask_chw = final_train_masks[nslice]   #(N_total,1,64,64) -> (1,64,64)
    # 2. 將 PyTorch Tensor (C, H, W) 轉換為 (H, W, C) 
    # 轉換維度順序：(0, 1, 2) -> (1, 2, 0)
    patch_hwc = patch_chw.permute(1, 2, 0)
    mask_hwc = mask_chw.permute(1, 2, 0)

    patch_display = patch_hwc.cpu().numpy()
    plt.figure(figsize=(10,5))
    plt.subplot(1,2,1)
    plt.imshow(patch_display.astype('uint8')) 
    plt.subplot(1,2,2)
    plt.imshow(mask_hwc, cmap='gray')

    plt.title(f"Patch Shape: {patch_display.shape}")
    plt.show()

if debug_plot : 
    for nslice in range(final_train_patches.shape[0]):
        
        patch_chw = final_train_patches[nslice] #(N_total,3,64,64) -> (3,64,64)
        mask_chw = final_train_masks[nslice] #(N_total,1,64,64) -> (1,64,64)
        # 2. 將 PyTorch Tensor (C, H, W) 轉換為 (H, W, C) 
        # 轉換維度順序：(0, 1, 2) -> (1, 2, 0)
        patch_hwc = patch_chw.permute(1, 2, 0)
        mask_hwc = mask_chw.permute(1, 2, 0)

        patch_display = patch_hwc.cpu().numpy()
        plt.figure(figsize=(10,5))
        plt.subplot(1,2,1)
        plt.imshow(patch_display.astype('uint8')) 
        plt.subplot(1,2,2)
        plt.imshow(mask_hwc.cpu().numpy().squeeze(), cmap='gray') # 註：為確保單通道(1,64,64) mask能正確顯示，最好加上 .squeeze()

        plt.title(f"Patch Shape: {patch_display.shape}")
        if dir_plot :
            os.mkdir("debug_output") if not os.path.exists("debug_output") else None
            plt.savefig(f"debug_output/patch_and_mask_slice_{nslice}.png") # 替換 plt.show()，將圖形儲存為檔案
        plt.close() # 關閉圖形，避免它在執行環境中顯示


### 下載pth

In [None]:
import torch
import os

# 設置儲存路徑和檔案名
SAVE_DIR = r"D:\Daniel\for_git\LDCT_git\TSCNN\first_stage_code\Data_processing"
FILE_NAME = "Train0.pth"
SAVE_PATH = os.path.join(SAVE_DIR, FILE_NAME)

# 確保儲存目錄存在
os.makedirs(SAVE_DIR, exist_ok=True)
data_to_save = {
    'patches': final_train_patches,# final_train_patches.shape: (N_total, 3, 64, 64)
    'masks': final_train_masks,# final_train_labels.shape: (N_total, 1, 64, 64)
    'description': 'LUNA16 Balanced 3D Patches (3x64x64, 0-255 scaled)',
    'total_count': final_train_patches.shape[0]
}

# 執行儲存
torch.save(data_to_save, SAVE_PATH)

print(f"✅ 數據已成功儲存到: {SAVE_PATH}")

✅ 數據已成功儲存到: D:\Daniel\for_git\LDCT_git\TSCNN\first_stage_code\Data_processing\VAL.pth


In [None]:
import torch

LOAD_PATH = r"D:\Daniel\for_git\LDCT_git\TSCNN\first_stage_code\Data_processing\Train0.pth"

# 執行加載
loaded_data = torch.load(LOAD_PATH)

# 取出 Patch 和 Label
loaded_patches = loaded_data['patches']
loaded_labels = loaded_data['masks']

print(f"✅ 數據已成功加載。")
print(f"加載的 Patch 總數: {loaded_patches.shape[0]}")
# 您現在可以直接將 loaded_patches 和 loaded_labels 用於訓練模型！

✅ 數據已成功加載。
加載的 Patch 總數: 30


  loaded_data = torch.load(LOAD_PATH)
