In [2]:
# %% [markdown]
# # End-to-End Lung Nodule Detection and Classification Pipeline
#
# This notebook outlines a cascade approach for lung nodule analysis:
# 1. **Segmentation:** A 3D U-Net (nnU-Net inspired) segments potential nodule regions.
# 2. **Detection & Proposal:** Sliding window on segmentation output to propose nodule candidates.
# 3. **Classification:** A 3D DenseNet121 (with SE blocks, pre-trained on MedicalNet) classifies candidate cubes.
#
# **Features:**
# - Hard Negative Mining for classifier training.
# - Focal Loss for imbalanced classification.
# - Mixed-precision training.
# - Cosine Learning Rate schedule.
# - Early Stopping based on validation AUC.
# - Evaluation: Detection Dice, Classification AUC, Sensitivity @ 95% Specificity.
# - Visualization: 3D nodule overlays, GradCAM.
# - Comparison: This cascade vs. a simpler direct 3D CNN classifier.

# %% [markdown]
# ## 1. Setup and Imports

# %%
import os
import glob
import time
import random
import json # For saving configs or reports
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Medical Imaging
import SimpleITK as sitk
# from radiomics import featureextractor # If you were to use radiomics for comparison

# Deep Learning - PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau

# Scikit-learn for metrics and utilities
from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import (
    classification_report, confusion_matrix, roc_auc_score,
    roc_curve, precision_recall_curve, auc, f1_score,
    precision_score, recall_score, accuracy_score, ConfusionMatrixDisplay,
    # For Dice
    jaccard_score # IoU, Dice = 2*TP / (2*TP + FP + FN) or 2*Intersection / (Union + Intersection)
)
from skimage.measure import label as skimage_label, regionprops
from skimage.morphology import disk, binary_closing, ball # For 3D morphology
from skimage.segmentation import clear_border
import scipy.ndimage as ndi

# Visualization
# import itkwidgets # For interactive 3D plotting
# from IPython.display import display

# For nnU-Net like preprocessing/data handling (if not using full framework)
# from batchgenerators.utilities.file_and_folder_operations import *
# from batchgenerators.transforms.spatial_transforms import SpatialTransform, MirrorTransform
# from batchgenerators.transforms.color_transforms import BrightnessMultiplicativeTransform, GammaTransform
# from batchgenerators.transforms.noise_transforms import GaussianNoiseTransform, GaussianBlurTransform
# from batchgenerators.transforms.utility_transforms import RemoveLabelTransform, RenameTransform, NumpyToTensor

# For Logging (optional, but recommended)
# import wandb
# from tensorboardX import SummaryWriter

# Custom modules (you'll create these)
# import utils
# import models_segmentation
# import models_classification
# import data_handling
# import training_loops
# import evaluation_metrics
# import visualization_tools

# Set device
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# For reproducibility
SEED = 42
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False # Can be True for speed if input sizes don't change

seed_everything(SEED)

# %% [markdown]
# ## 2. Configuration

# %%
CONFIG = {
    # --- Paths ---
    "data_dir": Path("/path/to/your/LIDC-IDRI-like/dataset"), # Contains CTs, masks, annotations
    "output_dir": Path("./output_nodule_pipeline"),
    "medicalnet_weights_path": Path("/path/to/your/MedicalNet_DenseNet121_weights.pth"), # Or ResNet50

    # --- General Preprocessing ---
    "target_spacing_seg": [1.5, 1.0, 1.0], # For segmentation model input
    "target_spacing_clf": [1.0, 1.0, 1.0], # For classifier input cubes
    "hu_clip_bounds": [-1000, 400],
    "norm_mean_std": {"mean": 0.25, "std": 0.25}, # Example, calculate from your data

    # --- Segmentation Model (nnU-Net inspired 3D U-Net) ---
    "seg_model_name": "Custom3DUNet",
    "seg_in_channels": 1, # CT
    "seg_out_channels": 1, # Nodule mask (binary)
    "seg_patch_size": [96, 128, 128], # D, H, W
    "seg_batch_size": 2,
    "seg_lr": 1e-4,
    "seg_epochs": 100, # Example
    "seg_loss_weights": {"dice": 0.6, "bce": 0.4}, # For combined loss

    # --- Nodule Detection/Proposal ---
    "seg_prob_threshold": 0.5, # Threshold for segmentation map
    "min_nodule_size_voxels": 20, # Minimum size for a connected component to be a nodule
    "sliding_window_stride": [32, 32, 32], # For dense proposal if needed, or use CCs directly

    # --- Candidate Cube Generation ---
    "clf_cube_size_raw": [48, 64, 64], # Before resizing to final classifier input
    "clf_cube_size_final": [32, 48, 48], # Final input size for classifier

    # --- Classification Model (3D DenseNet121 + SE) ---
    "clf_model_name": "DenseNet121_3D_SE",
    "clf_in_channels": 1,
    "clf_num_classes": 1, # Malignancy (binary or score) -> for sigmoid output
    "clf_use_medicalnet_pretrained": True,
    "clf_batch_size": 16, # Can be larger for smaller cubes
    "clf_lr": 1e-4, # Initial LR for classifier
    "clf_epochs": 75,
    "clf_focal_loss_alpha": 0.25, # Focal loss parameters
    "clf_focal_loss_gamma": 2.0,
    "clf_early_stopping_patience": 10, # For validation AUC
    "clf_cosine_lr_t_max": 75, # Total epochs for cosine schedule

    # --- Hard Negative Mining ---
    "hnm_start_epoch": 5, # Epoch to start HNM
    "hnm_ratio_neg_to_pos": 3, # e.g., 3 hard negatives for every 1 positive
    "hnm_num_hard_negatives_per_batch": lambda bs, ratio: int(bs * ratio / (1 + ratio)), # Dynamically calculate

    # --- Evaluation ---
    "eval_sensitivity_at_specificity": 0.95,

    # --- Logging & Reporting ---
    "use_wandb": False, # Set to True to use wandb
    "wandb_project_name": "lung_nodule_pipeline",

    # --- Comparison Model ---
    "comp_model_name": "Simple3DCNN_Classifier", # For comparison
    "comp_lr": 1e-4,
    "comp_epochs": 75,
}

CONFIG["output_dir"].mkdir(parents=True, exist_ok=True)
(CONFIG["output_dir"] / "segmentation_models").mkdir(exist_ok=True)
(CONFIG["output_dir"] / "classification_models").mkdir(exist_ok=True)
(CONFIG["output_dir"] / "comparison_models").mkdir(exist_ok=True)
(CONFIG["output_dir"] / "visualizations").mkdir(exist_ok=True)
(CONFIG["output_dir"] / "reports").mkdir(exist_ok=True)

# Save config
with open(CONFIG["output_dir"] / "config.json", 'w') as f:
    # Path objects are not JSON serializable directly
    json.dump({k: str(v) if isinstance(v, Path) else v for k,v in CONFIG.items()}, f, indent=4)

print("Configuration saved.")

# %% [markdown]
# ## 3. Data Loading and Preprocessing
#
# This section involves:
# - Identifying patient scan folders.
# - Loading CT images, lung masks (if available, or segment them first), and nodule annotations.
# - Preprocessing:
#     - Resampling to target spacing.
#     - HU intensity windowing/clipping.
#     - Normalization.
#     - Creating 3D binary masks for nodules (from annotations) for segmentation training.
# - Splitting data into train, validation, (and optionally test) sets.

# %%
# Placeholder for LIDC-IDRI (or similar) data parsing utilities
# Assume annotations give nodule centroids, bounding boxes, and possibly malignancy labels.

def load_patient_data(patient_id, data_root_dir):
    """
    Loads CT, lung mask (optional), and nodule annotations for a patient.
    Returns:
        sitk_ct_image, sitk_lung_mask (or None), nodule_annotations_list
    """
    # ct_path = data_root_dir / patient_id / "ct_scan.mha" # Example path
    # lung_mask_path = data_root_dir / patient_id / "lung_mask.mha" # Example path
    # annotations_path = data_root_dir / patient_id / "annotations.xml" # Example for LIDC

    # sitk_ct_image = sitk.ReadImage(str(ct_path))
    # sitk_lung_mask = sitk.ReadImage(str(lung_mask_path)) if lung_mask_path.exists() else None
    # nodule_annotations_list = parse_lidc_xml(annotations_path) # You'd need this parser

    # Dummy example:
    print(f"Loading data for patient {patient_id} (dummy implementation)")
    # Create a dummy CT image
    dummy_ct_array = np.random.rand(128, 256, 256).astype(np.float32) * 1000 - 500 # Dummy HU values
    sitk_ct_image = sitk.GetImageFromArray(dummy_ct_array)
    sitk_ct_image.SetSpacing([1.0, 0.7, 0.7]) # Dummy spacing

    # Create a dummy lung mask
    dummy_lung_mask_array = np.zeros_like(dummy_ct_array, dtype=np.uint8)
    dummy_lung_mask_array[30:100, 50:200, 50:200] = 1 # Dummy lung region
    sitk_lung_mask = sitk.GetImageFromArray(dummy_lung_mask_array)
    sitk_lung_mask.SetSpacing(sitk_ct_image.GetSpacing())


    # Dummy nodule annotations: list of dicts
    # Each dict: {'centroid_world': [x,y,z], 'diameter_mm': d, 'malignancy': m (0-5 or binary)}
    # For segmentation, we'd convert these to voxel coordinates and create a mask
    nodule_annotations_list = [
        {'centroid_world': np.array(sitk_ct_image.TransformContinuousIndexToPhysicalPoint([60,100,100])),
         'diameter_mm': 10, 'malignancy': 4, 'id': 'nod1'},
        {'centroid_world': np.array(sitk_ct_image.TransformContinuousIndexToPhysicalPoint([70,150,150])),
         'diameter_mm': 5, 'malignancy': 1, 'id': 'nod2'}
    ]
    return sitk_ct_image, sitk_lung_mask, nodule_annotations_list

