In [29]:
import os
import glob
import time
import SimpleITK as sitk
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from tqdm import tqdm
import random
import scipy.ndimage
from skimage.measure import label as skimage_label, regionprops
from skimage.morphology import disk, binary_closing
from skimage.segmentation import clear_border
import scipy.ndimage as ndi

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import (classification_report, confusion_matrix, roc_auc_score,
                             roc_curve, precision_recall_curve, auc, f1_score,
                             precision_score, recall_score, accuracy_score, ConfusionMatrixDisplay)

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Input locations -- adjust these three paths for the local machine.
DSB_PATH = r"C:\Users\rouaa\Documents\Final_Pneumatect\Stages"
DSB_LABELS_CSV = r"C:\Users\rouaa\Documents\Final_Pneumatect\stage1_labels.csv"
# Output directory for this balanced + attention run; it holds the cached
# volumes, the best checkpoint and every saved plot.
PREPROCESSED_DSB_PATH = r"C:\Users\rouaa\Documents\Final_Pneumatect\Preprocessed_Data_DSB_Balanced_Attention"

# Preprocessing parameters.
TARGET_SPACING = [1.5, 1.5, 1.5]      # voxel spacing every scan is resampled to
FINAL_SCAN_SIZE = (96, 128, 128)      # (Depth, Height, Width) fed to the model
CLIP_BOUND_HU = [-1000.0, 400.0]      # intensity window applied before scaling to [0, 1]
PIXEL_MEAN = 0.25                     # constant subtracted after scaling

# Training parameters.
NUM_CLASSES = 1                       # single logit, trained with BCEWithLogitsLoss
BATCH_SIZE = 4
LEARNING_RATE = 0.0001
EPOCHS = 50
SCAN_LIMIT_PER_CLASS = 50             # upper bound on scans taken from each class

# Seed shared by `random`, NumPy, PyTorch and the train/validation split.
SEED = 42

if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(f"Using device: {DEVICE}")

# The output directory must exist before anything is cached into it.
os.makedirs(PREPROCESSED_DSB_PATH, exist_ok=True)

# Seed every random number generator the script relies on.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    # Ask cuDNN for deterministic kernels and disable its autotuner.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# --- Data Loading and Selection (MODIFIED FOR BALANCING) ---
# Builds `scans_to_process` (an equal number of cancer / non-cancer patient ids)
# and `patient_labels` (id -> label) for the ids that have both a scan folder
# under DSB_PATH and a row in the labels CSV.

print(f"--- Loading Data and Selecting EVEN Number of Scans Per Class (up to {SCAN_LIMIT_PER_CLASS} each) ---")

if not os.path.isdir(DSB_PATH):
    raise SystemExit(f"ERROR: DSB Scans path not found: {DSB_PATH}")
if not os.path.isfile(DSB_LABELS_CSV):
    raise SystemExit(f"ERROR: DSB Labels CSV not found: {DSB_LABELS_CSV}")

dsb_labels_df = pd.read_csv(DSB_LABELS_CSV)
dsb_labels_df = dsb_labels_df.rename(columns={'id': 'patient_id'})
patient_labels_all = dsb_labels_df.set_index('patient_id')['cancer'].to_dict()

# Every sub-directory of DSB_PATH is treated as one patient's scan folder.
scan_folders = [f for f in os.listdir(DSB_PATH) if os.path.isdir(os.path.join(DSB_PATH, f))]
found_scan_ids = set(scan_folders)
labeled_patient_ids_all = set(dsb_labels_df['patient_id'])
common_ids_all = labeled_patient_ids_all.intersection(found_scan_ids)

# A set of strings iterates in an order that can change from one interpreter
# run to the next (string hash randomization). Sorting first gives the seeded
# shuffle below a fixed starting order, so the same scans are selected each run.
common_ids_cancer_available = sorted(
    pid for pid in common_ids_all if patient_labels_all.get(pid) == 1
)
common_ids_non_cancer_available = sorted(
    pid for pid in common_ids_all if patient_labels_all.get(pid) == 0
)

random.shuffle(common_ids_cancer_available)
random.shuffle(common_ids_non_cancer_available)

# Determine the number of scans to pick per class for balancing: limited by
# the smaller class and by SCAN_LIMIT_PER_CLASS.
num_to_select_per_class = min(len(common_ids_cancer_available),
                              len(common_ids_non_cancer_available),
                              SCAN_LIMIT_PER_CLASS)

print(f"Available Cancerous scans with labels: {len(common_ids_cancer_available)}")
print(f"Available Non-Cancerous scans with labels: {len(common_ids_non_cancer_available)}")
print(f"Selecting {num_to_select_per_class} scans from each class for balancing.")

selected_cancer_ids = common_ids_cancer_available[:num_to_select_per_class]
selected_non_cancer_ids = common_ids_non_cancer_available[:num_to_select_per_class]

print(f"Selected {len(selected_cancer_ids)} Cancerous scans.")
print(f"Selected {len(selected_non_cancer_ids)} Non-Cancerous scans.")

scans_to_process = selected_cancer_ids + selected_non_cancer_ids
random.shuffle(scans_to_process)  # Shuffle combined list

print(f"Total scans selected for preprocessing: {len(scans_to_process)}")
patient_labels = {pid: patient_labels_all[pid] for pid in scans_to_process}


# --- Preprocessing Functions (Identical to previous script) ---
def load_scan_series(dicom_folder_path):
    """Read the first DICOM series found in a folder into a NumPy volume.

    Args:
        dicom_folder_path: Directory that contains the DICOM files of one scan.

    Returns:
        ``(image_array, origin, spacing)``. ``origin`` and ``spacing`` are
        NumPy arrays holding the SimpleITK values in reversed order
        (presumably so they line up with the axes of ``image_array``).
        ``(None, None, None)`` is returned when the folder holds no series or
        when reading fails; the failure is printed, not raised.
    """
    failure = (None, None, None)
    try:
        reader_cls = sitk.ImageSeriesReader
        series_ids = reader_cls.GetGDCMSeriesIDs(dicom_folder_path)
        if not series_ids:
            return failure

        # Only the first series in the folder is read.
        file_names = reader_cls.GetGDCMSeriesFileNames(dicom_folder_path, series_ids[0])
        reader = reader_cls()
        reader.SetFileNames(file_names)
        itk_image = reader.Execute()

        volume = sitk.GetArrayFromImage(itk_image)
        origin = np.array(itk_image.GetOrigin()[::-1])
        spacing = np.array(itk_image.GetSpacing()[::-1])
        return volume, origin, spacing
    except Exception as e:
        print(f"Error reading DICOM {os.path.basename(dicom_folder_path)}: {e}")
        return failure

