# In [None]:
# %% Configuration (SET FORCE_REPROCESS = True FOR ONE RUN)
class Config:
    """Configuration class for all hyperparameters and paths.

    All filesystem paths are built from ``_PROJECT_ROOT``. It defaults to the
    original author's Windows project directory and can be overridden with the
    ``PNEUMATECT_ROOT`` environment variable, so the notebook can run on
    another machine without editing this cell.
    """
    _PROJECT_ROOT = os.environ.get(
        "PNEUMATECT_ROOT", r"C:\Users\rouaa\Documents\Final_Pneumatect"
    )

    DSB_PATH = os.path.join(_PROJECT_ROOT, "Stages")
    DSB_LABELS_CSV = os.path.join(_PROJECT_ROOT, "stage1_labels.csv")
    PREPROCESSED_DSB_PATH = os.path.join(_PROJECT_ROOT, "PreProCessing")
    MODEL_OUTPUT_DIR = os.path.join(_PROJECT_ROOT, "Models")

    # --- Preprocessing ---
    TARGET_SPACING = [1.5, 1.5, 1.5]
    FINAL_SCAN_SIZE = (32, 32, 32)  # Reduced resolution (D, H, W) fed to the model
    CLIP_BOUND_HU = [-1000.0, 400.0]  # HU window used for clipping/normalisation
    PIXEL_MEAN = 0.25  # Subtracted after normalising to [0, 1]

    SCAN_LIMIT_PER_CLASS = 50
    SEED = 42
    # When True, preprocess_scans() regenerates every .npz even if it already
    # exists. Set it to True for ONE run (e.g. after changing FINAL_SCAN_SIZE),
    # then back to False.
    FORCE_REPROCESS = True

    # --- Training (CPU) ---
    BATCH_SIZE = 1  # Keep 1 for CPU
    NUM_CLASSES = 1
    LEARNING_RATE = 1e-4
    EPOCHS = 10
    NUM_WORKERS = 0

    # --- Model ---
    UNETPP_INITIAL_FILTERS = 8
    UNETPP_DEPTH = 2
    UNETPP_TRANSFORMER_EMBED_DIM = 64
    UNETPP_TRANSFORMER_LAYERS = 1
    UNETPP_TRANSFORMER_HEADS = 4
    UNETPP_FINAL_FC_UNITS = 32
    STOCHASTIC_DEPTH_RATE = 0.1
    CLASSIFICATION_LOSS_WEIGHT = 1.0
    SEGMENTATION_LOSS_WEIGHT = 0.5



# Seed every RNG used in this notebook (Python, NumPy, PyTorch) so runs are
# repeatable. Everything runs on the CPU, so no CUDA seeding is required.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(Config.SEED)

# Training is pinned to the CPU even if a GPU is present.
DEVICE = torch.device("cpu")
print(f"Using device: {DEVICE}")

# Create the output directories up front so later writes never hit a missing folder.
for _out_dir in (Config.PREPROCESSED_DSB_PATH, Config.MODEL_OUTPUT_DIR):
    os.makedirs(_out_dir, exist_ok=True)
del _seed_fn, _out_dir

# %% Placeholder Data Functions (Ensure they use Config.FINAL_SCAN_SIZE)
def preprocess_scans(scan_ids):
    """Placeholder preprocessing: make sure a ``<id>.npz`` volume exists per scan.

    For each patient ID a random volume of shape ``Config.FINAL_SCAN_SIZE`` is
    written to ``Config.PREPROCESSED_DSB_PATH/<id>.npz`` under the key
    ``"image"``. This happens when the file is missing or when
    ``Config.FORCE_REPROCESS`` is set; otherwise the existing file is reused.

    Args:
        scan_ids: iterable of patient IDs.

    Returns:
        list of the IDs whose ``.npz`` file exists afterwards, in input order.
    """
    print("Placeholder: Preprocessing scans...")
    successful_ids = []
    os.makedirs(Config.PREPROCESSED_DSB_PATH, exist_ok=True)
    target_size = Config.FINAL_SCAN_SIZE  # Use size from Config
    hu_min, hu_max = Config.CLIP_BOUND_HU
    for patient_id in tqdm(scan_ids, desc="Preprocessing Simulation"):
        out_path = os.path.join(Config.PREPROCESSED_DSB_PATH, f"{patient_id}.npz")

        if Config.FORCE_REPROCESS or not os.path.exists(out_path):
            if Config.FORCE_REPROCESS:
                print(f"Reprocessing {patient_id} to size {target_size}...")
            else:
                print(f"Simulating creation for {patient_id} (file not found)...")

            # Draw uniform values inside the HU window, then apply the same
            # normalisation the real pipeline uses: scale to [0, 1], subtract the mean.
            dummy_image = np.random.rand(*target_size) * (hu_max - hu_min) + hu_min
            dummy_image = (dummy_image - hu_min) / (hu_max - hu_min)
            dummy_image -= Config.PIXEL_MEAN
            np.savez_compressed(out_path, image=dummy_image.astype(np.float32))

        # A scan counts as available if its file exists after the step above.
        if os.path.exists(out_path):
            successful_ids.append(patient_id)
        else:
            print(f"Warning: Failed to preprocess or find {patient_id}")

    if Config.FORCE_REPROCESS:
        print("\nIMPORTANT: FORCE_REPROCESS was True. Ensure it is False for next runs.\n")

    print(f"Successfully preprocessed/found {len(successful_ids)} scans.")
    return successful_ids

def load_and_select_data():
    """Read the labels CSV and draw a class-balanced random subset of patients.

    Returns:
        (selected_ids, patient_labels), where ``selected_ids`` holds up to
        ``Config.SCAN_LIMIT_PER_CLASS`` shuffled IDs with ``cancer == 0``,
        followed by up to the same number with ``cancer == 1``, and
        ``patient_labels`` maps every ID in the CSV to its ``cancer`` value.
        Returns ``([], {})`` when the CSV file is missing.
    """
    print("Placeholder: Loading and selecting data...")
    try:
        df_labels = pd.read_csv(Config.DSB_LABELS_CSV)
    except FileNotFoundError:
        print(f"ERROR: Labels CSV not found at {Config.DSB_LABELS_CSV}")
        return [], {}

    ordered_ids = df_labels['id'].tolist()
    patient_labels = df_labels.set_index('id')['cancer'].to_dict()
    limit = Config.SCAN_LIMIT_PER_CLASS

    # Shuffle each class independently (class 0 first, then class 1) and keep its head.
    per_class = {}
    for cls in (0, 1):
        pool = [pid for pid in ordered_ids if patient_labels.get(pid) == cls]
        random.shuffle(pool)
        per_class[cls] = pool[:limit]

    chosen = per_class[0] + per_class[1]
    print(f"Selected {len(chosen)} scans ({len(per_class[0])} class 0, {len(per_class[1])} class 1)")
    return chosen, patient_labels

