In [1]:
# %% [markdown]
# # Refined Advanced End-to-End Lung Nodule Detection and Classification Pipeline
#
# This notebook implements a refined version of the advanced nodule classifier, balancing novel architectural
# components against practical trainability on limited datasets.
#
# **Classifier Architecture (`AdvancedHybridNetRefined`):**
# 1. **Anisotropic Convolutional Stem.**
# 2. **Hierarchical Backbone:** Stages of custom "DenseResLayers" with SE and spatial dropout,
#    followed by downsampling. Simplified attention applied at one or two scales.
# 3. **Convolution-Free Transformer Head.**
#
# **Training Strategies:** Focal Loss, MixUp and intensity/elastic augmentations, AdamW, and a cosine-annealed
# learning rate, as in the previous notebook.

# %% [markdown]
# ## 1. Setup and Imports (Same as previous)

# %%
import os
import glob
import time
import random
import json
from pathlib import Path
import copy

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# import seaborn as sns # Optional

import SimpleITK as sitk
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import CosineAnnealingLR

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    classification_report, confusion_matrix, roc_auc_score,
    roc_curve, precision_recall_curve, auc, f1_score,
    precision_score, recall_score, accuracy_score, ConfusionMatrixDisplay
)
from skimage.measure import label as skimage_label, regionprops
import scipy.ndimage as ndi

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {DEVICE}")
SEED = 42


def seed_everything(seed):
    """Seed Python's `random`, NumPy's global RNG and PyTorch (CPU and, if present, every CUDA device).

    When CUDA is available, cuDNN is additionally switched to deterministic, non-benchmarking mode.
    NOTE(review): PYTHONHASHSEED is read at interpreter start-up, so setting it here presumably
    only influences processes launched after this call.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_everything(SEED)

# %% [markdown]
# ## 2. Refined Configuration

# %%
CONFIG = {
    # --- Paths ---
    "data_dir": Path("/path/to/your/LIDC-IDRI-like/dataset"), # Placeholder
    "output_dir": Path("./output_nodule_pipeline_refined"),
    # "medicalnet_weights_path": Path("..."), # Unlikely to be directly compatible

    # --- General Preprocessing ---
    "target_spacing_clf": [1.0, 1.0, 1.0],
    "hu_clip_bounds": [-1000, 400],
    "norm_mean_std": {"mean": 0.25, "std": 0.25},

    # --- Segmentation & Proposal (Assume reused/pre-computed for classifier focus) ---
    "seg_model_name": "Pretrained_Custom3DUNet",
    "seg_output_dir": Path("./output_nodule_pipeline/segmentation_models"),
    "seg_prob_threshold": 0.5,
    "min_nodule_size_voxels": 20,

    # --- Candidate Cube Generation ---
    "clf_cube_size_final": [40, 48, 48], # D, H, W input to classifier

    # --- Refined Classifier Model (`AdvancedHybridNetRefined`) ---
    "clf_model_name": "AdvancedHybridNetRefined",
    "clf_in_channels": 1,
    "clf_num_classes": 1,
    "clf_anisotropic_stem_out_channels": 32,
    # Backbone stage config: list of tuples (num_dense_res_layers, stage_channels, use_attention_after_stage).
    # How AdvancedHybridNetRefined uses these values:
    #   * every DenseResLayer appends `clf_dense_res_growth_rate` channels to its input;
    #   * between stages a transition (MaxPool3d(2) + 1x1x1 conv) halves D/H/W and projects to the
    #     NEXT stage's `stage_channels`, so `stage_channels` is a stage's *input* width
    #     (the first stage starts from the stem's width instead, so its value is unused);
    #   * spatial size per stage for a [40, 48, 48] cube: [40, 48, 48] -> [20, 24, 24] -> [10, 12, 12].
    # AttentionModule is full self-attention over every voxel, so memory grows with (D*H*W)^2:
    # at [20, 24, 24] that is 11520 tokens, i.e. ~4.2e9 attention weights for batch 8 x 4 heads
    # (~8.5 GB per copy in fp16, several copies during training). Attention is therefore only
    # applied at the last stage (1440 tokens).
    "clf_backbone_stages_config": [
        (2, 64, False),  # Stage 1 at [40,48,48]: 32 (stem) + 2*16 = 64 channels out
        (3, 128, False), # Stage 2 at [20,24,24]: 128 in -> 128 + 3*16 = 176 out (attention too costly here)
        (4, 256, True),  # Stage 3 at [10,12,12]: 256 in -> 256 + 4*16 = 320 out, then attention
    ],
    "clf_dense_res_growth_rate": 16,
    "clf_dense_res_bn_size": 4, # Bottleneck factor in dense_res_layer
    "clf_dense_res_se_reduction": 8,
    "clf_dense_res_spatial_dropout": 0.1,
    "clf_dense_res_stochastic_depth_prob": 0.1, # For DropPath around the layer output
    "clf_attention_module_heads": 4, # For attention modules after stages
    # Transformer Head (operates on the output of the last backbone stage: a [10, 12, 12] map).
    # Each patch dimension must divide the matching feature-map dimension; D = 10 is not
    # divisible by 4, so depth is tokenised in patches of 2.
    "clf_transformer_head_patch_size": (2, 4, 4), # (D,H,W) for tokenization -> 5*3*3 = 45 tokens
    "clf_transformer_head_embed_dim": 128,
    "clf_transformer_head_depth": 2,
    "clf_transformer_head_num_heads": 4,

    # --- Training & Optimization ---
    "clf_batch_size": 8,
    "clf_lr_initial": 1e-4, # Slightly reduced from 1e-3
    "clf_weight_decay": 1e-5,
    "clf_epochs": 75, # Adjusted
    "clf_focal_loss_alpha": 0.25,
    "clf_focal_loss_gamma": 2.0,
    "clf_early_stopping_patience": 15,
    "clf_cosine_lr_t_max": 75, # Keep equal to clf_epochs for one full cosine cycle

    # Augmentation
    "aug_elastic_alpha_sigma": ((0, 360.0), (7.0, 9.0)), # (alpha_range, sigma_range) for elastic
    "aug_intensity_scale_range": (0.8, 1.2),
    "aug_mixup_alpha": 0.2,

    # HNM
    "hnm_start_epoch": 10,
    "hnm_ratio_neg_to_pos": 2,

    # Evaluation
    "eval_sensitivity_at_specificity": 0.95,
}

# Create the output tree: the root plus one sub-folder per artefact type (new "_refined" folders).
_output_root = CONFIG["output_dir"]
_output_root.mkdir(parents=True, exist_ok=True)
for _subdir_name in ("classification_models_refined", "visualizations_refined", "reports_refined"):
    (_output_root / _subdir_name).mkdir(exist_ok=True)

# Path objects are not JSON-serialisable, so top-level Path values are written as strings.
_config_for_json = {key: (str(value) if isinstance(value, Path) else value) for key, value in CONFIG.items()}
with open(_output_root / "config_refined.json", 'w') as _config_file:
    json.dump(_config_for_json, _config_file, indent=4)
print("Refined Configuration saved.")

# %% [markdown]
# ## 3. Data Handling (Dataset, Augmentation, Cube Extraction - Reuse & Adapt)

# %%
# --- Preprocessing & cube extraction (load_patient_data, preprocess_image, extract_cube, etc.) are assumed to be defined.
# --- For brevity they are not repeated here; import them or copy them from the previous notebook.

# --- DropPath (Stochastic Depth) ---
class DropPath(nn.Module):
    """Stochastic depth: zero the whole input for a random subset of samples during training.

    Each sample along dim 0 is kept with probability ``1 - drop_prob``; kept samples are
    rescaled by ``1 / (1 - drop_prob)`` so the expected value is unchanged. In eval mode,
    or when ``drop_prob`` is ``None`` or 0, the input is returned unchanged.

    Args:
        drop_prob: probability of dropping a sample's path, in [0, 1]. ``None`` (the default)
            means 0. With 1.0 every path is dropped and zeros are returned.
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        drop_prob = self.drop_prob or 0.0  # the default None would otherwise break `1 - drop_prob`
        if drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - drop_prob
        if keep_prob <= 0.0:
            # Every path is dropped; skip the rescaling below, which would compute inf * 0 = NaN.
            return torch.zeros_like(x)
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # one draw per sample, broadcast over the rest
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()  # 1 with probability keep_prob, else 0
        return x.div(keep_prob) * random_tensor

    def extra_repr(self):
        return f"drop_prob={self.drop_prob}"