def preprocess_image(sitk_img, target_spacing, clip_bounds, norm_stats, is_mask=False):
    """ Resample, clip, normalize. """
    # Resample
    original_spacing = sitk_img.GetSpacing()
    original_size = sitk_img.GetSize()
    new_size = [
        int(round(osz * ospc / tspc))
        for osz, ospc, tspc in zip(original_size, original_spacing, target_spacing)
    ]
    resampler = sitk.ResampleImageFilter()
    resampler.SetOutputSpacing(target_spacing)
    resampler.SetSize(new_size)
    resampler.SetOutputDirection(sitk_img.GetDirection())
    resampler.SetOutputOrigin(sitk_img.GetOrigin())
    resampler.SetTransform(sitk.Transform())
    resampler.SetInterpolator(sitk.sitkLinear if not is_mask else sitk.sitkNearestNeighbor)
    resampled_img = resampler.Execute(sitk_img)

    img_array = sitk.GetArrayFromImage(resampled_img).astype(np.float32)

    if not is_mask:
        # Clip HU
        img_array = np.clip(img_array, clip_bounds[0], clip_bounds[1])
        # Normalize (example: (x - min) / (max - min) then (x - mean) / std)
        img_array = (img_array - clip_bounds[0]) / (clip_bounds[1] - clip_bounds[0])
        img_array = (img_array - norm_stats["mean"]) / norm_stats["std"]
    else:
        img_array = img_array.astype(np.uint8) # Ensure mask is int

    return img_array # Returns numpy array

def create_nodule_segmentation_mask(ct_image_sitk, nodule_annotations, target_shape_voxels):
    """
    Creates a 3D binary mask for nodules based on annotations for a specific resampled CT.
    ct_image_sitk: The resampled SimpleITK image to get physical to voxel coordinate transforms.
    nodule_annotations: List of nodule dicts with 'centroid_world' and 'diameter_mm'.
    target_shape_voxels: The D,H,W shape of the numpy array for the mask.
    """
    nodule_mask_np = np.zeros(target_shape_voxels, dtype=np.uint8)
    target_spacing = ct_image_sitk.GetSpacing() # Spacing of the resampled CT

    for nod in nodule_annotations:
        centroid_world = nod['centroid_world']
        diameter_mm = nod['diameter_mm']
        radius_mm = diameter_mm / 2.0

        # Convert world centroid to voxel centroid in the resampled image
        centroid_voxel_continuous = ct_image_sitk.TransformPhysicalPointToContinuousIndex(centroid_world)
        centroid_voxel = [int(round(c)) for c in centroid_voxel_continuous]

        # Define bounding box in voxel coordinates
        # Convert radius in mm to radius in voxels for each dimension
        radius_voxels = [radius_mm / spc for spc in target_spacing]

        z_min = max(0, int(round(centroid_voxel[2] - radius_voxels[2]))) # SimpleITK index order is x,y,z
        z_max = min(target_shape_voxels[0] -1, int(round(centroid_voxel[2] + radius_voxels[2]))) # Numpy array is D,H,W so z is index 0
        y_min = max(0, int(round(centroid_voxel[1] - radius_voxels[1])))
        y_max = min(target_shape_voxels[1] -1, int(round(centroid_voxel[1] + radius_voxels[1])))
        x_min = max(0, int(round(centroid_voxel[0] - radius_voxels[0])))
        x_max = min(target_shape_voxels[2] -1, int(round(centroid_voxel[0] + radius_voxels[0])))

        # Create a spherical mask (approximate by drawing in a 3D grid)
        # This is a simplification; more accurate rasterization might be needed
        for z_idx in range(z_min, z_max + 1):
            for y_idx in range(y_min, y_max + 1):
                for x_idx in range(x_min, x_max + 1):
                    # Check if point (z_idx, y_idx, x_idx) is within the sphere
                    # Convert voxel indices back to continuous for distance check
                    dist_sq = ( ((z_idx - centroid_voxel_continuous[2]) * target_spacing[2])**2 +
                                ((y_idx - centroid_voxel_continuous[1]) * target_spacing[1])**2 +
                                ((x_idx - centroid_voxel_continuous[0]) * target_spacing[0])**2 )
                    if dist_sq <= radius_mm**2:
                        # Check bounds before assignment
                        if 0 <= z_idx < target_shape_voxels[0] and \
                           0 <= y_idx < target_shape_voxels[1] and \
                           0 <= x_idx < target_shape_voxels[2]:
                            nodule_mask_np[z_idx, y_idx, x_idx] = 1 # Numpy order D, H, W
    return nodule_mask_np


# --- Prepare lists of patient IDs ---
# all_patient_ids = [f.name for f in CONFIG["data_dir"].iterdir() if f.is_dir()]
all_patient_ids = [f"patient_{i:03d}" for i in range(20)] # Dummy patient IDs
train_ids, val_test_ids = train_test_split(all_patient_ids, test_size=0.3, random_state=SEED)
val_ids, test_ids = train_test_split(val_test_ids, test_size=0.5, random_state=SEED)

print(f"Train IDs: {len(train_ids)}, Val IDs: {len(val_ids)}, Test IDs: {len(test_ids)}")

# --- Example of processing one patient (for segmentation data) ---
# This would typically be done inside a Dataset class's __getitem__ or a preprocessing script

# For a patient_id in train_ids:
#   sitk_ct, sitk_lm, annotations = load_patient_data(patient_id, CONFIG["data_dir"])
#   # Preprocess CT for segmentation model
#   processed_ct_np = preprocess_image(sitk_ct, CONFIG["target_spacing_seg"],
#                                       CONFIG["hu_clip_bounds"], CONFIG["norm_mean_std"])
#
#   # Create a resampled sitk_ct to pass to nodule mask creation for coordinate transformation
#   # This is a bit redundant here, ideally preprocess_image returns sitk object or transform info
#   _ , resampled_size_seg = sitk.ResampleImageFilter().Compute μετα# ... (previous code)

#   # Calculate the size of the resampled CT image for segmentation
#   original_spacing_ct = sitk_ct.GetSpacing()
#   original_size_ct = sitk_ct.GetSize()
#   target_spacing_seg = CONFIG["target_spacing_seg"]
#   resampled_size_for_seg_mask_sitk = [
#       int(round(osz * ospc / tspc))
#       for osz, ospc, tspc in zip(original_size_ct, original_spacing_ct, target_spacing_seg)
#   ]
#   # Create a dummy resampled sitk_ct for coordinate transformation purposes
#   # (In a real pipeline, preprocess_image would ideally handle this or return necessary info)
#   resampler_ref = sitk.ResampleImageFilter()
#   resampler_ref.SetOutputSpacing(target_spacing_seg)
#   resampler_ref.SetSize(resampled_size_for_seg_mask_sitk)
#   resampler_ref.SetOutputDirection(sitk_ct.GetDirection())
#   resampler_ref.SetOutputOrigin(sitk_ct.GetOrigin())
#   resampled_ct_for_coords_sitk = resampler_ref.Execute(sitk_ct) # just need its geometry
#
#   # Create nodule segmentation mask based on the geometry of the resampled CT
#   # The shape of the numpy array from preprocess_image must match here
#   nodule_seg_mask_np = create_nodule_segmentation_mask(
#       resampled_ct_for_coords_sitk,
#       annotations,
#       target_shape_voxels=processed_ct_np.shape # (D, H, W)
#   )
#   # Save processed_ct_np and nodule_seg_mask_np for training/validation
#   # e.g., to CONFIG["output_dir"] / "preprocessed_seg" / f"{patient_id}_ct.npy"
#   #      CONFIG["output_dir"] / "preprocessed_seg" / f"{patient_id}_seg_mask.npy"

# %% [markdown]
# ## 4. Segmentation Model (3D nnU-Net Inspired U-Net)
#
# - Define a 3D U-Net architecture.
# - Define Dataset and DataLoader for segmentation.
# - Define loss function (e.g., Dice + BCE).
# - Training loop for segmentation.

# %%
# Placeholder for 3D U-Net definition (models_segmentation.py)
# e.g., using blocks like: Conv3D -> InstanceNorm3D -> LeakyReLU

class DoubleConv3D(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, padding=padding, bias=False),
            nn.InstanceNorm3d(out_channels),
            nn.LeakyReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=kernel_size, padding=padding, bias=False),
            nn.InstanceNorm3d(out_channels),
            nn.LeakyReLU(inplace=True)
        )
    def forward(self, x):
        return self.double_conv(x)