def create_dataloaders(successful_ids, patient_labels):
    """Split labelled patients 80/20 and wrap both halves in DataLoaders.

    The split is stratified by label when possible. If stratification is
    impossible (e.g. a class with a single member), it falls back to a
    non-stratified split.

    Args:
        successful_ids: patient IDs whose preprocessed scan is available.
        patient_labels: mapping patient ID -> label. IDs without an entry (or
            whose label is -1) are excluded from the split.

    Returns:
        (train_loader, val_loader, train_ids, val_ids)

    Raises:
        ValueError: if ``successful_ids`` is empty, none of them has a label,
            or the fallback split itself fails (e.g. too few samples).
    """
    print("Placeholder: Creating dataloaders...")
    if not successful_ids:
        raise ValueError("No successfully processed patient IDs found.")

    # One direct dict lookup per ID; -1 doubles as the "no label" marker.
    filtered_ids, filtered_stratify_labels = [], []
    for pid in successful_ids:
        label = patient_labels.get(pid, -1)
        if label != -1:
            filtered_ids.append(pid)
            filtered_stratify_labels.append(label)

    if not filtered_ids:
        raise ValueError("No patient IDs with valid labels found for splitting.")

    try:
        train_ids, val_ids = train_test_split(
            filtered_ids, test_size=0.2, random_state=Config.SEED,
            stratify=filtered_stratify_labels
        )
    except ValueError as e:
        print(f"Stratified split failed: {e}. Falling back to non-stratified split.")
        train_ids, val_ids = train_test_split(
            filtered_ids, test_size=0.2, random_state=Config.SEED)

    # Scans and masks are both read from the preprocessed directory.
    train_dataset = PatientLevelDataset(train_ids, patient_labels, Config.PREPROCESSED_DSB_PATH, mask_path=Config.PREPROCESSED_DSB_PATH)
    val_dataset = PatientLevelDataset(val_ids, patient_labels, Config.PREPROCESSED_DSB_PATH, mask_path=Config.PREPROCESSED_DSB_PATH)

    train_loader = DataLoader(
        train_dataset, batch_size=Config.BATCH_SIZE, shuffle=True,
        num_workers=Config.NUM_WORKERS,
        pin_memory=False  # No benefit without a GPU
    )
    val_loader = DataLoader(
        val_dataset, batch_size=Config.BATCH_SIZE, shuffle=False,
        num_workers=Config.NUM_WORKERS,
        pin_memory=False  # No benefit without a GPU
    )
    print(f"Created DataLoaders: Train batches={len(train_loader)}, Val batches={len(val_loader)}")
    return train_loader, val_loader, train_ids, val_ids

