# In [2]:
# %% Imports
import os
import glob
import time
import SimpleITK as sitk
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# Use standard tqdm if not in a notebook environment
try:
    from tqdm.notebook import tqdm
except ImportError:
    from tqdm import tqdm
import random
import scipy.ndimage
from skimage.measure import label as skimage_label, regionprops
from skimage.morphology import disk, binary_closing # Removed convex_hull_image, roberts as not used
from skimage.segmentation import clear_border
import scipy.ndimage as ndi

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import (classification_report, confusion_matrix, roc_auc_score,
                             roc_curve, precision_recall_curve, auc, f1_score,
                             precision_score, recall_score, accuracy_score, ConfusionMatrixDisplay)

# %% Configuration
# --- MODIFY THESE PATHS ---
# Base directory of the DSB 2017 Stage 1 scans (contains the patient folders).
DSB_PATH = r"C:\Users\rouaa\Documents\Final_Pneumatect\Stages"
# CSV with the DSB patient cancer labels.
DSB_LABELS_CSV = r'C:\Users\rouaa\Documents\Final_Pneumatect\stage1_labels.csv'
# Output directory for the limited (per-class capped) preprocessed dataset.
PREPROCESSED_DSB_PATH = './preprocessed_dsb_50_each/'

# --- Preprocessing & model parameters ---
TARGET_SPACING = [1.5, 1.5, 1.5]
FINAL_SCAN_SIZE = (96, 128, 128)
CLIP_BOUND_HU = [-1000.0, 400.0]
PIXEL_MEAN = 0.25

# --- Training parameters ---
NUM_CLASSES = 1
BATCH_SIZE = 4
LEARNING_RATE = 0.0001
EPOCHS = 150
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {DEVICE}")

# Maximum number of scans kept for each class.
SCAN_LIMIT_PER_CLASS = 50

# The preprocessed volumes (and later the model checkpoint) are written here.
os.makedirs(PREPROCESSED_DSB_PATH, exist_ok=True)

# --- Random seed: seed every RNG the script touches ---
SEED = 42
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    # Trade cuDNN autotuning for repeatable kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


# %% --- Data Loading and Selection ---

print(f"--- Loading Data and Selecting up to {SCAN_LIMIT_PER_CLASS} Scans Per Class ---")

# --- Verify Paths ---
if not os.path.isdir(DSB_PATH):
    raise SystemExit(f"ERROR: DSB Scans path not found: {DSB_PATH}")
if not os.path.isfile(DSB_LABELS_CSV):
    raise SystemExit(f"ERROR: DSB Labels CSV not found: {DSB_LABELS_CSV}")
print(f"DSB Scans path: {DSB_PATH}")
print(f"DSB Labels CSV: {DSB_LABELS_CSV}")

# --- Load Labels ---
try:
    dsb_labels_df = pd.read_csv(DSB_LABELS_CSV)
    dsb_labels_df = dsb_labels_df.rename(columns={'id': 'patient_id'})
    print(f"Loaded {len(dsb_labels_df)} total DSB patient labels.")
    if 'cancer' not in dsb_labels_df.columns:
        raise ValueError("Labels CSV needs 'cancer' column.")
    print("Original label distribution:\n", dsb_labels_df['cancer'].value_counts())
    # Lookup table: patient_id -> cancer label
    patient_labels_all = dsb_labels_df.set_index('patient_id')['cancer'].to_dict()
except Exception as e:
    raise SystemExit(f"ERROR: Failed to load labels CSV: {e}")

# --- Check Scan Folders ---
scan_folders = [f for f in os.listdir(DSB_PATH) if os.path.isdir(os.path.join(DSB_PATH, f))]
print(f"Found {len(scan_folders)} potential patient scan folders.")
found_scan_ids = set(scan_folders)

# --- Find Common IDs (patients with both label and scan folder) ---
labeled_patient_ids_all = set(dsb_labels_df['patient_id'])
common_ids_all = labeled_patient_ids_all.intersection(found_scan_ids)
print(f"Found {len(common_ids_all)} patient IDs with both labels and scan folders.")
if not common_ids_all:
    raise SystemExit("No matching patient IDs found. Cannot continue.")

# --- Separate Common IDs by Class ---
# Iterate in sorted order: the iteration order of a set of strings changes
# between interpreter runs (hash randomization), which would make the seeded
# shuffles below pick a different subset on every run.
ordered_common_ids = sorted(common_ids_all)
common_ids_cancer = [pid for pid in ordered_common_ids if patient_labels_all.get(pid) == 1]
common_ids_non_cancer = [pid for pid in ordered_common_ids if patient_labels_all.get(pid) == 0]

print(f"Available Cancerous scans with labels: {len(common_ids_cancer)}")
print(f"Available Non-Cancerous scans with labels: {len(common_ids_non_cancer)}")

# --- Shuffle and Select Limited Number Per Class ---
random.shuffle(common_ids_cancer)  # in-place
random.shuffle(common_ids_non_cancer)

selected_cancer_ids = common_ids_cancer[:SCAN_LIMIT_PER_CLASS]
selected_non_cancer_ids = common_ids_non_cancer[:SCAN_LIMIT_PER_CLASS]

print(f"Selected {len(selected_cancer_ids)} Cancerous scans.")
print(f"Selected {len(selected_non_cancer_ids)} Non-Cancerous scans.")

# --- Combine selected IDs for processing ---
scans_to_process = selected_cancer_ids + selected_non_cancer_ids
random.shuffle(scans_to_process)  # mix the two classes

print(f"Total scans selected for preprocessing: {len(scans_to_process)}")

# Labels restricted to the selected scans
patient_labels = {pid: patient_labels_all[pid] for pid in scans_to_process}


# %% --- Preprocessing Functions (Unchanged) ---
# Functions: load_scan_series, resample, get_segmented_lungs,
#            normalize_hu, zero_center, resize_scan_to_target
# These remain the same as in the original script.

def load_scan_series(dicom_folder_path):
    """Read the first DICOM series found in a folder.

    Returns ``(image_array, origin, spacing)``. ``origin`` and ``spacing`` are
    the values reported by SimpleITK in reversed order (presumably so they
    line up with the axis order of ``image_array`` -- confirm against the
    SimpleITK docs). Returns ``(None, None, None)`` when the folder holds no
    series or when reading fails; failures are printed, not raised.
    """
    try:
        reader_cls = sitk.ImageSeriesReader
        found_series = reader_cls.GetGDCMSeriesIDs(dicom_folder_path)
        if not found_series:
            return None, None, None

        # Only the first series in the folder is loaded.
        dicom_files = reader_cls.GetGDCMSeriesFileNames(dicom_folder_path, found_series[0])
        reader = reader_cls()
        reader.SetFileNames(dicom_files)
        volume = reader.Execute()

        voxels = sitk.GetArrayFromImage(volume)
        origin_rev = np.array(volume.GetOrigin()[::-1])
        spacing_rev = np.array(volume.GetSpacing()[::-1])
        return voxels, origin_rev, spacing_rev
    except Exception as e:
        print(f"Error reading DICOM {os.path.basename(dicom_folder_path)}: {e}")
        return None, None, None

def resample(image, original_spacing, new_spacing=TARGET_SPACING):
    """Resample ``image`` from ``original_spacing`` towards ``new_spacing``.

    The output shape is rounded to whole voxels, so the spacing actually
    achieved can differ slightly from ``new_spacing``; it is returned as the
    second element. Returns ``(None, None)`` if anything fails (the error is
    printed).
    """
    try:
        in_shape = image.shape
        spacing_ratio = np.array(original_spacing) / np.array(new_spacing)
        out_shape = np.round(in_shape * spacing_ratio)
        # Zoom factor that lands exactly on the rounded integer shape.
        zoom_factor = out_shape / in_shape
        achieved_spacing = original_spacing / zoom_factor
        resampled = scipy.ndimage.zoom(image, zoom_factor, mode='nearest', order=1)
        return resampled, achieved_spacing
    except Exception as e:
        print(f"Error resamping: {e}")
        return None, None

def get_segmented_lungs(im_slice, hu_threshold=-320):
    """Mask everything outside the lung region of a single 2D slice.

    Steps: threshold below ``hu_threshold``, drop components touching the
    image border, keep the two largest remaining components (or the only one),
    close and dilate the mask, then overwrite every pixel outside the mask
    with ``CLIP_BOUND_HU[0] - 1``.

    Args:
        im_slice: 2D array of intensities. Non-2D input is returned unchanged.
        hu_threshold: Pixels strictly below this value are lung candidates.

    Returns:
        A copy of ``im_slice`` with out-of-mask pixels set to the background
        value.
    """
    if im_slice.ndim != 2:
        return im_slice

    binary = im_slice < hu_threshold
    cleared = clear_border(binary)
    label_image = skimage_label(cleared)

    # Compute region properties once and reuse them for both passes.
    regions = regionprops(label_image)
    areas = sorted(region.area for region in regions)
    if len(areas) >= 2:
        area_threshold = areas[-2]  # area of the second-largest component
    elif areas:
        area_threshold = areas[-1]
    else:
        area_threshold = 0

    if area_threshold > 0:
        for region in regions:
            if region.area < area_threshold:
                # Clear the whole component in one indexed assignment
                # instead of a Python loop over individual pixels.
                coords = region.coords
                label_image[coords[:, 0], coords[:, 1]] = 0

    binary = binary_closing(label_image > 0, disk(2))
    final_mask = ndi.binary_dilation(binary, structure=disk(5))

    # One below the lower clip bound, so normalisation maps it to the floor.
    background_val = CLIP_BOUND_HU[0] - 1
    segmented_slice = im_slice.copy()
    segmented_slice[final_mask == 0] = background_val
    return segmented_slice

def normalize_hu(image, clip_bounds=CLIP_BOUND_HU):
    """Clip ``image`` to ``clip_bounds`` and rescale it linearly to [0, 1].

    Returns a float32 array.
    """
    lower, upper = clip_bounds
    clipped = np.clip(image, lower, upper)
    scaled = (clipped - lower) / (upper - lower)
    return scaled.astype(np.float32)

def zero_center(image, pixel_mean=PIXEL_MEAN):
    """Subtract ``pixel_mean`` from ``image`` and return the result as float32."""
    return (image - pixel_mean).astype(np.float32)

def resize_scan_to_target(image, target_shape=FINAL_SCAN_SIZE):
    """Resize a volume to exactly ``target_shape``.

    The volume is zoomed with linear interpolation; if rounding leaves the
    result off-target it is edge-padded and/or centre-cropped. An input that
    already has the target shape is returned as-is (no copy, no dtype cast).
    Returns ``None`` (after printing a message) when resizing fails.
    """
    if image.shape == target_shape:
        return image

    zoom_factor = np.array(target_shape) / np.array(image.shape)
    try:
        out = scipy.ndimage.zoom(image, zoom_factor, order=1, mode='nearest')

        if out.shape != target_shape:
            # Positive entries need padding, negative entries need cropping.
            shortfall = np.array(target_shape) - np.array(out.shape)
            to_pad = np.maximum(shortfall, 0)
            to_crop = np.maximum(-shortfall, 0)

            pad_width = tuple((p // 2, p - p // 2) for p in to_pad)
            out = np.pad(out, pad_width, mode='edge')

            crop_window = tuple(
                slice(c // 2, size - (c - c // 2))
                for c, size in zip(to_crop, out.shape)
            )
            out = out[crop_window]

        if out.shape != target_shape:
            print(f"ERROR: Resize failed. Shape {out.shape}")
            return None
        return out.astype(np.float32)
    except Exception as e:
        print(f"Error resizing to target: {e}")
        return None

# --- Full Preprocessing Pipeline (Unchanged Logic, uses correct paths) ---
def preprocess_scan_dsb(patient_id, input_base_path, output_base_path, force_preprocess=False):
    """Run the full preprocessing pipeline for one patient and cache the result.

    Pipeline: load DICOM series -> resample to ``TARGET_SPACING`` -> per-slice
    lung segmentation -> HU normalisation -> zero-centring -> resize to
    ``FINAL_SCAN_SIZE`` -> save as ``<output_base_path>/<patient_id>.npz``
    under the key ``image``.

    Args:
        patient_id: Name of the patient folder inside ``input_base_path``.
        input_base_path: Directory containing the per-patient scan folders.
        output_base_path: Directory that receives the ``.npz`` file.
        force_preprocess: Recompute even if the output file already exists.

    Returns:
        True if the output file exists or was written, False on any failure.
    """
    scan_folder_path = os.path.join(input_base_path, patient_id)
    output_filename = os.path.join(output_base_path, f"{patient_id}.npz")
    if os.path.exists(output_filename) and not force_preprocess:
        return True

    image, _origin, spacing = load_scan_series(scan_folder_path)
    if image is None:
        return False

    resampled_image, _new_spacing = resample(image, spacing, TARGET_SPACING)
    del image  # free the raw volume as early as possible
    if resampled_image is None:
        return False

    segmented_lungs = np.zeros_like(resampled_image, dtype=np.float32)
    for idx, axial_slice in enumerate(resampled_image):
        segmented_lungs[idx] = get_segmented_lungs(axial_slice)
    del resampled_image

    normalized_image = normalize_hu(segmented_lungs, clip_bounds=CLIP_BOUND_HU)
    del segmented_lungs
    centered_image = zero_center(normalized_image, pixel_mean=PIXEL_MEAN)
    del normalized_image
    final_image = resize_scan_to_target(centered_image, target_shape=FINAL_SCAN_SIZE)
    del centered_image
    if final_image is None:
        return False

    # Write to a temporary file first and move it into place afterwards, so an
    # interrupted or failed save never leaves a truncated archive at the final
    # path (the existence check above would otherwise treat it as finished).
    # The temp name keeps the ".npz" suffix so numpy does not append another.
    tmp_filename = os.path.join(output_base_path, f"{patient_id}.tmp.npz")
    try:
        np.savez_compressed(tmp_filename, image=final_image.astype(np.float32))
        os.replace(tmp_filename, output_filename)
        return True
    except Exception as e:
        print(f"Error saving {patient_id}: {e}")
        return False
    finally:
        # After a successful replace the temp file is gone; otherwise clean up.
        try:
            if os.path.exists(tmp_filename):
                os.remove(tmp_filename)
        except OSError:
            pass


# %% --- Preprocessing Execution (Using Selected IDs) ---

successful_processed_ids = []
failed_processed_ids = []

print(f"\nStarting preprocessing for {len(scans_to_process)} selected scans...")
start_time = time.time()

for patient_id in tqdm(scans_to_process, desc=f"Preprocessing {len(scans_to_process)} Selected Scans"):
    # Existing outputs are reused (force_preprocess=False).
    success = preprocess_scan_dsb(patient_id, DSB_PATH, PREPROCESSED_DSB_PATH, force_preprocess=False)
    outcome_bucket = successful_processed_ids if success else failed_processed_ids
    outcome_bucket.append(patient_id)

end_time = time.time()
print(f"\nPreprocessing finished in {end_time - start_time:.2f} seconds.")
print(f"Successfully processed/found: {len(successful_processed_ids)} scans.")
if failed_processed_ids:
    print(f"Failed to process: {len(failed_processed_ids)} scans. IDs: {failed_processed_ids}")

# --- Final list for Dataset: only scans that were processed (or already cached) ---
final_patient_list = successful_processed_ids
if not final_patient_list:
    raise SystemExit("No scans processed successfully. Cannot continue.")

# Keep labels only for the patients that made it through preprocessing.
patient_labels = {pid: patient_labels[pid] for pid in final_patient_list}
print(f"Final patient count for training/validation: {len(final_patient_list)}")
final_cancer_count = sum(1 for pid in final_patient_list if patient_labels[pid] == 1)
final_non_cancer_count = len(final_patient_list) - final_cancer_count
print(f"  Cancerous: {final_cancer_count}, Non-Cancerous: {final_non_cancer_count}")


# %% --- Dataset and DataLoader ---

class PatientLevelDataset(Dataset):
    """Dataset yielding one preprocessed scan and its label per patient.

    Each item is read from ``<preprocessed_path>/<patient_id>.npz`` (key
    ``image``) and returned as ``(image_tensor, label_tensor)``, where the
    image gets a leading channel axis. If loading fails, an all-zero tensor of
    shape ``(1, *FINAL_SCAN_SIZE)`` with label ``-1`` is returned instead so
    the caller can filter it out.
    """

    def __init__(self, patient_ids, labels_dict, preprocessed_path):
        self.patient_ids = patient_ids
        self.labels_dict = labels_dict
        self.preprocessed_path = preprocessed_path

    def __len__(self):
        return len(self.patient_ids)

    def __getitem__(self, idx):
        patient_id = self.patient_ids[idx]
        label = self.labels_dict[patient_id]
        scan_path = os.path.join(self.preprocessed_path, f"{patient_id}.npz")
        try:
            with np.load(scan_path) as archive:
                volume = archive['image']
            # Add the channel axis in front of the volume.
            image_tensor = torch.unsqueeze(torch.from_numpy(volume).float(), 0)
            label_tensor = torch.tensor(label, dtype=torch.float32)
            return image_tensor, label_tensor
        except Exception as e:
            print(f"ERROR loading {patient_id}: {e}")
            placeholder = torch.zeros((1, *FINAL_SCAN_SIZE), dtype=torch.float32)
            # -1 marks the sample as invalid.
            return placeholder, torch.tensor(-1, dtype=torch.float32)

# --- Split Data (Train/Validation) ---
# Stratify on the labels of the processed subset so both splits keep the class mix.
stratify_labels = [patient_labels[pid] for pid in final_patient_list]
train_ids, val_ids = train_test_split(
    final_patient_list,
    test_size=0.2,
    random_state=SEED,
    stratify=stratify_labels,
)
print(f"\nTraining patients: {len(train_ids)}")
print(f"Validation patients: {len(val_ids)}")

# --- Create Datasets and DataLoaders ---
train_dataset = PatientLevelDataset(train_ids, patient_labels, PREPROCESSED_DSB_PATH)
val_dataset = PatientLevelDataset(val_ids, patient_labels, PREPROCESSED_DSB_PATH)

NUM_WORKERS = 0
print(f"Using {NUM_WORKERS} workers for DataLoader.")
# Options shared by both loaders; only shuffling differs.
loader_options = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=torch.cuda.is_available(),
)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)

# --- Check DataLoader Output ---
try:
    print("\nChecking DataLoader output (Limited Data)...")
    if len(train_loader) == 0:
        print("Train loader empty.")
    else:
        sample_batch, sample_labels = next(iter(train_loader))
        print(f"Sample batch shape: {sample_batch.shape}")
        print(f"Sample labels shape: {sample_labels.shape}")
        print(f"Sample labels: {sample_labels}")  # expected to hold 0s and 1s
        if torch.any(sample_labels == -1):
            print("WARNING: Error labels (-1) detected.")
except Exception as e:
    print(f"Error checking DataLoader: {e}")


# %% --- Model Definition (Unchanged) ---
class PatientLevel3DCNN(nn.Module):
    """Small 3D CNN producing ``num_classes`` logits per scan.

    Three Conv3d -> BatchNorm3d -> ReLU -> MaxPool3d stages (16, 32, 64
    channels) are followed by an adaptive max-pool to 2x2x2 and a two-layer
    classifier head with dropout. ``input_shape`` is accepted but not used by
    the layers built here; the adaptive pooling fixes the size seen by the
    classifier head.
    """

    def __init__(self, input_shape=FINAL_SCAN_SIZE, input_channels=1, num_classes=1):
        super().__init__()

        # Build the convolutional stages in order; each stage halves D, H and W.
        # e.g. [B, 1, 96, 128, 128] -> [B, 16, 48, 64, 64] -> [B, 32, 24, 32, 32]
        #      -> [B, 64, 12, 16, 16]
        stage_widths = (16, 32, 64)
        feature_layers = []
        channels_in = input_channels
        for channels_out in stage_widths:
            feature_layers.extend([
                nn.Conv3d(channels_in, channels_out, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm3d(channels_out),
                nn.ReLU(),
                nn.MaxPool3d(kernel_size=2, stride=2),
            ])
            channels_in = channels_out
        # Fixed-size output before the classifier: [B, 64, 2, 2, 2]
        feature_layers.append(nn.AdaptiveMaxPool3d((2, 2, 2)))
        self.conv_layers = nn.Sequential(*feature_layers)

        # 64 channels * 2 * 2 * 2 = 512 features after flattening
        flattened_size = 64 * 2 * 2 * 2

        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flattened_size, 256),
            nn.ReLU(),
            nn.Dropout(0.5),  # regularisation
            nn.Linear(256, num_classes),  # raw logit(s), no sigmoid here
        )

    def forward(self, x):
        features = self.conv_layers(x)
        return self.fc_layers(features)