def resample(image, original_spacing, new_spacing=TARGET_SPACING):
    """Resample a volume to (approximately) ``new_spacing`` with linear interpolation.

    The output shape is rounded to whole voxels, so the spacing actually
    achieved can differ slightly from ``new_spacing``; it is returned too.

    Args:
        image: N-D array to resample.
        original_spacing: Per-axis spacing of ``image``, one value per axis.
        new_spacing: Desired per-axis spacing.

    Returns:
        ``(resampled_image, actual_new_spacing)``, or ``(None, None)`` if
        resampling fails (the error is printed, not raised).
    """
    try:
        old_shape = np.array(image.shape)
        resize_factor = np.array(original_spacing) / np.array(new_spacing)
        # Round to an integer voxel grid, then derive the zoom factor that
        # lands exactly on that grid.
        new_shape = np.round(old_shape * resize_factor)
        real_resize_factor = new_shape / old_shape
        actual_new_spacing = original_spacing / real_resize_factor
        resampled_image = scipy.ndimage.zoom(image, real_resize_factor, mode='nearest', order=1)
        return resampled_image, actual_new_spacing
    except Exception as e:
        print(f"Error resampling: {e}")
        return None, None

def get_segmented_lungs(im_slice, hu_threshold=-320):
    """Mask out everything but the lung region of one 2-D slice.

    Steps: threshold below ``hu_threshold``, drop components touching the
    image border, keep the components whose area is at least that of the
    second-largest one, close the mask with a radius-2 disk and dilate it with
    a radius-5 disk. Pixels outside the final mask are set to
    ``CLIP_BOUND_HU[0] - 1``.

    Args:
        im_slice: 2-D array of intensities. Any other rank is returned
            unchanged.
        hu_threshold: Pixels strictly below this value are mask candidates.

    Returns:
        A copy of ``im_slice`` with the non-lung pixels replaced by the
        background value.
    """
    if im_slice.ndim != 2:
        return im_slice  # Should be a 2D slice

    binary = im_slice < hu_threshold
    cleared = clear_border(binary)
    label_image = skimage_label(cleared)

    # Measure the components once and reuse the result for both the area
    # ranking and the removal step.
    regions = regionprops(label_image)
    areas = sorted(r.area for r in regions)
    if len(areas) >= 2:
        area_threshold = areas[-2]
    elif areas:
        area_threshold = areas[-1]
    else:
        area_threshold = 0

    if area_threshold > 0:
        small_labels = [r.label for r in regions if r.area < area_threshold]
        if small_labels:
            # Clear all undersized components in one vectorized pass.
            label_image[np.isin(label_image, small_labels)] = 0

    binary = label_image > 0
    binary = binary_closing(binary, disk(2))
    final_mask = ndi.binary_dilation(binary, structure=disk(5))

    # One below the lower clip bound, so the background is clipped to the
    # window minimum during normalization.
    background_val = CLIP_BOUND_HU[0] - 1
    segmented_slice = im_slice.copy()
    segmented_slice[final_mask == 0] = background_val
    return segmented_slice

def normalize_hu(image, clip_bounds=CLIP_BOUND_HU):
    """Clip intensities to ``clip_bounds`` and rescale them linearly to [0, 1].

    Args:
        image: Array of intensities.
        clip_bounds: ``(lower, upper)`` pair; the lower bound maps to 0 and
            the upper bound to 1.

    Returns:
        A ``float32`` array with the shape of ``image``.
    """
    lower, upper = clip_bounds
    clipped = np.clip(image, lower, upper)
    scaled = (clipped - lower) / (upper - lower)
    return scaled.astype(np.float32)

def zero_center(image, pixel_mean=PIXEL_MEAN):
    """Subtract ``pixel_mean`` from every element and return a ``float32`` array."""
    shifted = image - pixel_mean
    return shifted.astype(np.float32)