class DownSample3D(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = DoubleConv3D(in_channels, out_channels)
        self.pool = nn.MaxPool3d(2, 2)
    def forward(self, x):
        skip_connection = self.conv(x)
        pooled = self.pool(skip_connection)
        return pooled, skip_connection

class UpSample3D(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Use ConvTranspose3d for upsampling
        self.up = nn.ConvTranspose3d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv3D(in_channels, out_channels) # in_channels because of skip connection concat
    def forward(self, x1, x2): # x1 from previous layer, x2 is skip connection
        x1 = self.up(x1)
        # Pad if dimensions mismatch (common in U-Nets)
        diffZ = x2.size()[2] - x1.size()[2]
        diffY = x2.size()[3] - x1.size()[3]
        diffX = x2.size()[4] - x1.size()[4]
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2,
                        diffZ // 2, diffZ - diffZ // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class Custom3DUNet(nn.Module):
    def __init__(self, in_channels, out_channels, features=[32, 64, 128, 256]):
        super().__init__()
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.pool = nn.MaxPool3d(kernel_size=2, stride=2)

        # Down part
        for feature in features:
            self.downs.append(DoubleConv3D(in_channels, feature))
            in_channels = feature

        # Up part
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose3d(feature * 2, feature, kernel_size=2, stride=2)
            )
            self.ups.append(DoubleConv3D(feature * 2, feature)) # After concat with skip

        self.bottleneck = DoubleConv3D(features[-1], features[-1] * 2)
        self.final_conv = nn.Conv3d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []
        for i in range(len(self.downs)):
            x = self.downs[i](x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1] # Reverse for up-sampling

        for i in range(0, len(self.ups), 2): # Step by 2 (ConvTranspose, DoubleConv)
            x = self.ups[i](x) # Upsample
            skip_connection = skip_connections[i//2]
            # Pad if necessary
            if x.shape != skip_connection.shape:
                # print(f"Padding needed. x: {x.shape}, skip: {skip_connection.shape}")
                diffZ = skip_connection.size()[2] - x.size()[2]
                diffY = skip_connection.size()[3] - x.size()[3]
                diffX = skip_connection.size()[4] - x.size()[4]
                x = F.pad(x, [diffX // 2, diffX - diffX // 2,
                                diffY // 2, diffY - diffY // 2,
                                diffZ // 2, diffZ - diffZ // 2])
            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[i+1](concat_skip) # DoubleConv

        return self.final_conv(x)


class SegmentationDataset(Dataset):
    def __init__(self, patient_ids, config, preprocessed_data_dir, transform=None):
        self.patient_ids = patient_ids
        self.config = config
        self.preprocessed_data_dir = Path(preprocessed_data_dir)
        self.transform = transform # For patch extraction and augmentation

    def __len__(self):
        return len(self.patient_ids) # Or num_patches if patch-based

    def __getitem__(self, idx):
        patient_id = self.patient_ids[idx]
        # ct_path = self.preprocessed_data_dir / f"{patient_id}_ct.npy"
        # mask_path = self.preprocessed_data_dir / f"{patient_id}_seg_mask.npy"
        # ct_array = np.load(ct_path)
        # mask_array = np.load(mask_path)

        # Dummy data for dataset
        ct_array = np.random.rand(*self.config["seg_patch_size"]).astype(np.float32)
        mask_array = (np.random.rand(*self.config["seg_patch_size"]) > 0.8).astype(np.uint8) # Sparse mask

        # Apply transforms (e.g., patching, nnU-Net style augmentations)
        # if self.transform:
        #     data_dict = self.transform(data={'data': ct_array[None], 'seg': mask_array[None]}) # Add channel dim
        #     ct_array, mask_array = data_dict['data'][0], data_dict['seg'][0]
        # else: # Simple patch extraction or use full image if small enough
        # This needs to be adapted for patch-based training if images are large
        # For now, assume ct_array and mask_array are already patch-sized or full images ready for training
        pass

        ct_tensor = torch.from_numpy(ct_array).float().unsqueeze(0) # Add channel dim (B, C, D, H, W)
        mask_tensor = torch.from_numpy(mask_array).float().unsqueeze(0)

        return ct_tensor, mask_tensor

# --- Loss Function ---
class DiceLoss(nn.Module):
    def __init__(self, smooth=1e-6):
        super(DiceLoss, self).__init__()
        self.smooth = smooth
    def forward(self, pred_probs, target):
        pred_probs = torch.sigmoid(pred_probs) # If model outputs logits
        intersection = (pred_probs * target).sum(dim=(2,3,4)) # Sum over D, H, W
        union = pred_probs.sum(dim=(2,3,4)) + target.sum(dim=(2,3,4))
        dice = (2. * intersection + self.smooth) / (union + self.smooth)
        return 1. - dice.mean() # Average over batch

class CombinedSegLoss(nn.Module):
    def __init__(self, dice_weight=0.5, bce_weight=0.5, smooth_dice=1e-6):
        super().__init__()
        self.dice_loss = DiceLoss(smooth=smooth_dice)
        self.bce_loss = nn.BCEWithLogitsLoss() # Takes logits directly
        self.dice_weight = dice_weight
        self.bce_weight = bce_weight
    def forward(self, pred_logits, target_mask):
        loss_dice = self.dice_loss(pred_logits, target_mask)
        loss_bce = self.bce_loss(pred_logits, target_mask)
        return self.dice_weight * loss_dice + self.bce_weight * loss_bce

# --- Training Loop for Segmentation ---
def train_segmentation_epoch(model, dataloader, optimizer, loss_fn, scaler, device):
    model.train()
    epoch_loss = 0
    for batch_idx, (data, target) in enumerate(tqdm(dataloader, desc="Seg Training Epoch")):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        with autocast(enabled=torch.cuda.is_available()): # Mixed precision
            predictions = model(data)
            loss = loss_fn(predictions, target)
        if torch.isnan(loss) or torch.isinf(loss):
            print(f"NaN/Inf loss in seg train: {loss.item()}. Skipping batch.")
            continue
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        epoch_loss += loss.item()
    return epoch_loss / len(dataloader)

def validate_segmentation_epoch(model, dataloader, loss_fn, device):
    model.eval()
    epoch_loss = 0
    all_preds, all_targets = [], []
    with torch.no_grad():
        for data, target in tqdm(dataloader, desc="Seg Validation Epoch"):
            data, target = data.to(device), target.to(device)
            with autocast(enabled=torch.cuda.is_available()):
                predictions = model(data)
                loss = loss_fn(predictions, target)
            if torch.isnan(loss) or torch.isinf(loss):
                print(f"NaN/Inf loss in seg val: {loss.item()}. Skipping batch.")
                continue
            epoch_loss += loss.item()
            # Store predictions and targets for Dice calculation
            # For simplicity, assuming binary segmentation for Dice here
            all_preds.append(torch.sigmoid(predictions).cpu().numpy() > 0.5)
            all_targets.append(target.cpu().numpy())

    if not all_preds: return 0.0, 0.0 # Handle empty validation

    all_preds_np = np.concatenate(all_preds, axis=0).flatten()
    all_targets_np = np.concatenate(all_targets, axis=0).flatten()

    # Calculate Dice score for the entire validation set
    # Note: This is a flattened Dice. Voxel-wise. For object-level Dice, need connected components.
    intersection = np.sum(all_preds_np * all_targets_np)
    val_dice = (2. * intersection) / (np.sum(all_preds_np) + np.sum(all_targets_np) + 1e-6)

    return epoch_loss / len(dataloader), val_dice

# Initialize segmentation model, optimizer, loss
seg_model = Custom3DUNet(
    in_channels=CONFIG["seg_in_channels"],
    out_channels=CONFIG["seg_out_channels"]
).to(DEVICE)
seg_optimizer = optim.Adam(seg_model.parameters(), lr=CONFIG["seg_lr"])
seg_loss_fn = CombinedSegLoss(
    dice_weight=CONFIG["seg_loss_weights"]["dice"],
    bce_weight=CONFIG["seg_loss_weights"]["bce"]
)
seg_scaler = GradScaler(enabled=torch.cuda.is_available())
seg_lr_scheduler = CosineAnnealingLR(seg_optimizer, T_max=CONFIG["seg_epochs"])


# Dummy preprocessed data path for seg dataset
# In reality, you'd run the preprocessing step from Section 3 first
dummy_preprocessed_seg_dir = CONFIG["output_dir"] / "preprocessed_seg_dummy"
dummy_preprocessed_seg_dir.mkdir(exist_ok=True)

seg_train_dataset = SegmentationDataset(train_ids, CONFIG, dummy_preprocessed_seg_dir) # Pass actual path
seg_val_dataset = SegmentationDataset(val_ids, CONFIG, dummy_preprocessed_seg_dir)     # Pass actual path
seg_train_loader = DataLoader(seg_train_dataset, batch_size=CONFIG["seg_batch_size"], shuffle=True, num_workers=0)
seg_val_loader = DataLoader(seg_val_dataset, batch_size=CONFIG["seg_batch_size"], shuffle=False, num_workers=0)

print(f"\n--- Training Segmentation Model ({CONFIG['seg_model_name']}) ---")
best_val_dice_seg = -1.0
for epoch in range(CONFIG["seg_epochs"]):
    train_loss_seg = train_segmentation_epoch(seg_model, seg_train_loader, seg_optimizer, seg_loss_fn, seg_scaler, DEVICE)
    val_loss_seg, val_dice_seg = validate_segmentation_epoch(seg_model, seg_val_loader, seg_loss_fn, DEVICE)
    seg_lr_scheduler.step()

    print(f"Epoch {epoch+1}/{CONFIG['seg_epochs']}: Seg Train Loss: {train_loss_seg:.4f}, Seg Val Loss: {val_loss_seg:.4f}, Seg Val Dice: {val_dice_seg:.4f}")

    if val_dice_seg > best_val_dice_seg:
        best_val_dice_seg = val_dice_seg
        torch.save(seg_model.state_dict(), CONFIG["output_dir"] / "segmentation_models" / f"{CONFIG['seg_model_name']}_best.pth")
        print(f"  Saved best segmentation model with Val Dice: {best_val_dice_seg:.4f}")

print("Segmentation training finished.")


# %% [markdown]
# ## 5. Sliding-Window Detection & Candidate Proposal
#
# - Load trained segmentation model.
# - Iterate through (test/validation) scans.
# - Apply segmentation model to get probability maps.
# - Threshold probability maps.
# - Use connected components (`skimage.measure.label`) to identify discrete nodule candidates.
# - Filter candidates by size or other criteria.
# - Output bounding boxes or centroids of these candidates.

# %%
def get_nodule_candidates_from_segmentation(
    ct_sitk_original, # Original SITK CT for coordinate mapping
    seg_model,
    config,
    device):
    """
    Takes a CT scan, applies segmentation model, and returns candidate nodule info.
    """
    seg_model.eval()

    # 1. Preprocess CT for segmentation model input
    processed_ct_np = preprocess_image(
        ct_sitk_original,
        config["target_spacing_seg"],
        config["hu_clip_bounds"],
        config["norm_mean_std"]
    )
    # This processed_ct_np might need to be broken into patches if model expects patches
    # For simplicity, assume model can take the whole resampled volume (or it's handled internally)
    # If patch-based, you'd need a sliding window inference here.
    ct_tensor = torch.from_numpy(processed_ct_np).float().unsqueeze(0).unsqueeze(0).to(device) # B, C, D, H, W

    with torch.no_grad(), autocast(enabled=torch.cuda.is_available()):
        seg_logits = seg_model(ct_tensor)
        seg_probs_np = torch.sigmoid(seg_logits).squeeze().cpu().numpy() # D, H, W

    # 2. Threshold and get connected components
    binary_mask = (seg_probs_np > config["seg_prob_threshold"]).astype(np.uint8)
    labeled_mask, num_labels = skimage_label(binary_mask, connectivity=3, return_num=True) # 3D connectivity

    candidates = []
    # Create a resampled sitk_ct for coordinate transformation (geometry of seg_probs_np)
    original_spacing = ct_sitk_original.GetSpacing()
    original_size = ct_sitk_original.GetSize()
    target_spacing_seg = config["target_spacing_seg"]
    resampled_size_for_coords = [
        int(round(osz * ospc / tspc))
        for osz, ospc, tspc in zip(original_size, original_spacing, target_spacing_seg)
    ]
    resampler_ref = sitk.ResampleImageFilter()
    resampler_ref.SetOutputSpacing(target_spacing_seg)
    resampler_ref.SetSize(resampled_size_for_coords) # Should match seg_probs_np.shape if D,H,W order is consistent
    resampler_ref.SetOutputDirection(ct_sitk_original.GetDirection())
    resampler_ref.SetOutputOrigin(ct_sitk_original.GetOrigin())
    # Execute with a dummy image just to get the transform object, or use the resampled ct if available
    # Ensure the geometry matches seg_probs_np for correct coordinate transforms
    # For this dummy execution, assume resampled_ct_for_coords_sitk has this geometry
    resampled_ct_for_coords_sitk = resampler_ref.Execute(sitk.Image(processed_ct_np.shape[::-1], sitk.sitkFloat32)) # Use shape of seg_probs_np
    resampled_ct_for_coords_sitk.SetSpacing(target_spacing_seg)
    resampled_ct_for_coords_sitk.SetOrigin(ct_sitk_original.GetOrigin()) # This might need adjustment based on how resampling affects origin

    for i in range(1, num_labels + 1):
        props = regionprops(labeled_mask == i, intensity_image=seg_probs_np)
        if props: # Should always be one region for label i
            prop = props[0]
            if prop.area >= config["min_nodule_size_voxels"]:
                centroid_voxel_seg = prop.centroid # (z, y, x) for numpy array
                # Convert centroid from seg_probs_np voxel coords to world coords
                # seg_probs_np is (D,H,W), SimpleITK continuous index is (x,y,z)
                centroid_world = resampled_ct_for_coords_sitk.TransformContinuousIndexToPhysicalPoint(
                    (centroid_voxel_seg[2], centroid_voxel_seg[1], centroid_voxel_seg[0]) # x,y,z for SITK
                )
                candidates.append({
                    "centroid_voxel_seg": centroid_voxel_seg, # (z,y,x) in seg_probs_np space
                    "centroid_world": np.array(centroid_world),
                    "bbox_voxel_seg": prop.bbox, # (min_z, min_y, min_x, max_z, max_y, max_x)
                    "mean_intensity_seg_prob": prop.mean_intensity, # Avg seg prob in candidate
                    "id": f"cand_{i}"
                })
    return candidates, seg_probs_np # Return seg_probs for Dice eval if needed

# Load best segmentation model
best_seg_model_path = CONFIG["output_dir"] / "segmentation_models" / f"{CONFIG['seg_model_name']}_best.pth"
if best_seg_model_path.exists():
    seg_model.load_state_dict(torch.load(best_seg_model_path, map_location=DEVICE))
    print(f"Loaded best segmentation model from {best_seg_model_path}")
else:
    print("WARNING: Best segmentation model not found. Using last epoch model for proposals.")

# Example: Get candidates for one validation patient
# patient_to_test_seg = val_ids[0]
# sitk_ct_test, _, true_annotations_test = load_patient_data(patient_to_test_seg, CONFIG["data_dir"])
# proposed_nodules, seg_map_for_dice = get_nodule_candidates_from_segmentation(
#     sitk_ct_test, seg_model, CONFIG, DEVICE
# )
# print(f"Proposed {len(proposed_nodules)} nodules for patient {patient_to_test_seg}.")
# for nod in proposed_nodules[:2]: print(nod)

# %% [markdown]
# ## 6. Candidate Cube Generation for Classifier
#
# - For each proposed nodule candidate (and ground truth nodules):
#     - Crop a 3D cube around its centroid from the **original CT scan, resampled to classifier's target spacing**.
#     - Resize/pad this cube to the classifier's fixed input size (`clf_cube_size_final`).
#     - Store these cubes and their labels (malignancy for GT, pseudo-labels for proposals if needed).
#     - For Hard Negative Mining, proposals that don't overlap with GT nodules are initially "negatives".

# %%
def extract_cube(
    sitk_ct_original, # Original CT to crop from
    world_centroid,   # World coordinates of the center of the cube
    raw_cube_size_voxels_clf, # Physical size of cube to extract, in voxels at *classifier* spacing
    target_spacing_clf, # Classifier's target spacing
    final_cube_shape_clf, # Target D,H,W for classifier input
    config,
    is_mask=False # Not used here, but for consistency
    ):
    """
    Extracts a 3D cube, resamples it to classifier spacing, normalizes, and resizes.
    raw_cube_size_voxels_clf: (Depth, Height, Width) number of voxels to aim for at target_spacing_clf
                              This defines the *physical* size of the cube to extract.
    """
    # 1. Determine physical size of the cube based on raw_cube_size_voxels_clf and target_spacing_clf
    physical_size_D = raw_cube_size_voxels_clf[0] * target_spacing_clf[0]
    physical_size_H = raw_cube_size_voxels_clf[1] * target_spacing_clf[1]
    physical_size_W = raw_cube_size_voxels_clf[2] * target_spacing_clf[2]

    # 2. Create a resampling grid centered at world_centroid with target_spacing_clf
    #    and size defined by physical_size / target_spacing_clf (which is raw_cube_size_voxels_clf)
    resampler = sitk.ResampleImageFilter()
    resampler.SetOutputSpacing(target_spacing_clf)
    resampler.SetSize(raw_cube_size_voxels_clf[::-1]) # SITK size is x,y,z

    # Calculate origin for the resampling grid
    # Center of the output cube in physical space is world_centroid
    # Origin = center - (size_physical / 2)
    # But SimpleITK origin is at the corner of the first voxel
    # So, origin = world_centroid - ( (size_voxels - 1) * spacing / 2 )
    # Or, more simply: if the output cube has 'N' voxels along an axis with 'spc' spacing,
    # its physical extent is N*spc. The origin for SITK resampler should make world_centroid
    # the center of this physical extent.
    origin_x = world_centroid[0] - (physical_size_W / 2.0) + (target_spacing_clf[2] / 2.0)
    origin_y = world_centroid[1] - (physical_size_H / 2.0) + (target_spacing_clf[1] / 2.0)
    origin_z = world_centroid[2] - (physical_size_D / 2.0) + (target_spacing_clf[0] / 2.0)
    resampler.SetOutputOrigin([origin_x, origin_y, origin_z])

    resampler.SetOutputDirection(sitk_ct_original.GetDirection()) # Assume same orientation
    resampler.SetInterpolator(sitk.sitkLinear)
    resampled_cube_sitk = resampler.Execute(sitk_ct_original)

    # 3. Convert to numpy, clip, normalize
    cube_np_raw_spacing = sitk.GetArrayFromImage(resampled_cube_sitk) # D, H, W
    cube_np_normalized = preprocess_image(
        resampled_cube_sitk, # Pass SITK image to reuse existing preprocess
        target_spacing=target_spacing_clf, # Already at this spacing, but preprocess_image expects it
        clip_bounds=config["hu_clip_bounds"],
        norm_stats=config["norm_mean_std"],
        is_mask=False
    ) # This will re-resample if target_spacing_clf is different, ensure it's the same.
      # Or, just do clip & norm on cube_np_raw_spacing directly.
      # For simplicity, let's assume preprocess_image called with same target_spacing doesn't resample again.
      # A cleaner way:
    # cube_np_raw_spacing = np.clip(cube_np_raw_spacing, config["hu_clip_bounds"][0], config["hu_clip_bounds"][1])
    # cube_np_raw_spacing = (cube_np_raw_spacing - config["hu_clip_bounds"][0]) / (config["hu_clip_bounds"][1] - config["hu_clip_bounds"][0])
    # cube_np_normalized = (cube_np_raw_spacing - config["norm_mean_std"]["mean"]) / config["norm_mean_std"]["std"]


    # 4. Resize/pad to final_cube_shape_clf (e.g., 32, 48, 48)
    # Scipy.ndimage.zoom for resizing
    current_shape = cube_np_normalized.shape
    zoom_factors = [f_dim / c_dim for f_dim, c_dim in zip(final_cube_shape_clf, current_shape)]
    final_cube_np = ndi.zoom(cube_np_normalized, zoom_factors, order=1, mode='nearest') # order=1 for linear

    # Ensure exact final shape (due to rounding in zoom) with padding/cropping
    # This is a simplified crop/pad. More robust padding might be needed.
    shape_diff = np.array(final_cube_shape_clf) - np.array(final_cube_np.shape)
    pad_dims = []
    for i in range(3): # D, H, W
        if shape_diff[i] >= 0: # Pad
            pad_before = shape_diff[i] // 2
            pad_after = shape_diff[i] - pad_before
            pad_dims.append((pad_before, pad_after))
        else: # Crop (should not happen if zoom is correct, but for robustness)
            crop_before = -shape_diff[i] // 2
            crop_after = -shape_diff[i] - crop_before
            final_cube_np = np.take(final_cube_np, range(crop_before, final_cube_np.shape[i] - crop_after), axis=i)
            pad_dims.append((0,0)) # No padding then

    if any(s[0]>0 or s[1]>0 for s in pad_dims): # Check if any padding is needed
         final_cube_np = np.pad(final_cube_np, pad_dims, mode='constant', constant_values=final_cube_np.min()) # Pad with min value

    # Final crop to exact size if over-padded/zoomed
    final_cube_np = final_cube_np[
        :final_cube_shape_clf[0],
        :final_cube_shape_clf[1],
        :final_cube_shape_clf[2]
    ]
    assert final_cube_np.shape == tuple(final_cube_shape_clf), \
        f"Cube shape mismatch: {final_cube_np.shape} vs {final_cube_shape_clf}"

    return final_cube_np.astype(np.float32)


# --- Generate lists of positive and negative cube paths/info for classifier training ---
# This is a complex step. You need to:
# 1. For GT nodules: Extract cubes, assign malignancy label (e.g., from LIDC 1-5 scale to binary).
#    Store cube path and label. These are your "positives" (if malignant) or "easy negatives" (if benign GT).
# 2. For proposed candidates from segmentation:
#    - Match them to GT nodules (e.g., by IoU of bounding boxes or distance of centroids).
#    - If a proposal matches a GT malignant nodule -> positive sample.
#    - If a proposal matches a GT benign nodule -> easy negative sample.
#    - If a proposal does NOT match any GT nodule -> potential hard negative sample.
#
# For now, a simplified placeholder:
# all_gt_nodule_cubes_info = [] # list of {'cube_path': path, 'label': 0/1, 'patient_id':pid}
# all_proposed_negative_cubes_info = [] # list of {'cube_path': path, 'label': 0, 'patient_id':pid}

# Example of processing and saving cubes (run this in a loop for all patients and nodules)
# for patient_id in all_patient_ids:
#     sitk_ct, _, annotations = load_patient_data(patient_id, CONFIG["data_dir"])
#     # For GT nodules
#     for annot in annotations:
#         if 'malignancy' in annot: # Assume malignancy is available
#             label = 1 if annot['malignancy'] >= 3 else 0 # Example threshold for binary
#             cube_np = extract_cube(sitk_ct, annot['centroid_world'],
#                                    CONFIG["clf_cube_size_raw"], CONFIG["target_spacing_clf"],
#                                    CONFIG["clf_cube_size_final"], CONFIG)
#             # Save cube_np and store info
#             # cube_path = CONFIG["output_dir"] / "classifier_cubes" / f"{patient_id}_{annot['id']}_gt_l{label}.npy"
#             # np.save(cube_path, cube_np)
#             # all_gt_nodule_cubes_info.append({'cube_path': cube_path, 'label': label, ...})

#     # For proposed nodules (run segmentation first)
#     # proposed_nodules_this_patient, _ = get_nodule_candidates_from_segmentation(...)
#     # For each proposed_nodule:
#         # Check if it's a GT match. If not, it's a potential negative.
#         # is_gt_match = check_match(proposed_nodule, annotations)
#         # if not is_gt_match:
#         #     cube_np = extract_cube(sitk_ct, proposed_nodule['centroid_world'], ...)
#             # Save cube_np for negatives
#             # all_proposed_negative_cubes_info.append({'cube_path': ..., 'label': 0, ...})

# %% [markdown]
# ## 7. Classifier Model (3D DenseNet121 + SE Blocks)
#
# - Define 3D DenseNet architecture with SE blocks.
# - Function to load MedicalNet pre-trained weights.
# - Define Focal Loss.
# - Define Dataset for classifier (handles cube loading, HNM sampling).
# - Training loop for classifier (incorporating HNM, cosine LR, early stopping).

# %%
# Placeholder for 3D DenseNet + SE (models_classification.py)

class SEBlock3D(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SEBlock3D, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )
    def forward(self, x):
        b, c, _, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1, 1)
        return x * y.expand_as(x)

class _DenseLayer3D(nn.Sequential):
    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, use_se=True):
        super(_DenseLayer3D, self).__init__()
        self.add_module('norm1', nn.BatchNorm3d(num_input_features)),
        self.add_module('relu1', nn.ReLU(inplace=True)),
        self.add_module('conv1', nn.Conv3d(num_input_features, bn_size * growth_rate,
                                           kernel_size=1, stride=1, bias=False)),
        self.add_module('norm2', nn.BatchNorm3d(bn_size * growth_rate)),
        self.add_module('relu2', nn.ReLU(inplace=True)),
        self.add_module('conv2', nn.Conv3d(bn_size * growth_rate, growth_rate,
                                           kernel_size=3, stride=1, padding=1, bias=False)),
        if use_se:
            self.add_module('se', SEBlock3D(growth_rate))
        self.drop_rate = drop_rate

    def forward(self, x):
        new_features = super(_DenseLayer3D, self).forward(x)
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
        return torch.cat([x, new_features], 1)

class _DenseBlock3D(nn.Sequential):
    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, use_se=True):
        super(_DenseBlock3D, self).__init__()
        for i in range(num_layers):
            layer = _DenseLayer3D(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate, use_se)
            self.add_module('denselayer%d' % (i + 1), layer)

class _Transition3D(nn.Sequential):
    def __init__(self, num_input_features, num_output_features):
        super(_Transition3D, self).__init__()
        self.add_module('norm', nn.BatchNorm3d(num_input_features))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module('conv', nn.Conv3d(num_input_features, num_output_features,
                                          kernel_size=1, stride=1, bias=False))
        self.add_module('pool', nn.AvgPool3d(kernel_size=2, stride=2))


class DenseNet3D_SE(nn.Module):
    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), # DenseNet-121 like
                 num_init_features=64, bn_size=4, drop_rate=0, num_classes=1, use_se=True, in_channels=1):
        super(DenseNet3D_SE, self).__init__()

        self.features = nn.Sequential()
        # Initial convolution
        self.features.add_module('conv0', nn.Conv3d(in_channels, num_init_features, kernel_size=7, stride=2, padding=3, bias=False))
        self.features.add_module('norm0', nn.BatchNorm3d(num_init_features))
        self.features.add_module('relu0', nn.ReLU(inplace=True))
        self.features.add_module('pool0', nn.MaxPool3d(kernel_size=3, stride=2, padding=1))

        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock3D(num_layers=num_layers, num_input_features=num_features,
                                bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate, use_se=use_se)
            self.features.add_module('denseblock%d' % (i + 1), block)
            num_features = num_features + num_layers * growth_rate
            if i != len(block_config) - 1:
                trans = _Transition3D(num_input_features=num_features, num_output_features=num_features // 2)
                self.features.add_module('transition%d' % (i + 1), trans)
                num_features = num_features // 2

        self.features.add_module('norm5', nn.BatchNorm3d(num_features))
        self.classifier = nn.Linear(num_features, num_classes)

        # Official init from torch repo.
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        features = self.features(x)
        out = F.relu(features, inplace=True)
        out = F.adaptive_avg_pool3d(out, (1, 1, 1)).view(features.size(0), -1)
        out = self.classifier(out)
        return out

def load_medicalnet_weights(model, weights_path, device):
    if weights_path and weights_path.exists():
        print(f"Loading MedicalNet weights from: {weights_path}")
        try:
            checkpoint = torch.load(weights_path, map_location=device)
            state_dict = checkpoint.get('state_dict', checkpoint) # Handle different checkpoint formats

            # Adapt keys if needed (e.g., MedicalNet might have 'module.' prefix from DataParallel)
            # Or if layer names differ (e.g. 'features.conv0' vs 'conv0')
            # This is a common pain point and requires inspecting both model's state_dict keys.
            # Example adaptation:
            new_state_dict = {}
            for k, v in state_dict.items():
                # name = k.replace("module.", "") # Remove `module.` prefix
                name = k # Assume keys match for now or adapt as per your MedicalNet checkpoint
                # Further adaptation may be needed depending on your DenseNet3D_SE vs MedicalNet's DenseNet
                new_state_dict[name] = v

            # Filter out mismatched keys (e.g. classifier if num_classes differs)
            model_dict = model.state_dict()
            pretrained_dict = {k: v for k, v in new_state_dict.items() if k in model_dict and model_dict[k].shape == v.shape}
            
            missing_keys, unexpected_keys = model.load_state_dict(pretrained_dict, strict=False)
            print(f"Weights loaded. Missing keys: {missing_keys}")
            print(f"Unexpected keys in pretrained: {unexpected_keys}")
            if not pretrained_dict:
                 print("Warning: No weights were loaded from MedicalNet checkpoint. Check layer names and shapes.")

        except Exception as e:
            print(f"Error loading MedicalNet weights: {e}")
    else:
        print("MedicalNet weights path not found or not provided. Training from scratch.")


class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs_logits, targets):
        BCE_loss = F.binary_cross_entropy_with_logits(inputs_logits, targets, reduction='none')
        probs = torch.sigmoid(inputs_logits)
        pt = torch.exp(-BCE_loss) # prevents nans when probability 0
        # pt = probs * targets + (1 - probs) * (1 - targets) # This is more direct for pt
        F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss

        if self.reduction == 'mean':
            return torch.mean(F_loss)
        elif self.reduction == 'sum':
            return torch.sum(F_loss)
        else:
            return F_loss