patient_model = PatientLevel3DCNN(input_shape=FINAL_SCAN_SIZE, num_classes=1).to(DEVICE)
print(patient_model)

# Smoke-test the forward pass on random input. Run it in eval mode without
# autograd so the random batch does not update the BatchNorm running
# statistics or build a graph before real training starts.
try:
    dummy_input = torch.randn(BATCH_SIZE, 1, *FINAL_SCAN_SIZE).to(DEVICE)
    patient_model.eval()
    with torch.no_grad():
        output = patient_model(dummy_input)
    print(f"\nModel output shape: {output.shape}")
except Exception as e:
    print(f"\nError model test: {e}")
finally:
    # Restore the default training mode regardless of the outcome.
    patient_model.train()


# %% --- Loss and Optimizer ---

# pos_weight is derived from the class counts of the actual training split.
# With a balanced selection it should come out close to 1.0.
train_labels_list = [patient_labels[pid] for pid in train_ids]
count_0 = train_labels_list.count(0)
count_1 = train_labels_list.count(1)
if count_0 == 0 or count_1 == 0:
    # One class is missing from the split: fall back to a neutral weight.
    print("Warning: Training set has only one class after split. Using default pos_weight=1.")
    pos_weight_tensor = torch.tensor([1.0], device=DEVICE)
else:
    pos_weight_val = count_0 / count_1
    print(f"Calculated positive weight for balanced training subset: {pos_weight_val:.4f}")
    pos_weight_tensor = torch.tensor([pos_weight_val], device=DEVICE)

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight_tensor)

optimizer = optim.Adam(patient_model.parameters(), lr=LEARNING_RATE)
# Gradient scaling is only switched on when CUDA is available.
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())


# %% --- Training and Validation Functions (Unchanged Logic) ---
def train_one_epoch_patient(model, dataloader, criterion, optimizer, device, scaler):
    """Train ``model`` for one pass over ``dataloader``.

    Samples whose label is ``-1`` are dropped from each batch; batches that
    become empty and batches whose loss is NaN are skipped. Mixed precision is
    enabled only when CUDA is available.

    Returns:
        ``(mean_loss, accuracy)`` over the samples actually used, or
        ``(0.0, 0.0)`` if no sample was used.
    """
    model.train()
    loss_sum = 0.0
    samples_seen = 0
    hits = 0
    use_amp = torch.cuda.is_available()

    batches = tqdm(dataloader, desc="Training", leave=False)
    for batch_x, batch_y in batches:
        keep = batch_y != -1
        batch_x = batch_x[keep].to(device)
        batch_y = batch_y[keep].unsqueeze(1).to(device)
        if batch_x.nelement() == 0:
            continue

        optimizer.zero_grad()
        with torch.amp.autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp):
            logits = model(batch_x)
            loss = criterion(logits, batch_y)
        if torch.isnan(loss):
            print("NaN loss!")
            continue

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_size = batch_x.size(0)
        loss_sum += loss.item() * batch_size
        samples_seen += batch_size
        predicted_positive = torch.sigmoid(logits) > 0.5
        hits += (predicted_positive == batch_y.bool()).sum().item()
        batches.set_postfix(loss=loss.item())

    if samples_seen == 0:
        return 0.0, 0.0
    return loss_sum / samples_seen, hits / samples_seen