def resize_scan_to_target(image, target_shape=FINAL_SCAN_SIZE):
    """Resize a volume to exactly ``target_shape`` using linear interpolation.

    If interpolation misses the target shape, the result is edge-padded and/or
    centre-cropped along each axis to correct it.

    Args:
        image: N-D array to resize.
        target_shape: Desired shape, one entry per axis of ``image``; a tuple
            or a list is accepted.

    Returns:
        A ``float32`` array of shape ``target_shape``, or ``None`` if resizing
        fails (the error is printed, not raised).
    """
    # Normalize to a tuple so the fast-path comparison also works for a list.
    target_shape = tuple(int(s) for s in target_shape)
    if image.shape == target_shape:
        # Same dtype contract as the resized path; no copy when already float32.
        return image.astype(np.float32, copy=False)

    resize_factor = np.array(target_shape) / np.array(image.shape)
    try:
        resized_image = scipy.ndimage.zoom(image, resize_factor, order=1, mode='nearest')
        if resized_image.shape != target_shape:
            diff = np.array(target_shape) - np.array(resized_image.shape)
            pad = np.maximum(diff, 0)
            crop = np.maximum(-diff, 0)
            # Split the padding / cropping as evenly as possible between both
            # ends of each axis.
            pad_width = tuple((p // 2, p - p // 2) for p in pad)
            resized_image = np.pad(resized_image, pad_width, mode='edge')
            crop_slice = tuple(
                slice(c // 2, s - (c - c // 2)) for c, s in zip(crop, resized_image.shape)
            )
            resized_image = resized_image[crop_slice]
        if resized_image.shape != target_shape:
            print(f"ERROR: Resize failed. Shape {resized_image.shape} vs Target {target_shape}")
            return None
        return resized_image.astype(np.float32)
    except Exception as e:
        print(f"Error resizing to target: {e}")
        return None

def preprocess_scan_dsb(patient_id, input_base_path, output_base_path, force_preprocess=False):
    """Preprocess one patient's scan and cache it as ``<patient_id>.npz``.

    Pipeline: load DICOM series -> resample to ``TARGET_SPACING`` -> per-slice
    lung segmentation -> intensity normalization -> zero-centering -> resize
    to ``FINAL_SCAN_SIZE``. The result is stored under the key ``image``.

    Args:
        patient_id: Name of the scan folder; also the stem of the output file.
        input_base_path: Directory containing the per-patient scan folders.
        output_base_path: Directory the ``.npz`` file is written to.
        force_preprocess: Recompute even when the output file already exists.

    Returns:
        ``True`` if the output file exists afterwards (cached or freshly
        written), ``False`` if any stage failed.
    """
    scan_folder_path = os.path.join(input_base_path, patient_id)
    output_filename = os.path.join(output_base_path, f"{patient_id}.npz")
    if os.path.exists(output_filename) and not force_preprocess:
        return True

    image, origin, spacing = load_scan_series(scan_folder_path)
    if image is None:
        return False

    resampled_image, new_spacing = resample(image, spacing, TARGET_SPACING)
    del image  # release the raw volume before allocating the next one
    if resampled_image is None:
        return False

    # Segment slice by slice along the first axis.
    segmented_lungs = np.zeros_like(resampled_image, dtype=np.float32)
    for i, volume_slice in enumerate(resampled_image):
        segmented_lungs[i] = get_segmented_lungs(volume_slice)
    del resampled_image

    normalized_image = normalize_hu(segmented_lungs, clip_bounds=CLIP_BOUND_HU)
    del segmented_lungs
    centered_image = zero_center(normalized_image, pixel_mean=PIXEL_MEAN)
    del normalized_image
    final_image = resize_scan_to_target(centered_image, target_shape=FINAL_SCAN_SIZE)
    del centered_image
    if final_image is None:
        return False

    # Write to a temporary file and rename it into place. An interrupted write
    # then never leaves a truncated file under the final name, which the cache
    # check above would otherwise accept as valid on the next run.
    # The temporary name must end in ".npz" or NumPy would append the suffix.
    tmp_filename = os.path.join(output_base_path, f"{patient_id}.tmp.npz")
    try:
        np.savez_compressed(tmp_filename, image=final_image.astype(np.float32))
        os.replace(tmp_filename, output_filename)
        return True
    except Exception as e:
        print(f"Error saving {patient_id}: {e}")
        if os.path.exists(tmp_filename):
            try:
                os.remove(tmp_filename)
            except OSError:
                pass  # best-effort cleanup; the failure was already reported
        return False

#--- Preprocessing Execution (Identical logic, uses new path) ---
# Runs (or re-uses the cache of) preprocessing for every selected scan and
# keeps only the patients whose cached volume is available afterwards.
successful_processed_ids = []
failed_processed_ids = []
print(f"\nStarting preprocessing for {len(scans_to_process)} selected scans (if not already done)...")
start_time = time.time()
for patient_id in tqdm(scans_to_process, desc=f"Preprocessing {len(scans_to_process)} Scans"):
    success = preprocess_scan_dsb(patient_id, DSB_PATH, PREPROCESSED_DSB_PATH, force_preprocess=False)
    if success:
        successful_processed_ids.append(patient_id)
    else:
        failed_processed_ids.append(patient_id)
end_time = time.time()
print(f"\nPreprocessing finished/checked in {end_time - start_time:.2f} seconds.")

# Report dropped scans explicitly; they change the class counts printed below.
if failed_processed_ids:
    print(f"Warning: {len(failed_processed_ids)} scan(s) failed preprocessing and were skipped: {failed_processed_ids}")

final_patient_list = successful_processed_ids
if not final_patient_list:
    raise SystemExit("No scans processed successfully. Cannot continue.")

# Restrict the label map to the patients that survived preprocessing.
patient_labels = {pid: patient_labels[pid] for pid in final_patient_list}
print(f"Final patient count for training/validation: {len(final_patient_list)}")
final_cancer_count = sum(1 for pid in final_patient_list if patient_labels[pid] == 1)
final_non_cancer_count = len(final_patient_list) - final_cancer_count
print(f"  Balanced - Cancerous: {final_cancer_count}, Non-Cancerous: {final_non_cancer_count}")


# --- Dataset and DataLoader (Identical) ---
class PatientLevelDataset(Dataset):
    """Dataset yielding one cached scan volume and its label per patient.

    Each item is read from ``<preprocessed_path>/<patient_id>.npz`` (array
    stored under the key ``image``) and returned as a float tensor with a
    leading channel axis, together with the label as a float scalar tensor.
    If loading fails, the error is printed and a zero volume of shape
    ``(1, *FINAL_SCAN_SIZE)`` with label ``-1`` is returned.
    """

    def __init__(self, patient_ids, labels_dict, preprocessed_path):
        self.patient_ids = patient_ids
        self.labels_dict = labels_dict
        self.preprocessed_path = preprocessed_path

    def __len__(self):
        return len(self.patient_ids)

    def __getitem__(self, idx):
        patient_id = self.patient_ids[idx]
        label = self.labels_dict[patient_id]
        scan_path = os.path.join(self.preprocessed_path, f"{patient_id}.npz")
        try:
            with np.load(scan_path) as npz_data:
                image = npz_data['image']
            # Add the channel axis in front of the volume axes.
            image_tensor = torch.from_numpy(image).to(torch.float32).unsqueeze(0)
            label_tensor = torch.tensor(label, dtype=torch.float32)
            return image_tensor, label_tensor
        except Exception as e:
            print(f"ERROR loading {patient_id}: {e}")
            # Label -1 marks the sample as invalid.
            dummy = torch.zeros((1, *FINAL_SCAN_SIZE), dtype=torch.float32)
            return dummy, torch.tensor(-1, dtype=torch.float32)

# 80/20 patient-level split, stratified on the label.
_stratify_labels = [patient_labels[pid] for pid in final_patient_list]
train_ids, val_ids = train_test_split(
    final_patient_list,
    test_size=0.2,
    random_state=SEED,
    stratify=_stratify_labels,
)

train_dataset = PatientLevelDataset(train_ids, patient_labels, PREPROCESSED_DSB_PATH)
val_dataset = PatientLevelDataset(val_ids, patient_labels, PREPROCESSED_DSB_PATH)

NUM_WORKERS = 0
# Options shared by both loaders; only the training loader shuffles.
_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=torch.cuda.is_available(),
)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)


# --- Model Definition (CNN-ViT Hybrid WITH SE ATTENTION BLOCKS) ---

class SEBlock3D(nn.Module):
    """Squeeze-and-Excitation block for 3D convolutional feature maps.

    Each channel is globally average-pooled to one value, passed through a
    two-layer bottleneck MLP ending in a sigmoid, and the resulting
    per-channel gate in (0, 1) rescales the input.

    Args:
        channel: Number of channels of the input feature map.
        reduction: Bottleneck ratio; the hidden width is
            ``channel // reduction``, clamped to at least 1 so the gate never
            collapses to a zero-width layer when ``channel < reduction``.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        # A zero-width bottleneck would output a constant 0.5 gate.
        hidden = max(1, channel // reduction)
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` scaled channel-wise; ``x`` must be 5-D ``(B, C, D, H, W)``."""
        b, c, _, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)         # squeeze: (B, C)
        y = self.fc(y).view(b, c, 1, 1, 1)      # excite: per-channel gate
        return x * y                            # broadcasts over D, H, W