class ClassifierDataset(Dataset):
    def __init__(self, positive_samples_info, negative_samples_info_all, epoch, config, transform=None):
        """
        positive_samples_info: list of dicts {'cube_path': path, 'label': 1}
        negative_samples_info_all: list of dicts {'cube_path': path, 'label': 0} (all available negatives)
        epoch: current epoch, for HNM logic
        """
        self.config = config
        self.epoch = epoch
        self.transform = transform # For augmentations on cubes

        self.positive_samples = positive_samples_info
        self.current_negative_samples = []

        if self.epoch < self.config["hnm_start_epoch"] or not negative_samples_info_all:
            # Randomly sample negatives or use all if few
            num_neg_to_sample = min(len(negative_samples_info_all),
                                    len(self.positive_samples) * self.config["hnm_ratio_neg_to_pos"])
            self.current_negative_samples = random.sample(negative_samples_info_all, num_neg_to_sample) \
                                            if negative_samples_info_all else []
        else:
            # HNM logic will be applied by the training loop before creating this dataset object for the epoch
            # This dataset will be created with the *already mined* hard negatives for this epoch
            self.current_negative_samples = negative_samples_info_all # Assume these are the hard ones

        self.all_samples = self.positive_samples + self.current_negative_samples
        random.shuffle(self.all_samples)

    def __len__(self):
        return len(self.all_samples)

    def __getitem__(self, idx):
        sample_info = self.all_samples[idx]
        # cube_array = np.load(sample_info['cube_path'])
        # label = sample_info['label']
        # Dummy data for classifier dataset
        cube_array = np.random.rand(*self.config["clf_cube_size_final"]).astype(np.float32)
        label = random.choice([0,1]) # Dummy label

        # Augment cube_array if self.transform is defined
        # if self.transform: cube_array = self.transform(cube_array)

        cube_tensor = torch.from_numpy(cube_array).float().unsqueeze(0) # C, D, H, W
        label_tensor = torch.tensor(label, dtype=torch.float32) # .unsqueeze(0) for BCEWithLogits

        return cube_tensor, label_tensor