def validate_patient(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Samples whose label is ``-1`` are dropped from each batch.

    Returns:
        ``(mean_loss, labels, probabilities)`` where the last two are flat
        NumPy arrays; ``(0.0, empty, empty)`` if no sample was evaluated.
    """
    model.eval()
    loss_sum = 0.0
    samples_seen = 0
    collected_probs = []
    collected_labels = []
    use_amp = torch.cuda.is_available()

    with torch.no_grad():
        batches = tqdm(dataloader, desc="Validating", leave=False)
        for batch_x, batch_y in batches:
            keep = batch_y != -1
            batch_x = batch_x[keep].to(device)
            batch_y = batch_y[keep].unsqueeze(1).to(device)
            if batch_x.nelement() == 0:
                continue

            with torch.amp.autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp):
                logits = model(batch_x)
                loss = criterion(logits, batch_y)

            batch_size = batch_x.size(0)
            loss_sum += loss.item() * batch_size
            samples_seen += batch_size
            collected_probs.extend(torch.sigmoid(logits).cpu().numpy())
            collected_labels.extend(batch_y.cpu().numpy())

    if samples_seen == 0:
        return 0.0, np.array([]), np.array([])
    labels_flat = np.array(collected_labels).flatten()
    probs_flat = np.array(collected_probs).flatten()
    return loss_sum / samples_seen, labels_flat, probs_flat


# %% --- Training Loop ---
print(f"\nStarting Training on Limited ({len(final_patient_list)}) Scan Dataset...")
best_val_loss = float('inf')
train_losses, val_losses, train_accs = [], [], []

# The best checkpoint is stored next to the preprocessed data.
MODEL_SAVE_PATH = os.path.join(PREPROCESSED_DSB_PATH, "patient_level_model_50_each_best.pth")

if torch.cuda.is_available():
    torch.cuda.empty_cache()

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    start_epoch_time = time.time()
    train_loss, train_acc = train_one_epoch_patient(
        patient_model, train_loader, criterion, optimizer, DEVICE, scaler
    )
    val_loss, val_labels_epoch, val_preds_proba_epoch = validate_patient(
        patient_model, val_loader, criterion, DEVICE
    )
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accs.append(train_acc)
    end_epoch_time = time.time()
    epoch_duration = end_epoch_time - start_epoch_time

    has_val_samples = len(val_labels_epoch) > 0
    val_acc_epoch = (
        accuracy_score(val_labels_epoch, (val_preds_proba_epoch > 0.5).astype(int))
        if has_val_samples
        else 0.0
    )
    print(f"Epoch {epoch+1} Summary: Duration: {epoch_duration:.2f}s")
    print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"  Val Loss: {val_loss:.4f}, Val Acc: {val_acc_epoch:.4f}")

    # Checkpoint only on a genuine improvement measured on real samples.
    if has_val_samples and val_loss < best_val_loss:
        best_val_loss = val_loss
        try:
            torch.save(patient_model.state_dict(), MODEL_SAVE_PATH)
            print(f"  Best model saved to {MODEL_SAVE_PATH}")
        except Exception as e:
            print(f"Error saving model: {e}")
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

print("\nLimited Scan Dataset Training Finished.")


# %% --- Plot Training History ---
# Derive the x-axis from the recorded history rather than from EPOCHS, so the
# plot still works when training stopped early or EPOCHS was changed after
# the run (a length mismatch would make plt.plot raise).
epochs_recorded = range(1, len(train_losses) + 1)

plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.plot(epochs_recorded, train_losses, label='Train')
plt.plot(epochs_recorded, val_losses, label='Val')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Loss Curve (50 Each)')
plt.legend()
plt.grid(True)

plt.subplot(1, 2, 2)
plt.plot(epochs_recorded, train_accs, label='Train Acc')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.title('Accuracy Curve (50 Each)')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.show()


# %% --- Model Evaluation ---
print("\nEvaluating Model on Limited Validation Set...")
if os.path.exists(MODEL_SAVE_PATH):
    try:
        # Rebuild the network if this cell runs without the training cells.
        if 'patient_model' not in locals():
            patient_model = PatientLevel3DCNN(input_shape=FINAL_SCAN_SIZE, num_classes=1).to(DEVICE)
        patient_model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))
        print(f"Loaded best model from {MODEL_SAVE_PATH}")
    except Exception as e:
        print(f"Could not load best model: {e}. Using last epoch model.")
else:
    print("Best model file not found. Using last epoch model.")

val_loss_final, final_val_labels, final_val_preds_proba = validate_patient(patient_model, val_loader, criterion, DEVICE)

if len(final_val_labels) == 0:
    print("No valid validation predictions. Cannot evaluate.")
else:
    print(f"\nFinal Validation Loss: {val_loss_final:.4f}")
    final_val_preds_binary = (final_val_preds_proba > 0.5).astype(int)
    accuracy = accuracy_score(final_val_labels, final_val_preds_binary)
    precision = precision_score(final_val_labels, final_val_preds_binary, zero_division=0)
    recall = recall_score(final_val_labels, final_val_preds_binary, zero_division=0)
    f1 = f1_score(final_val_labels, final_val_preds_binary, zero_division=0)
    try:
        auc_roc = roc_auc_score(final_val_labels, final_val_preds_proba)
    except ValueError as e:
        # Raised e.g. when only one class is present in the labels.
        print(f"AUC-ROC Error: {e}")
        auc_roc = float('nan')

    print("\n--- Final Validation Metrics (Limited Data) ---")
    print(f"Accuracy:  {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall:    {recall:.4f}")
    print(f"F1-Score:  {f1:.4f}")
    print(f"AUC-ROC:   {auc_roc:.4f}")

    # Pin the class set explicitly: with a tiny validation split the labels
    # and predictions may contain a single class, in which case the
    # two-entry target_names would not match the inferred classes.
    class_labels = [0, 1]
    target_names = ['Non-Cancer (0)', 'Cancer (1)']

    print("\nClassification Report (Limited Data):")
    print(classification_report(
        final_val_labels,
        final_val_preds_binary,
        labels=class_labels,
        target_names=target_names,
        zero_division=0,
    ))

    print("\nConfusion Matrix (Limited Data):")
    cm = confusion_matrix(final_val_labels, final_val_preds_binary, labels=class_labels)
    disp = ConfusionMatrixDisplay(cm, display_labels=target_names)
    disp.plot(cmap=plt.cm.Blues)
    plt.show()

    if not np.isnan(auc_roc):
        fpr, tpr, _ = roc_curve(final_val_labels, final_val_preds_proba)
        plt.figure(figsize=(8, 6))
        plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {auc_roc:.2f})')
        plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('ROC Curve (Limited Data)')
        plt.legend(loc="lower right")
        plt.grid(True)
        plt.show()
import torch.nn.functional as F
# Output directory for the limited dataset; the suffix reflects the 64-cube size.
PREPROCESSED_DSB_PATH = './preprocessed_dsb_50_each_64cube/'

# --- Preprocessing & model parameters ---
TARGET_SPACING = [1.5, 1.5, 1.5]
# 64-cube volumes, chosen to match the input of the Simp3D (Keras) model.
FINAL_SCAN_SIZE = (64, 64, 64)
print(f"*** Using FINAL_SCAN_SIZE: {FINAL_SCAN_SIZE} to match Simp3D model ***")

CLIP_BOUND_HU = [-1000.0, 400.0]
PIXEL_MEAN = 0.25

# --- Training parameters ---
NUM_CLASSES = 1  # single logit, for BCEWithLogitsLoss
BATCH_SIZE = 4  # adjust to the GPU memory available for this model
LEARNING_RATE = 4e-5  # learning rate taken from the Keras model definition
EPOCHS = 150
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {DEVICE}")

# Maximum number of scans kept for each class.
SCAN_LIMIT_PER_CLASS = 50

# Make sure the output directory exists.
os.makedirs(PREPROCESSED_DSB_PATH, exist_ok=True)

# --- Random seed: seed every RNG the script touches ---
SEED = 42
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# %% --- Data Loading and Selection ---
# Selects up to SCAN_LIMIT_PER_CLASS cancer and non-cancer scans, if available.
print(f"--- Loading Data and Selecting up to {SCAN_LIMIT_PER_CLASS} Scans Per Class ---")

# --- Verify Paths ---
if not os.path.isdir(DSB_PATH):
    raise SystemExit(f"ERROR: DSB Scans path not found: {DSB_PATH}")
if not os.path.isfile(DSB_LABELS_CSV):
    raise SystemExit(f"ERROR: DSB Labels CSV not found: {DSB_LABELS_CSV}")
print(f"DSB Scans path: {DSB_PATH}")
print(f"DSB Labels CSV: {DSB_LABELS_CSV}")

# --- Load Labels ---
try:
    dsb_labels_df = pd.read_csv(DSB_LABELS_CSV)
    dsb_labels_df = dsb_labels_df.rename(columns={'id': 'patient_id'})
    print(f"Loaded {len(dsb_labels_df)} total DSB patient labels.")
    if 'cancer' not in dsb_labels_df.columns:
        raise ValueError("Labels CSV needs 'cancer' column.")
    print("Original label distribution:\n", dsb_labels_df['cancer'].value_counts())
    # Lookup table: patient_id -> cancer label
    patient_labels_all = dsb_labels_df.set_index('patient_id')['cancer'].to_dict()
except Exception as e:
    raise SystemExit(f"ERROR: Failed to load labels CSV: {e}")

# --- Check Scan Folders ---
scan_folders = [f for f in os.listdir(DSB_PATH) if os.path.isdir(os.path.join(DSB_PATH, f))]
print(f"Found {len(scan_folders)} potential patient scan folders.")
found_scan_ids = set(scan_folders)

# --- Find Common IDs (patients with both label and scan folder) ---
labeled_patient_ids_all = set(dsb_labels_df['patient_id'])
common_ids_all = labeled_patient_ids_all.intersection(found_scan_ids)
print(f"Found {len(common_ids_all)} patient IDs with both labels and scan folders.")
if not common_ids_all:
    raise SystemExit("No matching patient IDs found. Cannot continue.")

# --- Separate Common IDs by Class ---
# Iterate in sorted order: the iteration order of a set of strings changes
# between interpreter runs (hash randomization), which would make the seeded
# shuffles below pick a different subset on every run.
ordered_common_ids = sorted(common_ids_all)
common_ids_cancer = [pid for pid in ordered_common_ids if patient_labels_all.get(pid) == 1]
common_ids_non_cancer = [pid for pid in ordered_common_ids if patient_labels_all.get(pid) == 0]

print(f"Available Cancerous scans with labels: {len(common_ids_cancer)}")
print(f"Available Non-Cancerous scans with labels: {len(common_ids_non_cancer)}")

# --- Shuffle and Select Limited Number Per Class ---
random.shuffle(common_ids_cancer)
random.shuffle(common_ids_non_cancer)
selected_cancer_ids = common_ids_cancer[:SCAN_LIMIT_PER_CLASS]
selected_non_cancer_ids = common_ids_non_cancer[:SCAN_LIMIT_PER_CLASS]
print(f"Selected {len(selected_cancer_ids)} Cancerous scans.")
print(f"Selected {len(selected_non_cancer_ids)} Non-Cancerous scans.")

# --- Combine selected IDs for processing ---
scans_to_process = selected_cancer_ids + selected_non_cancer_ids
random.shuffle(scans_to_process)
print(f"Total scans selected for preprocessing: {len(scans_to_process)}")
# Labels restricted to the selected scans
patient_labels = {pid: patient_labels_all[pid] for pid in scans_to_process}


# %% --- Preprocessing Functions ---
# (Functions load_scan_series, resample, get_segmented_lungs,
#  normalize_hu, zero_center remain unchanged)
def load_scan_series(dicom_folder_path):
    """Read the first DICOM series found in a folder.

    Returns ``(image_array, origin, spacing)``. ``origin`` and ``spacing`` are
    the values reported by SimpleITK in reversed order (presumably so they
    line up with the axis order of ``image_array`` -- confirm against the
    SimpleITK docs). Returns ``(None, None, None)`` when the folder holds no
    series or when reading fails; failures are printed, not raised.
    """
    try:
        reader_cls = sitk.ImageSeriesReader
        found_series = reader_cls.GetGDCMSeriesIDs(dicom_folder_path)
        if not found_series:
            return None, None, None

        # Only the first series in the folder is loaded.
        dicom_files = reader_cls.GetGDCMSeriesFileNames(dicom_folder_path, found_series[0])
        reader = reader_cls()
        reader.SetFileNames(dicom_files)
        volume = reader.Execute()

        voxels = sitk.GetArrayFromImage(volume)
        origin_rev = np.array(volume.GetOrigin()[::-1])
        spacing_rev = np.array(volume.GetSpacing()[::-1])
        return voxels, origin_rev, spacing_rev
    except Exception as e:
        print(f"Error reading DICOM {os.path.basename(dicom_folder_path)}: {e}")
        return None, None, None

def resample(image, original_spacing, new_spacing=TARGET_SPACING):
    """Resample ``image`` from ``original_spacing`` towards ``new_spacing``.

    The output shape is rounded to whole voxels, so the spacing actually
    achieved can differ slightly from ``new_spacing``; it is returned as the
    second element. Returns ``(None, None)`` if anything fails (the error is
    printed).
    """
    try:
        in_shape = image.shape
        spacing_ratio = np.array(original_spacing) / np.array(new_spacing)
        out_shape = np.round(in_shape * spacing_ratio)
        # Zoom factor that lands exactly on the rounded integer shape.
        zoom_factor = out_shape / in_shape
        achieved_spacing = original_spacing / zoom_factor
        resampled = scipy.ndimage.zoom(image, zoom_factor, mode='nearest', order=1)
        return resampled, achieved_spacing
    except Exception as e:
        print(f"Error resamping: {e}")
        return None, None

def get_segmented_lungs(im_slice, hu_threshold=-320):
    """Blank out everything outside the largest low-density regions of a slice.

    Args:
        im_slice: 2D array of intensities. Inputs that are not 2D are
            returned untouched.
        hu_threshold: pixels strictly below this value are treated as
            candidate foreground.

    Returns:
        A copy of ``im_slice`` in which every pixel outside the final mask is
        set to ``CLIP_BOUND_HU[0] - 1``.

    The mask keeps the connected components whose area is at least that of
    the second-largest component (or the single component if there is only
    one), then applies a closing with ``disk(2)`` and a dilation with
    ``disk(5)``.
    """
    if im_slice.ndim != 2:
        return im_slice

    # Threshold, then drop components that touch the image border.
    binary = im_slice < hu_threshold
    cleared = clear_border(binary)
    label_image = skimage_label(cleared)

    # Measure the components once and reuse the result for both the
    # threshold and the removal step.
    regions = regionprops(label_image)
    areas = sorted(region.area for region in regions)
    if len(areas) >= 2:
        area_threshold = areas[-2]
    elif areas:
        area_threshold = areas[-1]
    else:
        area_threshold = 0

    if area_threshold > 0:
        small_labels = [region.label for region in regions if region.area < area_threshold]
        if small_labels:
            # One vectorized mask instead of clearing pixels one at a time.
            label_image[np.isin(label_image, small_labels)] = 0

    binary = label_image > 0
    binary = binary_closing(binary, disk(2))
    final_mask = ndi.binary_dilation(binary, structure=disk(5))

    # One below the lower clip bound, so masked pixels sit below any kept value.
    background_val = CLIP_BOUND_HU[0] - 1
    segmented_slice = im_slice.copy()
    segmented_slice[final_mask == 0] = background_val
    return segmented_slice

def normalize_hu(image, clip_bounds=CLIP_BOUND_HU):
    """Clip ``image`` to ``clip_bounds`` and rescale it linearly to [0, 1].

    Args:
        image: array of intensities.
        clip_bounds: ``(lower, upper)`` pair; values outside are clamped.

    Returns:
        float32 array where ``lower`` maps to 0 and ``upper`` maps to 1.
    """
    lower, upper = clip_bounds
    clipped = np.clip(image, lower, upper)
    scaled = (clipped - lower) / (upper - lower)
    return scaled.astype(np.float32)

def zero_center(image, pixel_mean=PIXEL_MEAN):
    """Subtract ``pixel_mean`` from every element and return the result as float32."""
    shifted = image - pixel_mean
    return shifted.astype(np.float32)


# --- Modified resize_scan_to_target to use the new FINAL_SCAN_SIZE ---
def resize_scan_to_target(image, target_shape=FINAL_SCAN_SIZE): # Uses global FINAL_SCAN_SIZE
    """Resize ``image`` to exactly ``target_shape``.

    The volume is zoomed with linear interpolation (``order=1``). If rounding
    inside ``zoom`` leaves the result slightly off, it is edge-padded and/or
    centre-cropped to the exact target.

    Args:
        image: array to resize.
        target_shape: desired shape; any sequence of integers is accepted.

    Returns:
        The resized array as float32, the original ``image`` object unchanged
        when it already has the target shape, or ``None`` if resizing fails
        (including a rank mismatch between ``image`` and ``target_shape``).
    """
    # image.shape is always a tuple; normalise the target so that a list such
    # as [96, 128, 128] compares equal instead of always reporting a failure.
    target_shape = tuple(int(s) for s in target_shape)
    if image.shape == target_shape:
        return image
    try:
        # Computed inside the try so that incompatible ranks are reported
        # through the normal None return rather than an uncaught exception.
        resize_factor = np.array(target_shape) / np.array(image.shape)
        resized_image = scipy.ndimage.zoom(image, resize_factor, order=1, mode='nearest')

        # Simple crop/pad if the zoom result is slightly off due to rounding.
        if resized_image.shape != target_shape:
            diff = np.array(target_shape) - np.array(resized_image.shape)
            pad = np.maximum(diff, 0)
            crop = np.maximum(-diff, 0)
            pad_width = tuple((p // 2, p - p // 2) for p in pad)
            resized_image = np.pad(resized_image, pad_width, mode='edge')  # Pad with edge value
            crop_slice = tuple(
                slice(c // 2, s - (c - c // 2))
                for c, s in zip(crop, resized_image.shape)
            )
            resized_image = resized_image[crop_slice]

        # Final check
        if resized_image.shape != target_shape:
            print(f"ERROR: Resize failed. Target: {target_shape}, Got: {resized_image.shape}")
            return None
        return resized_image.astype(np.float32)
    except Exception as e:
        print(f"Error resizing image of shape {image.shape} to target {target_shape}: {e}")
        return None


# --- Full Preprocessing Pipeline (Unchanged Logic, uses correct paths and FINAL_SCAN_SIZE) ---
def preprocess_scan_dsb(patient_id, input_base_path, output_base_path, force_preprocess=False):
    """Run the full preprocessing chain for one patient and save it as ``.npz``.

    Steps: load the DICOM series, resample to ``TARGET_SPACING``, mask each
    slice with ``get_segmented_lungs``, normalise, zero-centre, resize to
    ``FINAL_SCAN_SIZE`` and write ``<output_base_path>/<patient_id>.npz``
    with the volume stored under the key ``image``.

    Args:
        patient_id: name of the patient folder under ``input_base_path``.
        input_base_path: directory containing the per-patient DICOM folders.
        output_base_path: directory that receives the ``.npz`` file.
        force_preprocess: when False, an existing output file is left alone
            and counted as success.

    Returns:
        True if the output file already existed or was written, False if any
        stage failed.
    """
    output_filename = os.path.join(output_base_path, f"{patient_id}.npz")
    if not force_preprocess and os.path.exists(output_filename):
        return True  # Already exists

    raw_volume, _origin, raw_spacing = load_scan_series(os.path.join(input_base_path, patient_id))
    if raw_volume is None:
        return False

    volume, _achieved_spacing = resample(raw_volume, raw_spacing, TARGET_SPACING)
    del raw_volume
    if volume is None:
        return False

    # Segment lungs slice by slice into a float32 buffer.
    lung_only = np.zeros_like(volume, dtype=np.float32)
    for z, axial_slice in enumerate(volume):
        lung_only[z] = get_segmented_lungs(axial_slice)
    del volume

    # Intermediates are released as soon as the next stage has its own copy.
    normalized = normalize_hu(lung_only, clip_bounds=CLIP_BOUND_HU)
    del lung_only

    centered = zero_center(normalized, pixel_mean=PIXEL_MEAN)
    del normalized

    final_image = resize_scan_to_target(centered, target_shape=FINAL_SCAN_SIZE)
    del centered
    if final_image is None:
        return False

    try:
        np.savez_compressed(output_filename, image=final_image.astype(np.float32))
    except Exception as e:
        print(f"Error saving {patient_id}: {e}")
        return False
    return True


# %% --- Preprocessing Execution (Using Selected IDs) ---
# Ensure you re-run this if you changed FINAL_SCAN_SIZE
print(f"\n--- Starting Preprocessing for {len(scans_to_process)} scans to size {FINAL_SCAN_SIZE} ---")
print(f"Output directory: {PREPROCESSED_DSB_PATH}")
# Set force_preprocess=True if you need to overwrite existing files with the new size.
# NOTE: with FORCE_REPROCESS = False, preprocess_scan_dsb reports an already
# existing .npz as success without inspecting it, so files written for a
# different FINAL_SCAN_SIZE are only detected later by the dataset's shape check.
FORCE_REPROCESS = False # Set to True to re-process all selected scans

# Per-patient outcome of the preprocessing pass.
successful_processed_ids = []
failed_processed_ids = []
start_time = time.time()

for patient_id in tqdm(scans_to_process, desc=f"Preprocessing {len(scans_to_process)} Scans ({FINAL_SCAN_SIZE})"):
    # True means "file written" or "file was already there".
    success = preprocess_scan_dsb(patient_id, DSB_PATH, PREPROCESSED_DSB_PATH, force_preprocess=FORCE_REPROCESS)
    if success:
        successful_processed_ids.append(patient_id)
    else:
        failed_processed_ids.append(patient_id)

end_time = time.time()
print(f"\nPreprocessing finished in {end_time - start_time:.2f} seconds.")
print(f"Successfully processed/found: {len(successful_processed_ids)} scans.")
if failed_processed_ids: print(f"Failed to process: {len(failed_processed_ids)} scans. IDs: {failed_processed_ids}")

# --- Final list for Dataset (only successfully processed from the selection) ---
# This is the same list object as successful_processed_ids, not a copy.
final_patient_list = successful_processed_ids
# Abort the script here: nothing downstream can run without data.
if not final_patient_list: raise SystemExit("No scans processed successfully. Cannot continue.")

# Update patient_labels to only include successfully processed patients
patient_labels = {pid: patient_labels[pid] for pid in final_patient_list}
print(f"Final patient count for training/validation: {len(final_patient_list)}")
# Label value 1 is counted as cancerous; everything else as non-cancerous.
final_cancer_count = sum(1 for pid in final_patient_list if patient_labels[pid] == 1)
final_non_cancer_count = len(final_patient_list) - final_cancer_count
print(f"  Cancerous: {final_cancer_count}, Non-Cancerous: {final_non_cancer_count}")


# %% --- Dataset and DataLoader ---

class PatientLevelDataset(Dataset):
    """Dataset yielding one preprocessed volume and its label per patient.

    Each item is read from ``<preprocessed_path>/<patient_id>.npz`` (array
    stored under the key ``image``). Items that cannot be loaded, or whose
    shape differs from ``target_size``, are replaced by an all-zero volume
    with label ``-1`` so the consumer can filter them out.
    """

    def __init__(self, patient_ids, labels_dict, preprocessed_path):
        self.patient_ids = patient_ids
        self.labels_dict = labels_dict
        self.preprocessed_path = preprocessed_path
        # Expected shape of every stored volume; used for the per-item check.
        self.target_size = FINAL_SCAN_SIZE

    def __len__(self):
        return len(self.patient_ids)

    def _error_item(self):
        """Return the placeholder ``(zeros, -1)`` pair used for unusable scans."""
        placeholder = torch.zeros((1, *self.target_size), dtype=torch.float32)
        return placeholder, torch.tensor(-1, dtype=torch.float32)  # Error label

    def __getitem__(self, idx):
        patient_id = self.patient_ids[idx]
        label = self.labels_dict[patient_id]
        scan_path = os.path.join(self.preprocessed_path, f"{patient_id}.npz")
        try:
            with np.load(scan_path) as npz_data:
                image = npz_data['image']

            if image.shape != self.target_size:
                print(f"ERROR: Shape mismatch for {patient_id}. Expected {self.target_size}, got {image.shape}. Skipping.")
                return self._error_item()

            # Leading channel dimension: (1, *target_size).
            volume = torch.from_numpy(image).float().unsqueeze(0)
            return volume, torch.tensor(label, dtype=torch.float32)
        except FileNotFoundError:
            print(f"ERROR: File not found {scan_path}. Skipping.")
        except Exception as e:
            print(f"ERROR loading {patient_id}: {e}. Skipping.")
        return self._error_item()

# --- Split Data (Train/Validation) ---
# 80/20 split, stratified on the patient label so both subsets keep the same
# class ratio. SEED is defined elsewhere in the file.
train_ids, val_ids = train_test_split(
    final_patient_list,
    test_size=0.2,
    random_state=SEED,
    stratify=[patient_labels[pid] for pid in final_patient_list] # Stratify
)
print(f"\nTraining patients: {len(train_ids)}")
print(f"Validation patients: {len(val_ids)}")

# --- Create Datasets and DataLoaders ---
train_dataset = PatientLevelDataset(train_ids, patient_labels, PREPROCESSED_DSB_PATH)
val_dataset = PatientLevelDataset(val_ids, patient_labels, PREPROCESSED_DSB_PATH)

NUM_WORKERS = 0 # Set to 0 for easier debugging if issues arise
print(f"Using {NUM_WORKERS} workers for DataLoader.")
# Only the training loader shuffles; pinned memory is requested whenever CUDA is available.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=torch.cuda.is_available())
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=torch.cuda.is_available())

# --- Check DataLoader Output ---
# Pull one batch from each loader as a sanity check; any failure is only printed.
try:
    print("\nChecking DataLoader output...")
    if len(train_loader) > 0:
        sample_batch, sample_labels = next(iter(train_loader))
        print(f"Sample batch shape: {sample_batch.shape}") # Expected [B, 1, *FINAL_SCAN_SIZE]
        print(f"Sample labels shape: {sample_labels.shape}")
        print(f"Sample labels: {sample_labels}")
        # -1 is the label PatientLevelDataset assigns to scans it could not load.
        if torch.any(sample_labels == -1): print("WARNING: Error labels (-1) detected in first batch.")
    else: print("Train loader empty.")
    if len(val_loader) > 0:
        sample_batch_val, _ = next(iter(val_loader))
        print(f"Validation batch shape: {sample_batch_val.shape}")
    else: print("Validation loader empty.")
except Exception as e: print(f"Error checking DataLoader: {e}")


# %% --- Model Definition (Simp3DNet - PyTorch version of get_simp3d) ---

# Filter numbers from the Keras model definition
# (Simp3DNet below only uses the first four entries.)
num_filters = [16, 32, 64, 128, 256, 1028]

class Simp3DNet(nn.Module):
    """Plain 3D CNN classifier: four blocks of unpadded convolutions + dense head.

    Every convolution is followed by BatchNorm and ReLU; blocks 1 and 2 end
    with a 2x2x2 max-pool. The feature map is then brought to 4x4x4 by an
    adaptive average pool, flattened, and passed through
    Linear(8192, 256) -> ReLU -> BatchNorm1d -> Linear(256, num_classes).

    The shape comments below are for a 64x64x64 input, for which the
    convolutional stack already ends at 4x4x4 and the adaptive pool is an
    identity. For any other input size (for example a non-cubic
    FINAL_SCAN_SIZE) the pool reduces or expands the map to 4x4x4, so the
    dense head keeps a fixed input width instead of failing with a shape
    mismatch. The pool and the ReLU module hold no parameters, so state-dict
    keys are unchanged.

    Args:
        input_channels: number of channels of the input volume.
        num_classes: number of output logits (1 for BCEWithLogitsLoss).

    ``forward`` takes a tensor of shape (B, input_channels, D, H, W) and
    returns raw logits of shape (B, num_classes). Each spatial dimension must
    be at least 50 for the unpadded convolutions to be valid.
    """

    def __init__(self, input_channels=1, num_classes=1): # num_classes=1 for BCEWithLogitsLoss
        super(Simp3DNet, self).__init__()

        # Block 1: Input (B, 1, 64, 64, 64)
        # Conv 9x9x9, valid -> (B, 16, 56, 56, 56)
        self.conv1a = nn.Conv3d(input_channels, num_filters[0], kernel_size=9, padding=0)
        self.bn1a = nn.BatchNorm3d(num_filters[0])
        # Conv 3x3x3, valid -> (B, 16, 54, 54, 54)
        self.conv1b = nn.Conv3d(num_filters[0], num_filters[0], kernel_size=3, padding=0)
        self.bn1b = nn.BatchNorm3d(num_filters[0])
        # Conv 5x5x5, valid -> (B, 16, 50, 50, 50)
        self.conv1c = nn.Conv3d(num_filters[0], num_filters[0], kernel_size=5, padding=0)
        self.bn1c = nn.BatchNorm3d(num_filters[0])
        # Pool 2x2x2 -> (B, 16, 25, 25, 25)
        self.pool1 = nn.MaxPool3d(kernel_size=2, stride=2)

        # Block 2
        # Conv 3x3x3, valid -> (B, 32, 23, 23, 23)
        self.conv2a = nn.Conv3d(num_filters[0], num_filters[1], kernel_size=3, padding=0)
        self.bn2a = nn.BatchNorm3d(num_filters[1])
        # Conv 3x3x3, valid -> (B, 32, 21, 21, 21)
        self.conv2b = nn.Conv3d(num_filters[1], num_filters[1], kernel_size=3, padding=0)
        self.bn2b = nn.BatchNorm3d(num_filters[1])
        # Pool 2x2x2 -> (B, 32, 10, 10, 10) - Note: floor(21/2) = 10
        self.pool2 = nn.MaxPool3d(kernel_size=2, stride=2)

        # Block 3
        # Conv 3x3x3, valid -> (B, 64, 8, 8, 8)
        self.conv3a = nn.Conv3d(num_filters[1], num_filters[2], kernel_size=3, padding=0)
        self.bn3a = nn.BatchNorm3d(num_filters[2])
        # Conv 3x3x3, valid -> (B, 64, 6, 6, 6)
        self.conv3b = nn.Conv3d(num_filters[2], num_filters[2], kernel_size=3, padding=0)
        self.bn3b = nn.BatchNorm3d(num_filters[2])
        # No pooling

        # Block 4
        # Conv 3x3x3, valid -> (B, 128, 4, 4, 4)
        self.conv4a = nn.Conv3d(num_filters[2], num_filters[3], kernel_size=3, padding=0)
        self.bn4a = nn.BatchNorm3d(num_filters[3])
        # No pooling

        # Shared, stateless activation (a module, so no functional import is needed).
        self.relu = nn.ReLU()

        # Fix the spatial size fed to the dense head at 4x4x4 for any input size.
        self.pre_flatten_pool = nn.AdaptiveAvgPool3d((4, 4, 4))

        # Flatten Layer
        self.flatten = nn.Flatten()

        # 128 filters * 4 * 4 * 4 volume = 8192 features
        flattened_features = num_filters[3] * 4 * 4 * 4

        # Dense Layers
        self.fc1 = nn.Linear(flattened_features, 256)
        # Keras model applies BN *after* activation for Dense layer. Replicate this.
        # Use BatchNorm1d for features after flattening.
        self.bn_fc1 = nn.BatchNorm1d(256)
        # Output layer: logits for BCEWithLogitsLoss
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        # Block 1 - Conv -> BN -> ReLU pattern
        x = self.relu(self.bn1a(self.conv1a(x)))
        x = self.relu(self.bn1b(self.conv1b(x)))
        x = self.relu(self.bn1c(self.conv1c(x)))
        x = self.pool1(x)

        # Block 2
        x = self.relu(self.bn2a(self.conv2a(x)))
        x = self.relu(self.bn2b(self.conv2b(x)))
        x = self.pool2(x)

        # Block 3
        x = self.relu(self.bn3a(self.conv3a(x)))
        x = self.relu(self.bn3b(self.conv3b(x)))

        # Block 4
        x = self.relu(self.bn4a(self.conv4a(x)))

        # Normalise spatial size, then flatten for the dense head.
        x = self.pre_flatten_pool(x)
        x = self.flatten(x)
        x = self.fc1(x)
        # Apply ReLU then BN, matching Keras Dense -> activation -> BN
        x = self.relu(x)
        x = self.bn_fc1(x)
        # Final output layer (logits)
        x = self.fc2(x)
        return x

# Instantiate the new model
# Build the classifier, then move it to the configured device.
patient_model = Simp3DNet(input_channels=1, num_classes=NUM_CLASSES)
patient_model = patient_model.to(DEVICE)
print("\n--- Using Simp3DNet Model ---")
print(patient_model)

# Smoke test: push one random batch of the configured scan size through the
# network. A failure here is reported but does not stop the script.
try:
    dummy_shape = (BATCH_SIZE, 1) + tuple(FINAL_SCAN_SIZE)
    dummy_input = torch.randn(dummy_shape).to(DEVICE)
    print(f"Dummy input shape: {dummy_input.shape}")
    output = patient_model(dummy_input)
    print(f"Model output shape: {output.shape}") # Should be [B, 1]
except Exception as e:
    # A size-mismatch message here points at the flattened feature count
    # expected by the model's first dense layer versus FINAL_SCAN_SIZE.
    print(f"\nError during model test forward pass: {e}")

# %% --- Loss and Optimizer ---

# Weight the positive class by the negative/positive ratio of the *actual*
# training split; fall back to 1.0 whenever that ratio is undefined.
train_labels_list = [patient_labels[pid] for pid in train_ids]
count_0 = train_labels_list.count(0)
count_1 = train_labels_list.count(1)
if count_0 > 0 and count_1 > 0:
    pos_weight_val = count_0 / count_1
    print(f"Calculated positive weight for training subset ({count_1} pos / {count_0} neg): {pos_weight_val:.4f}")
    pos_weight_tensor = torch.tensor([pos_weight_val], device=DEVICE)
else:
    # Degenerate split: report which case it is, then use a neutral weight.
    if count_1 > 0:
        print("Warning: Training set only contains positive samples. Using pos_weight=1.")
    elif count_0 > 0:
        print("Warning: Training set only contains negative samples. Using pos_weight=1.")
    else:
        print("Warning: Training set appears empty. Using pos_weight=1.")
    pos_weight_tensor = torch.tensor([1.0], device=DEVICE)

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight_tensor)
# Adam with the learning rate taken from the configuration block.
optimizer = optim.Adam(patient_model.parameters(), lr=LEARNING_RATE)
# Gradient scaling is active only when CUDA is available.
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())


# %% --- Training and Validation Functions (Unchanged Logic) ---
# These functions work with single logit output and BCEWithLogitsLoss
def train_one_epoch_patient(model, dataloader, criterion, optimizer, device, scaler):
    """Train ``model`` for one pass over ``dataloader``.

    Samples whose label is -1 (the marker used for scans that failed to
    load) are dropped from each batch; a batch with no valid sample, or one
    that yields a NaN loss, is skipped entirely.

    Args:
        model: network producing one logit per sample.
        dataloader: yields ``(inputs, labels)`` batches with flat labels.
        criterion: loss taking ``(logits, labels)`` of shape (B, 1).
        optimizer: optimizer stepping ``model``'s parameters.
        device: ``torch.device`` the batch is moved to.
        scaler: gradient scaler used for the backward pass and the step.

    Returns:
        ``(epoch_loss, epoch_acc)`` averaged over the samples actually used,
        or ``(0.0, 0.0)`` if none were.
    """
    model.train()
    running_loss = 0.0
    total_samples = 0
    correct_predictions = 0

    # Enable float16 autocast based on where this call actually runs, not on
    # whether the machine happens to have a GPU: a CPU `device` on a CUDA
    # host must not be put into float16 autocast.
    use_amp = device.type == "cuda"

    progress_bar = tqdm(dataloader, desc="Training", leave=False)
    for inputs, labels in progress_bar:
        # Filter out error labels before moving to device
        valid_indices = labels != -1
        if not torch.any(valid_indices):
            continue  # Skip batch if all are errors
        inputs = inputs[valid_indices].to(device)
        labels = labels[valid_indices].unsqueeze(1).to(device)  # (B,) -> (B, 1) for the loss

        optimizer.zero_grad()
        with torch.amp.autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp):
            outputs = model(inputs)
            loss = criterion(outputs, labels)

        if torch.isnan(loss):
            print("NaN loss encountered during training! Skipping batch.")
            continue  # Skip this batch

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_size = inputs.size(0)
        running_loss += loss.item() * batch_size
        total_samples += batch_size
        preds = torch.sigmoid(outputs) > 0.5  # Binary predictions at a 0.5 threshold
        correct_predictions += (preds == labels.bool()).sum().item()
        progress_bar.set_postfix(loss=loss.item())

    if total_samples == 0:
        return 0.0, 0.0
    epoch_loss = running_loss / total_samples
    epoch_acc = correct_predictions / total_samples
    return epoch_loss, epoch_acc

def validate_patient(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without updating it.

    Samples labelled -1 (scans that failed to load) are dropped; batches
    with no valid sample or with a NaN loss contribute neither to the loss
    nor to the returned predictions.

    Args:
        model: network producing one logit per sample.
        dataloader: yields ``(inputs, labels)`` batches with flat labels.
        criterion: loss taking ``(logits, labels)`` of shape (B, 1).
        device: ``torch.device`` the batch is moved to.

    Returns:
        ``(val_loss, labels, probabilities)`` where the last two are 1-D
        NumPy arrays aligned sample by sample and ``probabilities`` are the
        sigmoid of the logits. ``(0.0, empty, empty)`` if no sample was used.
    """
    model.eval()
    running_loss = 0.0
    total_samples = 0
    all_preds_proba = []
    all_labels = []

    # Autocast follows the device this call runs on rather than global CUDA
    # availability, so a CPU `device` on a CUDA host stays in full precision.
    use_amp = device.type == "cuda"

    with torch.no_grad():
        progress_bar = tqdm(dataloader, desc="Validating", leave=False)
        for inputs, labels in progress_bar:
            # Filter out error labels
            valid_indices = labels != -1
            if not torch.any(valid_indices):
                continue
            inputs = inputs[valid_indices].to(device)
            labels = labels[valid_indices].to(device)  # Kept flat for the metrics

            with torch.amp.autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp):
                outputs = model(inputs)
                # The loss expects (B, 1) targets, like in training.
                loss = criterion(outputs, labels.unsqueeze(1))

            if torch.isnan(loss):
                print("NaN loss encountered during validation!")
                # The whole batch is left out of both loss and predictions.
                continue

            batch_size = inputs.size(0)
            running_loss += loss.item() * batch_size
            total_samples += batch_size

            # Store predictions (probabilities) and true labels for metrics
            all_preds_proba.extend(torch.sigmoid(outputs).cpu().numpy().flatten())
            all_labels.extend(labels.cpu().numpy().flatten())

    if total_samples == 0:
        return 0.0, np.array([]), np.array([])
    val_loss = running_loss / total_samples
    return val_loss, np.array(all_labels), np.array(all_preds_proba)


