In [None]:
import os
import torch
import numpy as np
import nibabel as nib  # 處理 .nii 檔案的標準庫
from torch.utils.data import Dataset

class CRNNDataset(Dataset):
    """Sliding-window dataset over paired image / label NIfTI volumes.

    Each sample is a window of ``n_frames`` consecutive slices taken along the
    third volume axis (``shape[2]``) of one image volume. It comes with a binary
    mask built from one slice of the label volume that has the same file name.
    Volumes are read lazily through nibabel's ``dataobj`` proxy, so only the
    slices a sample needs are loaded.
    """

    def __init__(self, root_dir, target_dir, n_frames=3, is_train=True,
                 use_grad=True, target_label=5):
        """
        Args:
            root_dir (str): Folder holding the training images (.nii / .nii.gz).
            target_dir (str): Folder holding the label volumes. Each label file
                must have exactly the same file name as its image.
            n_frames (int): Number of consecutive slices fed to the model
                (e.g. 3). Must be >= 1.
            is_train (bool): Stored on the instance; not read anywhere in this
                class yet.
            use_grad (bool): If true, ``_preprocess`` appends six gradient
                channels (8 channels total). Otherwise it returns only the base
                and blurred channels (2 channels).
            target_label (int): Label value mapped to 1.0 in the binary mask;
                every other value maps to 0.0.

        Raises:
            ValueError: If ``n_frames`` < 1 or ``root_dir`` is not a directory.
        """
        if n_frames < 1:
            raise ValueError(f"n_frames must be >= 1, got {n_frames}")
        self.root_dir = root_dir
        self.target_dir = target_dir
        self.n_frames = n_frames
        self.is_train = is_train
        self.use_grad = use_grad
        self.target_label = target_label

        # The index stores only *where* each window lives (file + start slice);
        # voxel data is read on demand in __getitem__.
        self.samples = self._make_dataset()

    def _make_dataset(self):
        """Index every valid window position without loading voxel data.

        Only NIfTI headers are read here; ``nib.load`` returns a lazy proxy.
        An image is skipped, with a printed warning, in three cases:
        - its label file is missing;
        - it has fewer than 3 dimensions;
        - its label's first three dimensions differ from the image's.

        Returns:
            list[dict]: One dict per window with keys ``'img_path'``,
            ``'lbl_path'`` and ``'start_idx'``. ``'start_idx'`` is the index of
            the window's first slice along axis 2.

        Raises:
            ValueError: If ``root_dir`` is not an existing directory.
        """
        if not os.path.isdir(self.root_dir):
            raise ValueError(f"Directory not found: {self.root_dir}")

        file_names = sorted(
            f for f in os.listdir(self.root_dir)
            if f.endswith(('.nii', '.nii.gz'))
        )
        print(f"Found {len(file_names)} NIfTI files in {self.root_dir}")

        samples = []
        for fname in file_names:
            img_path = os.path.join(self.root_dir, fname)
            # Assumes label files use exactly the same name as their image.
            lbl_path = os.path.join(self.target_dir, fname)

            if not os.path.exists(lbl_path):
                print(f"[Warning] Label not found for {fname}, skipping...")
                continue

            img_shape = nib.load(img_path).shape
            if len(img_shape) < 3:
                print(f"[Warning] {fname} has shape {img_shape}, "
                      f"expected at least 3 dims, skipping...")
                continue

            # A label whose grid differs from the image would misalign masks
            # or break batch collation later, so reject it up front.
            lbl_shape = nib.load(lbl_path).shape
            if tuple(lbl_shape[:3]) != tuple(img_shape[:3]):
                print(f"[Warning] Shape mismatch for {fname}: image {img_shape} "
                      f"vs label {lbl_shape}, skipping...")
                continue

            # One window per start index. The range is empty when the volume
            # has fewer slices than n_frames.
            n_windows = img_shape[2] - self.n_frames + 1
            samples.extend(
                {'img_path': img_path, 'lbl_path': lbl_path, 'start_idx': i}
                for i in range(n_windows)
            )

        print(f"Total samples (slices) created: {len(samples)}")
        return samples

    def _load_nii_slab(self, path, start_idx, count):
        """Read ``count`` consecutive slices along axis 2 of a NIfTI file.

        Slicing ``dataobj`` (nibabel's lazy array proxy) reads only the
        requested slab instead of the whole volume. Reading the window in one
        call avoids re-opening and re-reading the file once per slice; this is
        especially costly for gzip-compressed .nii.gz files. No windowing or
        normalization is applied here.

        Returns:
            np.ndarray: float32 array of shape ``(X, Y, count)`` for a 3-D
            volume; trailing axes of 4-D+ volumes are kept. Axes follow the
            file's storage order. No transpose to (H, W) is done — TODO confirm
            the orientation the model expects.
        """
        img_obj = nib.load(path)
        slab = img_obj.dataobj[:, :, start_idx:start_idx + count]
        return np.array(slab, dtype=np.float32)

    def _load_nii_slice(self, path, slice_idx):
        """Read the single slice ``slice_idx`` along axis 2 as float32.

        Axis 2 is indexed explicitly so it matches the axis counted in
        ``_make_dataset`` (``shape[2]``). ``[..., idx]`` would instead index the
        last axis, which is a different axis for 4-D volumes.
        """
        return self._load_nii_slab(path, slice_idx, 1)[:, :, 0]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the input window and target mask for sample ``idx``.

        Returns:
            dict: The input keys depend on ``n_frames``:

            - 3: ``'prev'``, ``'curr'``, ``'next'``.
            - 2: ``'prev'``, ``'curr'``.
            - any other value: ``'frames'``, all slices stacked to shape
              ``(n_frames, 1, X, Y)``.

            Each named frame is a ``(1, X, Y)`` float32 tensor of raw intensities.
            ``'label'`` is a ``(1, X, Y)`` float32 mask that is 1.0 where the
            label slice equals ``target_label`` and 0.0 elsewhere.
        """
        info = self.samples[idx]
        start_idx = info['start_idx']

        # --- 1. Input window: n_frames consecutive slices, read in one go ---
        slab = self._load_nii_slab(info['img_path'], start_idx, self.n_frames)
        # TODO: once windowing/CLAHE is enabled, feed each slice through
        # self._preprocess(...) here instead of using raw intensities.
        tensors = [
            torch.from_numpy(np.ascontiguousarray(slab[:, :, i])).unsqueeze(0)
            for i in range(self.n_frames)
        ]

        if self.n_frames == 3:
            sample = {'prev': tensors[0], 'curr': tensors[1], 'next': tensors[2]}
        elif self.n_frames == 2:
            sample = {'prev': tensors[0], 'curr': tensors[1]}
        else:
            # No named-frame convention for other sizes: return the whole window.
            sample = {'frames': torch.stack(tensors, dim=0)}

        # --- 2. Label for the target frame of the window ---
        # n_frames // 2 is:
        # - the middle frame for odd windows (offset 1 for 3 frames);
        # - the later frame for 2 frames (offset 1);
        # - offset 0 for a single frame, so it always stays inside the window.
        lbl_slice_idx = start_idx + self.n_frames // 2
        raw_mask = self._load_nii_slice(info['lbl_path'], lbl_slice_idx)

        # Binary mask: target_label -> 1.0, everything else -> 0.0.
        mask = (raw_mask == self.target_label).astype(np.float32)
        sample['label'] = torch.from_numpy(mask).unsqueeze(0)

        return sample

    @staticmethod
    def _normalize_minmax(arr, eps=1e-8):
        """Linearly rescale ``arr`` to [0, 1] as float32.

        A constant (or empty) array returns zeros of the same shape rather than
        dividing by zero.
        """
        arr = np.asarray(arr, dtype=np.float32)
        if arr.size == 0:
            return arr
        lo = float(arr.min())
        hi = float(arr.max())
        if hi - lo < eps:
            return np.zeros_like(arr)
        return (arr - lo) / (hi - lo)

    def _preprocess(self, img_np):
        """Turn one raw 2-D slice into a stack of normalized feature channels.

        Args:
            img_np (np.ndarray): 2-D float32 slice of raw intensities. The
                window below presumes CT Hounsfield units — TODO confirm.

        Returns:
            torch.Tensor: float32 tensor of shape ``(8, H, W)`` when
            ``self.use_grad`` is true, otherwise ``(2, H, W)``. Channels:

            - 0: CLAHE image.
            - 1: Gaussian blur of channel 0.
            - 2-4: |Sobel x|, |Sobel y| and |Laplacian| of the blurred image.
            - 5-7: the same three gradients of the unblurred CLAHE image.

            Each gradient channel is min-max normalized to [0, 1].

        Requires OpenCV (``cv2``). It is imported here, so only preprocessing
        depends on it.
        """
        import cv2  # local import: the rest of the dataset works without OpenCV

        # --- Stage 1: windowing (neck CT: center 50, width 350 -> [-125, 225]) ---
        min_hu, max_hu = -125, 225
        img = np.clip(img_np, min_hu, max_hu)

        # --- Stage 2: scale to [0, 1], then CLAHE (which needs uint8 input) ---
        img = (img - min_hu) / (max_hu - min_hu)
        img_uint8 = (img * 255).astype(np.uint8)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        img_clahe = clahe.apply(img_uint8)

        # Channel 0: the pre-blur base image.
        base_img = img_clahe.astype(np.float32) / 255.0
        features = [base_img]

        # --- Stage 3: feature extraction ---
        # Channel 1: Gaussian blur (kernel 7x7, sigma 1.1) as a context channel.
        blur_img = cv2.GaussianBlur(base_img, (7, 7), 1.1)
        features.append(blur_img)

        if self.use_grad:
            ksize = 3
            # Blurred source first (channels 2-4, coarse contours), then the
            # unblurred CLAHE image (channels 5-7, fine detail including noise).
            for src in (blur_img, base_img):
                sobel_x = cv2.Sobel(src, cv2.CV_32F, 1, 0, ksize=ksize)
                sobel_y = cv2.Sobel(src, cv2.CV_32F, 0, 1, ksize=ksize)
                lap = cv2.Laplacian(src, cv2.CV_32F, ksize=ksize)
                features.extend(
                    self._normalize_minmax(np.abs(g)) for g in (sobel_x, sobel_y, lap)
                )

        # --- Stage 4: stack channels -> (C, H, W) ---
        final_img = np.stack(features, axis=0)
        return torch.from_numpy(final_img).float()

ModuleNotFoundError: No module named 'torch'