# --- Training Loop for Classifier (with HNM) ---
def train_classifier_epoch(model, dataloader, optimizer, loss_fn, scaler, device, epoch_num, config):
    model.train()
    epoch_loss = 0
    for batch_idx, (cubes, labels) in enumerate(tqdm(dataloader, desc=f"Clf Training Epoch {epoch_num+1}")):
        cubes, labels = cubes.to(device), labels.to(device).unsqueeze(1) # Ensure label is [B,1]
        optimizer.zero_grad()
        with autocast(enabled=torch.cuda.is_available()):
            predictions_logits = model(cubes)
            loss = loss_fn(predictions_logits, labels)
        if torch.isnan(loss) or torch.isinf(loss):
            print(f"NaN/Inf loss in clf train: {loss.item()}. Skipping batch.")
            continue
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        epoch_loss += loss.item()
    return epoch_loss / len(dataloader)

def validate_classifier_epoch(model, dataloader, loss_fn, device):
    model.eval()
    epoch_loss = 0
    all_preds_probs_clf, all_labels_clf = [], []
    with torch.no_grad():
        for cubes, labels in tqdm(dataloader, desc="Clf Validation Epoch"):
            cubes, labels = cubes.to(device), labels.to(device).unsqueeze(1)
            with autocast(enabled=torch.cuda.is_available()):
                predictions_logits = model(cubes)
                loss = loss_fn(predictions_logits, labels)
            if torch.isnan(loss) or torch.isinf(loss):
                print(f"NaN/Inf loss in clf val: {loss.item()}. Skipping batch.")
                continue

            epoch_loss += loss.item()
            all_preds_probs_clf.extend(torch.sigmoid(predictions_logits).cpu().numpy())
            all_labels_clf.extend(labels.cpu().numpy())

    if not all_labels_clf: return 0.0, 0.0 # Handle empty validation

    all_labels_clf_np = np.array(all_labels_clf).flatten()
    all_preds_probs_clf_np = np.array(all_preds_probs_clf).flatten()

    val_auc = 0.0
    if len(np.unique(all_labels_clf_np)) > 1: # Check if more than one class present
        val_auc = roc_auc_score(all_labels_clf_np, all_preds_probs_clf_np)
    else:
        print("Warning: Only one class present in validation labels, AUC cannot be calculated.")

    return epoch_loss / len(dataloader), val_auc