# %% --- Training Loop ---
print(f"\nStarting Training with Simp3DNet Model ({len(final_patient_list)} scans)...")
# Per-epoch history; the three lists grow together, one entry per epoch.
best_val_loss = float('inf'); train_losses, val_losses, train_accs = [], [], []

# <<< --- Save model in the new directory with appropriate name --- >>>
# NOTE(review): the file name says "64cube" while the input size is taken
# from FINAL_SCAN_SIZE - confirm the name still describes the data.
MODEL_SAVE_PATH = os.path.join(PREPROCESSED_DSB_PATH, "simp3dnet_model_50_each_64cube_best.pth")

if torch.cuda.is_available(): torch.cuda.empty_cache()

for epoch in range(EPOCHS):
    epoch_start_time = time.time()
    print(f"\nEpoch {epoch+1}/{EPOCHS}")

    # One optimisation pass, then a full validation pass with the updated weights.
    train_loss, train_acc = train_one_epoch_patient(patient_model, train_loader, criterion, optimizer, DEVICE, scaler)
    val_loss, val_labels_epoch, val_preds_proba_epoch = validate_patient(patient_model, val_loader, criterion, DEVICE)

    train_losses.append(train_loss); val_losses.append(val_loss); train_accs.append(train_acc)

    epoch_end_time = time.time()
    epoch_duration = epoch_end_time - epoch_start_time

    # Calculate validation accuracy for reporting (0.5 probability threshold).
    # It is only printed; it is not stored in a history list.
    val_acc_epoch = 0.0
    if len(val_labels_epoch) > 0:
        val_preds_binary_epoch = (val_preds_proba_epoch > 0.5).astype(int)
        val_acc_epoch = accuracy_score(val_labels_epoch, val_preds_binary_epoch)

    print(f"Epoch {epoch+1} Summary: Duration: {epoch_duration:.2f}s")
    print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"  Val Loss:   {val_loss:.4f}, Val Acc:   {val_acc_epoch:.4f}")

    # Save best model based on validation loss.
    # validate_patient returns a loss of 0.0 when it used no samples, so the
    # length check keeps an empty validation pass from counting as "best".
    if val_loss < best_val_loss and len(val_labels_epoch) > 0: # Ensure validation wasn't empty
        best_val_loss = val_loss
        try:
            torch.save(patient_model.state_dict(), MODEL_SAVE_PATH)
            print(f"  Best model saved to {MODEL_SAVE_PATH} (Val Loss: {best_val_loss:.4f})")
        except Exception as e:
            # A failed save is reported but does not interrupt training.
            print(f"Error saving model: {e}")

    if torch.cuda.is_available(): torch.cuda.empty_cache() # Clear cache per epoch