class CNNViTHybrid3DWithAttention(nn.Module):
    """3D CNN backbone with SE blocks feeding a small Vision-Transformer head.

    Data flow: ``(B, C, D, H, W)`` volume -> stack of
    Conv3d/BatchNorm/ReLU/SE/MaxPool stages -> strided Conv3d that turns the
    feature map into non-overlapping patch embeddings -> CLS token plus
    learned positional embedding -> Transformer encoder -> LayerNorm on the
    CLS token -> linear head producing ``num_classes`` logits.

    Args:
        input_shape_dhw: ``(depth, height, width)`` of the input volume. The
            positional embedding is sized for this shape, so inputs passed to
            ``forward`` are expected to match it.
        in_channels_cnn: Number of input channels.
        cnn_channels_setup: Output channels of each CNN stage, in order.
        cnn_kernel_size: Convolution kernel size (padding is ``kernel // 2``).
        cnn_pool_kernel_size: Kernel size and stride of each max-pool.
        se_reduction: Reduction factor passed to every ``SEBlock3D``.
        patch_size_3d: Kernel size and stride of the patch-embedding conv.
        embed_dim_vit: Token embedding dimension.
        num_transformer_layers_vit: Number of Transformer encoder layers.
        num_heads_vit: Attention heads per encoder layer.
        mlp_ratio_vit: Feed-forward width as a multiple of ``embed_dim_vit``.
        dropout_vit: Dropout used inside the encoder layers.
        num_classes: Size of the output logit vector.

    Raises:
        ValueError: If the CNN feature map is smaller than ``patch_size_3d``
            along any axis, i.e. no patch could be extracted.
    """

    def __init__(self,
                 input_shape_dhw=FINAL_SCAN_SIZE,
                 in_channels_cnn=1,
                 cnn_channels_setup=(16, 32, 64),
                 cnn_kernel_size=3,
                 cnn_pool_kernel_size=2,
                 se_reduction=16,  # Reduction factor for SE blocks
                 patch_size_3d=(3, 4, 4),
                 embed_dim_vit=128,
                 num_transformer_layers_vit=3,
                 num_heads_vit=4,
                 mlp_ratio_vit=2.0,
                 dropout_vit=0.1,
                 num_classes=1):
        super().__init__()

        self.input_shape_dhw = input_shape_dhw
        self.patch_size_3d = patch_size_3d
        self.embed_dim_vit = embed_dim_vit

        # 1. CNN Backbone with SE Blocks
        cnn_layers = []
        current_channels = in_channels_cnn
        current_d, current_h, current_w = input_shape_dhw

        for out_ch in cnn_channels_setup:
            cnn_layers.extend([
                nn.Conv3d(current_channels, out_ch, kernel_size=cnn_kernel_size,
                          stride=1, padding=cnn_kernel_size // 2),
                nn.BatchNorm3d(out_ch),
                nn.ReLU(inplace=True),
                SEBlock3D(out_ch, reduction=se_reduction),
                nn.MaxPool3d(kernel_size=cnn_pool_kernel_size, stride=cnn_pool_kernel_size),
            ])
            current_channels = out_ch
            # Track the feature-map size: each pooling stage floor-divides it.
            current_d //= cnn_pool_kernel_size
            current_h //= cnn_pool_kernel_size
            current_w //= cnn_pool_kernel_size

        self.cnn_backbone = nn.Sequential(*cnn_layers)
        self.cnn_feature_map_size_dhw = (current_d, current_h, current_w)
        cnn_out_channels_final = current_channels

        # 2. Patch Embedding for ViT: kernel == stride gives non-overlapping patches.
        self.patch_embed_conv = nn.Conv3d(
            cnn_out_channels_final, embed_dim_vit,
            kernel_size=patch_size_3d, stride=patch_size_3d
        )
        num_patches_d = self.cnn_feature_map_size_dhw[0] // patch_size_3d[0]
        num_patches_h = self.cnn_feature_map_size_dhw[1] // patch_size_3d[1]
        num_patches_w = self.cnn_feature_map_size_dhw[2] // patch_size_3d[2]
        self.num_patches = num_patches_d * num_patches_h * num_patches_w
        if self.num_patches < 1:
            # Fail at construction time with a readable message instead of a
            # convolution size error on the first forward pass.
            raise ValueError(
                f"CNN feature map {self.cnn_feature_map_size_dhw} is smaller than "
                f"patch_size_3d {tuple(patch_size_3d)}; no patches can be extracted."
            )

        # 3. CLS Token and Positional Embedding (one position per patch + CLS)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim_vit))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim_vit))
        nn.init.trunc_normal_(self.pos_embed, std=.02)
        nn.init.trunc_normal_(self.cls_token, std=.02)

        # 4. Transformer Encoder
        transformer_encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim_vit, nhead=num_heads_vit,
            dim_feedforward=int(embed_dim_vit * mlp_ratio_vit),
            dropout=dropout_vit, batch_first=True, activation='gelu'
        )
        self.transformer_encoder = nn.TransformerEncoder(
            transformer_encoder_layer, num_layers=num_transformer_layers_vit
        )

        # 5. Classification Head
        self.norm_layer = nn.LayerNorm(embed_dim_vit)
        self.mlp_head = nn.Linear(embed_dim_vit, num_classes)

        # Re-initialize every Linear / LayerNorm / Conv3d submodule.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initializer applied to each submodule via ``nn.Module.apply``."""
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv3d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return logits of shape ``(B, num_classes)`` for a batch of volumes."""
        x = self.cnn_backbone(x)
        x = self.patch_embed_conv(x)
        # (B, E, d, h, w) -> (B, num_patches, E)
        x = x.flatten(2).transpose(1, 2)
        batch_size = x.shape[0]
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.transformer_encoder(x)
        # Classify from the CLS token only.
        cls_token_output = self.norm_layer(x[:, 0])
        logits = self.mlp_head(cls_token_output)
        return logits

# Instantiate the Hybrid Model with Attention
patient_model_attention = CNNViTHybrid3DWithAttention(
    input_shape_dhw=FINAL_SCAN_SIZE,
    num_classes=NUM_CLASSES,
).to(DEVICE)

print(patient_model_attention)  # Print the new model structure

# Smoke test: push one random batch through the model to confirm the shapes
# line up. It runs in eval mode without autograd, so the random input neither
# updates the BatchNorm running statistics nor builds a computation graph.
try:
    dummy_input = torch.randn(BATCH_SIZE, 1, *FINAL_SCAN_SIZE).to(DEVICE)
    patient_model_attention.eval()
    with torch.no_grad():
        output = patient_model_attention(dummy_input)
    print(f"\nHybrid Model with Attention output shape: {output.shape}")
except Exception as e:
    print(f"\nError during hybrid model with attention test: {e}")
    raise
finally:
    # Restore training mode, whatever the outcome of the test.
    patient_model_attention.train()

# --- Loss and Optimizer ---
# The selection is balanced, but the train/validation split can leave the
# training part slightly uneven, so pos_weight is derived from the actual
# training labels (negatives / positives). It is 1.0 for a perfect balance.
train_labels_list = [patient_labels[pid] for pid in train_ids]
count_0 = train_labels_list.count(0)
count_1 = train_labels_list.count(1)

if count_0 > 0 and count_1 > 0:
    pos_weight_val = 1.0 if count_0 == count_1 else count_0 / count_1
else:
    # One class missing or empty: fall back to an unweighted loss.
    pos_weight_val = 1.0
    print("Warning: Training set has only one class or is empty after split. Using default pos_weight=1.")

pos_weight_tensor = torch.tensor([pos_weight_val], device=DEVICE)
print(f"Calculated positive weight for BCEWithLogitsLoss (balanced data): {pos_weight_val:.4f}")

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight_tensor)
optimizer = optim.AdamW(
    patient_model_attention.parameters(),
    lr=LEARNING_RATE,
    weight_decay=1e-5,
)
# Gradient scaling is only active when CUDA is available.
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())


#--- Training and Validation Functions (Identical) ---
def train_one_epoch_patient(model, dataloader, criterion, optimizer, device, scaler):
    """Train ``model`` for one pass over ``dataloader``.

    Samples whose label is ``-1`` are dropped from each batch before the
    forward pass. Batches that produce a NaN loss are reported and skipped.

    Args:
        model: Network producing one logit per sample.
        dataloader: Yields ``(inputs, labels)`` batches.
        criterion: Loss taking ``(logits, labels)`` with labels of shape ``(B, 1)``.
        optimizer: Optimizer over ``model``'s parameters.
        device: ``torch.device`` the batches are moved to.
        scaler: Gradient scaler used for the backward pass and optimizer step.

    Returns:
        ``(mean_loss, accuracy)`` over the samples actually trained on, or
        ``(0.0, 0.0)`` if there were none.
    """
    model.train()
    running_loss = 0.0
    total_samples = 0
    correct_predictions = 0
    # Mixed precision follows the device the batch runs on, not whether a GPU
    # merely exists on the machine.
    use_amp = device.type == "cuda"

    progress_bar = tqdm(dataloader, desc="Training", leave=False, ncols=100)
    for inputs, labels in progress_bar:
        valid_indices = labels != -1
        inputs = inputs[valid_indices].to(device)
        labels = labels[valid_indices].unsqueeze(1).to(device)
        if inputs.nelement() == 0:
            continue

        optimizer.zero_grad()
        with torch.amp.autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
        if torch.isnan(loss):
            print("NaN loss detected!")
            continue

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_size = inputs.size(0)
        running_loss += loss.item() * batch_size
        total_samples += batch_size
        preds = torch.sigmoid(outputs) > 0.5
        correct_predictions += (preds == labels.bool()).sum().item()
        progress_bar.set_postfix(loss=loss.item())

    if total_samples == 0:
        return 0.0, 0.0
    return running_loss / total_samples, correct_predictions / total_samples

def validate_patient(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Samples whose label is ``-1`` are dropped from each batch.

    Args:
        model: Network producing one logit per sample.
        dataloader: Yields ``(inputs, labels)`` batches.
        criterion: Loss taking ``(logits, labels)`` with labels of shape ``(B, 1)``.
        device: ``torch.device`` the batches are moved to.

    Returns:
        ``(mean_loss, labels, probabilities)`` where the last two are flat
        NumPy arrays over all evaluated samples, or ``(0.0, [], [])`` (empty
        arrays) if no valid sample was seen.
    """
    model.eval()
    running_loss = 0.0
    total_samples = 0
    all_preds_proba = []
    all_labels = []
    # Mixed precision follows the device the batch runs on, not whether a GPU
    # merely exists on the machine.
    use_amp = device.type == "cuda"

    with torch.no_grad():
        progress_bar = tqdm(dataloader, desc="Validating", leave=False, ncols=100)
        for inputs, labels in progress_bar:
            valid_indices = labels != -1
            inputs = inputs[valid_indices].to(device)
            labels = labels[valid_indices].unsqueeze(1).to(device)
            if inputs.nelement() == 0:
                continue

            with torch.amp.autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp):
                outputs = model(inputs)
                loss = criterion(outputs, labels)

            batch_size = inputs.size(0)
            running_loss += loss.item() * batch_size
            total_samples += batch_size
            # Cast the logits to fp32 before the sigmoid so the probabilities
            # used for the metrics are not quantized to half precision.
            all_preds_proba.extend(torch.sigmoid(outputs.float()).cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if total_samples == 0:
        return 0.0, np.array([]), np.array([])
    return (running_loss / total_samples,
            np.array(all_labels).flatten(),
            np.array(all_preds_proba).flatten())

#--- Training Loop ---
# Trains for EPOCHS epochs, records the loss / accuracy history and saves the
# weights whenever the validation loss reaches a new minimum.
print(f"\nStarting Training Hybrid Model with Attention (Balanced Data) for {EPOCHS} epochs...")
best_val_loss = float('inf')
train_losses = []
val_losses = []
train_accs = []
val_accs_list = []
MODEL_SAVE_PATH = os.path.join(PREPROCESSED_DSB_PATH, "patient_level_hybrid_attention_best.pth")

if torch.cuda.is_available():
    torch.cuda.empty_cache()

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    start_epoch_time = time.time()

    train_loss, train_acc = train_one_epoch_patient(
        patient_model_attention, train_loader, criterion, optimizer, DEVICE, scaler
    )
    val_loss, val_labels_epoch, val_preds_proba_epoch = validate_patient(
        patient_model_attention, val_loader, criterion, DEVICE
    )

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accs.append(train_acc)

    end_epoch_time = time.time()
    epoch_duration = end_epoch_time - start_epoch_time

    # Validation accuracy at a 0.5 threshold; stays 0.0 if nothing was evaluated.
    have_val_results = val_labels_epoch.size > 0 and val_preds_proba_epoch.size > 0
    if have_val_results:
        val_acc_epoch = accuracy_score(val_labels_epoch, (val_preds_proba_epoch > 0.5).astype(int))
    else:
        val_acc_epoch = 0.0
    val_accs_list.append(val_acc_epoch)

    print(f"Epoch {epoch+1} Summary: Duration: {epoch_duration:.2f}s")
    print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"  Val Loss: {val_loss:.4f}, Val Acc: {val_acc_epoch:.4f}")

    # Checkpoint on improvement, but only when validation actually saw data
    # (an empty validation pass reports a loss of 0.0).
    if len(val_labels_epoch) > 0 and val_loss < best_val_loss:
        best_val_loss = val_loss
        try:
            torch.save(patient_model_attention.state_dict(), MODEL_SAVE_PATH)
            print(f"  Best model saved to {MODEL_SAVE_PATH}")
        except Exception as e:
            print(f"Error saving model: {e}")

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