# --- Augmentation Functions (Placeholders - Implement using libraries like TorchIO or batchgenerators) ---
def augment_elastic_deformation_3d_placeholder(volume, alpha_sigma_ranges, random_state=None):
    """No-op stand-in for a 3-D elastic deformation.

    A real implementation would sample ``alpha`` from ``alpha_sigma_ranges[0]`` and ``sigma``
    from ``alpha_sigma_ranges[1]`` and warp ``volume``. For now the input object itself is
    returned (not a copy); ``alpha_sigma_ranges`` and ``random_state`` are accepted only so
    callers already use the final signature.
    """
    del alpha_sigma_ranges, random_state  # intentionally unused until implemented
    return volume

def augment_intensity_scaling_3d_placeholder(volume, scale_range, random_state=None):
    """Multiply ``volume`` by one factor drawn uniformly from ``[scale_range[0], scale_range[1])``.

    Args:
        volume: array to scale (any shape); a new scaled array is returned.
        scale_range: ``(low, high)`` bounds for the scale factor.
        random_state: object with a ``uniform(low, high)`` method (e.g. ``np.random.RandomState``).
            Defaults to NumPy's global RNG, so the draw honours ``np.random.seed`` /
            ``seed_everything``. (A fresh ``RandomState(None)`` would reseed from OS entropy on
            every call and make augmentation irreproducible.)
    """
    rng = np.random if random_state is None else random_state
    scale = rng.uniform(scale_range[0], scale_range[1])
    return volume * scale


class RefinedClassifierDataset(Dataset): # Adapted from AdvancedClassifierDataset
    """Candidate-cube dataset combining positive samples with this epoch's negative samples.

    Items are ``(cube_tensor, label_tensor, sample_info)``; ``cube_tensor`` has shape
    ``(1, *config["clf_cube_size_final"])`` and ``label_tensor`` is a float32 scalar.

    NOTE: the real loading path (``np.load(sample_info['cube_path'])`` and ``sample_info['label']``)
    is disabled; cubes are uniform noise in [-0.5, 0.5) and labels are random coin flips.

    Args:
        positive_samples_info: list of sample-info dicts for positive candidates.
        negative_samples_info_for_epoch: list of sample-info dicts for negatives used this epoch.
        epoch: epoch index (stored, not otherwise used here).
        config: dict; reads ``clf_cube_size_final`` and the optional augmentation ranges.
        is_train: when True, samples are shuffled once and random augmentations are applied.
    """
    def __init__(self, positive_samples_info, negative_samples_info_for_epoch, epoch, config, is_train=True):
        self.config = config
        self.epoch = epoch
        self.is_train = is_train
        self.all_samples = positive_samples_info + negative_samples_info_for_epoch
        if self.is_train:
            random.shuffle(self.all_samples)

        self.aug_elastic_alpha_sigma = config.get("aug_elastic_alpha_sigma", ((0,0),(1,1)))
        self.aug_intensity_scale_range = config.get("aug_intensity_scale_range", (1.0, 1.0))

    def __len__(self):
        return len(self.all_samples)

    def _augment(self, cube):
        """Apply each augmentation independently with probability 0.3 (elastic first, then intensity)."""
        if random.random() < 0.3:
            cube = augment_elastic_deformation_3d_placeholder(cube, self.aug_elastic_alpha_sigma)
        if random.random() < 0.3:
            cube = augment_intensity_scaling_3d_placeholder(cube, self.aug_intensity_scale_range)
        # TODO: add gamma, noise, blur and mirroring as in nnU-Net.
        return cube

    def __getitem__(self, idx):
        info = self.all_samples[idx]
        # Dummy data for execution (real path: np.load(info['cube_path']) and info['label']).
        cube = np.random.rand(*self.config["clf_cube_size_final"]).astype(np.float32) - 0.5
        label = random.choice([0, 1])

        if self.is_train:
            cube = self._augment(cube)

        cube_tensor = torch.from_numpy(cube.copy()).float().unsqueeze(0)  # add the channel axis
        return cube_tensor, torch.tensor(label, dtype=torch.float32), info