print("\nSimp3DNet Model Training Finished.")


# %% --- Plot Training History ---
# Plot against the number of epochs actually recorded rather than the
# configured EPOCHS: if training was interrupted, the history lists are
# shorter than EPOCHS and plotting against range(1, EPOCHS + 1) would raise
# a length-mismatch error. The three history lists always grow together.
epochs_ran = range(1, len(train_losses) + 1)

plt.figure(figsize=(12, 5))

# Left panel: training and validation loss.
plt.subplot(1, 2, 1)
plt.plot(epochs_ran, train_losses, label='Train Loss')
plt.plot(epochs_ran, val_losses, label='Val Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Simp3DNet Loss Curve (64cube)')
plt.legend()
plt.grid(True)

# Right panel: training accuracy. Validation accuracy is not stored per
# epoch by the training loop, so it cannot be plotted here.
plt.subplot(1, 2, 2)
plt.plot(epochs_ran, train_accs, label='Train Acc')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.title('Simp3DNet Train Accuracy Curve (64cube)')
plt.legend()
plt.grid(True)

plt.tight_layout()
# Save the plot
plot_save_path = os.path.join(PREPROCESSED_DSB_PATH, "simp3dnet_training_curves_64cube.png")
plt.savefig(plot_save_path)
print(f"Training curves saved to {plot_save_path}")
plt.show()


# %% --- Model Evaluation ---
print("\nEvaluating Simp3DNet Model on Validation Set...")
# Prefer the best checkpoint written during training; fall back to the
# in-memory model from the last epoch if it is missing or cannot be loaded.
if os.path.exists(MODEL_SAVE_PATH):
    try:
        # Re-initialize model structure before loading state_dict
        patient_model_eval = Simp3DNet(input_channels=1, num_classes=NUM_CLASSES).to(DEVICE)
        patient_model_eval.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))
        print(f"Loaded best model from {MODEL_SAVE_PATH} for evaluation.")
        patient_model_eval.eval()  # Set to evaluation mode
    except Exception as e:
        print(f"Could not load best model from {MODEL_SAVE_PATH}: {e}. Evaluating last epoch model instead.")
        patient_model.eval()
        patient_model_eval = patient_model  # Use the model from the end of training
else:
    print(f"Best model file not found at {MODEL_SAVE_PATH}. Evaluating model from the end of training.")
    patient_model.eval()  # Set to evaluation mode
    patient_model_eval = patient_model

# Perform validation using the loaded/final model
val_loss_final, final_val_labels, final_val_preds_proba = validate_patient(patient_model_eval, val_loader, criterion, DEVICE)

if len(final_val_labels) == 0:
    print("No valid validation predictions were generated. Cannot evaluate metrics.")
else:
    print(f"\nFinal Validation Loss: {val_loss_final:.4f}")

    # Calculate metrics using a 0.5 threshold
    final_val_preds_binary = (final_val_preds_proba > 0.5).astype(int)

    accuracy = accuracy_score(final_val_labels, final_val_preds_binary)
    # Use zero_division=0 to avoid warnings if a class has no predictions/labels
    precision = precision_score(final_val_labels, final_val_preds_binary, zero_division=0)
    recall = recall_score(final_val_labels, final_val_preds_binary, zero_division=0)
    f1 = f1_score(final_val_labels, final_val_preds_binary, zero_division=0)

    # AUC-ROC is undefined when only one class is present in the true labels.
    auc_roc = float('nan')  # Default to NaN
    if len(np.unique(final_val_labels)) > 1:
        try:
            auc_roc = roc_auc_score(final_val_labels, final_val_preds_proba)
        except ValueError as e:
            print(f"AUC-ROC Calculation Error: {e}")
    else:
        print("AUC-ROC cannot be calculated: Only one class present in validation labels.")

    print("\n--- Final Validation Metrics (Simp3DNet, 64cube) ---")
    print(f"Accuracy:  {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall:    {recall:.4f}")
    print(f"F1-Score:  {f1:.4f}")
    print(f"AUC-ROC:   {auc_roc:.4f}")

    print("\nClassification Report (Simp3DNet, 64cube):")
    target_names = ['Non-Cancer (0)', 'Cancer (1)']
    # labels=[0, 1] pins the report to both classes. Without it the call
    # raises when only one class occurs among labels and predictions, because
    # the two target_names would not match the single class found.
    print(classification_report(
        final_val_labels.astype(int),
        final_val_preds_binary,
        labels=[0, 1],
        target_names=target_names,
        zero_division=0,
    ))

    print("\nConfusion Matrix (Simp3DNet, 64cube):")
    # Bound up front so the except branch can always print it, even when
    # computing the matrix itself is what failed.
    cm = None
    try:
        cm = confusion_matrix(final_val_labels.astype(int), final_val_preds_binary, labels=[0, 1])  # Ensure labels are 0 and 1
        disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=target_names)
        disp.plot(cmap=plt.cm.Blues)
        # Save the plot
        cm_save_path = os.path.join(PREPROCESSED_DSB_PATH, "simp3dnet_confusion_matrix_64cube.png")
        plt.savefig(cm_save_path)
        print(f"Confusion matrix saved to {cm_save_path}")
        plt.show()
    except Exception as e:
        print(f"Error displaying confusion matrix: {e}")
        print("Raw CM data:", cm)  # Print raw data if plot fails

    # Plot ROC Curve if AUC is valid
    if not np.isnan(auc_roc):
        fpr, tpr, _ = roc_curve(final_val_labels, final_val_preds_proba)
        plt.figure(figsize=(8, 6))
        plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {auc_roc:.2f})')
        plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('ROC Curve (Simp3DNet, 64cube)')
        plt.legend(loc="lower right")
        plt.grid(True)
        # Save the plot
        roc_save_path = os.path.join(PREPROCESSED_DSB_PATH, "simp3dnet_roc_curve_64cube.png")
        plt.savefig(roc_save_path)
        print(f"ROC curve saved to {roc_save_path}")
        plt.show()
# %% --- Model Definition (UNetPlusPlus_SE_Transformer) ---
import torch
import torch.nn as nn
import torch.nn.functional as F

# --- Building Blocks ---