# %% Model Definition (remains the same)
class SEBlock3D(nn.Module):
    """Squeeze-and-excitation channel gating for 3-D feature maps.

    Each channel is multiplied by a gate in (0, 1), computed as: global average
    pool -> 1x1x1 conv down to ``max(1, channels // reduction)`` channels ->
    ReLU -> 1x1x1 conv back to ``channels`` -> sigmoid.
    """
    def __init__(self, channels, reduction=16):
        super().__init__()
        bottleneck = max(1, channels // reduction)
        self.squeeze = nn.AdaptiveAvgPool3d(1)
        self.excitation = nn.Sequential(
            nn.Conv3d(channels, bottleneck, kernel_size=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv3d(bottleneck, channels, kernel_size=1, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        gate = self.excitation(self.squeeze(x))
        return x * gate

class ConvBlock3D(nn.Module):
    """Two (Conv3d -> BatchNorm3d -> ReLU) stages, optional SE gating, optional dropout.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by both convolutions.
        kernel_size, padding: passed to both Conv3d layers (the defaults keep
            the spatial size unchanged).
        use_se: apply an ``SEBlock3D`` after the second stage.
        stochastic_depth_rate: despite the name, this is element-wise
            ``nn.Dropout`` on the block output. The block has no residual
            connection (channel counts differ), so true stochastic depth does
            not apply here.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, use_se=True, stochastic_depth_rate=0.0):
        super(ConvBlock3D, self).__init__()
        self.use_se = use_se
        self.stochastic_depth_rate = stochastic_depth_rate
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding, bias=False)
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size, padding=padding, bias=False)
        self.bn2 = nn.BatchNorm3d(out_channels)
        self.relu2 = nn.ReLU(inplace=True)
        # NOTE(review): the value passed as `reduction` is out_channels // 16, so
        # the SE bottleneck is out_channels when out_channels < 32, and otherwise
        # between 16 and 31 channels. This may have been meant as reduction=16;
        # it is kept as-is so existing checkpoints still load.
        self.se = SEBlock3D(out_channels, reduction=max(1, out_channels//16)) if use_se else nn.Identity()
        self.dropout = nn.Dropout(stochastic_depth_rate) if stochastic_depth_rate > 0 else nn.Identity()

    def forward(self, x):
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.se(x)
        # nn.Dropout is already a no-op in eval mode, and self.dropout is
        # nn.Identity when the rate is 0, so no explicit training check is needed.
        return self.dropout(x)

class EfficientTransformerLayer3D(nn.Module):
    """Pre-norm Transformer encoder layer over a token sequence.

    Computes ``x = x + attn(LN(x))`` and then ``x = x + FFN(LN(x))``. The FFN
    is Linear -> GELU -> Dropout -> Linear. Inputs are (batch, seq_len,
    embed_dim) because the attention is built with ``batch_first=True``.

    Stochastic depth (drop-path) is applied per sample to each residual
    *branch* during training. A dropped sample keeps its identity path
    unchanged; kept branches are rescaled by ``1 / (1 - rate)``.

    Args:
        embed_dim: token width; must be divisible by ``num_heads``.
        num_heads: attention heads.
        ff_dim_factor: FFN hidden width as a multiple of ``embed_dim``.
        dropout: dropout used inside attention, the FFN, and on both branch outputs.
        stochastic_depth_rate: probability of dropping a residual branch per sample.

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``num_heads``.
    """
    def __init__(self, embed_dim, num_heads, ff_dim_factor=4, dropout=0.1, stochastic_depth_rate=0.0):
        super(EfficientTransformerLayer3D, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        if embed_dim % num_heads != 0:
            raise ValueError(f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})")

        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(dropout)

        ff_dim = embed_dim * ff_dim_factor
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, embed_dim),
        )
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout2 = nn.Dropout(dropout)
        self.stochastic_depth_rate = stochastic_depth_rate

    def _drop_path(self, branch):
        """Zero a residual branch for a random subset of samples (training only)."""
        if not self.training or self.stochastic_depth_rate <= 0.0:
            return branch
        keep_prob = 1.0 - self.stochastic_depth_rate
        if keep_prob <= 0.0:
            return torch.zeros_like(branch)
        # One Bernoulli draw per sample, broadcast over all remaining dimensions.
        mask_shape = (branch.shape[0],) + (1,) * (branch.dim() - 1)
        keep = torch.empty(mask_shape, dtype=branch.dtype, device=branch.device).bernoulli_(keep_prob)
        return branch * keep / keep_prob

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """Return ``(output, attn_weights)``; the output has the same shape as ``src``.

        ``attn_weights`` comes from ``nn.MultiheadAttention`` with
        ``need_weights=True`` and the default head averaging, i.e. (batch, L, S).
        """
        src_norm = self.norm1(src)
        attn_output, attn_weights = self.attn(src_norm, src_norm, src_norm, attn_mask=src_mask,
                                              key_padding_mask=src_key_padding_mask,
                                              need_weights=True)  # Weights are used for visualization
        src = src + self._drop_path(self.dropout1(attn_output))

        ffn_output = self.ffn(self.norm2(src))
        src = src + self._drop_path(self.dropout2(ffn_output))

        return src, attn_weights

class TransformerBlock(nn.Module):
    """Self-attention over every voxel of a 3-D feature map.

    A (B, C, D, H, W) map is projected to ``embed_dim`` channels with a 1x1x1
    conv (skipped when C == embed_dim) and flattened into D*H*W tokens. A
    learned positional embedding is added, and the tokens pass through
    ``num_layers`` EfficientTransformerLayer3D layers and a final LayerNorm.
    The result is reshaped back to a volume and projected back to C channels.

    Args:
        in_channels: channels of the incoming feature map.
        embed_dim: token width inside the Transformer.
        num_layers: number of Transformer layers.
        num_heads: attention heads per layer.
        spatial_dims: expected (D, H, W); sizes the positional embedding.
        dropout: dropout for the positional embedding and inside each layer.
        stochastic_depth_rate: largest drop-path rate; layer k gets a linearly
            interpolated rate from 0 up to this value.

    Returns (forward):
        (output, attn_weights): the output has the input's shape; attn_weights
        is a list with one entry per layer.
    """
    def __init__(self, in_channels, embed_dim, num_layers, num_heads, spatial_dims, dropout=0.1, stochastic_depth_rate=0.0):
        super(TransformerBlock, self).__init__()
        self.patch_projection = nn.Conv3d(in_channels, embed_dim, kernel_size=1) if in_channels != embed_dim else nn.Identity()
        if not spatial_dims or any(s <= 0 for s in spatial_dims):
            raise ValueError(f"Invalid spatial_dims for TransformerBlock: {spatial_dims}")
        num_patches = int(np.prod(spatial_dims))
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches, embed_dim) * 0.02)
        self.dropout_pos = nn.Dropout(dropout)

        sd_rates = torch.linspace(0, stochastic_depth_rate, num_layers).tolist()
        self.transformer_layers = nn.ModuleList([
            EfficientTransformerLayer3D(embed_dim, num_heads, dropout=dropout,
                                        stochastic_depth_rate=rate)
            for rate in sd_rates
        ])
        self.norm_out = nn.LayerNorm(embed_dim)
        self.out_projection = nn.Conv3d(embed_dim, in_channels, kernel_size=1) if in_channels != embed_dim else nn.Identity()
        # Stored as a tuple so the (d, h, w) comparison in forward() works even
        # if a list was passed.
        self.spatial_dims = tuple(spatial_dims)

    def forward(self, x):
        b, _, d, h, w = x.shape
        if (d, h, w) != self.spatial_dims:
            print(f"Warning: Input spatial dimensions {d,h,w} differ from expected {self.spatial_dims} in TransformerBlock.")

        x_proj = self.patch_projection(x)
        embed_dim = x_proj.shape[1]
        tokens = x_proj.flatten(2).transpose(1, 2)  # (b, d*h*w, embed_dim)

        if tokens.shape[1] != self.pos_embed.shape[1]:
            print(f"Warning: Patch count {tokens.shape[1]} mismatch with pos embed {self.pos_embed.shape[1]}. Skipping pos embed.")
        else:
            tokens = tokens + self.pos_embed

        tokens = self.dropout_pos(tokens)

        attn_weights = []
        for layer in self.transformer_layers:
            tokens, attn = layer(tokens)
            attn_weights.append(attn)

        tokens = self.norm_out(tokens)
        # Reshape using the ACTUAL input extent rather than self.spatial_dims,
        # so a size mismatch (already warned about above) does not crash here.
        x_reshaped = tokens.transpose(1, 2).reshape(b, embed_dim, d, h, w)
        x_out = self.out_projection(x_reshaped)
        return x_out, attn_weights


class EfficientUNetPlusPlus_SE_Transformer(nn.Module):
    """3-D U-Net++ with SE conv blocks and a Transformer on every encoder level except 0.

    Encoder level ``i`` (0..``depth``) is a ConvBlock3D with
    ``initial_filters * 2**i`` output channels. Each level ``i > 0`` is
    followed by a TransformerBlock over that level's voxels. Level 0 uses
    ``nn.Identity`` instead, presumably to avoid attention over every
    full-resolution voxel. Consecutive levels are linked by 2x max-pooling.

    The decoder fills the nested U-Net++ grid. Node ``X[i][j]`` (``j >= 1``)
    convolves the concatenation of ``X[i][0..j-1]`` with the transposed-conv
    upsampling of ``X[i+1][j-1]``. The top node ``X[0][depth]`` feeds a
    classification head (global average pool -> FC -> ReLU -> Dropout -> FC)
    and a 1x1x1-conv segmentation head. Both heads return logits.

    Args:
        input_scan_size: expected (D, H, W) as a tuple; inputs of another size
            are trilinearly resized in ``forward``.
        in_channels: input channels.
        num_classes: number of classification logits.
        initial_filters: channels at level 0; doubled at each deeper level.
        depth: number of pooling steps (the encoder has ``depth + 1`` levels).
        use_se: use squeeze-and-excitation inside every ConvBlock3D.
        transformer_embed_dim, transformer_layers, transformer_heads,
        transformer_dropout: TransformerBlock settings.
        final_fc_units: hidden width of the classification head.
        stochastic_depth_rate: forwarded to every ConvBlock3D and TransformerBlock.

    Raises:
        ValueError: if ``input_scan_size`` is not a 3-tuple, or if it becomes
            non-positive after ``depth`` halvings.
    """
    def __init__(self, input_scan_size=Config.FINAL_SCAN_SIZE,  # Use Config default
                 in_channels=1, num_classes=1,
                 initial_filters=Config.UNETPP_INITIAL_FILTERS, depth=Config.UNETPP_DEPTH,
                 use_se=True, transformer_embed_dim=Config.UNETPP_TRANSFORMER_EMBED_DIM,
                 transformer_layers=Config.UNETPP_TRANSFORMER_LAYERS,
                 transformer_heads=Config.UNETPP_TRANSFORMER_HEADS, transformer_dropout=0.1,
                 final_fc_units=Config.UNETPP_FINAL_FC_UNITS,
                 stochastic_depth_rate=Config.STOCHASTIC_DEPTH_RATE):
        super(EfficientUNetPlusPlus_SE_Transformer, self).__init__()
        self.depth = depth
        if not isinstance(input_scan_size, tuple) or len(input_scan_size) != 3:
            raise ValueError("input_scan_size must be a tuple of 3 integers (D, H, W)")
        self.input_scan_size = input_scan_size
        nf = initial_filters

        # Encoder
        self.encoder_blocks = nn.ModuleList()
        self.transformer_blocks = nn.ModuleList()  # TransformerBlock, or Identity at level 0
        self.pools = nn.ModuleList()
        encoder_output_channels = []
        current_channels = in_channels

        for i in range(depth + 1):
            out_ch = nf * (2**i)
            embed_dim_for_layer = transformer_embed_dim
            # Round embed_dim down to a multiple of the head count (minimum: one head-width).
            if i > 0 and embed_dim_for_layer % transformer_heads != 0:
                adjusted_embed_dim = (embed_dim_for_layer // transformer_heads) * transformer_heads
                if adjusted_embed_dim == 0:
                    adjusted_embed_dim = transformer_heads
                print(f"Warning: Level {i} embed_dim {embed_dim_for_layer} adjusted to {adjusted_embed_dim} for heads {transformer_heads}.")
                embed_dim_for_layer = adjusted_embed_dim

            self.encoder_blocks.append(ConvBlock3D(
                current_channels, out_ch, use_se=use_se,
                stochastic_depth_rate=stochastic_depth_rate
            ))

            # Spatial extent at this level after i floor-halving max-pools.
            spatial_dims = tuple(s // (2**i) for s in input_scan_size)
            if any(s <= 0 for s in spatial_dims):
                raise ValueError(f"Calculated spatial dimensions {spatial_dims} at depth {i} are invalid.")

            # Apply Transformer only for levels i > 0
            if i > 0:
                self.transformer_blocks.append(TransformerBlock(
                    in_channels=out_ch,
                    embed_dim=embed_dim_for_layer,
                    num_layers=transformer_layers,
                    num_heads=transformer_heads,
                    spatial_dims=spatial_dims,
                    dropout=transformer_dropout,
                    stochastic_depth_rate=stochastic_depth_rate
                ))
            else:
                self.transformer_blocks.append(nn.Identity())  # Skip transformer at full resolution

            encoder_output_channels.append(out_ch)
            if i < depth:
                self.pools.append(nn.MaxPool3d(2, 2))
            current_channels = out_ch

        # Decoder: upsamplers[i] maps level i+1 channels -> level i channels.
        self.decoder_conv_modulelist = nn.ModuleList()
        self.upsamplers = nn.ModuleList()
        for i in range(depth):
            ch_from_below = encoder_output_channels[i+1]
            ch_to_current_level_filters = encoder_output_channels[i]
            self.upsamplers.append(
                nn.ConvTranspose3d(ch_from_below, ch_to_current_level_filters, kernel_size=2, stride=2)
            )

        # decoder_conv_modulelist[i][j-1] builds X[i][j]. Its input is j same-level
        # nodes plus one upsampled node, each with encoder_output_channels[i] channels.
        for i in range(depth):
            level_i_decoder_blocks = nn.ModuleList()
            for j in range(1, depth - i + 1):
                in_ch_Xij = encoder_output_channels[i] * j + encoder_output_channels[i]
                out_ch_Xij = encoder_output_channels[i]
                level_i_decoder_blocks.append(ConvBlock3D(
                    in_ch_Xij, out_ch_Xij, use_se=use_se,
                    stochastic_depth_rate=stochastic_depth_rate
                ))
            self.decoder_conv_modulelist.append(level_i_decoder_blocks)

        # Classification Head
        final_decoder_output_channels = encoder_output_channels[0]
        self.classification_head = nn.Sequential(
            nn.AdaptiveAvgPool3d((1, 1, 1)),
            nn.Flatten(),
            nn.Linear(final_decoder_output_channels, final_fc_units),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(final_fc_units, num_classes)
        )

        # Segmentation Head (outputs logits)
        self.segmentation_head = nn.Sequential(
            nn.Conv3d(final_decoder_output_channels, 1, kernel_size=1),
            nn.Identity()  # Output logits
        )

    def forward(self, x):
        """Run the network.

        Args:
            x: tensor of shape (batch, in_channels, D, H, W). It is resized to
                ``input_scan_size`` (with a printed warning) if needed.

        Returns:
            (classification_logits, segmentation_logits, attention_maps):
            classification logits of shape (batch, num_classes); segmentation
            logits of shape (batch, 1, *input_scan_size); and one list per
            encoder level holding each Transformer layer's attention weights
            (empty for level 0).
        """
        if x.shape[2:] != self.input_scan_size:
            print(f"Warning: Input scan size {x.shape[2:]} does not match model expected size {self.input_scan_size}. Interpolating.")
            x = F.interpolate(x, size=self.input_scan_size, mode='trilinear', align_corners=False)

        # X_features[i][j] is node X^{i,j} of the U-Net++ grid.
        X_features = [[None] * (self.depth + 1) for _ in range(self.depth + 1)]
        attention_maps = [[] for _ in range(self.depth + 1)]

        # Encoder path: X[i][0] = transformer(conv(pooled previous level)).
        current = x
        for i in range(self.depth + 1):
            conv_out = self.encoder_blocks[i](current)
            transformer_module = self.transformer_blocks[i]

            if isinstance(transformer_module, TransformerBlock):
                trans_out, attn_list = transformer_module(conv_out)
                X_features[i][0] = trans_out
                if attn_list:  # one attention tensor per Transformer layer
                    attention_maps[i].extend(attn_list)
            else:  # nn.Identity at level 0
                X_features[i][0] = transformer_module(conv_out)

            if i < self.depth:
                current = self.pools[i](X_features[i][0])

        # Decoder path: fill the grid column by column (j), top-down within a column.
        for j in range(1, self.depth + 1):
            for i in range(self.depth - j + 1):
                inputs_same_level = [X_features[i][k] for k in range(j)]
                upsampled_input = self.upsamplers[i](X_features[i+1][j-1])
                # Odd sizes lose a voxel to floor pooling; resize to match exactly.
                target_spatial = X_features[i][0].shape[2:]
                if upsampled_input.shape[2:] != target_spatial:
                    upsampled_input = F.interpolate(
                        upsampled_input, size=target_spatial,
                        mode='trilinear', align_corners=False
                    )
                combined = torch.cat(inputs_same_level + [upsampled_input], dim=1)
                X_features[i][j] = self.decoder_conv_modulelist[i][j-1](combined)

        # Final Outputs
        final_decoder_output = X_features[0][self.depth]
        classification_logits = self.classification_head(final_decoder_output)
        segmentation_logits = self.segmentation_head(final_decoder_output)

        return classification_logits, segmentation_logits, attention_maps

    def get_attention_maps(self, x):
        """Return the first-layer attention weights of every level that has a Transformer.

        Runs a no-grad forward pass in eval mode and restores the previous
        train/eval mode afterwards, so calling this during training is safe.
        Level 0 contributes nothing. With ``transformer_layers >= 1``, element
        ``k`` of the result therefore comes from encoder level ``k + 1``. Each
        tensor is whatever the Transformer layer returned as attention weights.
        Because that layer calls ``nn.MultiheadAttention`` without
        ``average_attn_weights=False``, these are presumably head-averaged
        (batch, seq, seq).
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                if x.shape[2:] != self.input_scan_size:
                    x = F.interpolate(x, size=self.input_scan_size, mode='trilinear', align_corners=False)
                _, _, attention_maps_nested = self.forward(x)
        finally:
            self.train(was_training)

        # Skip empty per-level lists (level 0) and keep each level's first layer.
        return [level_maps[0] for level_maps in attention_maps_nested if level_maps]


# %% Dataset
class PatientLevelDataset(Dataset):
    """Per-patient samples of (scan volume, label, segmentation mask).

    Scans are read from ``<preprocessed_path>/<id>.npz`` (key ``"image"``).
    Masks are read from ``<mask_path>/<id>_mask.npz`` (key ``"mask"``). Both
    are resampled with ``ndimage.zoom`` when their shape differs from
    ``Config.FINAL_SCAN_SIZE``. A sample that cannot be produced is returned
    as all-zero tensors with label -1, which the training and validation loops
    filter out. A missing or unreadable mask is not fatal; an all-zero mask is
    used instead.
    """

    def __init__(self, patient_ids, labels_dict, preprocessed_path, mask_path=None):
        self.patient_ids = patient_ids
        self.labels_dict = labels_dict
        self.preprocessed_path = preprocessed_path
        # Masks default to living next to the scans.
        self.mask_path = mask_path or preprocessed_path
        self.target_size = Config.FINAL_SCAN_SIZE

    def __len__(self):
        return len(self.patient_ids)

    def _invalid_sample(self):
        """Return a fresh (zero image, label -1, zero mask) triple marking an unusable sample."""
        volume_shape = (1, *self.target_size)
        return (
            torch.zeros(volume_shape, dtype=torch.float32),
            torch.tensor(-1, dtype=torch.float32),
            torch.zeros(volume_shape, dtype=torch.float32),
        )

    def _zoom_factors(self, shape):
        """Return per-axis scale factors that map ``shape`` onto ``self.target_size``."""
        return tuple(t / s for t, s in zip(self.target_size, shape))

    def _read_scan(self, patient_id, scan_path):
        """Return the scan as an array of ``target_size``, or None after printing why not."""
        with np.load(scan_path) as archive:
            if 'image' not in archive:
                print(f"Error: 'image' key not found in {scan_path}")
                return None
            volume = archive['image']

        if volume.shape == self.target_size:
            return volume

        print(f"Shape mismatch for {patient_id}: Expected {self.target_size}, got {volume.shape}. Attempting resize.")
        try:
            # order=1: linear interpolation suits continuous intensities.
            resized = ndimage.zoom(volume, self._zoom_factors(volume.shape), order=1)
            if resized.shape != self.target_size:
                print(f"Error: Resizing failed for {patient_id}. Got shape {resized.shape}")
                return None
            return resized
        except Exception as resize_e:
            print(f"Error during image resizing for {patient_id}: {resize_e}")
            return None

    def _read_mask(self, patient_id):
        """Return the mask as a (1, *target_size) tensor; all zeros if it is absent or unusable."""
        fallback = torch.zeros((1, *self.target_size), dtype=torch.float32)
        mask_file = os.path.join(self.mask_path, f"{patient_id}_mask.npz")
        if not os.path.exists(mask_file):
            return fallback
        try:
            with np.load(mask_file) as archive:
                if 'mask' not in archive:
                    return fallback
                mask = archive['mask']
                if mask.shape == self.target_size:
                    return torch.from_numpy(mask).float().unsqueeze(0)
                try:
                    # order=0: nearest neighbour keeps mask values discrete.
                    resized = ndimage.zoom(mask, self._zoom_factors(mask.shape), order=0)
                    if resized.shape == self.target_size:
                        return torch.from_numpy(resized).float().unsqueeze(0)
                except Exception as resize_mask_e:
                    print(f"Error during mask resizing for {patient_id}: {resize_mask_e}. Using zero mask.")
        except Exception as e_mask:
            print(f"Error loading mask for {patient_id}: {e_mask}. Using zero mask.")
        return fallback

    def __getitem__(self, idx):
        patient_id = self.patient_ids[idx]
        label = self.labels_dict.get(patient_id, -1)
        if label == -1:
            return self._invalid_sample()

        scan_path = os.path.join(self.preprocessed_path, f"{patient_id}.npz")
        try:
            if not os.path.exists(scan_path):
                return self._invalid_sample()

            image = self._read_scan(patient_id, scan_path)
            if image is None:
                return self._invalid_sample()

            image_tensor = torch.from_numpy(image).float().unsqueeze(0)
            label_tensor = torch.tensor(label, dtype=torch.float32)
            mask_tensor = self._read_mask(patient_id).float()
            return image_tensor, label_tensor, mask_tensor

        except FileNotFoundError:
            return self._invalid_sample()
        except Exception as e:
            print(f"Error loading data for {patient_id}: {e}")
            import traceback
            traceback.print_exc()
            return self._invalid_sample()


# %% Training and Validation Functions (CPU VERSION - NO AMP/SCALER)
def train_one_epoch(model, dataloader, criterion_cls, criterion_seg, optimizer, device):
    """Run one training epoch on the weighted classification + segmentation loss.

    Samples with label -1 (PatientLevelDataset's marker for unusable samples)
    are dropped. A batch whose combined loss is NaN/Inf is skipped without an
    optimizer step and is NOT counted in the returned averages.

    Args:
        model: returns ``(cls_logits, seg_logits, attention_maps)``.
        dataloader: yields ``(inputs, labels, masks)`` batches.
        criterion_cls, criterion_seg: losses applied to the two logit outputs.
        optimizer: stepped once per counted batch.
        device: device the batch tensors are moved to.

    Returns:
        (avg_cls_loss, avg_seg_loss, avg_weighted_loss, accuracy) over the
        counted samples; all 0.0 if no sample was counted.
    """
    model.train()
    running_loss_cls = 0.0
    running_loss_seg = 0.0
    total_loss_weighted = 0.0
    total_samples = 0
    correct_predictions = 0

    for inputs, labels, masks in tqdm(dataloader, desc="Training (CPU)", leave=False):
        valid_indices = labels != -1
        if not torch.any(valid_indices):
            continue

        inputs = inputs[valid_indices].to(device)
        labels = labels[valid_indices].unsqueeze(1).to(device)  # (N, 1) to match cls_logits
        masks = masks[valid_indices].to(device)
        current_batch_size = inputs.size(0)

        optimizer.zero_grad(set_to_none=True)

        cls_logits, seg_logits, _ = model(inputs)
        loss_cls = criterion_cls(cls_logits, labels)
        loss_seg = criterion_seg(seg_logits, masks)
        loss = Config.CLASSIFICATION_LOSS_WEIGHT * loss_cls + Config.SEGMENTATION_LOSS_WEIGHT * loss_seg

        if not torch.isfinite(loss):
            print(f"Warning: NaN/Inf loss detected (Cls: {loss_cls.item():.4f}, Seg: {loss_seg.item():.4f}). Skipping batch.")
            continue

        loss.backward()
        optimizer.step()

        # Count the batch only once it has actually contributed, so skipped
        # batches do not dilute the averages.
        total_samples += current_batch_size
        running_loss_cls += loss_cls.item() * current_batch_size
        running_loss_seg += loss_seg.item() * current_batch_size
        total_loss_weighted += loss.item() * current_batch_size

        preds = torch.sigmoid(cls_logits) > 0.5
        correct_predictions += (preds == labels.bool()).sum().item()

    if total_samples == 0:
        print("Warning: No valid samples processed in training epoch.")
        return 0.0, 0.0, 0.0, 0.0

    avg_loss_cls = running_loss_cls / total_samples
    avg_loss_seg = running_loss_seg / total_samples
    avg_loss_total = total_loss_weighted / total_samples
    avg_acc = correct_predictions / total_samples

    return avg_loss_cls, avg_loss_seg, avg_loss_total, avg_acc


def validate(model, dataloader, criterion_cls, criterion_seg, device):
    """Evaluate the model without gradient tracking.

    Samples with label -1 are dropped. A batch whose combined loss is NaN/Inf
    is skipped and excluded from both the averages and the returned arrays,
    so the arrays stay aligned with the counted samples.

    Returns:
        (avg_cls_loss, avg_seg_loss, avg_weighted_loss, labels, probabilities),
        where ``labels`` and ``probabilities`` are 1-D arrays aligned
        sample-by-sample. Returns zeros and empty arrays if nothing was counted.
    """
    model.eval()
    running_loss_cls = 0.0
    running_loss_seg = 0.0
    total_loss_weighted = 0.0
    total_samples = 0
    all_preds_proba = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels, masks in tqdm(dataloader, desc="Validating (CPU)", leave=False):
            valid_indices = labels != -1
            if not torch.any(valid_indices):
                continue

            inputs = inputs[valid_indices].to(device)
            labels = labels[valid_indices].to(device)  # Keep 1-D for the metric arrays
            masks = masks[valid_indices].to(device)
            current_batch_size = inputs.size(0)

            cls_logits, seg_logits, _ = model(inputs)
            loss_cls = criterion_cls(cls_logits, labels.unsqueeze(1))
            loss_seg = criterion_seg(seg_logits, masks)
            loss = Config.CLASSIFICATION_LOSS_WEIGHT * loss_cls + Config.SEGMENTATION_LOSS_WEIGHT * loss_seg

            if not torch.isfinite(loss):
                print(f"Warning: NaN/Inf validation loss detected (Cls: {loss_cls.item():.4f}, Seg: {loss_seg.item():.4f}). Skipping batch.")
                continue

            # Count only batches that contribute to the averages and arrays.
            total_samples += current_batch_size
            running_loss_cls += loss_cls.item() * current_batch_size
            running_loss_seg += loss_seg.item() * current_batch_size
            total_loss_weighted += loss.item() * current_batch_size

            all_preds_proba.extend(torch.sigmoid(cls_logits).cpu().numpy().flatten())
            all_labels.extend(labels.cpu().numpy().flatten())

    if total_samples == 0:
        print("Warning: No valid samples processed in validation.")
        return 0.0, 0.0, 0.0, np.array([]), np.array([])

    avg_loss_cls = running_loss_cls / total_samples
    avg_loss_seg = running_loss_seg / total_samples
    avg_loss_total = total_loss_weighted / total_samples

    return avg_loss_cls, avg_loss_seg, avg_loss_total, np.array(all_labels), np.array(all_preds_proba)

# %% Attention Visualization (remains the same, uses device="cpu")
def visualize_attention_maps(model, input_tensor, save_path, device):
    """Save one heat-map PNG per Transformer level from the first layer's attention.

    Args:
        model: exposes ``get_attention_maps(batch)``; its maps are assumed to be
            ordered from encoder level 1 upward (level 0 has no Transformer).
        input_tensor: one unbatched sample; a batch dimension is added here.
        save_path: output directory, created if missing.
        device: device the sample is moved to.

    Attention tensors may be head-averaged (batch, seq, seq), which is
    ``nn.MultiheadAttention``'s default, or per-head (batch, heads, seq, seq).
    Either way, the first sample is reduced to one (seq, seq) matrix. The
    model's previous train/eval mode is restored afterwards.
    """
    os.makedirs(save_path, exist_ok=True)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            input_tensor = input_tensor.unsqueeze(0).to(device)
            attention_maps_list = model.get_attention_maps(input_tensor)
    finally:
        model.train(was_training)

    if not attention_maps_list:
        print("No attention maps returned by the model.")
        return

    print(f"Visualizing {len(attention_maps_list)} attention map(s)...")
    for map_idx, attn_map_layer in enumerate(attention_maps_list):
        # Level 0 has no Transformer, so the first map comes from level 1.
        encoder_level = map_idx + 1

        if attn_map_layer is None or attn_map_layer.numel() == 0:
            print(f"Skipping empty attention map for Encoder Level {encoder_level}")
            continue

        attn = attn_map_layer.detach().cpu()[0]  # first (only) sample in the batch
        if attn.dim() == 3:  # (heads, seq, seq) -> average over heads
            attn = attn.mean(dim=0)
        if attn.dim() != 2:
            print(f"Skipping attention map with unexpected shape {tuple(attn_map_layer.shape)} for Encoder Level {encoder_level}")
            continue
        mean_attn = attn.numpy()

        plt.figure(figsize=(8, 8))
        im = plt.imshow(mean_attn, cmap='viridis', aspect='auto')
        plt.colorbar(im)
        plt.title(f'Mean Attention Map - Encoder Level {encoder_level} (CPU)')
        plt.xlabel('Key Positions (Flattened)')
        plt.ylabel('Query Positions (Flattened)')
        plt.savefig(os.path.join(save_path, f'attention_encoder_level_{encoder_level}_cpu.png'))
        plt.close()
    print(f"Attention maps saved to {save_path}")


# %% Main Execution (CPU VERSION)
def _compute_pos_weight(patient_labels, patient_ids):
    """Return the negative/positive count ratio used as BCEWithLogitsLoss ``pos_weight``.

    Labels equal to 0 count as negatives and labels equal to 1 as positives.
    Any other label, or a missing one, is ignored. Falls back to 1.0 when no
    positives are present.
    """
    labels = [patient_labels.get(pid) for pid in patient_ids]
    num_neg = sum(1 for lbl in labels if lbl == 0)
    num_pos = sum(1 for lbl in labels if lbl == 1)
    return num_neg / num_pos if num_pos > 0 else 1.0


def _plot_training_history(history, save_path):
    """Save a 1x3 figure of classification loss, segmentation loss and accuracy curves.

    ``history`` maps the keys used below to per-epoch lists of equal length.
    """
    print("\nPlotting training history...")
    epochs = range(1, len(history["train_cls"]) + 1)
    # (train key, train label, val key, val label, y-axis label, title)
    panels = (
        ("train_cls", "Train Cls Loss", "val_cls", "Val Cls Loss", "Loss", "Classification Loss (CPU)"),
        ("train_seg", "Train Seg Loss", "val_seg", "Val Seg Loss", "Loss", "Segmentation Loss (CPU)"),
        ("train_acc", "Train Acc", "val_acc", "Val Acc", "Accuracy", "Accuracy (CPU)"),
    )
    fig = plt.figure(figsize=(18, 6))
    try:
        for position, (train_key, train_label, val_key, val_label, y_label, title) in enumerate(panels, start=1):
            plt.subplot(1, 3, position)
            plt.plot(epochs, history[train_key], label=train_label)
            plt.plot(epochs, history[val_key], label=val_label)
            plt.xlabel('Epochs')
            plt.ylabel(y_label)
            plt.title(title)
            plt.legend()
            plt.grid(True)
        plt.tight_layout()
        plt.savefig(save_path)
        print(f"Training curves saved to {save_path}")
    finally:
        plt.close(fig)


def _evaluate_best_model(model, val_loader, criterion_cls, criterion_seg,
                         model_save_path, best_model_saved):
    """Reload the checkpoint written during this run and print validation metrics.

    Evaluation is skipped when this run never saved a checkpoint. This stops a
    stale file left at ``model_save_path`` by an earlier run from being
    reported as this run's best model.
    """
    print("\n--- Final Evaluation on Validation Set using Best Model (CPU) ---")
    if not best_model_saved:
        print("No best model was saved during this run. Skipping final evaluation.")
        return
    if not os.path.exists(model_save_path):
        print(f"Best model file not found at {model_save_path}. Skipping final evaluation.")
        return

    try:
        model.load_state_dict(torch.load(model_save_path, map_location=torch.device('cpu')))
        print(f"Loaded best CPU model from {model_save_path}")
    except Exception as e:
        print(f"Error loading best model: {e}. Evaluation might use last epoch's weights.")

    val_loss_cls, val_loss_seg, val_loss_total, val_labels, val_preds_proba = validate(
        model, val_loader, criterion_cls, criterion_seg, DEVICE
    )
    if len(val_labels) == 0 or len(val_preds_proba) != len(val_labels):
        print("Validation set empty or prediction/label length mismatch. No final metrics.")
        return

    val_preds_binary = (val_preds_proba > 0.5).astype(int)
    auc_roc = float('nan')
    # ROC AUC is undefined when only one class is present.
    if len(np.unique(val_labels)) > 1:
        try:
            auc_roc = roc_auc_score(val_labels, val_preds_proba)
        except ValueError as e:
            print(f"Could not calculate AUC: {e}")

    print(f"\nValidation Loss (Cls/Seg/Total): {val_loss_cls:.4f} / {val_loss_seg:.4f} / {val_loss_total:.4f}")
    print(f"Accuracy:  {accuracy_score(val_labels, val_preds_binary):.4f}")
    print(f"Precision: {precision_score(val_labels, val_preds_binary, zero_division=0):.4f}")
    print(f"Recall:    {recall_score(val_labels, val_preds_binary, zero_division=0):.4f}")
    print(f"F1 Score:  {f1_score(val_labels, val_preds_binary, zero_division=0):.4f}")
    print(f"AUC ROC:   {auc_roc:.4f}")

    # Pin the label set so the report and matrix stay 2-class even when the
    # validation split has a single class. Otherwise the two-entry
    # target_names / display_labels do not match and sklearn raises.
    class_ids = [0, 1]
    target_names = ['Class 0 (Non-Cancer)', 'Class 1 (Cancer)']

    print("\nClassification Report:")
    try:
        print(classification_report(val_labels, val_preds_binary, labels=class_ids,
                                    target_names=target_names, zero_division=0))
    except Exception as e:
        print(f"Could not generate classification report: {e}")

    print("\nConfusion Matrix:")
    try:
        cm = confusion_matrix(val_labels, val_preds_binary, labels=class_ids)
        disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=target_names)
        disp.plot(cmap=plt.cm.Blues)
        cm_save_path = os.path.join(Config.MODEL_OUTPUT_DIR, "confusion_matrix_cpu.png")
        plt.savefig(cm_save_path)
        print(f"Confusion matrix saved to {cm_save_path}")
    except Exception as e:
        print(f"Could not plot confusion matrix: {e}")
    finally:
        # Close the confusion-matrix figure even when plotting fails part-way.
        plt.close()


def _visualize_first_validation_sample(model, val_loader):
    """Draw attention maps for the first validation sample unless it is marked invalid (-1)."""
    print("\n--- Visualizing Attention Maps (CPU) ---")
    if not (val_loader and len(val_loader.dataset) > 0):
        print("Validation loader/dataset not available or empty, skipping attention visualization.")
        return
    try:
        input_sample, vis_label, _ = val_loader.dataset[0]
        if vis_label != -1:
            attention_save_path = os.path.join(Config.MODEL_OUTPUT_DIR, "attention_maps_cpu")
            visualize_attention_maps(model, input_sample, attention_save_path, DEVICE)
        else:
            print("First validation sample is invalid, skipping attention visualization.")
    except IndexError:
        print("Validation dataset is empty, cannot visualize attention.")
    except Exception as e:
        import traceback
        print(f"Error during attention visualization: {e}")
        traceback.print_exc()


def main():
    """Run the CPU training pipeline end to end.

    The steps are:
    1. Select and preprocess scans, then build the train/val loaders.
    2. Train for ``Config.EPOCHS`` epochs. After each epoch, checkpoint the
       weights when the total validation loss is the lowest finite value so far.
    3. Plot the training curves.
    4. Evaluate the checkpoint written during this run.
    5. Visualize attention maps for the first validation sample.

    Returns early, after printing an error, when data preparation yields nothing usable.
    """
    scans_to_process, patient_labels = load_and_select_data()
    if not scans_to_process:
        print("Error: No scans selected. Exiting.")
        return

    # Preprocess scans (ensures they exist at Config.FINAL_SCAN_SIZE).
    successful_ids = preprocess_scans(scans_to_process)
    if not successful_ids:
        print("Error: No scans were successfully preprocessed or found. Exiting.")
        return

    try:
        train_loader, val_loader, _train_ids, _val_ids = create_dataloaders(
            successful_ids, patient_labels
        )
    except ValueError as e:
        print(f"Error creating dataloaders: {e}. Exiting.")
        return

    model = EfficientUNetPlusPlus_SE_Transformer().to(DEVICE)

    # NOTE(review): the class balance is measured over every preprocessed ID,
    # presumably including validation IDs. Confirm whether it should use the
    # training split only.
    pos_weight_val = _compute_pos_weight(patient_labels, successful_ids)
    print(f"Calculated pos_weight for BCEWithLogitsLoss: {pos_weight_val:.2f}")

    criterion_cls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight_val]))
    criterion_seg = nn.BCEWithLogitsLoss()

    optimizer = optim.Adam(model.parameters(), lr=Config.LEARNING_RATE, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=Config.EPOCHS)

    best_val_loss = float('inf')
    # Becomes True only once this run has actually written a checkpoint.
    best_model_saved = False
    history = {
        key: []
        for key in ("train_cls", "train_seg", "train_total", "train_acc",
                    "val_cls", "val_seg", "val_total", "val_acc")
    }
    model_save_path = os.path.join(Config.MODEL_OUTPUT_DIR, "efficient_unetpp_se_transformer_cpu.pth")

    print(f"\nStarting training on CPU for {Config.EPOCHS} epochs with Batch Size {Config.BATCH_SIZE}...")
    for epoch in range(Config.EPOCHS):
        print(f"\nEpoch {epoch+1}/{Config.EPOCHS}")
        start_time = time.time()
        train_loss_cls, train_loss_seg, train_loss_total, train_acc = train_one_epoch(
            model, train_loader, criterion_cls, criterion_seg, optimizer, DEVICE
        )
        val_loss_cls, val_loss_seg, val_loss_total, val_labels, val_preds_proba = validate(
            model, val_loader, criterion_cls, criterion_seg, DEVICE
        )
        epoch_time = time.time() - start_time
        # Read the LR used in this epoch before stepping. After step() the
        # scheduler already reports the next epoch's value.
        epoch_lr = scheduler.get_last_lr()[0]
        scheduler.step()

        val_acc = 0.0
        if len(val_labels) > 0 and len(val_preds_proba) == len(val_labels):
            try:
                val_acc = accuracy_score(val_labels, (val_preds_proba > 0.5).astype(int))
            except Exception as e:
                print(f"Could not calculate validation accuracy: {e}")

        epoch_metrics = {
            "train_cls": train_loss_cls, "train_seg": train_loss_seg,
            "train_total": train_loss_total, "train_acc": train_acc,
            "val_cls": val_loss_cls, "val_seg": val_loss_seg,
            "val_total": val_loss_total, "val_acc": val_acc,
        }
        for key, value in epoch_metrics.items():
            history[key].append(value)

        print(f"Time: {epoch_time:.2f}s")
        print(f"Train Loss (Cls/Seg/Total): {train_loss_cls:.4f} / {train_loss_seg:.4f} / {train_loss_total:.4f}, Train Acc: {train_acc:.4f}")
        print(f"Val Loss   (Cls/Seg/Total): {val_loss_cls:.4f} / {val_loss_seg:.4f} / {val_loss_total:.4f},   Val Acc: {val_acc:.4f}")
        print(f"Current LR: {epoch_lr:.6f}")

        # Checkpoint on the lowest finite total validation loss.
        loss_is_finite = np.isfinite(val_loss_total)
        if loss_is_finite and val_loss_total < best_val_loss and len(val_labels) > 0:
            best_val_loss = val_loss_total
            try:
                torch.save(model.state_dict(), model_save_path)
                best_model_saved = True
                print(f"Saved best model (Val Loss: {best_val_loss:.4f}) to {model_save_path}")
            except Exception as save_e:
                print(f"Error saving model: {save_e}")
        elif not loss_is_finite:
            print(f"Skipping model save due to invalid validation loss: {val_loss_total}")

    _plot_training_history(history, os.path.join(Config.MODEL_OUTPUT_DIR, "training_curves_cpu.png"))
    _evaluate_best_model(model, val_loader, criterion_cls, criterion_seg,
                         model_save_path, best_model_saved)
    _visualize_first_validation_sample(model, val_loader)

    print("\n--- Execution Finished (CPU) ---")


if __name__ == "__main__":
    import traceback

    # Top-level boundary: any uncaught failure from main() is reported with
    # its full traceback instead of an unannotated crash.
    try:
        main()
    except Exception:
        print("\n--- A critical error occurred during execution ---")
        traceback.print_exc()