print("\nHybrid Model with Attention (Balanced Data) Training Finished.")

# --- Plot Training History ---
# Two side-by-side panels: loss on the left, accuracy on the right.
epoch_axis = range(1, EPOCHS + 1)
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.plot(epoch_axis, train_losses, label='Train Loss')
plt.plot(epoch_axis, val_losses, label='Val Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Loss Curve (Hybrid+Attention, Balanced)')
plt.legend()
plt.grid(True)

plt.subplot(1, 2, 2)
plt.plot(epoch_axis, train_accs, label='Train Acc')
plt.plot(epoch_axis, val_accs_list, label='Val Acc')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.title('Accuracy Curve (Hybrid+Attention, Balanced)')
plt.legend()
plt.grid(True)

plt.tight_layout()
plot_save_path = os.path.join(PREPROCESSED_DSB_PATH, "training_curves_hybrid_attention_balanced.png")
plt.savefig(plot_save_path)
print(f"Training curves plot saved to {plot_save_path}")
plt.close()

# --- Model Evaluation ---
print("\nEvaluating Hybrid Model with Attention (Balanced Data) on Validation Set...")

# Prefer the checkpoint with the lowest validation loss; fall back to the
# in-memory (last-epoch) model if it is missing or cannot be loaded.
if not os.path.exists(MODEL_SAVE_PATH):
    print("Best model file not found. Using last epoch model.")
    eval_model = patient_model_attention