# Initialize Classifier Model
clf_model = DenseNet3D_SE(
    num_classes=CONFIG["clf_num_classes"],
    in_channels=CONFIG["clf_in_channels"],
    # growth_rate, block_config, etc. can be adjusted
).to(DEVICE)

if CONFIG["clf_use_medicalnet_pretrained"]:
    load_medicalnet_weights(clf_model, CONFIG["medicalnet_weights_path"], DEVICE)

clf_optimizer = optim.AdamW(clf_model.parameters(), lr=CONFIG["clf_lr"], weight_decay=1e-5)
clf_loss_fn = FocalLoss(alpha=CONFIG["clf_focal_loss_alpha"], gamma=CONFIG["clf_focal_loss_gamma"])
clf_scaler = GradScaler(enabled=torch.cuda.is_available())
clf_lr_scheduler = CosineAnnealingLR(clf_optimizer, T_max=CONFIG["clf_cosine_lr_t_max"])
# For early stopping:
# early_stopper = ReduceLROnPlateau(clf_optimizer, mode='max', factor=0.5, patience=CONFIG["clf_early_stopping_patience"] // 2, verbose=True)
# Or custom early stopping logic based on val_auc not improving. For now, simplified.
best_val_auc_clf = -1.0
epochs_no_improve_clf = 0


# --- Dummy data for classifier training ---
# In a real scenario, these lists would be populated after cube extraction (Section 6)
# And `all_proposed_negative_cubes_info` would be updated with HNM results each epoch
num_dummy_pos = 50
num_dummy_neg_total = 200
dummy_positive_cubes_info = [{'cube_path': f'dummy_pos_{i}.npy', 'label': 1, 'patient_id': 'p_pos'} for i in range(num_dummy_pos)]
dummy_all_negative_cubes_info = [{'cube_path': f'dummy_neg_{i}.npy', 'label': 0, 'patient_id': 'p_neg'} for i in range(num_dummy_neg_total)]

# Split positives and negatives for validation to maintain similar distribution
dummy_pos_train, dummy_pos_val = train_test_split(dummy_positive_cubes_info, test_size=0.2, random_state=SEED)
dummy_neg_train_pool, dummy_neg_val_pool = train_test_split(dummy_all_negative_cubes_info, test_size=0.2, random_state=SEED)