# --- Candidate lists & splits (Same dummy data generation as before for now) ---
num_dummy_pos_ref = 60
num_dummy_neg_total_ref = 250


def _dummy_cube_infos(kind, label, count):
    """Build ``count`` placeholder sample-info dicts (cube files that do not exist on disk)."""
    return [
        {'cube_path': f'dummy_{kind}_ref_{i}.npy', 'label': label, 'id': f'{kind}{i}'}
        for i in range(count)
    ]


dummy_positive_cubes_info_ref = _dummy_cube_infos('pos', 1, num_dummy_pos_ref)
dummy_all_negative_cubes_info_ref = _dummy_cube_infos('neg', 0, num_dummy_neg_total_ref)

# 80/20 train/validation split of each pool, seeded for reproducibility.
dummy_pos_train_ref, dummy_pos_val_ref = train_test_split(
    dummy_positive_cubes_info_ref, test_size=0.2, random_state=SEED)
dummy_neg_train_pool_ref, dummy_neg_val_pool_ref = train_test_split(
    dummy_all_negative_cubes_info_ref, test_size=0.2, random_state=SEED)


# %% [markdown]
# ## 4. Refined Classifier Model Definition (`AdvancedHybridNetRefined`)

# %%
# --- Component blocks (SEBlock3D, AnisotropicConvModule, AttentionModule, TransformerEncoderLayer, ConvFreeTransformerHead)
# --- are copied below from the previous "Advanced" blueprint so this notebook is self-contained.
# --- A new DenseResLayer is also defined for the backbone stages.

class SEBlock3D(nn.Module): # Copied for self-containment
    """Squeeze-and-Excitation channel gating for 5-D tensors (B, C, D, H, W).

    Global-average-pools each channel, passes the C-vector through Linear -> ReLU -> Linear ->
    Sigmoid, and multiplies the input channel-wise by the resulting gates in (0, 1).

    Args:
        channel: number of input channels C.
        reduction: bottleneck ratio; the hidden layer has ``max(1, channel // reduction)`` units
            so small ``channel`` values never produce an empty (always-0.5) gate.
    """
    def __init__(self, channel, reduction=16):
        super(SEBlock3D, self).__init__()
        hidden = max(1, channel // reduction)
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False), nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False), nn.Sigmoid())

    def forward(self, x):
        b, c, _, _, _ = x.size()  # enforces a 5-D input
        gates = self.fc(self.avg_pool(x).view(b, c)).view(b, c, 1, 1, 1)
        return x * gates.expand_as(x)