else:
    try:
        eval_model = CNNViTHybrid3DWithAttention(
            input_shape_dhw=FINAL_SCAN_SIZE, num_classes=NUM_CLASSES
        ).to(DEVICE)
        eval_model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))
        print(f"Loaded best hybrid attention model from {MODEL_SAVE_PATH}")
    except Exception as e:
        print(f"Could not load best model: {e}. Using last epoch model.")
        eval_model = patient_model_attention

val_loss_final, final_val_labels, final_val_preds_proba = validate_patient(
    eval_model, val_loader, criterion, DEVICE
)

if len(final_val_labels) == 0:
    print("No valid validation predictions. Cannot evaluate.")
else:
    print(f"\nFinal Validation Loss (Hybrid+Attention, Balanced): {val_loss_final:.4f}")

    # Threshold-based metrics at 0.5.
    final_val_preds_binary = (final_val_preds_proba > 0.5).astype(int)
    accuracy = accuracy_score(final_val_labels, final_val_preds_binary)
    precision = precision_score(final_val_labels, final_val_preds_binary, zero_division=0)
    recall = recall_score(final_val_labels, final_val_preds_binary, zero_division=0)
    f1 = f1_score(final_val_labels, final_val_preds_binary, zero_division=0)

    # AUC and the classification report need both classes in the true labels.
    both_classes_present = len(np.unique(final_val_labels)) > 1
    auc_roc = float('nan')
    if both_classes_present:
        try:
            auc_roc = roc_auc_score(final_val_labels, final_val_preds_proba)
        except ValueError as e:
            print(f"AUC-ROC Error: {e}. Setting to NaN.")
    else:
        print("AUC-ROC cannot be calculated: only one class in y_true for validation.")

    print("\n--- Final Validation Metrics (Hybrid+Attention, Balanced) ---")
    print(f"Accuracy:  {accuracy:.4f}\nPrecision: {precision:.4f}\nRecall:    {recall:.4f}")
    print(f"F1-Score:  {f1:.4f}\nAUC-ROC:   {auc_roc:.4f}")

    target_names = ['Non-Cancer (0)', 'Cancer (1)']
    print("\nClassification Report (Hybrid+Attention, Balanced):")
    if both_classes_present:
        print(classification_report(final_val_labels, final_val_preds_binary,
                                    target_names=target_names, zero_division=0))
    else:
        print("Classification report not generated: only one class in y_true for validation.")

    print("\nConfusion Matrix (Hybrid+Attention, Balanced):")
    cm = confusion_matrix(final_val_labels, final_val_preds_binary, labels=[0, 1])
    disp = ConfusionMatrixDisplay(cm, display_labels=target_names)
    disp.plot(cmap=plt.cm.Blues)
    plt.title("Confusion Matrix (Hybrid+Attention, Balanced)")
    cm_plot_save_path = os.path.join(PREPROCESSED_DSB_PATH, "confusion_matrix_hybrid_attention_balanced.png")
    plt.savefig(cm_plot_save_path)
    print(f"Confusion matrix plot saved to {cm_plot_save_path}")
    plt.close()

    if np.isnan(auc_roc):
        print("ROC curve not plotted (AUC is NaN).")
    else:
        fpr, tpr, _ = roc_curve(final_val_labels, final_val_preds_proba)
        plt.figure(figsize=(8, 6))
        plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {auc_roc:.2f})')
        # Diagonal reference line (chance level).
        plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('ROC Curve (Hybrid+Attention, Balanced)')
        plt.legend(loc="lower right")
        plt.grid(True)
        roc_plot_save_path = os.path.join(PREPROCESSED_DSB_PATH, "roc_curve_hybrid_attention_balanced.png")
        plt.savefig(roc_plot_save_path)
        print(f"ROC curve plot saved to {roc_plot_save_path}")
        plt.close()

print("\nScript finished.")

Using device: cuda
--- Loading Data and Selecting EVEN Number of Scans Per Class (up to 50 each) ---
Available Cancerous scans with labels: 27
Available Non-Cancerous scans with labels: 71
Selecting 27 scans from each class for balancing.
Selected 27 Cancerous scans.
Selected 27 Non-Cancerous scans.
Total scans selected for preprocessing: 54

Starting preprocessing for 54 selected scans (if not already done)...


Preprocessing 54 Scans: 100%|██████████| 54/54 [02:40<00:00,  2.97s/it]



Preprocessing finished/checked in 160.46 seconds.
Final patient count for training/validation: 54
  Balanced - Cancerous: 27, Non-Cancerous: 27