class SEBlock3D(nn.Module):
    """
    3D Squeeze-and-Excitation Block.

    Globally average-pools the input to one value per channel, passes that
    through a two-layer 1x1x1 convolutional bottleneck ending in a sigmoid,
    and multiplies the input by the resulting per-channel gate.

    Args:
        channels: number of input (and output) channels.
        reduction: bottleneck ratio; the hidden width is
            ``channels // reduction`` but never less than 1.
    """
    def __init__(self, channels, reduction=16):
        super(SEBlock3D, self).__init__()
        # Guard against a zero-width bottleneck when channels < reduction,
        # which would leave the gate with nothing to learn from.
        hidden_channels = max(1, channels // reduction)
        self.squeeze = nn.AdaptiveAvgPool3d(1)
        self.excitation = nn.Sequential(
            nn.Conv3d(channels, hidden_channels, kernel_size=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv3d(hidden_channels, channels, kernel_size=1, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Squeeze: one descriptor per channel; excite: gate in (0, 1) per channel.
        y = self.squeeze(x)
        y = self.excitation(y)
        return x * y.expand_as(x)

class ConvBlock3D(nn.Module):
    """
    Two-stage 3D convolutional block.

    Each stage is Conv3d (no bias) -> BatchNorm3d -> ReLU; when ``use_se`` is
    true the result is additionally passed through an ``SEBlock3D``.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by both convolutions.
        kernel_size: kernel size shared by both convolutions.
        padding: padding shared by both convolutions.
        use_se: whether to append the squeeze-and-excitation gate.
        se_reduction: reduction ratio handed to ``SEBlock3D``.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, use_se=True, se_reduction=16):
        super().__init__()
        self.use_se = use_se
        # Settings common to both convolutions; bias is redundant before BatchNorm.
        conv_kwargs = dict(kernel_size=kernel_size, padding=padding, bias=False)

        self.conv1 = nn.Conv3d(in_channels, out_channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv3d(out_channels, out_channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm3d(out_channels)
        self.relu2 = nn.ReLU(inplace=True)

        if self.use_se:
            self.se = SEBlock3D(out_channels, reduction=se_reduction)

    def forward(self, x):
        stages = (
            (self.conv1, self.bn1, self.relu1),
            (self.conv2, self.bn2, self.relu2),
        )
        for conv, norm, activation in stages:
            x = activation(norm(conv(x)))
        return self.se(x) if self.use_se else x

class TransformerEncoderLayer3D(nn.Module):
    """
    One Transformer encoder layer working on patch sequences.

    Input and output have shape (B, N_patches, E_dim). The layer applies
    multi-head self-attention and a GELU feed-forward network, each wrapped
    as: residual add, then LayerNorm.

    Args:
        embed_dim: size of each patch embedding.
        num_heads: number of attention heads.
        ff_dim_factor: the feed-forward hidden width is
            ``embed_dim * ff_dim_factor``.
        dropout: dropout probability used in attention, inside the
            feed-forward network and on both residual branches.
    """
    def __init__(self, embed_dim, num_heads, ff_dim_factor=4, dropout=0.1):
        super().__init__()
        hidden_dim = embed_dim * ff_dim_factor

        # Self-attention sub-layer
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(dropout)

        # Position-wise feed-forward sub-layer
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
        )
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        # The attention weights returned alongside the output are not used.
        attended = self.attn(
            src, src, src,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
        )[0]
        src = self.norm1(src + self.dropout1(attended))
        return self.norm2(src + self.dropout2(self.ffn(src)))

class TransformerBottleneck(nn.Module):
    """
    Transformer Bottleneck for 3D U-Net.

    Takes a feature map (B, C_in, D, H, W), projects it to ``embed_dim``
    channels with a 1x1x1 convolution when the channel counts differ, treats
    every voxel as one token (B, N_patches, E_dim), adds a learnable
    positional embedding, runs the Transformer layers followed by a final
    LayerNorm, and reshapes back to (B, E_dim, D, H, W).

    Args:
        in_channels: channels of the incoming feature map.
        embed_dim: token embedding size (and output channel count).
        num_layers: number of ``TransformerEncoderLayer3D`` layers.
        num_heads: attention heads per layer.
        bottleneck_spatial_dims: expected (D, H, W) of the input; any
            3-element sequence is accepted. It fixes the size of the
            positional embedding.
        dropout: dropout probability for the embedding and the layers.

    Raises:
        ValueError: in ``forward`` if the input's spatial size or projected
            channel count does not match the configured values.
    """
    def __init__(self, in_channels, embed_dim, num_layers, num_heads, bottleneck_spatial_dims, dropout=0.1):
        super(TransformerBottleneck, self).__init__()
        self.in_channels = in_channels
        self.embed_dim = embed_dim
        # Stored as a tuple so the comparison with the (d, h, w) tuple in
        # forward() also succeeds when a list was passed in.
        self.bottleneck_spatial_dims = tuple(bottleneck_spatial_dims) # (D_b, H_b, W_b) at bottleneck

        num_patches = (
            self.bottleneck_spatial_dims[0]
            * self.bottleneck_spatial_dims[1]
            * self.bottleneck_spatial_dims[2]
        )

        # Project input channels to embedding dimension if they differ
        if in_channels != embed_dim:
            self.patch_projection = nn.Conv3d(in_channels, embed_dim, kernel_size=1)
        else:
            self.patch_projection = nn.Identity() # No projection needed if channels match

        self.pos_embed = nn.Parameter(torch.randn(1, num_patches, embed_dim)) # Learnable positional embedding
        self.dropout_pos = nn.Dropout(dropout)

        self.transformer_layers = nn.ModuleList([
            TransformerEncoderLayer3D(embed_dim, num_heads, dropout=dropout)
            for _ in range(num_layers)
        ])
        self.norm_out = nn.LayerNorm(embed_dim) # Final normalization on patch embeddings

    def forward(self, x):
        # x shape: (B, C_in, D_b, H_b, W_b)
        x = self.patch_projection(x) # (B, E_dim, D_b, H_b, W_b)

        b, e, d, h, w = x.shape
        if (d, h, w) != self.bottleneck_spatial_dims:
            raise ValueError(f"Spatial dimensions mismatch in TransformerBottleneck. Expected {self.bottleneck_spatial_dims}, got {(d,h,w)}")
        if e != self.embed_dim:
            raise ValueError(f"Embedding dimension mismatch after projection. Expected {self.embed_dim}, got {e}")

        # Flatten spatial dimensions to create sequence of patches
        x = x.flatten(2) # (B, E_dim, N_patches) where N_patches = D_b*H_b*W_b
        x = x.transpose(1, 2) # (B, N_patches, E_dim) - batch_first for MHA

        # Add positional embedding (broadcast over the batch dimension)
        x = x + self.pos_embed
        x = self.dropout_pos(x)

        # Pass through Transformer layers
        for layer in self.transformer_layers:
            x = layer(x)

        x = self.norm_out(x)

        # Reshape back to (B, E_dim, D_b, H_b, W_b)
        x = x.transpose(1, 2) # (B, E_dim, N_patches)
        x = x.view(b, e, d, h, w)
        return x

# --- Main Model: UNetPlusPlus_SE_Transformer ---
class UNetPlusPlus_SE_Transformer(nn.Module):
    """3D UNet++-style classifier with optional SE blocks and a transformer bottleneck.

    The encoder produces X_{i,0} for levels i = 0..depth. The deepest map is
    passed through ``TransformerBottleneck``. Nested decoder nodes X_{i,j}
    (j > 0) are computed from all earlier nodes on the same level plus an
    upsampled node from the level below, and X_{0,depth} is globally pooled
    and fed to a small fully connected head that emits the logits.
    """

    def __init__(self, input_scan_size, in_channels=1, num_classes=1, initial_filters=16, depth=4,
                 use_se=True, transformer_embed_dim=256, transformer_layers=2,
                 transformer_heads=8, transformer_dropout=0.1, final_fc_units=128):
        """Build encoder, bottleneck, nested decoder and classification head.

        Args:
            input_scan_size: Sequence of 3 ints (D, H, W) describing the input volume.
            in_channels: Channels of the input volume.
            num_classes: Number of output logits.
            initial_filters: Channels of level 0; level i uses initial_filters * 2**i.
            depth: Number of pooling operations (levels = depth + 1).
            use_se: Forwarded to every ``ConvBlock3D``.
            transformer_embed_dim: Embedding width of the bottleneck.
            transformer_layers: Number of transformer layers in the bottleneck.
            transformer_heads: Attention heads per transformer layer.
            transformer_dropout: Dropout used inside the bottleneck.
            final_fc_units: Hidden width of the classification head.

        Raises:
            ValueError: If ``input_scan_size`` is not 3 ints, ``depth`` is
                negative, or the bottleneck would have a non-positive size.
        """
        super(UNetPlusPlus_SE_Transformer, self).__init__()

        if not (len(input_scan_size) == 3 and all(isinstance(s, int) for s in input_scan_size)):
            raise ValueError("input_scan_size must be a tuple of 3 integers (D, H, W)")
        # A negative depth would otherwise surface later as an opaque IndexError.
        if depth < 0:
            raise ValueError(f"depth must be >= 0, got {depth}")

        self.depth = depth  # Number of pooling operations. Levels = depth + 1.
        self.use_se = use_se
        nf = initial_filters

        # Encoder path (X_i,0)
        self.pools = nn.ModuleList()
        self.encoder_blocks = nn.ModuleList()
        encoder_output_channels = []

        current_channels_enc = in_channels
        for i in range(depth + 1):  # Levels 0..depth
            out_ch_enc = nf * (2**i)
            self.encoder_blocks.append(ConvBlock3D(current_channels_enc, out_ch_enc, use_se=use_se))
            encoder_output_channels.append(out_ch_enc)
            if i < depth:  # No pooling after the deepest level
                self.pools.append(nn.MaxPool3d(2, 2))
            current_channels_enc = out_ch_enc

        # Transformer bottleneck operating on X_depth,0
        bottleneck_in_channels = encoder_output_channels[-1]
        s_d, s_h, s_w = (input_scan_size[0] // (2**depth),
                         input_scan_size[1] // (2**depth),
                         input_scan_size[2] // (2**depth))

        if not (s_d > 0 and s_h > 0 and s_w > 0):
            raise ValueError(f"Input scan size {input_scan_size} with depth {depth} results in non-positive bottleneck dims: {(s_d,s_h,s_w)}")

        self.transformer_bottleneck = TransformerBottleneck(
            in_channels=bottleneck_in_channels,
            embed_dim=transformer_embed_dim,
            num_layers=transformer_layers,
            num_heads=transformer_heads,
            bottleneck_spatial_dims=(s_d, s_h, s_w),
            dropout=transformer_dropout
        )

        # Decoder path
        self.decoder_conv_modulelist = nn.ModuleList()  # One ModuleList of X_i,j blocks per level
        self.upsamplers = nn.ModuleList()  # self.upsamplers[i] upsamples from level i+1 to level i

        for i in range(depth):
            # The deepest level feeds the bottleneck output, whose width is the embed dim.
            ch_from_level_below = transformer_embed_dim if (i + 1) == depth else encoder_output_channels[i+1]
            ch_to_level_current = encoder_output_channels[i]
            self.upsamplers.append(
                nn.ConvTranspose3d(ch_from_level_below, ch_to_level_current, kernel_size=2, stride=2)
            )

        # Decoder convolutional blocks (X_i,j for j > 0)
        for i in range(depth):
            level_i_decoder_blocks = nn.ModuleList()
            for j in range(1, depth - i + 1):
                # Inputs to X_i,j: X_i,0 .. X_i,j-1 plus Up(X_{i+1,j-1}); each of these
                # j + 1 maps carries encoder_output_channels[i] channels.
                num_concat_features = j + 1
                in_ch_for_Xij = encoder_output_channels[i] * num_concat_features
                out_ch_for_Xij = encoder_output_channels[i]
                level_i_decoder_blocks.append(ConvBlock3D(in_ch_for_Xij, out_ch_for_Xij, use_se=use_se))
            self.decoder_conv_modulelist.append(level_i_decoder_blocks)

        # Classification head. X_0,depth normally has encoder_output_channels[0] channels;
        # with depth == 0 there is no decoder and the head receives the bottleneck
        # output, whose width is transformer_embed_dim.
        final_decoder_out_channels = encoder_output_channels[0] if depth > 0 else transformer_embed_dim
        self.classification_head = nn.Sequential(
            nn.AdaptiveAvgPool3d((1,1,1)),
            nn.Flatten(),
            nn.Linear(final_decoder_out_channels, final_fc_units),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(final_fc_units, num_classes)
        )

    def forward(self, x):
        """Return the logits produced from the X_0,depth feature map for input ``x``."""
        X_features = [[None for _ in range(self.depth + 1)] for _ in range(self.depth + 1)]

        # Encoder path (computes X_i,0)
        current_feature_map = x
        for i in range(self.depth + 1):
            current_feature_map = self.encoder_blocks[i](current_feature_map)
            X_features[i][0] = current_feature_map
            if i < self.depth:
                current_feature_map = self.pools[i](current_feature_map)

        # Transformer bottleneck replaces X_depth,0
        X_features[self.depth][0] = self.transformer_bottleneck(X_features[self.depth][0])

        # Decoder path: levels from depth-1 down to 0, dense index j from 1 to depth-i,
        # so X_{i+1,j-1} is always available when X_i,j is computed.
        for i in range(self.depth - 1, -1, -1):
            for j in range(1, self.depth - i + 1):
                inputs_same_level = [X_features[i][k] for k in range(j)]

                feature_from_below = X_features[i+1][j-1]
                upsampled_input = self.upsamplers[i](feature_from_below)

                # Match X_i,0's spatial size before concatenation. `nn.functional` is
                # used instead of the alias `F`, which is only bound later in this file.
                target_spatial_size = X_features[i][0].shape[2:]
                if upsampled_input.shape[2:] != target_spatial_size:
                    upsampled_input = nn.functional.interpolate(
                        upsampled_input, size=target_spatial_size, mode='trilinear', align_corners=False)

                combined_inputs = torch.cat(inputs_same_level + [upsampled_input], dim=1)

                # decoder_conv_modulelist[i][j-1] is the ConvBlock for X_i,j
                X_features[i][j] = self.decoder_conv_modulelist[i][j-1](combined_inputs)

        # Classification uses X_0,depth
        final_decoder_output = X_features[0][self.depth]
        logits = self.classification_head(final_decoder_output)
        return logits
    
# %% Imports
import os
import glob
import time
import SimpleITK as sitk
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# Use standard tqdm if not in a notebook environment
try:
    from tqdm.notebook import tqdm
except ImportError:
    from tqdm import tqdm
import random
import scipy.ndimage
from skimage.measure import label as skimage_label, regionprops
from skimage.morphology import disk, binary_closing
from skimage.segmentation import clear_border
import scipy.ndimage as ndi

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F # Needed for UNet++ model
from sklearn.model_selection import train_test_split
from sklearn.metrics import (classification_report, confusion_matrix, roc_auc_score,
                             roc_curve, precision_recall_curve, auc, f1_score,
                             precision_score, recall_score, accuracy_score, ConfusionMatrixDisplay)

# %% Configuration
# --- MODIFY THESE PATHS ---
# Raw DSB 2017 Stage 1 scans and the patient-level cancer label CSV.
DSB_PATH = r'F:/DSB3/stage1'
DSB_LABELS_CSV = r'F:\DSB3\stage1_labels.csv'
# Preprocessed volumes are read from here; this must be the directory the
# earlier preprocessing run wrote to.
PREPROCESSED_DSB_PATH = './preprocessed_dsb_50_each_64cube/'
# Outputs of this specific model are kept in their own directory.
MODEL_OUTPUT_DIR = './model_unetpp_se_transformer_64cube/'

# Preprocessing & model parameters
TARGET_SPACING = [1.5, 1.5, 1.5]
FINAL_SCAN_SIZE = (64, 64, 64)  # MUST match the preprocessed data
CLIP_BOUND_HU = [-1000.0, 400.0]
PIXEL_MEAN = 0.25

# Training parameters
NUM_CLASSES = 1       # single logit, for BCEWithLogitsLoss
BATCH_SIZE = 2        # keep small for this memory-heavy model; raise only if memory allows
LEARNING_RATE = 1e-5  # conservative starting point for the complex model
EPOCHS = 50           # reduced for initial testing (was 150)
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {DEVICE}")

# Scan limit per class (used during the initial data selection phase)
SCAN_LIMIT_PER_CLASS = 50

# Make sure both output directories exist
for _output_dir in (PREPROCESSED_DSB_PATH, MODEL_OUTPUT_DIR):
    os.makedirs(_output_dir, exist_ok=True)


def _seed_everything(seed):
    """Seed the Python, NumPy and PyTorch RNGs; force deterministic cuDNN when CUDA is present."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


# Random seed
SEED = 42
_seed_everything(SEED)

# %% --- Data Loading and Selection (Verification Stage) ---
# This section assumes preprocessing was already done and verifies the files exist.
# It re-reads the labels and identifies the successfully preprocessed files.

print(f"--- Verifying Preprocessed Data in: {PREPROCESSED_DSB_PATH} ---")

# --- Load Labels ---
# Builds {patient_id: cancer} from the CSV's 'id' and 'cancer' columns.
try:
    dsb_labels_df = pd.read_csv(DSB_LABELS_CSV)
    dsb_labels_df = dsb_labels_df.rename(columns={'id': 'patient_id'})
    patient_labels_all = dsb_labels_df.set_index('patient_id')['cancer'].to_dict()
    print(f"Loaded {len(patient_labels_all)} total DSB patient labels.")
except Exception as e:
    # Chain the cause so the original error stays attached to the exit.
    raise SystemExit(f"ERROR: Failed to load labels CSV: {e}") from e

# --- Find existing preprocessed files ---
# glob returns matches in arbitrary, file-system dependent order. Sorting makes
# the patient list, and therefore the seeded train/validation split computed
# from it, identical on every machine.
existing_files = sorted(glob.glob(os.path.join(PREPROCESSED_DSB_PATH, "*.npz")))
if not existing_files:
    raise SystemExit(f"ERROR: No preprocessed .npz files found in {PREPROCESSED_DSB_PATH}. Run the preprocessing step first.")

# Extract patient IDs from filenames and keep only those that have a label
available_processed_ids = []
for f_path in existing_files:
    # splitext strips only the trailing extension; str.replace would also remove
    # a '.npz' occurring in the middle of the file name.
    p_id = os.path.splitext(os.path.basename(f_path))[0]
    if p_id in patient_labels_all:
        available_processed_ids.append(p_id)
    else:
        print(f"Warning: Preprocessed file found for patient {p_id}, but no label exists. Skipping.")

print(f"Found {len(available_processed_ids)} preprocessed scans with labels.")

# --- Use only the available processed IDs ---
final_patient_list = available_processed_ids
if not final_patient_list:
    raise SystemExit("No usable preprocessed scans found. Cannot continue.")

# Restrict the label dictionary to the usable patients and report class balance
patient_labels = {pid: patient_labels_all[pid] for pid in final_patient_list}
print(f"Final patient count for training/validation: {len(final_patient_list)}")
final_cancer_count = sum(1 for pid in final_patient_list if patient_labels[pid] == 1)
final_non_cancer_count = len(final_patient_list) - final_cancer_count
print(f"  Cancerous: {final_cancer_count}, Non-Cancerous: {final_non_cancer_count}")


# %% --- Preprocessing Functions (Required for Dataset Class checks, but not executed again) ---
# Keep these definitions available in case the Dataset class needs them implicitly,
# but the main preprocessing execution loop is skipped as we assume data exists.
# Placeholder definitions only: every function below has an empty body and
# returns None. NOTE(review): if real implementations with these names were
# defined earlier in the file, these definitions replace them - confirm that
# nothing later relies on the real versions.
def load_scan_series(dicom_folder_path):
    """No-op placeholder; returns None."""


def resample(image, original_spacing, new_spacing=TARGET_SPACING):
    """No-op placeholder; returns None."""


def get_segmented_lungs(im_slice, hu_threshold=-320):
    """No-op placeholder; returns None."""


def normalize_hu(image, clip_bounds=CLIP_BOUND_HU):
    """No-op placeholder; returns None."""


def zero_center(image, pixel_mean=PIXEL_MEAN):
    """No-op placeholder; returns None."""


def resize_scan_to_target(image, target_shape=FINAL_SCAN_SIZE):
    """No-op placeholder; returns None."""


def preprocess_scan_dsb(patient_id, input_base_path, output_base_path, force_preprocess=False):
    """No-op placeholder; returns None."""


print("Preprocessing functions defined (but execution skipped as data should exist).")


# %% --- Dataset and DataLoader ---

class PatientLevelDataset(Dataset):
    """Dataset yielding one preprocessed volume and its label per patient.

    Each item is read from ``<preprocessed_path>/<patient_id>.npz`` (array key
    ``'image'``). Any loading problem is reported on stdout and replaced by an
    all-zero volume with label -1, which callers can filter out.
    """

    def __init__(self, patient_ids, labels_dict, preprocessed_path):
        self.patient_ids = patient_ids
        self.labels_dict = labels_dict
        self.preprocessed_path = preprocessed_path
        # Expected array shape; every loaded volume is checked against it.
        self.target_size = FINAL_SCAN_SIZE

    def __len__(self):
        return len(self.patient_ids)

    def _error_sample(self):
        """Return the (zero volume, -1 label) pair that marks a failed load."""
        placeholder = torch.zeros((1, *self.target_size), dtype=torch.float32)
        return placeholder, torch.tensor(-1, dtype=torch.float32)  # Error label

    def __getitem__(self, idx):
        """Return ``(image, label)``: float32 tensors of shape (1, *target_size) and ()."""
        patient_id = self.patient_ids[idx]
        label = self.labels_dict[patient_id]
        scan_path = os.path.join(self.preprocessed_path, f"{patient_id}.npz")
        try:
            with np.load(scan_path) as npz_data:
                image = npz_data['image']

            # Reject volumes whose shape differs from the configured target
            if image.shape != self.target_size:
                print(f"ERROR: Shape mismatch for {patient_id}. Expected {self.target_size}, got {image.shape}. Returning error data.")
                return self._error_sample()

            return (
                torch.from_numpy(image).float().unsqueeze(0),  # add channel dim
                torch.tensor(label, dtype=torch.float32),
            )
        except FileNotFoundError:
            print(f"ERROR: File not found {scan_path}. Returning error data.")
            return self._error_sample()
        except Exception as e:
            print(f"ERROR loading {patient_id}: {e}. Returning error data.")
            return self._error_sample()

# --- Split Data (Train/Validation) ---
# 80/20 split, stratified on the per-patient label so both subsets keep the class ratio.
stratify_labels = [patient_labels[pid] for pid in final_patient_list]
train_ids, val_ids = train_test_split(
    final_patient_list, test_size=0.2, random_state=SEED, stratify=stratify_labels
)
print(f"\nTraining patients: {len(train_ids)}")
print(f"Validation patients: {len(val_ids)}")

# --- Create Datasets and DataLoaders ---
train_dataset = PatientLevelDataset(train_ids, patient_labels, PREPROCESSED_DSB_PATH)
val_dataset = PatientLevelDataset(val_ids, patient_labels, PREPROCESSED_DSB_PATH)

NUM_WORKERS = 0  # Set to 0 for easier debugging
print(f"Using {NUM_WORKERS} workers for DataLoader.")
print(f"Using Batch Size: {BATCH_SIZE}")
# Options shared by both loaders; only the training loader shuffles.
_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=torch.cuda.is_available(),
)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

# --- Check DataLoader Output ---
# Best-effort sanity check: any failure is reported, not raised.
try:
    print("\nChecking DataLoader output...")
    if len(train_loader) == 0:
        print("Train loader empty.")
    else:
        sample_batch, sample_labels = next(iter(train_loader))
        print(f"Sample batch shape: {sample_batch.shape}")  # Should be [B, 1, 64, 64, 64]
        print(f"Sample labels shape: {sample_labels.shape}")
        print(f"Sample labels: {sample_labels}")
        if torch.any(sample_labels == -1):
            print("WARNING: Error labels (-1) detected in first batch.")
except Exception as e:
    print(f"Error checking DataLoader: {e}")


# %% --- Model Definition (UNetPlusPlus_SE_Transformer) ---
# --- Building Blocks ---

class SEBlock3D(nn.Module):
    """3D Squeeze-and-Excitation block.

    Global-average-pools the input to one value per channel, passes it through
    a two-layer 1x1x1 convolutional bottleneck ending in a sigmoid, and
    rescales the input channels with the resulting gates.

    Args:
        channels: Number of input (and output) channels.
        reduction: Divisor applied to ``channels`` for the bottleneck width.
            The width is clamped to at least 1 so the block stays valid when
            ``channels < reduction``.
    """
    def __init__(self, channels, reduction=16):
        super(SEBlock3D, self).__init__()
        # Floor division yields 0 for channels < reduction, which would create
        # a zero-width bottleneck; keep at least one hidden channel.
        hidden_channels = max(1, channels // reduction)
        self.squeeze = nn.AdaptiveAvgPool3d(1)
        self.excitation = nn.Sequential(
            nn.Conv3d(channels, hidden_channels, kernel_size=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv3d(hidden_channels, channels, kernel_size=1, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return ``x`` scaled per channel by the learned gates (same shape as ``x``)."""
        y = self.squeeze(x)      # one pooled value per channel
        y = self.excitation(y)   # per-channel gates in (0, 1)
        return x * y

class ConvBlock3D(nn.Module):
    """Two stacked Conv3d -> BatchNorm3d -> ReLU stages, optionally followed by an SE block.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by both convolution stages.
        kernel_size: Kernel size shared by both convolutions.
        padding: Padding shared by both convolutions.
        use_se: When true, an ``SEBlock3D`` is applied to the output.
        se_reduction: Reduction factor forwarded to ``SEBlock3D``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, use_se=True, se_reduction=16):
        super().__init__()
        self.use_se = use_se
        # Both convolutions share these settings; bias is omitted because a
        # BatchNorm layer directly follows each convolution.
        conv_kwargs = dict(kernel_size=kernel_size, padding=padding, bias=False)

        # Stage 1
        self.conv1 = nn.Conv3d(in_channels, out_channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.relu1 = nn.ReLU(inplace=True)

        # Stage 2
        self.conv2 = nn.Conv3d(out_channels, out_channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm3d(out_channels)
        self.relu2 = nn.ReLU(inplace=True)

        # The SE module only exists when requested
        if use_se:
            self.se = SEBlock3D(out_channels, reduction=se_reduction)

    def forward(self, x):
        """Apply both conv stages, then the SE block if enabled."""
        stages = (
            (self.conv1, self.bn1, self.relu1),
            (self.conv2, self.bn2, self.relu2),
        )
        for conv, norm, act in stages:
            x = act(norm(conv(x)))
        return self.se(x) if self.use_se else x

class TransformerEncoderLayer3D(nn.Module):
    """One post-norm transformer encoder layer operating on a (B, N, E) sequence.

    Self-attention and a GELU feed-forward network are each wrapped in a
    residual connection followed by LayerNorm.

    Args:
        embed_dim: Embedding width E.
        num_heads: Number of attention heads.
        ff_dim_factor: Hidden width of the feed-forward network as a multiple of ``embed_dim``.
        dropout: Dropout probability used in attention, the FFN and both residual branches.
    """

    def __init__(self, embed_dim, num_heads, ff_dim_factor=4, dropout=0.1):
        super().__init__()
        hidden_dim = embed_dim * ff_dim_factor

        # Attention sub-layer
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(dropout)

        # Feed-forward sub-layer
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
        )
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """Return the transformed sequence; the masks are forwarded to the attention module."""
        attended = self.attn(
            src, src, src,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
        )[0]  # attention weights are discarded
        src = self.norm1(src + self.dropout1(attended))
        src = self.norm2(src + self.dropout2(self.ffn(src)))
        return src

class TransformerBottleneck(nn.Module):
    """Transformer bottleneck: feature map -> patch sequence -> transformer -> feature map.

    Args:
        in_channels: Channels of the incoming feature map.
        embed_dim: Embedding width; a 1x1x1 convolution projects the input
            when it differs from ``in_channels``.
        num_layers: Number of ``TransformerEncoderLayer3D`` layers.
        num_heads: Attention heads per layer.
        bottleneck_spatial_dims: Three sizes (D, H, W) of the feature map;
            any sequence is accepted and stored as a tuple.
        dropout: Dropout after the positional embedding and inside each layer.

    Raises:
        ValueError: If ``bottleneck_spatial_dims`` does not have three entries.
    """
    def __init__(self, in_channels, embed_dim, num_layers, num_heads, bottleneck_spatial_dims, dropout=0.1):
        super(TransformerBottleneck, self).__init__()
        # Store as a tuple: forward() compares against a tuple built from
        # x.shape, and a list never compares equal to a tuple.
        spatial_dims = tuple(bottleneck_spatial_dims)
        if len(spatial_dims) != 3:
            raise ValueError(f"bottleneck_spatial_dims must have 3 entries (D, H, W), got {spatial_dims}")
        self.in_channels = in_channels
        self.embed_dim = embed_dim
        self.bottleneck_spatial_dims = spatial_dims
        num_patches = spatial_dims[0] * spatial_dims[1] * spatial_dims[2]
        # Project channels only when they differ from the embedding width
        self.patch_projection = nn.Conv3d(in_channels, embed_dim, kernel_size=1) if in_channels != embed_dim else nn.Identity()
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches, embed_dim))  # learnable positional embedding
        self.dropout_pos = nn.Dropout(dropout)
        self.transformer_layers = nn.ModuleList([
            TransformerEncoderLayer3D(embed_dim, num_heads, dropout=dropout) for _ in range(num_layers)
        ])
        self.norm_out = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Map a (B, C_in, D, H, W) feature map to (B, embed_dim, D, H, W).

        Raises:
            ValueError: If the projected map's spatial size or channel count
                differs from the values given at construction.
        """
        x = self.patch_projection(x)
        b, e, d, h, w = x.shape
        if (d, h, w) != self.bottleneck_spatial_dims:
            raise ValueError(f"Spatial dim mismatch. Expected {self.bottleneck_spatial_dims}, got {(d,h,w)}")
        if e != self.embed_dim:
            raise ValueError(f"Embed dim mismatch. Expected {self.embed_dim}, got {e}")
        x = x.flatten(2).transpose(1, 2)  # (B, N_patches, E_dim)
        x = x + self.pos_embed
        x = self.dropout_pos(x)
        for layer in self.transformer_layers:
            x = layer(x)
        x = self.norm_out(x)
        # reshape tolerates the non-contiguous layout produced by the transpose
        return x.transpose(1, 2).reshape(b, e, d, h, w)

# --- Main Model: UNetPlusPlus_SE_Transformer ---
class UNetPlusPlus_SE_Transformer(nn.Module):
    """3D UNet++-style classifier with optional SE blocks and a transformer bottleneck.

    Encoder levels 0..depth produce X_{i,0}; the deepest map goes through
    ``TransformerBottleneck``; nested decoder nodes X_{i,j} (j > 0) combine
    all earlier nodes of level i with an upsampled node from level i+1; the
    node X_{0,depth} is globally pooled and classified.

    Args:
        input_scan_size: Sequence of 3 ints (D, H, W).
        in_channels: Channels of the input volume.
        num_classes: Number of output logits.
        initial_filters: Channels at level 0; level i uses initial_filters * 2**i.
        depth: Number of pooling operations (levels = depth + 1).
        use_se: Forwarded to every ``ConvBlock3D``.
        transformer_embed_dim: Embedding width of the bottleneck.
        transformer_layers: Transformer layers in the bottleneck.
        transformer_heads: Attention heads per layer.
        transformer_dropout: Dropout used inside the bottleneck.
        final_fc_units: Hidden width of the classification head.

    Raises:
        ValueError: If ``input_scan_size`` is not 3 ints, ``depth`` is negative,
            or the bottleneck would have a non-positive spatial size.
    """
    def __init__(self, input_scan_size, in_channels=1, num_classes=1, initial_filters=16, depth=4,
                 use_se=True, transformer_embed_dim=256, transformer_layers=2,
                 transformer_heads=8, transformer_dropout=0.1, final_fc_units=128):
        super(UNetPlusPlus_SE_Transformer, self).__init__()
        if not (len(input_scan_size) == 3 and all(isinstance(s, int) for s in input_scan_size)):
            raise ValueError("input_scan_size must be tuple (D, H, W)")
        # A negative depth would otherwise fail later with an opaque IndexError.
        if depth < 0:
            raise ValueError(f"depth must be >= 0, got {depth}")
        self.depth = depth
        self.use_se = use_se
        nf = initial_filters

        # Encoder (X_i,0): channels double at every level, pooling between levels
        self.pools = nn.ModuleList()
        self.encoder_blocks = nn.ModuleList()
        encoder_output_channels = []
        current_channels_enc = in_channels
        for i in range(depth + 1):
            out_ch_enc = nf * (2**i)
            self.encoder_blocks.append(ConvBlock3D(current_channels_enc, out_ch_enc, use_se=use_se))
            encoder_output_channels.append(out_ch_enc)
            if i < depth:
                self.pools.append(nn.MaxPool3d(2, 2))
            current_channels_enc = out_ch_enc

        # Transformer bottleneck on X_depth,0
        bottleneck_in_channels = encoder_output_channels[-1]
        s_d, s_h, s_w = (input_scan_size[axis] // (2**depth) for axis in range(3))
        if not all(s > 0 for s in (s_d, s_h, s_w)):
            raise ValueError(f"Non-positive bottleneck dims: {(s_d,s_h,s_w)}")
        self.transformer_bottleneck = TransformerBottleneck(
            in_channels=bottleneck_in_channels, embed_dim=transformer_embed_dim, num_layers=transformer_layers,
            num_heads=transformer_heads, bottleneck_spatial_dims=(s_d, s_h, s_w), dropout=transformer_dropout)

        # Decoder
        self.decoder_conv_modulelist = nn.ModuleList()
        self.upsamplers = nn.ModuleList()
        for i in range(depth):  # upsamplers[i] maps level i+1 to level i
            # The deepest level feeds the bottleneck output, whose width is the embed dim.
            ch_from_below = transformer_embed_dim if (i + 1) == depth else encoder_output_channels[i+1]
            ch_to_current = encoder_output_channels[i]
            self.upsamplers.append(nn.ConvTranspose3d(ch_from_below, ch_to_current, kernel_size=2, stride=2))

        for i in range(depth):  # decoder blocks X_i,j (j > 0)
            level_i_decoder_blocks = nn.ModuleList()
            for j in range(1, depth - i + 1):
                # j same-level maps plus one upsampled map, each with encoder_output_channels[i]
                num_concat = j + 1
                in_ch_Xij = encoder_output_channels[i] * num_concat
                out_ch_Xij = encoder_output_channels[i]
                level_i_decoder_blocks.append(ConvBlock3D(in_ch_Xij, out_ch_Xij, use_se=use_se))
            self.decoder_conv_modulelist.append(level_i_decoder_blocks)

        # Classification head. X_0,depth normally has encoder_output_channels[0] channels;
        # with depth == 0 there is no decoder and the head receives the bottleneck
        # output, whose width is transformer_embed_dim.
        final_decoder_out_channels = encoder_output_channels[0] if depth > 0 else transformer_embed_dim
        self.classification_head = nn.Sequential(
            nn.AdaptiveAvgPool3d((1,1,1)), nn.Flatten(),
            nn.Linear(final_decoder_out_channels, final_fc_units), nn.ReLU(inplace=True),
            nn.Dropout(0.5), nn.Linear(final_fc_units, num_classes))

    def forward(self, x):
        """Return the logits produced from the X_0,depth feature map for input ``x``."""
        X_features = [[None] * (self.depth + 1) for _ in range(self.depth + 1)]
        # Encoder
        current = x
        for i in range(self.depth + 1):
            X_features[i][0] = self.encoder_blocks[i](current)
            if i < self.depth:
                current = self.pools[i](X_features[i][0])
        # Bottleneck replaces the deepest encoder output
        X_features[self.depth][0] = self.transformer_bottleneck(X_features[self.depth][0])
        # Decoder: deepest level first, so X_{i+1,j-1} exists when X_i,j is computed
        for i in range(self.depth - 1, -1, -1):
            for j in range(1, self.depth - i + 1):
                inputs_same_level = [X_features[i][k] for k in range(j)]
                upsampled_input = self.upsamplers[i](X_features[i+1][j-1])
                target_spatial = X_features[i][0].shape[2:]  # match X_i,0's spatial size
                if upsampled_input.shape[2:] != target_spatial:
                    upsampled_input = F.interpolate(upsampled_input, size=target_spatial, mode='trilinear', align_corners=False)
                combined = torch.cat(inputs_same_level + [upsampled_input], dim=1)
                X_features[i][j] = self.decoder_conv_modulelist[i][j-1](combined)
        # Final output from X_0,depth
        logits = self.classification_head(X_features[0][self.depth])
        return logits

# %% --- Instantiate the UNetPlusPlus_SE_Transformer Model ---

# Builds the model on DEVICE, reports its size and runs one dummy forward pass.
# Any failure is printed with a traceback and then turned into SystemExit.
print("\n--- Instantiating UNetPlusPlus_SE_Transformer Model ---")

# Hyperparameters for the UNet++ Transformer model.
# Adjust these based on your GPU memory and dataset size. These are starting points.
unetpp_initial_filters = 16     # Filters in the first layer (e.g., 16)
unetpp_depth = 3                # Number of pooling layers (e.g., 3 for 64->32->16->8)
                                # Bottleneck spatial size: 64 / (2^3) = 8x8x8
# Equal to the deepest encoder level's channel count (16 * 2**3 = 128).
unetpp_transformer_embed_dim = unetpp_initial_filters * (2**unetpp_depth) # Match bottleneck channels (16 * 8 = 128)
unetpp_transformer_layers = 1   # Number of transformer layers (e.g., 1 or 2)
unetpp_transformer_heads = 4    # Number of attention heads (must divide embed_dim, e.g., 4)
unetpp_final_fc_units = 64      # Size of the final FC layer before output (e.g., 64)

# Clear any previous model from memory (important if running cells multiple times).
# NOTE(review): assumes this runs at module/cell level, where locals() sees the
# previously created `patient_model` - confirm if this code is moved into a function.
if 'patient_model' in locals():
    del patient_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    print("Cleared previous model instance.")

try:
    patient_model = UNetPlusPlus_SE_Transformer(
        input_scan_size=FINAL_SCAN_SIZE,       # (64, 64, 64)
        in_channels=1,
        num_classes=NUM_CLASSES,               # 1 for BCEWithLogitsLoss
        initial_filters=unetpp_initial_filters,
        depth=unetpp_depth,
        use_se=True,                           # Use SE blocks
        transformer_embed_dim=unetpp_transformer_embed_dim,
        transformer_layers=unetpp_transformer_layers,
        transformer_heads=unetpp_transformer_heads,
        transformer_dropout=0.1,
        final_fc_units=unetpp_final_fc_units
    ).to(DEVICE)

    print(f"Instantiated UNetPlusPlus_SE_Transformer with:")
    print(f"  Initial Filters: {unetpp_initial_filters}, Depth: {unetpp_depth}")
    print(f"  Transformer Embed Dim: {unetpp_transformer_embed_dim}, Layers: {unetpp_transformer_layers}, Heads: {unetpp_transformer_heads}")
    print(f"  Final FC Units: {unetpp_final_fc_units}")
    # print(patient_model) # Optional: Print model structure (can be very long)

    # Count trainable parameters to gauge model size
    total_params = sum(p.numel() for p in patient_model.parameters() if p.requires_grad)
    print(f"Total trainable parameters: {total_params:,}")

    # Dry run with a dummy batch on the target device.
    # This helps catch CUDA errors or shape mismatches before training.
    print("Performing a quick model check with a dummy batch...")
    _dummy_check_batch_size = 1 # Use 1 to minimize memory check impact
    # The tensor is drawn on the CPU and then moved to DEVICE.
    dummy_input = torch.randn(_dummy_check_batch_size, 1, *FINAL_SCAN_SIZE).to(DEVICE)
    with torch.no_grad():
        # NOTE(review): the model is still in its default training mode here (no
        # .eval() call), so BatchNorm/Dropout run in train mode - confirm this is intended.
        output = patient_model(dummy_input)
    print(f"Model check successful. Output shape: {output.shape}") # Should be [1, 1]

except Exception as e:
    print(f"\nERROR during model instantiation or check: {e}")
    print("Check hyperparameters, input_scan_size, and available memory.")
    import traceback
    traceback.print_exc()
    # Stop execution if model instantiation fails
    raise SystemExit("Model instantiation failed.")


# %% --- Loss and Optimizer ---

# Positive-class weight derived from the *training* subset: negatives / positives.
train_labels_list = [patient_labels[pid] for pid in train_ids]
count_0 = sum(1 for lbl in train_labels_list if lbl == 0)
count_1 = sum(1 for lbl in train_labels_list if lbl == 1)
has_negatives = count_0 > 0
has_positives = count_1 > 0

if has_positives and has_negatives:
    pos_weight_val = count_0 / count_1
    print(f"Calculated positive weight for training subset ({count_1} pos / {count_0} neg): {pos_weight_val:.4f}")
    pos_weight_tensor = torch.tensor([pos_weight_val], device=DEVICE)
else:
    # Fall back to a neutral weight of 1 and explain why
    pos_weight_tensor = torch.tensor([1.0], device=DEVICE)
    if has_negatives:
        print("Warning: Training set only contains negative samples.")
    elif has_positives:
        print("Warning: Training set only contains positive samples.")
    else:
        print("Warning: Training set appears empty or invalid. Using default pos_weight=1.")

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight_tensor)
# Adam over all model parameters with the configured learning rate
print(f"Using Learning Rate: {LEARNING_RATE}")
optimizer = optim.Adam(patient_model.parameters(), lr=LEARNING_RATE)
# Gradient scaling is active only when CUDA is available
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())


# %% --- Training and Validation Functions (Unchanged Logic) ---
def train_one_epoch_patient(model, dataloader, criterion, optimizer, device, scaler):
    """Train ``model`` for one pass over ``dataloader``.

    Samples labelled -1 (the dataset's error marker) are dropped from each
    batch, and batches whose loss is NaN are skipped without an optimizer step.

    Args:
        model: Network to train; put into training mode here.
        dataloader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss applied to ``(outputs, labels)`` with labels shaped (B, 1).
        optimizer: Optimizer stepped through ``scaler``.
        device: ``torch.device`` the batch is moved to.
        scaler: Gradient scaler; its enabled state also toggles autocast.

    Returns:
        ``(mean loss per sample, accuracy at threshold 0.5)``, or
        ``(0.0, 0.0)`` when no valid sample was processed.
    """
    model.train()
    loss_sum = 0.0
    n_seen = 0
    n_correct = 0

    progress_bar = tqdm(dataloader, desc="Training", leave=False)
    for batch_idx, (inputs, labels) in enumerate(progress_bar):
        # Drop samples the dataset flagged as failed loads
        keep = labels != -1
        if not torch.any(keep):
            continue
        inputs = inputs[keep].to(device)
        labels = labels[keep].unsqueeze(1).to(device)

        optimizer.zero_grad()
        with torch.amp.autocast(device_type=device.type, dtype=torch.float16, enabled=scaler.is_enabled()):
            outputs = model(inputs)
            loss = criterion(outputs, labels)

        if torch.isnan(loss):
            print(f"NaN loss encountered during training batch {batch_idx}! Skipping batch.")
            optimizer.zero_grad()  # Zero grad again before skipping
            continue

        scaler.scale(loss).backward()
        # Optional: gradient clipping can help stabilize training for complex models
        # scaler.unscale_(optimizer) # Unscale gradients before clipping
        # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()

        batch_size = inputs.size(0)
        batch_loss = loss.item()
        loss_sum += batch_loss * batch_size
        n_seen += batch_size
        hits = (torch.sigmoid(outputs) > 0.5) == labels.bool()
        n_correct += hits.sum().item()
        progress_bar.set_postfix(loss=batch_loss)

    if n_seen == 0:
        return 0.0, 0.0
    return loss_sum / n_seen, n_correct / n_seen

def validate_patient(model, dataloader, criterion, device):
    """Evaluate `model` on `dataloader` without gradient tracking.

    Samples whose label equals -1 are removed from each batch (presumably a
    sentinel for scans that failed to load - confirm against the Dataset),
    and batches left without any sample are skipped. A batch whose loss is
    NaN is reported and contributes neither to the loss nor to the returned
    labels/probabilities.

    Args:
        model: Network to evaluate; switched to eval mode here.
        dataloader: Iterable yielding `(inputs, labels)` batches.
        criterion: Loss callable applied to `(outputs, labels)`; the
            sample-weighted averaging below assumes it returns a batch mean.
        device: Device the batches are moved to. Mixed precision is used
            only when this device is a CUDA device.

    Returns:
        `(val_loss, labels, probabilities)`: the sample-weighted mean loss
        and two flat NumPy arrays with the true labels and the sigmoid
        probabilities. `(0.0, empty, empty)` when no sample was processed.
    """
    model.eval()
    running_loss = 0.0
    total_samples = 0
    all_preds_proba = []
    all_labels = []

    # Tie autocast to the device actually used rather than to global CUDA
    # availability, so a CPU device on a CUDA-capable host is not run under
    # float16 CPU autocast.
    use_amp = device.type == "cuda"

    with torch.no_grad():
        progress_bar = tqdm(dataloader, desc="Validating", leave=False)
        for inputs, labels in progress_bar:
            valid_indices = labels != -1
            if not torch.any(valid_indices):
                continue
            inputs = inputs[valid_indices].to(device)
            labels = labels[valid_indices].to(device)

            with torch.amp.autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp):
                outputs = model(inputs)
                # Add a trailing dimension so the labels line up with the model output
                loss = criterion(outputs, labels.unsqueeze(1))

            if torch.isnan(loss):
                print("NaN loss encountered during validation! Skipping batch contribution.")
                continue

            batch_size = inputs.size(0)
            running_loss += loss.item() * batch_size
            total_samples += batch_size
            # Upcast before the sigmoid: under autocast the outputs may be
            # float16, and half-precision probabilities collapse nearby
            # scores into ties that distort ROC/AUC ranking.
            probabilities = torch.sigmoid(outputs.float())
            all_preds_proba.extend(probabilities.cpu().numpy().flatten())
            all_labels.extend(labels.cpu().numpy().flatten())

    if total_samples == 0:
        return 0.0, np.array([]), np.array([])
    val_loss = running_loss / total_samples
    return val_loss, np.array(all_labels), np.array(all_preds_proba)


# %% --- Training Loop ---
# Announce the run configuration before the first epoch
print("\n--- Starting Training: UNetPlusPlus_SE_Transformer Model ---")
print(f"Dataset size: {len(final_patient_list)} scans ({len(train_ids)} train, {len(val_ids)} val)")
print(f"Epochs: {EPOCHS}, Batch Size: {BATCH_SIZE}, Learning Rate: {LEARNING_RATE}")
print(f"Saving outputs to: {MODEL_OUTPUT_DIR}")

# Per-epoch history (read by the plotting cell) and the best validation loss so far
train_losses = []
val_losses = []
train_accs = []
val_accs = []
best_val_loss = float('inf')

# Checkpoint location inside the model-specific output directory
MODEL_SAVE_PATH = os.path.join(MODEL_OUTPUT_DIR, "unetpp_se_transformer_64cube_best.pth")

if torch.cuda.is_available():
    torch.cuda.empty_cache()

for epoch in range(EPOCHS):
    epoch_start_time = time.time()
    print(f"\nEpoch {epoch+1}/{EPOCHS}")

    train_loss, train_acc = train_one_epoch_patient(
        patient_model, train_loader, criterion, optimizer, DEVICE, scaler
    )
    val_loss, val_labels_epoch, val_preds_proba_epoch = validate_patient(
        patient_model, val_loader, criterion, DEVICE
    )

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accs.append(train_acc)

    # Duration covers the training and validation passes only
    epoch_end_time = time.time()
    epoch_duration = epoch_end_time - epoch_start_time

    # Validation accuracy at a 0.5 threshold; recorded as 0.0 when validation produced nothing
    validation_ran = len(val_labels_epoch) > 0
    if validation_ran:
        val_preds_binary_epoch = (val_preds_proba_epoch > 0.5).astype(int)
        val_acc_epoch = accuracy_score(val_labels_epoch, val_preds_binary_epoch)
    else:
        val_acc_epoch = 0.0
    val_accs.append(val_acc_epoch)

    print(f"Epoch {epoch+1} Summary: Duration: {epoch_duration:.2f}s")
    print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"  Val Loss:   {val_loss:.4f}, Val Acc:   {val_acc_epoch:.4f}")

    # Checkpoint on a new best validation loss, but never on an empty validation pass
    if not validation_ran:
        print("  Skipping save, validation set was empty this epoch.")
    elif val_loss < best_val_loss:
        best_val_loss = val_loss
        try:
            torch.save(patient_model.state_dict(), MODEL_SAVE_PATH)
            print(f"  Best model saved to {MODEL_SAVE_PATH} (Val Loss: {best_val_loss:.4f})")
        except Exception as e:
            print(f"Error saving model: {e}")

    # Release cached GPU memory between epochs
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

print("\n--- UNetPlusPlus_SE_Transformer Model Training Finished ---")


# %% --- Plot Training History ---
# Each curve is plotted against the number of epochs actually recorded for it,
# not the configured EPOCHS: if the training cell was interrupted early the
# history lists are shorter than EPOCHS and a fixed x-range would make
# matplotlib raise on the x/y length mismatch.
plt.figure(figsize=(12, 5))

# Left panel: loss curves
plt.subplot(1, 2, 1)
plt.plot(range(1, len(train_losses) + 1), train_losses, label='Train Loss')
plt.plot(range(1, len(val_losses) + 1), val_losses, label='Val Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('U-Net++ SE Transformer Loss')
plt.legend()
plt.grid(True)

# Right panel: accuracy curves
plt.subplot(1, 2, 2)
plt.plot(range(1, len(train_accs) + 1), train_accs, label='Train Acc')
plt.plot(range(1, len(val_accs) + 1), val_accs, label='Val Acc')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.title('U-Net++ SE Transformer Accuracy')
plt.legend()
plt.grid(True)

plt.tight_layout()
# Save the figure to the model-specific directory before showing it
plot_save_path = os.path.join(MODEL_OUTPUT_DIR, "unetpp_se_transformer_training_curves.png")
plt.savefig(plot_save_path)
print(f"Training curves saved to {plot_save_path}")
plt.show()


# %% --- Model Evaluation ---
print("\n--- Evaluating UNetPlusPlus_SE_Transformer Model on Validation Set ---")

# Choose the weights to evaluate: the saved best checkpoint when it exists and
# loads cleanly, otherwise the in-memory model from the end of training.
best_model_loaded = False
if not os.path.exists(MODEL_SAVE_PATH):
    print(f"Best model file not found at {MODEL_SAVE_PATH}. Evaluating model from the end of training.")
    patient_model.eval()
    patient_model_eval = patient_model
else:
    try:
        # Rebuild the architecture with the same hyperparameters before loading the weights
        patient_model_eval = UNetPlusPlus_SE_Transformer(
            input_scan_size=FINAL_SCAN_SIZE,
            in_channels=1,
            num_classes=NUM_CLASSES,
            initial_filters=unetpp_initial_filters,
            depth=unetpp_depth,
            use_se=True,
            transformer_embed_dim=unetpp_transformer_embed_dim,
            transformer_layers=unetpp_transformer_layers,
            transformer_heads=unetpp_transformer_heads,
            transformer_dropout=0.1,
            final_fc_units=unetpp_final_fc_units,
        ).to(DEVICE)
        patient_model_eval.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))
        print(f"Loaded best model from {MODEL_SAVE_PATH} for evaluation.")
        patient_model_eval.eval()
    except Exception as e:
        print(f"Could not load best model from {MODEL_SAVE_PATH}: {e}. Evaluating last epoch model instead.")
        # Fall back to the model as it stands after the final epoch
        patient_model.eval()
        patient_model_eval = patient_model
    else:
        best_model_loaded = True

# Run one validation pass with whichever model was selected above
val_loss_final, final_val_labels, final_val_preds_proba = validate_patient(
    patient_model_eval, val_loader, criterion, DEVICE
)

# Report final validation metrics, or explain why none can be computed
if len(final_val_labels) == 0:
    print("No valid validation predictions were generated. Cannot evaluate metrics.")
else:
    print(f"\nFinal Validation Loss (Best Model={best_model_loaded}): {val_loss_final:.4f}")

    # Binarise the probabilities at a 0.5 threshold
    final_val_preds_binary = (final_val_preds_proba > 0.5).astype(int)

    accuracy = accuracy_score(final_val_labels, final_val_preds_binary)
    precision = precision_score(final_val_labels, final_val_preds_binary, zero_division=0)
    recall = recall_score(final_val_labels, final_val_preds_binary, zero_division=0)
    f1 = f1_score(final_val_labels, final_val_preds_binary, zero_division=0)

    # AUC-ROC needs both classes in the labels; it stays NaN otherwise
    auc_roc = float('nan')
    if len(np.unique(final_val_labels)) > 1:
        try:
            auc_roc = roc_auc_score(final_val_labels, final_val_preds_proba)
        except ValueError as e:
            print(f"AUC-ROC Calculation Error: {e}")
    else:
        print("AUC-ROC cannot be calculated: Only one class present in validation labels.")

    # Build the AUC text outside the f-string: a conditional expression placed
    # after ':' is parsed as part of the format spec and raises
    # "ValueError: Invalid format specifier" at runtime.
    auc_roc_text = 'N/A' if np.isnan(auc_roc) else f"{auc_roc:.4f}"

    print("\n--- Final Validation Metrics (U-Net++ SE Transformer) ---")
    print(f"Accuracy:  {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall:    {recall:.4f}")
    print(f"F1-Score:  {f1:.4f}")
    print(f"AUC-ROC:   {auc_roc_text}")

    print("\nClassification Report (U-Net++ SE Transformer):")
    target_names = ['Non-Cancer (0)', 'Cancer (1)']
    # labels=[0, 1] keeps the report aligned with the two target names even
    # when only one class occurs in the labels and predictions (without it
    # scikit-learn raises on the class-count mismatch); it also matches the
    # confusion-matrix call below.
    print(classification_report(
        final_val_labels.astype(int),
        final_val_preds_binary,
        labels=[0, 1],
        target_names=target_names,
        zero_division=0,
    ))

    print("\nConfusion Matrix (U-Net++ SE Transformer):")
    try:
        cm = confusion_matrix(final_val_labels.astype(int), final_val_preds_binary, labels=[0, 1])
        disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=target_names)
        disp.plot(cmap=plt.cm.Blues)
        cm_save_path = os.path.join(MODEL_OUTPUT_DIR, "unetpp_se_transformer_confusion_matrix.png")
        plt.savefig(cm_save_path)
        print(f"Confusion matrix saved to {cm_save_path}")
        plt.show()
    except Exception as e:
        # Plotting is best-effort; the numeric metrics above are already printed
        print(f"Error displaying/saving confusion matrix: {e}")

    # ROC curve is drawn only when an AUC value could be computed
    if not np.isnan(auc_roc):
        fpr, tpr, _ = roc_curve(final_val_labels, final_val_preds_proba)
        plt.figure(figsize=(8, 6))
        plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {auc_roc:.2f})')
        plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')  # chance diagonal
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('ROC Curve (U-Net++ SE Transformer)')
        plt.legend(loc="lower right")
        plt.grid(True)
        roc_save_path = os.path.join(MODEL_OUTPUT_DIR, "unetpp_se_transformer_roc_curve.png")
        plt.savefig(roc_save_path)
        print(f"ROC curve saved to {roc_save_path}")
        plt.show()

print("\n--- Evaluation Complete ---")

Using device: cuda
--- Loading Data and Selecting up to 50 Scans Per Class ---
DSB Scans path: C:\Users\rouaa\Documents\Final_Pneumatect\Stages
DSB Labels CSV: C:\Users\rouaa\Documents\Final_Pneumatect\stage1_labels.csv
Loaded 1595 total DSB patient labels.
Original label distribution:
 cancer
0    1176
1     419
Name: count, dtype: int64
Found 98 potential patient scan folders.
Found 98 patient IDs with both labels and scan folders.
Available Cancerous scans with labels: 27
Available Non-Cancerous scans with labels: 71
Selected 27 Cancerous scans.
Selected 50 Non-Cancerous scans.
Total scans selected for preprocessing: 77

Starting preprocessing for 77 selected scans...


ImportError: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html