print(f"\n--- Training Classifier Model ({CONFIG['clf_model_name']}) ---")
for epoch in range(CONFIG["clf_epochs"]):
    current_negatives_for_epoch = []
    if epoch < CONFIG["hnm_start_epoch"]:
        # Randomly sample negatives before HNM starts
        num_to_sample = min(len(dummy_neg_train_pool), len(dummy_pos_train) * CONFIG["hnm_ratio_neg_to_pos"])
        current_negatives_for_epoch = random.sample(dummy_neg_train_pool, num_to_sample) if dummy_neg_train_pool else []
    else:
        # --- HNM Step ---
        print(f"Epoch {epoch+1}: Performing Hard Negative Mining...")
        # Create a dataset of ALL available training negatives to score them
        temp_hnm_dataset = ClassifierDataset([], dummy_neg_train_pool, epoch, CONFIG) # Pass epoch for consistency
        temp_hnm_loader = DataLoader(temp_hnm_dataset, batch_size=CONFIG["clf_batch_size"] * 2, shuffle=False) # Larger batch for inference

        clf_model.eval() # Set model to evaluation mode for HNM scoring
        scored_negatives = [] # List of (score, negative_info_dict)
        with torch.no_grad():
            for cubes_neg, _ in tqdm(temp_hnm_loader, desc="HNM Scoring Negatives"):
                cubes_neg = cubes_neg.to(DEVICE)
                preds_logits_neg = clf_model(cubes_neg)
                preds_probs_neg = torch.sigmoid(preds_logits_neg).cpu().numpy().flatten()
                # Associate probs with their original info (requires careful indexing if dataset shuffles)
                # For simplicity, assume temp_hnm_dataset.all_samples retains order or map back
                start_idx = len(scored_negatives) # This simple indexing works if loader doesn't shuffle
                for i in range(len(preds_probs_neg)):
                     # This assumes temp_hnm_dataset.all_samples are the negatives from dummy_neg_train_pool in order
                    if start_idx + i < len(temp_hnm_dataset.all_samples):
                        neg_info = temp_hnm_dataset.all_samples[start_idx + i]
                        scored_negatives.append((preds_probs_neg[i], neg_info))


        scored_negatives.sort(key=lambda x: x[0], reverse=True) # Sort by prob (score), highest first (hardest)
        num_hard_neg = CONFIG["hnm_num_hard_negatives_per_batch"](len(dummy_pos_train), CONFIG["hnm_ratio_neg_to_pos"])
        # The above lambda is num hard neg per batch, we need for whole epoch
        # Correct: num_hard_neg = len(dummy_pos_train) * CONFIG["hnm_ratio_neg_to_pos"]
        num_hard_neg = min(len(scored_negatives), len(dummy_pos_train) * CONFIG["hnm_ratio_neg_to_pos"])

        current_negatives_for_epoch = [info for score, info in scored_negatives[:num_hard_neg]]
        print(f"  Selected {len(current_negatives_for_epoch)} hard negatives for this epoch.")
        clf_model.train() # Set model back to training mode

    clf_train_dataset = ClassifierDataset(dummy_pos_train, current_negatives_for_epoch, epoch, CONFIG)
    clf_train_loader = DataLoader(clf_train_dataset, batch_size=CONFIG["clf_batch_size"], shuffle=True, num_workers=0)

    # Validation dataset (fixed negatives, not from HNM pool unless you design it that way)
    num_val_neg_to_sample = min(len(dummy_neg_val_pool), len(dummy_pos_val) * CONFIG["hnm_ratio_neg_to_pos"])
    val_neg_samples = random.sample(dummy_neg_val_pool, num_val_neg_to_sample) if dummy_neg_val_pool else []
    clf_val_dataset = ClassifierDataset(dummy_pos_val, val_neg_samples, epoch, CONFIG) # epoch not used for val sampling logic here
    clf_val_loader = DataLoader(clf_val_dataset, batch_size=CONFIG["clf_batch_size"], shuffle=False, num_workers=0)

    if not clf_train_loader or len(clf_train_loader.dataset) == 0:
        print(f"Skipping training for epoch {epoch+1} due to empty train_loader.")
        continue

    train_loss_clf = train_classifier_epoch(clf_model, clf_train_loader, clf_optimizer, clf_loss_fn, clf_scaler, DEVICE, epoch, CONFIG)

    if not clf_val_loader or len(clf_val_loader.dataset) == 0:
        print(f"Skipping validation for epoch {epoch+1} due to empty val_loader.")
        val_loss_clf, val_auc_clf = float('inf'), 0.0 # Or previous values
    else:
        val_loss_clf, val_auc_clf = validate_classifier_epoch(clf_model, clf_val_loader, clf_loss_fn, DEVICE)

    clf_lr_scheduler.step()
    # early_stopper.step(val_auc_clf) # If using ReduceLROnPlateau

    print(f"Epoch {epoch+1}/{CONFIG['clf_epochs']}: Clf Train Loss: {train_loss_clf:.4f}, Clf Val Loss: {val_loss_clf:.4f}, Clf Val AUC: {val_auc_clf:.4f}")

    if val_auc_clf > best_val_auc_clf:
        best_val_auc_clf = val_auc_clf
        torch.save(clf_model.state_dict(), CONFIG["output_dir"] / "classification_models" / f"{CONFIG['clf_model_name']}_best.pth")
        print(f"  Saved best classifier model with Val AUC: {best_val_auc_clf:.4f}")
        epochs_no_improve_clf = 0
    else:
        epochs_no_improve_clf += 1

    if epochs_no_improve_clf >= CONFIG["clf_early_stopping_patience"]:
        print(f"Early stopping triggered for classifier at epoch {epoch+1} due to no improvement in Val AUC.")
        break
    # if clf_optimizer.param_groups[0]['lr'] < 1e-7: # Another early stopping condition if LR gets too small
    #     print("Learning rate too small, stopping early.")
    #     break

print("Classifier training finished.")


# %% [markdown]
# ## 8. End-to-End Evaluation
#
# - On a test set of patients:
#   1. Run segmentation model to get probability map.
#   2. Calculate Detection Dice:
#      - Compare predicted nodule segmentation map (after thresholding) with ground truth nodule segmentation map.
#      - This requires GT nodule *masks*, not just centroids.
#   3. Perform candidate proposal (as in section 5).
#   4. For each proposed candidate, extract cube and run classifier.
#   5. For classification performance:
#      - Match proposals to GT nodules to assign true labels to proposals.
#      - Calculate AUC on these matched proposals.
#      - Calculate overall Sensitivity @ 95% Specificity (FROC-like analysis might be more appropriate here).
#
# This is complex. A simplified version might be:
# - Evaluate segmentation Dice on test set.
# - Evaluate classifier AUC on pre-extracted GT test cubes + proposed hard negatives from test set.
# - A true end-to-end metric would consider detection recall and then classification accuracy on correctly detected nodules.

# %%
# Placeholder for end-to-end evaluation logic.
# This would involve iterating through test_ids, running the full pipeline.

# --- Detection Dice Evaluation ---
# Requires GT segmentation masks for nodules on the test set.
# Create these similar to how `nodule_seg_mask_np` was created for training.
# Then, for each test patient:
#   pred_seg_map_np = get_segmentation_output(test_ct_sitk, seg_model, ...)
#   pred_binary_map = (pred_seg_map_np > CONFIG["seg_prob_threshold"])
#   dice = calculate_dice(pred_binary_map, gt_nodule_mask_np)
#   Store and average Dice scores.

# --- Classification Metrics on Test Set ---
# Load best classifier model
best_clf_model_path = CONFIG["output_dir"] / "classification_models" / f"{CONFIG['clf_model_name']}_best.pth"
if best_clf_model_path.exists():
    clf_model.load_state_dict(torch.load(best_clf_model_path, map_location=DEVICE))
    print(f"Loaded best classifier model from {best_clf_model_path}")
else:
    print("WARNING: Best classifier model not found. Using last epoch model for final eval.")


# Generate test cubes (GT positives and proposed negatives from test set scans)
# dummy_test_pos_cubes_info = ...
# dummy_test_neg_cubes_info_pool = ... (all negatives found on test scans)
# For a fair AUC, you might want a balanced set of test negatives or use all FPs from detection.

# For this example, re-use validation cube lists as dummy test lists
clf_test_dataset = ClassifierDataset(dummy_pos_val, val_neg_samples, CONFIG["clf_epochs"], CONFIG) # epoch doesn't matter for fixed test set
clf_test_loader = DataLoader(clf_test_dataset, batch_size=CONFIG["clf_batch_size"], shuffle=False)

test_loss_clf, test_auc_clf = validate_classifier_epoch(clf_model, clf_test_loader, clf_loss_fn, DEVICE) # Reuses validation function
print(f"\n--- Classifier Test Set Performance ---")
print(f"Test Loss: {test_loss_clf:.4f}, Test AUC: {test_auc_clf:.4f}")