class AnisotropicConvModule(nn.Module): # Copied
    """Stem mixing in-plane and through-plane context with separate convolutions.

    Two parallel branches see the same input: a (1,3,3) conv (H/W) and a (3,1,1) conv (D), each
    followed by BatchNorm + LeakyReLU(0.01). Their outputs are concatenated along channels and fused
    by a 1x1x1 conv + BatchNorm + LeakyReLU. Padding keeps D, H, W unchanged.

    Args:
        in_channels: input channels.
        out_channels: output channels (>= 2). The in-plane branch gets ``out_channels // 2``
            channels and the through-plane branch the remainder, so odd values work.

    Raises:
        ValueError: if ``out_channels < 2`` (a branch would have no channels).
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        if out_channels < 2:
            raise ValueError(f"out_channels must be >= 2, got {out_channels}")
        xy_channels = out_channels // 2
        z_channels = out_channels - xy_channels  # absorbs the extra channel when out_channels is odd
        self.conv_xy = nn.Conv3d(in_channels, xy_channels, kernel_size=(1,3,3), padding=(0,1,1), bias=False)
        self.bn_xy = nn.BatchNorm3d(xy_channels)
        self.conv_z = nn.Conv3d(in_channels, z_channels, kernel_size=(3,1,1), padding=(1,0,0), bias=False)
        self.bn_z = nn.BatchNorm3d(z_channels)
        self.relu = nn.LeakyReLU(0.01, inplace=True) # Use LeakyReLU
        self.fuse_conv = nn.Conv3d(out_channels, out_channels, kernel_size=1, bias=False)
        self.bn_fuse = nn.BatchNorm3d(out_channels)

    def forward(self, x):
        x_xy = self.relu(self.bn_xy(self.conv_xy(x)))
        x_z = self.relu(self.bn_z(self.conv_z(x)))
        x_cat = torch.cat([x_xy, x_z], dim=1)  # xy_channels + z_channels == out_channels
        return self.relu(self.bn_fuse(self.fuse_conv(x_cat)))

class AttentionModule(nn.Module): # Copied (Simplified Spatial Attention)
    """Multi-head self-attention over every D*H*W voxel of a 5-D feature map (B, C, D, H, W).

    Q, K and V come from a single 1x1x1 conv; heads split the channel axis into
    ``num_heads`` groups of ``C // num_heads``. The attended map goes through a 1x1x1
    projection conv and dropout. Only that projected output is returned: there is no
    residual connection and no normalisation inside this module.

    Memory note: the attention matrix has ``B * num_heads * (D*H*W)**2`` entries, so cost
    grows quadratically with the number of voxels.

    Args:
        in_channels: channels C of the input (and output).
        num_heads: number of heads; must divide ``in_channels``.
        proj_drop: dropout probability after the output projection.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide ``in_channels``.
    """
    def __init__(self, in_channels, num_heads=4, proj_drop=0.1):
        super().__init__()
        if num_heads <= 0 or in_channels % num_heads != 0:
            raise ValueError(
                f"in_channels ({in_channels}) must be divisible by a positive num_heads ({num_heads})")
        self.num_heads = num_heads
        self.head_dim = in_channels // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv_conv = nn.Conv3d(in_channels, in_channels * 3, kernel_size=1, bias=False)
        self.proj_conv = nn.Conv3d(in_channels, in_channels, kernel_size=1)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x): # x: B, C, D, H, W
        B, C, D, H, W = x.shape
        num_tokens = D * H * W
        # (B, 3C, D, H, W) -> (3, B, heads, tokens, head_dim)
        qkv = self.qkv_conv(x).reshape(B, 3, self.num_heads, self.head_dim, num_tokens).permute(1, 0, 2, 4, 3)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = ((q @ k.transpose(-2, -1)) * self.scale).softmax(dim=-1)
        # (B, heads, tokens, head_dim) -> (B, heads, head_dim, tokens) -> (B, C, D, H, W)
        x_attn = (attn @ v).transpose(-2, -1).reshape(B, C, D, H, W)
        return self.proj_drop(self.proj_conv(x_attn))


# --- New DenseResLayer for Backbone Stages ---
class DenseResLayer(nn.Module):
    """DenseNet-style layer with SE gating, spatial dropout and stochastic depth.

    Computes ``growth_rate`` new feature maps with a pre-activation bottleneck:
    BN -> LeakyReLU -> 1x1x1 conv to ``bn_size * growth_rate`` channels -> BN -> LeakyReLU ->
    3x3x3 conv (padding 1) to ``growth_rate`` channels. The new features are re-weighted by an
    SE block, passed through ``nn.Dropout3d`` and ``DropPath``, and concatenated to the input.

    There is no additive or 1x1x1-projected residual: the identity path is the concatenation
    itself, and DropPath may zero a sample's new features during training. The output has
    ``in_channels + growth_rate`` channels and the same D, H, W as the input.

    Args:
        in_channels: channels of the input tensor.
        growth_rate: number of new channels this layer appends.
        bn_size: bottleneck width multiplier (bottleneck has ``bn_size * growth_rate`` channels).
        se_reduction: reduction ratio of the SE block.
        spatial_dropout: ``Dropout3d`` probability; 0 disables it (``nn.Identity``).
        stochastic_depth_prob: DropPath probability for the new features.
    """
    def __init__(self, in_channels, growth_rate, bn_size, se_reduction, spatial_dropout, stochastic_depth_prob):
        super().__init__()
        self.stochastic_depth = DropPath(stochastic_depth_prob)
        bottleneck_channels = bn_size * growth_rate

        self.conv_block = nn.Sequential(
            nn.BatchNorm3d(in_channels),
            nn.LeakyReLU(0.01, inplace=True),
            nn.Conv3d(in_channels, bottleneck_channels, kernel_size=1, bias=False),
            nn.BatchNorm3d(bottleneck_channels),
            nn.LeakyReLU(0.01, inplace=True),
            nn.Conv3d(bottleneck_channels, growth_rate, kernel_size=3, padding=1, bias=False)
        )
        self.se = SEBlock3D(growth_rate, reduction=se_reduction)
        self.spatial_dropout = nn.Dropout3d(spatial_dropout) if spatial_dropout > 0 else nn.Identity()

    def forward(self, x): # x is input tensor to this specific layer
        new_features = self.conv_block(x)
        new_features = self.spatial_dropout(self.se(new_features))
        new_features = self.stochastic_depth(new_features)  # per-sample drop, training only
        # DenseNet-style feature reuse: keep the input and append the new features.
        return torch.cat([x, new_features], dim=1)


# --- Convolution-Free Transformer Head (Copied, ensure it's loaded) ---
class TransformerEncoderLayer(nn.Module): # Copied
    """Pre-norm transformer encoder block on batch-first tokens (B, N, embed_dim).

    ``x = x + Dropout(MHSA(LN1(x)))`` followed by ``x = x + Dropout(MLP(LN2(x)))``, where the
    MLP is Linear -> activation -> Dropout -> Linear -> Dropout with hidden width
    ``int(embed_dim * mlp_ratio)``.
    """
    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.1, act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)
        self.norm1 = norm_layer(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm2 = norm_layer(embed_dim)
        mlp_layers = [
            nn.Linear(embed_dim, hidden_dim),
            act_layer(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*mlp_layers)
        # Dropout on each sub-block's output before it re-enters the residual stream.
        self.dropout = nn.Dropout(dropout)

    def _self_attention(self, tokens):
        normed = self.norm1(tokens)
        attended, _unused_weights = self.attn(normed, normed, normed, need_weights=False)
        return attended

    def forward(self, src):
        src = src + self.dropout(self._self_attention(src))
        return src + self.dropout(self.mlp(self.norm2(src)))

class ConvFreeTransformerHead(nn.Module): # Copied
    """ViT-style classification head over a 5-D CNN feature map (B, C, D, H, W).

    The map is cut into non-overlapping (pD, pH, pW) patches; each patch (C*pD*pH*pW values)
    is flattened and linearly projected to ``embed_dim``. A learnable CLS token and learned
    positional embeddings are added, ``depth`` ``TransformerEncoderLayer``s are applied, and the
    LayerNorm-ed CLS token is mapped to ``num_classes`` logits.

    Args:
        in_channels: channels C of the incoming feature map.
        feature_map_size_dhw: expected (D, H, W) of the incoming map. It is fixed at construction
            because the positional embedding has one entry per patch.
        patch_size_dhw: (pD, pH, pW); each must divide the matching feature-map dimension.
        embed_dim: token width.
        depth: number of transformer layers.
        num_heads: attention heads per transformer layer.
        num_classes: number of output logits.

    Raises:
        AssertionError: at construction, if the feature map is not divisible by the patch size.
        RuntimeError: in ``forward``, if the input's D/H/W differ from ``feature_map_size_dhw``.
    """
    def __init__(self, in_channels, feature_map_size_dhw, patch_size_dhw, embed_dim, depth, num_heads, num_classes):
        super().__init__()
        self.patch_d, self.patch_h, self.patch_w = patch_size_dhw
        self.feat_d, self.feat_h, self.feat_w = feature_map_size_dhw
        assert (self.feat_d % self.patch_d == 0 and self.feat_h % self.patch_h == 0
                and self.feat_w % self.patch_w == 0), (
            "Feature map not divisible by patch size: "
            f"feature map (D, H, W) = {(self.feat_d, self.feat_h, self.feat_w)}, "
            f"patch (D, H, W) = {(self.patch_d, self.patch_h, self.patch_w)}")
        self.num_patches = (self.feat_d//self.patch_d) * (self.feat_h//self.patch_h) * (self.feat_w//self.patch_w)
        patch_dim_flat = in_channels * self.patch_d * self.patch_h * self.patch_w
        self.patch_proj = nn.Linear(patch_dim_flat, embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1,1,embed_dim)); nn.init.trunc_normal_(self.cls_token,std=.02)
        self.pos_embed = nn.Parameter(torch.zeros(1,self.num_patches+1,embed_dim)); nn.init.trunc_normal_(self.pos_embed,std=.02)
        self.transformer_layers = nn.Sequential(*[TransformerEncoderLayer(embed_dim,num_heads) for _ in range(depth)])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim,num_classes)

    def forward(self, x): # B,C,Df,Hf,Wf
        B, C, D, H, W = x.shape
        expected_dhw = (self.feat_d, self.feat_h, self.feat_w)
        # Without this check, a map with different D/H/W but the same voxel count would be
        # silently mis-tokenised by the reshape below.
        if (D, H, W) != expected_dhw:
            raise RuntimeError(
                f"ConvFreeTransformerHead expected a feature map with (D, H, W) = {expected_dhw}, got {(D, H, W)}")
        # (B, C, nPd*pD, nPh*pH, nPw*pW) -> (B, C, nPd, pD, nPh, pH, nPw, pW)
        x_patched = x.reshape(B, C,
                              self.feat_d // self.patch_d, self.patch_d,
                              self.feat_h // self.patch_h, self.patch_h,
                              self.feat_w // self.patch_w, self.patch_w)
        # Patch-grid axes first, within-patch axes last: (B, nPd, nPh, nPw, C, pD, pH, pW)
        x_patched = x_patched.permute(0,2,4,6,1,3,5,7).contiguous()
        x_tokens = x_patched.view(B, self.num_patches, -1) # B, num_patches, C*pD*pH*pW
        x_proj = self.patch_proj(x_tokens)
        cls = self.cls_token.expand(B,-1,-1)
        x_emb = torch.cat((cls,x_proj),dim=1) + self.pos_embed
        x_tf = self.transformer_layers(x_emb)
        return self.head(self.norm(x_tf[:,0]))  # classify from the CLS token


# --- Refined Main Classifier Model ---
class AdvancedHybridNetRefined(nn.Module):
    """Hybrid 3-D CNN + transformer nodule classifier.

    Pipeline: ``AnisotropicConvModule`` stem -> backbone stages of ``DenseResLayer``s, each
    optionally followed by an ``AttentionModule`` -> transitions between stages ->
    ``ConvFreeTransformerHead`` producing ``clf_num_classes`` logits.

    Channel / size bookkeeping as implemented below:
      * every ``DenseResLayer`` appends ``config["clf_dense_res_growth_rate"]`` channels;
      * after each stage except the last, a transition (MaxPool3d(2) -> 1x1x1 conv -> BN ->
        LeakyReLU) floor-halves D/H/W and projects to the channel count in the *next* stage's
        config tuple. That entry is therefore a stage's input width; the first stage's entry is
        never read (the first stage starts from the stem's output width);
      * the head sees the last stage's input width plus ``num_layers * growth_rate`` channels, on
        a map equal to ``clf_cube_size_final`` floor-halved once per transition (this assumes the
        stem does not change spatial size — TODO confirm if the stem is ever made strided).

    Args:
        config: dict providing the ``clf_*`` keys read in ``__init__``.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        stage_configs = config["clf_backbone_stages_config"]
        growth_rate = config["clf_dense_res_growth_rate"]
        stem_out_c = config["clf_anisotropic_stem_out_channels"]

        self.stem = AnisotropicConvModule(config["clf_in_channels"], stem_out_c)
        current_channels = stem_out_c
        # Spatial size is tracked alongside channels; the stem keeps D, H, W unchanged.
        current_dhw = list(config["clf_cube_size_final"])

        self.backbone_stages = nn.ModuleList()    # one ModuleList of DenseResLayers per stage
        self.attention_stages = nn.ModuleList()   # one entry per stage: AttentionModule or nn.Identity
        self.downsample_layers = nn.ModuleList()  # one transition between consecutive stages

        for i, (num_layers, _stage_channels, use_attention) in enumerate(stage_configs):
            stage_layers = nn.ModuleList()
            for _ in range(num_layers):
                stage_layers.append(DenseResLayer(
                    in_channels=current_channels,
                    growth_rate=growth_rate,
                    bn_size=config["clf_dense_res_bn_size"],
                    se_reduction=config["clf_dense_res_se_reduction"],
                    spatial_dropout=config["clf_dense_res_spatial_dropout"],
                    stochastic_depth_prob=config["clf_dense_res_stochastic_depth_prob"]
                ))
                current_channels += growth_rate  # DenseResLayer concatenates `growth_rate` new channels
            self.backbone_stages.append(stage_layers)

            # Attention (if enabled) runs on the full stage output, before any transition.
            if use_attention:
                self.attention_stages.append(AttentionModule(current_channels, config["clf_attention_module_heads"]))
            else:
                self.attention_stages.append(nn.Identity())

            if i < len(stage_configs) - 1:
                # Transition: downsample by 2 and project to the next stage's configured width.
                next_stage_channels = stage_configs[i + 1][1]
                self.downsample_layers.append(nn.Sequential(
                    nn.MaxPool3d(kernel_size=2, stride=2),
                    nn.Conv3d(current_channels, next_stage_channels, kernel_size=1, bias=False),
                    nn.BatchNorm3d(next_stage_channels),
                    nn.LeakyReLU(0.01, inplace=True)
                ))
                current_channels = next_stage_channels
                current_dhw = [d // 2 for d in current_dhw]
            # The last stage feeds the transformer head directly (no transition).

        self.final_feature_map_channels = current_channels
        self.final_feature_map_dhw = tuple(current_dhw)
        print(f"DEBUG: Final CNN feature map D,H,W before head: {self.final_feature_map_dhw}")
        print(f"DEBUG: Final CNN feature map Channels before head: {self.final_feature_map_channels}")

        self.transformer_head = ConvFreeTransformerHead(
            in_channels=self.final_feature_map_channels,
            feature_map_size_dhw=self.final_feature_map_dhw,
            patch_size_dhw=config["clf_transformer_head_patch_size"],
            embed_dim=config["clf_transformer_head_embed_dim"],
            depth=config["clf_transformer_head_depth"],
            num_heads=config["clf_transformer_head_num_heads"],
            num_classes=config["clf_num_classes"]
        )
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Re-initialise Linear (trunc-normal, zero bias), BN/LayerNorm (1, 0) and Conv3d weights (Kaiming).

        Conv3d biases (e.g. the attention projection conv) keep PyTorch's default initialisation.
        """
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if m.bias is not None: nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.BatchNorm3d, nn.LayerNorm)):
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Conv3d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')

    def forward(self, x):
        """Map input cubes (B, clf_in_channels, D, H, W) to logits (B, clf_num_classes)."""
        x = self.stem(x)
        for i, (stage_layers, attention) in enumerate(zip(self.backbone_stages, self.attention_stages)):
            for layer in stage_layers:
                x = layer(x)
            # NOTE(review): the attention output *replaces* the stage features (no residual);
            # consider `x = x + attention(x)` for stages that use an AttentionModule.
            x = attention(x)
            if i < len(self.downsample_layers):
                x = self.downsample_layers[i](x)
        # x is now the final feature map from the CNN backbone.
        return self.transformer_head(x)

# %% [markdown]
# ## 5. Training Setup & Loop (Refined)

# %%
# --- Loss function: FocalLoss (copied from the previous notebook) ---
class FocalLoss(nn.Module): # Copied
    """Binary focal loss on raw logits (Lin et al., 2017).

    ``loss = alpha_t * (1 - p_t) ** gamma * BCE(logits, targets)``, where ``p_t = exp(-BCE)``
    and ``alpha_t = alpha`` for positive targets and ``1 - alpha`` for negative ones (linearly
    interpolated for soft targets in [0, 1]).

    Args:
        alpha: weight of the positive class in [0, 1]. Pass ``None`` or a negative value to
            disable class weighting.
        gamma: focusing exponent (>= 0); 0 gives (alpha-weighted) BCE.
        reduction: ``'mean'``, ``'sum'``, or anything else (e.g. ``'none'``) for per-element losses.
    """
    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super().__init__(); self.alpha=alpha; self.gamma=gamma; self.reduction=reduction

    def forward(self, inputs_logits, targets):
        bce_loss = F.binary_cross_entropy_with_logits(inputs_logits, targets, reduction='none')
        pt = torch.exp(-bce_loss)  # model probability assigned to the target class
        focal_loss = (1 - pt) ** self.gamma * bce_loss
        if self.alpha is not None and self.alpha >= 0:
            # Class balancing: a uniform alpha would only rescale the loss.
            alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
            focal_loss = alpha_t * focal_loss
        if self.reduction == 'mean':
            return torch.mean(focal_loss)
        if self.reduction == 'sum':
            return torch.sum(focal_loss)
        return focal_loss

# --- MixUp helpers (copied from the previous notebook) ---
def mixup_data(x, y, alpha=0.2, device=None):
    """MixUp a batch with a randomly permuted copy of itself (Zhang et al., 2018).

    Args:
        x: input batch; samples are mixed along dim 0.
        y: targets aligned with ``x`` along dim 0.
        alpha: Beta(alpha, alpha) concentration; ``alpha <= 0`` disables mixing (``lam = 1``).
        device: device for the permutation index. ``None`` (default) uses ``x.device``, so the
            index always matches the batch; an explicit device is still honoured.

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` with ``mixed_x = lam * x + (1 - lam) * x[perm]``,
        ``y_a = y``, ``y_b = y[perm]`` and ``lam`` a Python float in [0, 1].
    """
    lam = float(np.random.beta(alpha, alpha)) if alpha > 0 else 1.0
    index_device = x.device if device is None else device
    index = torch.randperm(x.size(0)).to(index_device)
    mixed_x = lam * x + (1 - lam) * x[index, :]
    return mixed_x, y, y[index], lam
def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Blend the loss of ``pred`` against both MixUp targets: ``lam * L(y_a) + (1 - lam) * L(y_b)``."""
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b

# --- Training/validation epoch functions: refined versions of train_classifier_epoch_advanced and
# --- validate_classifier_epoch from the previous notebook. The training function applies MixUp to a
# --- random subset of batches when CONFIG["aug_mixup_alpha"] > 0.

def train_classifier_epoch_refined(model, dataloader, optimizer, loss_fn, scaler, device, epoch_num, config):
    """Train ``model`` for one epoch with AMP, optional MixUp and skipping of non-finite losses.

    Args:
        model: classifier whose logits are compared against labels reshaped to (B, 1).
        dataloader: sized iterable yielding ``(cubes, labels, info)`` batches.
        optimizer: optimizer stepped through ``scaler``.
        loss_fn: callable ``(logits, targets) -> scalar tensor``.
        scaler: ``GradScaler`` used for backward and the optimizer step.
        device: device each batch is moved to.
        epoch_num: zero-based epoch index, used only in the progress-bar label.
        config: reads ``aug_mixup_alpha`` (MixUp off when missing or <= 0) and the optional
            ``aug_mixup_prob`` (chance of mixing a given batch, default 0.5).

    Returns:
        Mean loss over the batches whose loss was finite; 0.0 if there were none.
    """
    try:
        from tqdm.auto import tqdm as progress_factory
    except ImportError:  # tqdm is not imported at the top of this notebook; keep it optional
        progress_factory = None

    model.train()
    if len(dataloader) == 0:
        return 0.0

    mixup_alpha = config.get("aug_mixup_alpha", 0)
    mixup_prob = config.get("aug_mixup_prob", 0.5)
    batches = (progress_factory(dataloader, desc=f"RefClf Train E{epoch_num+1}", leave=False)
               if progress_factory is not None else dataloader)

    loss_total, counted_batches = 0.0, 0
    for cubes, labels, _ in batches:
        cubes, labels = cubes.to(device), labels.to(device).unsqueeze(1)
        # Short-circuit: the RNG is only consumed when MixUp is enabled.
        use_mixup = mixup_alpha > 0 and random.random() < mixup_prob

        optimizer.zero_grad()
        with autocast(enabled=torch.cuda.is_available()):
            if use_mixup:
                mixed_cubes, labels_a, labels_b, lam = mixup_data(cubes, labels, mixup_alpha, device)
                predictions_logits = model(mixed_cubes)
                loss = mixup_criterion(loss_fn, predictions_logits, labels_a, labels_b, lam)
            else:
                predictions_logits = model(cubes)
                loss = loss_fn(predictions_logits, labels)

        if not torch.isfinite(loss).all():
            print(f"Invalid loss: {loss.item()}. Skipping batch.")
            continue

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        loss_value = loss.item()
        loss_total += loss_value
        counted_batches += 1
        if progress_factory is not None:
            batches.set_postfix(loss=loss_value)

    # Average only over batches that contributed, so skipped batches do not bias the mean.
    return loss_total / counted_batches if counted_batches else 0.0


def validate_classifier_epoch_refined(model, dataloader, loss_fn, device): # Can reuse prev validate
    model.eval()
    epoch_loss = 0.0; all_preds_probs, all_labels = [], []
    num_batches = len(dataloader)
    if num_batches == 0: return 0.0, 0.0
    
    progress_bar = tqdm(dataloader, desc="RefClf Validating", leave=False)
    with torch.no_grad():
        for cubes, labels, _ in progress_bar:
            cubes, labels = cubes.to(device), labels.to(device).unsqueeze(1)
            with autocast(enabled=torch.cuda.is_available()):
                predictions_logits = model(cubes)
                loss = loss_fn(predictions_logits, labels)
            if torch.isnan(loss) or torch.isinf(loss): print(f"Invalid val_loss: {loss.item()}."); continue
            epoch_loss += loss.item()
            all_preds_probs.extend(torch.sigmoid(predictions_logits).cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            progress_bar.set_postfix(loss=loss.item())

    all_labels_np = np.array(all_labels).flatten()
    all_preds_probs_np = np.array(all_preds_probs).flatten()
    val_auc = 0.0
    if len(np.unique(all_labels_np)) > 1 and len(all_labels_np) > 0:
        val_auc = roc_auc_score(all_labels_np, all_preds_probs_np)
    return epoch_loss / num_batches, val_auc


# --- Initialize Refined Classifier Model ---
ref_clf_model = AdvancedHybridNetRefined(CONFIG).to(DEVICE)
print(f"Refined Classifier Model Instantiated. Params: {sum(p.numel() for p in ref_clf_model.parameters() if p.requires_grad):,}")

ref_clf_optimizer = optim.AdamW(ref_clf_model.parameters(), lr=CONFIG["clf_lr_initial"], weight_decay=CONFIG["clf_weight_decay"])
ref_clf_loss_fn = FocalLoss(alpha=CONFIG["clf_focal_loss_alpha"], gamma=CONFIG["clf_focal_loss_gamma"])
ref_clf_scaler = GradScaler(enabled=torch.cuda.is_available())
ref_clf_lr_scheduler = CosineAnnealingLR(ref_clf_optimizer, T_max=CONFIG["clf_cosine_lr_t_max"])

best_val_auc_ref_clf = -1.0
epochs_no_improve_ref_clf = 0

# Create the checkpoint directory up front so torch.save cannot fail on a missing folder.
ref_clf_ckpt_dir = CONFIG["output_dir"] / "classification_models_refined"
ref_clf_ckpt_dir.mkdir(parents=True, exist_ok=True)

# Sample the validation negatives ONCE so every epoch is scored on the same validation set.
# Resampling per epoch would make val AUC, and therefore best-checkpoint selection and early
# stopping, depend on which negatives were drawn rather than on the model.
num_val_neg_to_sample_ref = min(len(dummy_neg_val_pool_ref), len(dummy_pos_val_ref) * CONFIG["hnm_ratio_neg_to_pos"])
val_neg_samples_ref = random.sample(dummy_neg_val_pool_ref, num_val_neg_to_sample_ref) if dummy_neg_val_pool_ref else []

# --- Training Loop (Refined) ---
print(f"\n--- Training Refined Classifier Model ({CONFIG['clf_model_name']}) ---")
for epoch in range(CONFIG["clf_epochs"]):
    # Negative sampling. HNM is still a placeholder: before and after `hnm_start_epoch`
    # a random subset of the negative pool is drawn (replace with real hard-negative scoring).
    if epoch >= CONFIG["hnm_start_epoch"] and dummy_neg_train_pool_ref:
        print(f"Epoch {epoch+1}: HNM Placeholder - using random subset of negatives.") # Replace with actual HNM
    num_train_neg_ref = min(len(dummy_neg_train_pool_ref), len(dummy_pos_train_ref) * CONFIG["hnm_ratio_neg_to_pos"])
    current_negatives_for_epoch_ref = random.sample(dummy_neg_train_pool_ref, num_train_neg_ref) if dummy_neg_train_pool_ref else []

    ref_clf_train_dataset = RefinedClassifierDataset(dummy_pos_train_ref, current_negatives_for_epoch_ref, epoch, CONFIG, is_train=True)
    ref_clf_train_loader = DataLoader(ref_clf_train_dataset, batch_size=CONFIG["clf_batch_size"], shuffle=True, num_workers=0)

    ref_clf_val_dataset = RefinedClassifierDataset(dummy_pos_val_ref, val_neg_samples_ref, epoch, CONFIG, is_train=False)
    ref_clf_val_loader = DataLoader(ref_clf_val_dataset, batch_size=CONFIG["clf_batch_size"], shuffle=False, num_workers=0)

    if len(ref_clf_train_loader.dataset) == 0:
        print(f"Skipping train epoch {epoch+1}, empty train_loader.")
        ref_clf_lr_scheduler.step()  # keep the cosine schedule aligned with the epoch counter
        continue

    train_loss_ref_clf = train_classifier_epoch_refined(ref_clf_model, ref_clf_train_loader, ref_clf_optimizer, ref_clf_loss_fn, ref_clf_scaler, DEVICE, epoch, CONFIG)

    val_loss_ref_clf, val_auc_ref_clf = 0.0, 0.0
    if len(ref_clf_val_loader.dataset) > 0:
        val_loss_ref_clf, val_auc_ref_clf = validate_classifier_epoch_refined(ref_clf_model, ref_clf_val_loader, ref_clf_loss_fn, DEVICE)
    else:
        print(f"Skipping val epoch {epoch+1}, empty val_loader.")

    ref_clf_lr_scheduler.step()

    print(f"Epoch {epoch+1}/{CONFIG['clf_epochs']}: RefClf Train Loss: {train_loss_ref_clf:.4f}, RefClf Val Loss: {val_loss_ref_clf:.4f}, RefClf Val AUC: {val_auc_ref_clf:.4f}")

    # Only checkpoint on a genuine AUC improvement with a finite validation loss.
    if val_auc_ref_clf > best_val_auc_ref_clf and np.isfinite(val_loss_ref_clf):
        best_val_auc_ref_clf = val_auc_ref_clf
        save_path = ref_clf_ckpt_dir / f"{CONFIG['clf_model_name']}_best.pth"
        torch.save(ref_clf_model.state_dict(), save_path)
        print(f"  Saved best refined classifier model to {save_path} (Val AUC: {best_val_auc_ref_clf:.4f})")
        epochs_no_improve_ref_clf = 0
    else:
        epochs_no_improve_ref_clf += 1

    if epochs_no_improve_ref_clf >= CONFIG["clf_early_stopping_patience"]:
        print(f"Early stopping for refined classifier at epoch {epoch+1}.")
        break
print("Refined Classifier training finished.")

# %% [markdown]
# ## 6. Evaluation, Visualization, Reporting (Adapt as before)

# %%
# --- Final Evaluation (using ref_clf_model) ---
# Placeholder section. Remaining tasks (reuse logic from earlier notebook sections,
# adapting model names and paths):
# 1. Load the best `ref_clf_model` checkpoint.
# 2. Create a test `RefinedClassifierDataset` and `DataLoader` (validation data stands in below for now).
# 3. Run `validate_classifier_epoch_refined` to get test loss and AUC.
# 4. Perform full metric calculation (Precision, Recall, F1, ROC, Sens@Spec).
# 5. Generate visualizations (3D overlay, GradCAM for `ref_clf_model`).
# 6. Write the final report detailing `AdvancedHybridNetRefined` performance.

print("\n--- Evaluation, Visualization, Reporting (Placeholders) ---")

# Example test evaluation using the validation split as a stand-in "test" set.
# NOTE: relies on `dummy_pos_val_ref`, `val_neg_samples_ref` and `ref_clf_loss_fn`
# from the training cell above (hidden-state dependency; run that cell first).
print("\nEvaluating Refined Model on 'Test' Set (using dummy validation data)...")
best_model_path_ref = CONFIG["output_dir"] / "classification_models_refined" / f"{CONFIG['clf_model_name']}_best.pth"
if best_model_path_ref.exists():
    eval_model_ref = AdvancedHybridNetRefined(CONFIG).to(DEVICE)
    # The checkpoint is a plain state_dict, so restrict unpickling to tensors/containers.
    state_dict_ref = torch.load(best_model_path_ref, map_location=DEVICE, weights_only=True)
    eval_model_ref.load_state_dict(state_dict_ref)
    print(f"Loaded best refined model from {best_model_path_ref}")

    ref_clf_test_dataset = RefinedClassifierDataset(dummy_pos_val_ref, val_neg_samples_ref, CONFIG['clf_epochs'], CONFIG, is_train=False)
    if len(ref_clf_test_dataset) > 0:
        ref_clf_test_loader = DataLoader(ref_clf_test_dataset, batch_size=CONFIG["clf_batch_size"], shuffle=False)
        test_loss_ref, test_auc_ref = validate_classifier_epoch_refined(eval_model_ref, ref_clf_test_loader, ref_clf_loss_fn, DEVICE)
        print(f"Refined Model 'Test' Loss: {test_loss_ref:.4f}, 'Test' AUC: {test_auc_ref:.4f}")
    else:
        print("Refined Model 'Test' set is empty. Skipping evaluation.")
else:
    print("Best refined model not found for evaluation.")

# %% [markdown]
# ## End of Notebook

Using device: cuda
Placeholder: Loading and selecting data...
Selected 100 scans (50 class 0, 50 class 1)
Placeholder: Preprocessing scans...


Preprocessing Simulation: 100%|██████████| 100/100 [00:00<00:00, 14248.89it/s]

Successfully preprocessed/found 100 scans.
Placeholder: Creating dataloaders...
Created DataLoaders: Train batches=40, Val batches=10





Calculated pos_weight for BCEWithLogitsLoss: 1.00

Epoch 1/10


Training:   0%|          | 0/40 [00:00<?, ?it/s]

Shape mismatch for 1fdbc07019192de4a114e090389c8330: Expected (32, 32, 32), got (64, 64, 64). Attempting resize.
Shape mismatch for 4b351d0c19be183cc880f5af3fe5abee: Expected (32, 32, 32), got (64, 64, 64). Attempting resize.


                                                


--- An error occurred during execution ---


Traceback (most recent call last):
  File "C:\Users\rouaa\AppData\Local\Temp\ipykernel_21784\2919523034.py", line 986, in <module>
    main()
  File "C:\Users\rouaa\AppData\Local\Temp\ipykernel_21784\2919523034.py", line 822, in main
    train_loss_cls, train_loss_seg, train_loss_total, train_acc = train_one_epoch(
                                                                  ^^^^^^^^^^^^^^^^
  File "C:\Users\rouaa\AppData\Local\Temp\ipykernel_21784\2919523034.py", line 624, in train_one_epoch
    loss_seg = criterion_seg(seg_outputs_probs, masks)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\rouaa\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\rouaa\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\modules\module.py", line 1747, in _call_impl
    return forward_ca