CNNViTHybrid3DWithAttention(
  (cnn_backbone): Sequential(
    (0): Conv3d(1, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): SEBlock3D(
      (avg_pool): AdaptiveAvgPool3d(output_size=1)
      (fc): Sequential(
        (0): Linear(in_features=16, out_features=1, bias=False)
        (1): ReLU(inplace=True)
        (2): Linear(in_features=1, out_features=16, bias=False)
        (3): Sigmoid()
      )
    )
    (4): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (6): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): ReLU(inplace=True)
    (8): SEBlock3D(

                                                                                                    

Epoch 1 Summary: Duration: 4.85s
  Train Loss: 0.7197, Train Acc: 0.3721
  Val Loss: 0.6857, Val Acc: 0.4545
  Best model saved to C:\Users\rouaa\Documents\Final_Pneumatect\Preprocessed_Data_DSB_Balanced_Attention\patient_level_hybrid_attention_best.pth

Epoch 2/50


                                                                                                    

Epoch 2 Summary: Duration: 3.34s
  Train Loss: 0.6707, Train Acc: 0.4884
  Val Loss: 0.6735, Val Acc: 0.5455
  Best model saved to C:\Users\rouaa\Documents\Final_Pneumatect\Preprocessed_Data_DSB_Balanced_Attention\patient_level_hybrid_attention_best.pth

Epoch 3/50


                                                                                                    

Epoch 3 Summary: Duration: 3.44s
  Train Loss: 0.6815, Train Acc: 0.5116
  Val Loss: 0.6803, Val Acc: 0.4545

Epoch 4/50


                                                                                                    

Epoch 4 Summary: Duration: 3.71s
  Train Loss: 0.6717, Train Acc: 0.5814
  Val Loss: 0.6814, Val Acc: 0.4545

Epoch 5/50


                                                                                                    

Epoch 5 Summary: Duration: 3.39s
  Train Loss: 0.6925, Train Acc: 0.4651
  Val Loss: 0.6692, Val Acc: 0.5455
  Best model saved to C:\Users\rouaa\Documents\Final_Pneumatect\Preprocessed_Data_DSB_Balanced_Attention\patient_level_hybrid_attention_best.pth

Epoch 6/50


                                                                                                    

Epoch 6 Summary: Duration: 3.28s
  Train Loss: 0.6776, Train Acc: 0.4884
  Val Loss: 0.6786, Val Acc: 0.5455

Epoch 7/50


                                                                                                    

Epoch 7 Summary: Duration: 3.11s
  Train Loss: 0.7018, Train Acc: 0.4651
  Val Loss: 0.6971, Val Acc: 0.4545

Epoch 8/50


                                                                                                    

Epoch 8 Summary: Duration: 3.17s
  Train Loss: 0.6814, Train Acc: 0.4651
  Val Loss: 0.6726, Val Acc: 0.5455

Epoch 9/50


                                                                                                    

Epoch 9 Summary: Duration: 3.06s
  Train Loss: 0.6773, Train Acc: 0.4651
  Val Loss: 0.6677, Val Acc: 0.5455
  Best model saved to C:\Users\rouaa\Documents\Final_Pneumatect\Preprocessed_Data_DSB_Balanced_Attention\patient_level_hybrid_attention_best.pth

Epoch 10/50


                                                                                                    

Epoch 10 Summary: Duration: 3.09s
  Train Loss: 0.6688, Train Acc: 0.6047
  Val Loss: 0.6702, Val Acc: 0.5455

Epoch 11/50


                                                                                                    

Epoch 11 Summary: Duration: 3.09s
  Train Loss: 0.6755, Train Acc: 0.4651
  Val Loss: 0.6842, Val Acc: 0.4545

Epoch 12/50


                                                                                                    

Epoch 12 Summary: Duration: 3.10s
  Train Loss: 0.6721, Train Acc: 0.4884
  Val Loss: 0.6681, Val Acc: 0.3636

Epoch 13/50


                                                                                                    

Epoch 13 Summary: Duration: 3.11s
  Train Loss: 0.6862, Train Acc: 0.4186
  Val Loss: 0.6615, Val Acc: 0.5455
  Best model saved to C:\Users\rouaa\Documents\Final_Pneumatect\Preprocessed_Data_DSB_Balanced_Attention\patient_level_hybrid_attention_best.pth

Epoch 14/50


                                                                                                    

Epoch 14 Summary: Duration: 3.10s
  Train Loss: 0.6799, Train Acc: 0.4884
  Val Loss: 0.6683, Val Acc: 0.6364

Epoch 15/50


                                                                                                    

Epoch 15 Summary: Duration: 3.08s
  Train Loss: 0.6623, Train Acc: 0.6744
  Val Loss: 0.6584, Val Acc: 0.5455
  Best model saved to C:\Users\rouaa\Documents\Final_Pneumatect\Preprocessed_Data_DSB_Balanced_Attention\patient_level_hybrid_attention_best.pth

Epoch 16/50


                                                                                                    

Epoch 16 Summary: Duration: 3.05s
  Train Loss: 0.6560, Train Acc: 0.5581
  Val Loss: 0.6580, Val Acc: 0.6364
  Best model saved to C:\Users\rouaa\Documents\Final_Pneumatect\Preprocessed_Data_DSB_Balanced_Attention\patient_level_hybrid_attention_best.pth

Epoch 17/50


                                                                                                    

Epoch 17 Summary: Duration: 3.05s
  Train Loss: 0.6420, Train Acc: 0.5814
  Val Loss: 0.6759, Val Acc: 0.5455

Epoch 18/50


                                                                                                    

Epoch 18 Summary: Duration: 3.10s
  Train Loss: 0.7051, Train Acc: 0.4651
  Val Loss: 0.6494, Val Acc: 0.4545
  Best model saved to C:\Users\rouaa\Documents\Final_Pneumatect\Preprocessed_Data_DSB_Balanced_Attention\patient_level_hybrid_attention_best.pth

Epoch 19/50


                                                                                                    

Epoch 19 Summary: Duration: 3.02s
  Train Loss: 0.6860, Train Acc: 0.5116
  Val Loss: 0.6545, Val Acc: 0.5455

Epoch 20/50


                                                                                                    

Epoch 20 Summary: Duration: 3.06s
  Train Loss: 0.6359, Train Acc: 0.6047
  Val Loss: 0.6792, Val Acc: 0.6364

Epoch 21/50


                                                                                                    

Epoch 21 Summary: Duration: 3.11s
  Train Loss: 0.6753, Train Acc: 0.5581
  Val Loss: 0.6554, Val Acc: 0.5455

Epoch 22/50


                                                                                                    

Epoch 22 Summary: Duration: 3.03s
  Train Loss: 0.6113, Train Acc: 0.6744
  Val Loss: 0.6468, Val Acc: 0.5455
  Best model saved to C:\Users\rouaa\Documents\Final_Pneumatect\Preprocessed_Data_DSB_Balanced_Attention\patient_level_hybrid_attention_best.pth

Epoch 23/50


                                                                                                    

Epoch 23 Summary: Duration: 3.05s
  Train Loss: 0.6094, Train Acc: 0.6977
  Val Loss: 0.6626, Val Acc: 0.6364

Epoch 24/50


                                                                                                    

Epoch 24 Summary: Duration: 3.07s
  Train Loss: 0.5837, Train Acc: 0.6977
  Val Loss: 0.6308, Val Acc: 0.4545
  Best model saved to C:\Users\rouaa\Documents\Final_Pneumatect\Preprocessed_Data_DSB_Balanced_Attention\patient_level_hybrid_attention_best.pth

Epoch 25/50


                                                                                                    

Epoch 25 Summary: Duration: 3.04s
  Train Loss: 0.5361, Train Acc: 0.7674
  Val Loss: 0.6514, Val Acc: 0.3636

Epoch 26/50


                                                                                                    

Epoch 26 Summary: Duration: 3.01s
  Train Loss: 0.5168, Train Acc: 0.7442
  Val Loss: 0.6808, Val Acc: 0.5455

Epoch 27/50


                                                                                                    

Epoch 27 Summary: Duration: 3.01s
  Train Loss: 0.5567, Train Acc: 0.6744
  Val Loss: 0.6927, Val Acc: 0.6364

Epoch 28/50


                                                                                                    

Epoch 28 Summary: Duration: 3.10s
  Train Loss: 0.6403, Train Acc: 0.4884
  Val Loss: 0.6726, Val Acc: 0.5455

Epoch 29/50


                                                                                                    

Epoch 29 Summary: Duration: 3.10s
  Train Loss: 0.5109, Train Acc: 0.8140
  Val Loss: 0.6047, Val Acc: 0.3636
  Best model saved to C:\Users\rouaa\Documents\Final_Pneumatect\Preprocessed_Data_DSB_Balanced_Attention\patient_level_hybrid_attention_best.pth

Epoch 30/50


                                                                                                    

Epoch 30 Summary: Duration: 3.04s
  Train Loss: 0.4943, Train Acc: 0.7209
  Val Loss: 0.7598, Val Acc: 0.5455

Epoch 31/50


                                                                                                    

Epoch 31 Summary: Duration: 3.03s
  Train Loss: 0.6458, Train Acc: 0.5581
  Val Loss: 0.6601, Val Acc: 0.6364

Epoch 32/50


                                                                                                    

Epoch 32 Summary: Duration: 3.02s
  Train Loss: 0.6184, Train Acc: 0.6279
  Val Loss: 0.6293, Val Acc: 0.5455

Epoch 33/50


                                                                                                    

Epoch 33 Summary: Duration: 3.02s
  Train Loss: 0.5638, Train Acc: 0.8140
  Val Loss: 0.6234, Val Acc: 0.6364

Epoch 34/50


                                                                                                    

Epoch 34 Summary: Duration: 3.02s
  Train Loss: 0.5034, Train Acc: 0.7907
  Val Loss: 0.7074, Val Acc: 0.6364

Epoch 35/50


                                                                                                    

Epoch 35 Summary: Duration: 3.00s
  Train Loss: 0.4937, Train Acc: 0.7442
  Val Loss: 0.8820, Val Acc: 0.6364

Epoch 36/50


                                                                                                    

Epoch 36 Summary: Duration: 3.10s
  Train Loss: 0.6105, Train Acc: 0.6744
  Val Loss: 0.6999, Val Acc: 0.6364

Epoch 37/50


                                                                                                    

Epoch 37 Summary: Duration: 3.10s
  Train Loss: 0.4243, Train Acc: 0.8140
  Val Loss: 0.6557, Val Acc: 0.5455

Epoch 38/50


                                                                                                    

Epoch 38 Summary: Duration: 3.02s
  Train Loss: 0.2669, Train Acc: 0.9302
  Val Loss: 0.9339, Val Acc: 0.5455

Epoch 39/50


                                                                                                    

Epoch 39 Summary: Duration: 3.02s
  Train Loss: 0.6571, Train Acc: 0.6279
  Val Loss: 0.6177, Val Acc: 0.5455

Epoch 40/50


                                                                                                    

Epoch 40 Summary: Duration: 3.07s
  Train Loss: 0.5594, Train Acc: 0.6047
  Val Loss: 0.6294, Val Acc: 0.6364

Epoch 41/50


                                                                                                    

Epoch 41 Summary: Duration: 3.04s
  Train Loss: 0.5607, Train Acc: 0.7442
  Val Loss: 0.6209, Val Acc: 0.5455

Epoch 42/50


                                                                                                    

Epoch 42 Summary: Duration: 3.04s
  Train Loss: 0.4441, Train Acc: 0.9535
  Val Loss: 0.6437, Val Acc: 0.6364

Epoch 43/50


                                                                                                    

Epoch 43 Summary: Duration: 3.01s
  Train Loss: 0.3852, Train Acc: 0.8837
  Val Loss: 0.6637, Val Acc: 0.4545

Epoch 44/50


                                                                                                    

Epoch 44 Summary: Duration: 3.01s
  Train Loss: 0.4673, Train Acc: 0.8140
  Val Loss: 0.9850, Val Acc: 0.5455

Epoch 45/50


                                                                                                    

Epoch 45 Summary: Duration: 3.09s
  Train Loss: 0.3827, Train Acc: 0.8372
  Val Loss: 0.6749, Val Acc: 0.6364

Epoch 46/50


                                                                                                    

Epoch 46 Summary: Duration: 3.04s
  Train Loss: 0.2759, Train Acc: 0.9070
  Val Loss: 0.8708, Val Acc: 0.6364

Epoch 47/50


                                                                                                    

Epoch 47 Summary: Duration: 3.01s
  Train Loss: 0.1442, Train Acc: 0.9535
  Val Loss: 0.9264, Val Acc: 0.6364

Epoch 48/50


                                                                                                    

Epoch 48 Summary: Duration: 3.03s
  Train Loss: 0.9447, Train Acc: 0.5814
  Val Loss: 0.6720, Val Acc: 0.6364

Epoch 49/50


                                                                                                    

Epoch 49 Summary: Duration: 3.10s
  Train Loss: 0.6434, Train Acc: 0.5814
  Val Loss: 0.6384, Val Acc: 0.6364

Epoch 50/50


                                                                                                    

Epoch 50 Summary: Duration: 3.04s
  Train Loss: 0.4929, Train Acc: 0.8605
  Val Loss: 0.6260, Val Acc: 0.6364

Hybrid Model with Attention (Balanced Data) Training Finished.
Training curves plot saved to C:\Users\rouaa\Documents\Final_Pneumatect\Preprocessed_Data_DSB_Balanced_Attention\training_curves_hybrid_attention_balanced.png

Evaluating Hybrid Model with Attention (Balanced Data) on Validation Set...
Loaded best hybrid attention model from C:\Users\rouaa\Documents\Final_Pneumatect\Preprocessed_Data_DSB_Balanced_Attention\patient_level_hybrid_attention_best.pth


                                                                                                    


Final Validation Loss (Hybrid+Attention, Balanced): 0.6047

--- Final Validation Metrics (Hybrid+Attention, Balanced) ---
Accuracy:  0.3636
Precision: 0.2500
Recall:    0.2000
F1-Score:  0.2222
AUC-ROC:   0.5333

Classification Report (Hybrid+Attention, Balanced):
                precision    recall  f1-score   support

Non-Cancer (0)       0.43      0.50      0.46         6
    Cancer (1)       0.25      0.20      0.22         5

      accuracy                           0.36        11
     macro avg       0.34      0.35      0.34        11
  weighted avg       0.35      0.36      0.35        11


Confusion Matrix (Hybrid+Attention, Balanced):
Confusion matrix plot saved to C:\Users\rouaa\Documents\Final_Pneumatect\Preprocessed_Data_DSB_Balanced_Attention\confusion_matrix_hybrid_attention_balanced.png
ROC curve plot saved to C:\Users\rouaa\Documents\Final_Pneumatect\Preprocessed_Data_DSB_Balanced_Attention\roc_curve_hybrid_attention_balanced.png

Script finished.