# For Sensitivity @ Specificity:
# Need all_preds_probs and all_labels from the test set (returned by validate_classifier_epoch if adapted)
# Example: (Assuming validate_classifier_epoch is modified to return probs and labels)
# _, final_test_probs, final_test_labels = validate_classifier_epoch_for_eval(clf_model, clf_test_loader, clf_loss_fn, DEVICE)
# fpr, tpr, thresholds = roc_curve(final_test_labels, final_test_probs)
# # Find threshold for 95% specificity (Specificity = 1 - FPR)
# target_fpr = 1.0 - CONFIG["eval_sensitivity_at_specificity"]
# closest_fpr_idx = np.argmin(np.abs(fpr - target_fpr))
# sensitivity_at_target_spec = tpr[closest_fpr_idx]
# threshold_at_target_spec = thresholds[closest_fpr_idx]
# print(f"Sensitivity at {CONFIG['eval_sensitivity_at_specificity']*100}% Specificity: {sensitivity_at_target_spec:.4f} (Threshold: {threshold_at_target_spec:.4f})")


# %% [markdown]
# ## 9. Visualization
#
# - **3D Overlay:** Show CT scan with predicted nodule segmentations (or bounding boxes) overlaid.
#   - `itkwidgets` or slice-by-slice `matplotlib`.
# - **GradCAM:** For the classifier, on input cubes.
#   - Adapt standard GradCAM implementations for 3D and your DenseNet architecture.
#   - This requires access to the model's convolutional feature maps and gradients.

# %%
# Placeholder for 3D overlay visualization (visualization_tools.py)
def visualize_3d_overlay(ct_np_array, seg_mask_np_array, slice_idx=None, title=""):
    """ Simple slice-wise overlay """
    if slice_idx is None:
        slice_idx = ct_np_array.shape[0] // 2 # Middle slice for Depth
    
    plt.figure(figsize=(12, 6))
    plt.subplot(1,2,1)
    plt.imshow(ct_np_array[slice_idx, :, :], cmap='gray')
    plt.title(f"CT Slice {slice_idx}")
    plt.axis('off')

    plt.subplot(1,2,2)
    plt.imshow(ct_np_array[slice_idx, :, :], cmap='gray')
    plt.imshow(seg_mask_np_array[slice_idx, :, :], cmap='jet', alpha=0.5, vmin=0, vmax=1) # Overlay mask
    plt.title(f"CT + Seg Overlay Slice {slice_idx}")
    plt.axis('off')
    if title: plt.suptitle(title)
    plt.show()

# Example:
# test_patient_ct_np = preprocess_image(sitk_ct_test, CONFIG["target_spacing_seg"], ...)
# test_patient_pred_seg_map_np = seg_map_for_dice # from get_nodule_candidates
# visualize_3d_overlay(test_patient_ct_np, test_patient_pred_seg_map_np > CONFIG["seg_prob_threshold"], title=f"Pred Seg for {patient_to_test_seg}")


# Placeholder for GradCAM (visualization_tools.py)
# This is non-trivial to implement fully here.
# Key steps:
# 1. Hook the target convolutional layer in your classifier.
# 2. Forward pass with the cube to get activations and logits.
# 3. Backward pass from the target class logit to get gradients w.r.t activations.
# 4. Compute weights (global average pool of gradients).
# 5. Compute weighted sum of activation maps.
# 6. ReLU and normalize heatmap.
# 7. Upsample heatmap to cube size and overlay.

# %% [markdown]
# ## 10. Logging Metrics, ROC Curves, and Final Report
#
# - Log all relevant metrics (Dice, AUC, Loss, Acc, Sens/Spec) for train/val/test.
# - Plot and save ROC curves for the classifier.
# - **Final Report Generation:**
#   - Summarize dataset characteristics.
#   - Segmentation model performance (Dice).
#   - Candidate proposal statistics (e.g., number of candidates per scan, FPs).
#   - Classifier performance (AUC, Sens@Spec, other metrics from classification_report).
#   - **Comparison: Pure 3D CNN vs. Cascade Approach:**
#     - To do this properly, you'd need to define, train, and evaluate a "Pure 3D CNN" for direct nodule classification (e.g., on the same extracted cubes, or an end-to-end detection model).
#     - For this notebook, we can describe what such a model would be and qualitatively compare the expected pros/cons based on the cascade's performance.

# %%
# --- Plot ROC Curve for Classifier ---
# Assuming final_test_labels, final_test_probs are available from evaluation
# fpr, tpr, thresholds = roc_curve(final_test_labels, final_test_probs)
# roc_auc_value = auc(fpr, tpr)

# plt.figure(figsize=(8,6))
# plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc_value:.2f})')
# plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
# plt.xlim([0.0, 1.0]); plt.ylim([0.0, 1.05])
# plt.xlabel('False Positive Rate'); plt.ylabel('True Positive Rate')
# plt.title('Classifier ROC Curve (Test Set)')
# plt.legend(loc="lower right"); plt.grid(True)
# plt.savefig(CONFIG["output_dir"] / "reports" / "classifier_roc_curve.png")
# plt.show()

# --- Final Report (Markdown or text file) ---
report_content = f"""
# Lung Nodule Detection and Classification Pipeline Report

## 1. Dataset
- Source: (e.g., LIDC-IDRI subset)
- Number of patients (Train/Val/Test): {len(train_ids)}/{len(val_ids)}/{len(test_ids)}
- Annotation details: (e.g., nodule centroids, diameters, malignancy scores)

## 2. Preprocessing
- CT Resampling Spacing (Seg): {CONFIG["target_spacing_seg"]}
- CT Resampling Spacing (Clf): {CONFIG["target_spacing_clf"]}
- HU Clipping: {CONFIG["hu_clip_bounds"]}
- Normalization: Mean={CONFIG["norm_mean_std"]["mean"]}, Std={CONFIG["norm_mean_std"]["std"]}

## 3. Cascade Pipeline Performance

### 3.1. Segmentation Model ({CONFIG['seg_model_name']})
- Architecture: Custom 3D U-Net
- Training Epochs: {CONFIG["seg_epochs"]}
- Best Validation Dice: {best_val_dice_seg:.4f} (if tracked)
- Test Set Dice: (Calculate this on test set)

### 3.2. Candidate Proposal
- Method: Thresholding segmentation map ({CONFIG["seg_prob_threshold"]}) + Connected Components
- Min Nodule Size: {CONFIG["min_nodule_size_voxels"]} voxels
- Statistics: (e.g., Avg candidates/scan, False Positive Rate of proposals - requires matching to GT)

### 3.3. Classification Model ({CONFIG['clf_model_name']})
- Architecture: 3D DenseNet121 with SE Blocks
- Pre-trained weights: {"MedicalNet" if CONFIG["clf_use_medicalnet_pretrained"] else "From Scratch"}
- Loss Function: Focal Loss (alpha={CONFIG["clf_focal_loss_alpha"]}, gamma={CONFIG["clf_focal_loss_gamma"]})
- Training Strategy: Hard Negative Mining (ratio={CONFIG["hnm_ratio_neg_to_pos"]}), Cosine LR, Early Stopping (patience={CONFIG["clf_early_stopping_patience"]})
- Training Epochs: {CONFIG["clf_epochs"]} (or actual if early stopped)
- Best Validation AUC: {best_val_auc_clf:.4f}
- **Test Set Performance:**
    - AUC: {test_auc_clf:.4f} (if calculated)
    - Sensitivity @ {CONFIG["eval_sensitivity_at_specificity"]*100}% Specificity: (Calculate this)
    - Other metrics: (Add Precision, Recall, F1 from classification_report on test proposals)

## 4. Comparison: Cascade vs. Pure 3D CNN Approach

### 4.1. Pure 3D CNN (Hypothetical or Implemented)
- Define the architecture (e.g., {CONFIG['comp_model_name']} - a 3D ResNet or simpler CNN).
- Training: On extracted cubes directly for malignancy classification.
- Performance: (Report its AUC, Sens/Spec if implemented and trained).

### 4.2. Discussion
- **Cascade Approach Pros:**
    - Modular: Can optimize segmentation and classification separately.
    - Segmentation can find a wide range of candidates.
    - Classifier focuses only on small, relevant cubes.
    - May handle varying nodule sizes well if segmentation is robust.
- **Cascade Approach Cons:**
    - Error propagation: Segmentation errors directly impact classification.
    - Complex pipeline to manage and tune.
    - Candidate generation step can be critical and hard to optimize (many FPs).
- **Pure 3D CNN (Direct Classification/Detection) Pros:**
    - Potentially simpler end-to-end training.
    - Learns features directly for the final task.
    - No intermediate hard-to-tune proposal stage if using an end-to-end detector.
- **Pure 3D CNN Cons:**
    - If classifying cubes directly, still needs a good proposal mechanism or processes whole scan patches (computationally expensive).
    - End-to-end detectors (like 3D RetinaNet) are complex to implement and train, very data-hungry.
    - May struggle with highly imbalanced data if not handled carefully (many non-nodule regions).

## 5. Conclusion and Future Work
- Summary of findings.
- Limitations of the current pipeline.
- Potential improvements (e.g., better data augmentation, more advanced HNM, different model architectures, full end-to-end model).
"""

report_path = CONFIG["output_dir"] / "reports" / "final_pipeline_report.md"
with open(report_path, 'w') as f:
    f.write(report_content)
print(f"Final report structure saved to {report_path}")


# %% [markdown]
# ## End of Notebook

Using device: cuda


TypeError: Object of type function is not JSON serializable