In [3]:
import os
import glob
import time
import SimpleITK as sitk
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from tqdm import tqdm
import random
import scipy.ndimage
from skimage.measure import label as skimage_label, regionprops
from skimage.morphology import disk, binary_closing
from skimage.segmentation import clear_border
import scipy.ndimage as ndi

import torch
import torch.nn as nn
import torch.nn.functional as F # For pad
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import (classification_report, confusion_matrix, roc_auc_score,
                             roc_curve, precision_recall_curve, auc, f1_score,
                             precision_score, recall_score, accuracy_score, ConfusionMatrixDisplay)
# from einops import rearrange, repeat # Not strictly needed for this version

# Configuration
# --- MODIFY THESE PATHS ---
DSB_PATH = r"C:\Users\rouaa\Documents\Final_Pneumatect\Stages"
DSB_LABELS_CSV = r"C:\Users\rouaa\Documents\Final_Pneumatect\stage1_labels.csv"
# --- MODIFIED: New output path for Swin3D model ---
PREPROCESSED_DSB_PATH = r"C:\Users\rouaa\Documents\Final_Pneumatect\Preprocessed_Data_DSB_Swin3D_Corrected" # Added _Corrected
# ---

# Preprocessing & Model Params (mostly unchanged)
TARGET_SPACING = [1.5, 1.5, 1.5]
FINAL_SCAN_SIZE = (96, 128, 128) # (Depth, Height, Width)
CLIP_BOUND_HU = [-1000.0, 400.0]
PIXEL_MEAN = 0.25

# Training Params
NUM_CLASSES = 1
BATCH_SIZE = 2 # Swin Transformers can be memory intensive
LEARNING_RATE = 0.0001
EPOCHS = 50
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

SCAN_LIMIT_PER_CLASS = 50 # Max scans per class if available

# Ensure output directory exists
os.makedirs(PREPROCESSED_DSB_PATH, exist_ok=True)

# Random seed
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# --- Data Loading and Selection (BALANCED, Identical to previous script) ---
print(f"--- Loading Data and Selecting EVEN Number of Scans Per Class (up to {SCAN_LIMIT_PER_CLASS} each) ---")
if not os.path.isdir(DSB_PATH): raise SystemExit(f"ERROR: DSB Scans path not found: {DSB_PATH}")
if not os.path.isfile(DSB_LABELS_CSV): raise SystemExit(f"ERROR: DSB Labels CSV not found: {DSB_LABELS_CSV}")
dsb_labels_df = pd.read_csv(DSB_LABELS_CSV)
dsb_labels_df = dsb_labels_df.rename(columns={'id': 'patient_id'})
patient_labels_all = dsb_labels_df.set_index('patient_id')['cancer'].to_dict()
scan_folders = [f for f in os.listdir(DSB_PATH) if os.path.isdir(os.path.join(DSB_PATH, f))]
found_scan_ids = set(scan_folders)
labeled_patient_ids_all = set(dsb_labels_df['patient_id'])
common_ids_all = labeled_patient_ids_all.intersection(found_scan_ids)
common_ids_cancer_available = [pid for pid in common_ids_all if patient_labels_all.get(pid) == 1]
common_ids_non_cancer_available = [pid for pid in common_ids_all if patient_labels_all.get(pid) == 0]
random.shuffle(common_ids_cancer_available)
random.shuffle(common_ids_non_cancer_available)
num_to_select_per_class = min(len(common_ids_cancer_available), len(common_ids_non_cancer_available), SCAN_LIMIT_PER_CLASS)
print(f"Available Cancerous: {len(common_ids_cancer_available)}, Non-Cancerous: {len(common_ids_non_cancer_available)}")
print(f"Selecting {num_to_select_per_class} from each class.")
selected_cancer_ids = common_ids_cancer_available[:num_to_select_per_class]
selected_non_cancer_ids = common_ids_non_cancer_available[:num_to_select_per_class]
scans_to_process = selected_cancer_ids + selected_non_cancer_ids
random.shuffle(scans_to_process)
print(f"Total scans selected: {len(scans_to_process)}")
patient_labels = {pid: patient_labels_all[pid] for pid in scans_to_process}

# --- Preprocessing Functions (Identical) ---
def load_scan_series(dicom_folder_path):
    try:
        series_ids = sitk.ImageSeriesReader.GetGDCMSeriesIDs(dicom_folder_path)
        if not series_ids: return None, None, None
        series_file_names = sitk.ImageSeriesReader.GetGDCMSeriesFileNames(dicom_folder_path, series_ids[0])
        series_reader = sitk.ImageSeriesReader(); series_reader.SetFileNames(series_file_names)
        itkimage = series_reader.Execute()
        image_array = sitk.GetArrayFromImage(itkimage); origin = np.array(list(reversed(itkimage.GetOrigin()))); spacing = np.array(list(reversed(itkimage.GetSpacing())))
        return image_array, origin, spacing
    except Exception as e: print(f"Error reading DICOM {os.path.basename(dicom_folder_path)}: {e}"); return None, None, None

def resample(image, original_spacing, new_spacing=TARGET_SPACING):
    try:
        resize_factor = np.array(original_spacing) / np.array(new_spacing)
        new_real_shape = image.shape * resize_factor; new_shape = np.round(new_real_shape)
        real_resize_factor = new_shape / image.shape; actual_new_spacing = original_spacing / real_resize_factor
        resampled_image = scipy.ndimage.zoom(image, real_resize_factor, mode='nearest', order=1)
        return resampled_image, actual_new_spacing
    except Exception as e: print(f"Error resamping: {e}"); return None, None

def get_segmented_lungs(im_slice, hu_threshold=-320):
    if im_slice.ndim != 2: return im_slice
    binary = im_slice < hu_threshold; cleared = clear_border(binary)
    label_image = skimage_label(cleared); areas = [r.area for r in regionprops(label_image)]; areas.sort()
    area_threshold = areas[-2] if len(areas) >= 2 else (areas[-1] if len(areas) == 1 else 0)
    if area_threshold > 0:
        for region in regionprops(label_image):
            if region.area < area_threshold:
                for coordinates in region.coords: label_image[coordinates[0], coordinates[1]] = 0
    binary = label_image > 0; selem = disk(2); binary = binary_closing(binary, selem)
    selem_dilate = disk(5); final_mask = ndi.binary_dilation(binary, structure=selem_dilate)
    background_val = CLIP_BOUND_HU[0] - 1; segmented_slice = im_slice.copy()
    segmented_slice[final_mask == 0] = background_val
    return segmented_slice

def normalize_hu(image, clip_bounds=CLIP_BOUND_HU):
    min_bound, max_bound = clip_bounds; image = np.clip(image, min_bound, max_bound)
    image = (image - min_bound) / (max_bound - min_bound)
    return image.astype(np.float32)

def zero_center(image, pixel_mean=PIXEL_MEAN):
    image = image - pixel_mean
    return image.astype(np.float32)

def resize_scan_to_target(image, target_shape=FINAL_SCAN_SIZE):
    if image.shape == target_shape: return image
    resize_factor = np.array(target_shape) / np.array(image.shape)
    try:
        resized_image = scipy.ndimage.zoom(image, resize_factor, order=1, mode='nearest')
        if resized_image.shape != target_shape:
            current_shape = resized_image.shape; diff = np.array(target_shape) - np.array(current_shape)
            pad = np.maximum(diff, 0); crop = np.maximum(-diff, 0)
            pad_width = tuple((p // 2, p - p // 2) for p in pad)
            resized_image = np.pad(resized_image, pad_width, mode='edge')
            crop_slice = tuple(slice(c // 2, s - (c - c // 2)) for c, s in zip(crop, resized_image.shape))
            resized_image = resized_image[crop_slice]
        if resized_image.shape != target_shape: print(f"ERROR: Resize failed. Shape {resized_image.shape} vs Target {target_shape}"); return None
        return resized_image.astype(np.float32)
    except Exception as e: print(f"Error resizing to target: {e}"); return None

def preprocess_scan_dsb(patient_id, input_base_path, output_base_path, force_preprocess=False):
    scan_folder_path = os.path.join(input_base_path, patient_id)
    output_filename = os.path.join(output_base_path, f"{patient_id}.npz")
    if os.path.exists(output_filename) and not force_preprocess: return True
    image, origin, spacing = load_scan_series(scan_folder_path);
    if image is None: return False
    resampled_image, new_spacing = resample(image, spacing, TARGET_SPACING)
    if resampled_image is None: del image; return False;
    del image; segmented_lungs = np.zeros_like(resampled_image, dtype=np.float32)
    for i in range(resampled_image.shape[0]): segmented_lungs[i] = get_segmented_lungs(resampled_image[i])
    del resampled_image; normalized_image = normalize_hu(segmented_lungs, clip_bounds=CLIP_BOUND_HU); del segmented_lungs;
    centered_image = zero_center(normalized_image, pixel_mean=PIXEL_MEAN); del normalized_image;
    final_image = resize_scan_to_target(centered_image, target_shape=FINAL_SCAN_SIZE); del centered_image;
    if final_image is None: return False
    try: np.savez_compressed(output_filename, image=final_image.astype(np.float32)); return True
    except Exception as e: print(f"Error saving {patient_id}: {e}"); return False

#--- Preprocessing Execution ---
successful_processed_ids = []
print(f"\nPreprocessing for {len(scans_to_process)} scans (if not already done)...")
start_time = time.time()
for patient_id in tqdm(scans_to_process, desc=f"Preprocessing"):
    if preprocess_scan_dsb(patient_id, DSB_PATH, PREPROCESSED_DSB_PATH, force_preprocess=False):
        successful_processed_ids.append(patient_id)
end_time = time.time()
print(f"\nPreprocessing finished/checked in {end_time - start_time:.2f} seconds.")
final_patient_list = successful_processed_ids
if not final_patient_list: raise SystemExit("No scans processed. Cannot continue.")
patient_labels = {pid: patient_labels[pid] for pid in final_patient_list}
print(f"Final patient count for training/validation: {len(final_patient_list)}")

# --- Dataset and DataLoader (Identical) ---
class PatientLevelDataset(Dataset):
    def __init__(self, patient_ids, labels_dict, preprocessed_path):
        self.patient_ids = patient_ids; self.labels_dict = labels_dict; self.preprocessed_path = preprocessed_path
    def __len__(self): return len(self.patient_ids)
    def __getitem__(self, idx):
        patient_id = self.patient_ids[idx]; label = self.labels_dict[patient_id]
        scan_path = os.path.join(self.preprocessed_path, f"{patient_id}.npz")
        try:
            with np.load(scan_path) as npz_data: image = npz_data['image']
            image_tensor = torch.from_numpy(image).float().unsqueeze(0)
            label_tensor = torch.tensor(label, dtype=torch.float32)
            return image_tensor, label_tensor
        except Exception as e:
            print(f"ERROR loading {patient_id}: {e}"); dummy = torch.zeros((1, *FINAL_SCAN_SIZE), dtype=torch.float32)
            return dummy, torch.tensor(-1, dtype=torch.float32)

train_ids, val_ids = train_test_split(final_patient_list, test_size=0.2, random_state=SEED,
                                      stratify=[patient_labels[pid] for pid in final_patient_list])
train_dataset = PatientLevelDataset(train_ids, patient_labels, PREPROCESSED_DSB_PATH)
val_dataset = PatientLevelDataset(val_ids, patient_labels, PREPROCESSED_DSB_PATH)
NUM_WORKERS = 0
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=torch.cuda.is_available())
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=torch.cuda.is_available())


# --- 3D Swin Transformer Model (Corrected) ---

def window_partition_3d(x, window_size):
    """
    Args:
        x: (B, D, H, W, C)
        window_size (tuple[int]): window size (Depth, Height, Width)
    Returns:
        windows: (num_windows*B, window_size_d, window_size_h, window_size_w, C)
    """
    B, D, H, W, C = x.shape
    wd, wh, ww = window_size
    # Ensure D, H, W are divisible by window_size dimensions for this simplified partition
    assert D % wd == 0 and H % wh == 0 and W % ww == 0, \
        f"Input shape ({D},{H},{W}) not divisible by window_size ({wd},{wh},{ww})"
    x = x.view(B, D // wd, wd, H // wh, wh, W // ww, ww, C)
    windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, wd, wh, ww, C)
    return windows

def window_reverse_3d(windows, window_size, D, H, W):
    """
    Args:
        windows: (num_windows*B, window_size_d, window_size_h, window_size_w, C)
        window_size (tuple[int]): Window size (Depth, Height, Width)
        D, H, W: Original image dimensions (before windowing)
    Returns:
        x: (B, D, H, W, C)
    """
    wd, wh, ww = window_size
    # Calculate B based on the total number of elements and known dimensions
    num_windows_per_sample = (D // wd) * (H // wh) * (W // ww)
    if num_windows_per_sample == 0 : # Should not happen if assertions in partition hold
        B = 0
    else:
        B = int(windows.shape[0] / num_windows_per_sample)

    x = windows.view(B, D // wd, H // wh, W // ww, wd, wh, ww, -1)
    x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, D, H, W, -1)
    return x

class WindowAttention3D(nn.Module):
    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wd, Wh, Ww (tuple)
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1), num_heads))

        coords_d = torch.arange(self.window_size[0])
        coords_h = torch.arange(self.window_size[1])
        coords_w = torch.arange(self.window_size[2])
        coords = torch.stack(torch.meshgrid([coords_d, coords_h, coords_w], indexing='ij'))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += self.window_size[0] - 1
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 2] += self.window_size[2] - 1
        relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1)
        relative_coords[:, :, 1] *= (2 * self.window_size[2] - 1)
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index, persistent=False)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            N, N, -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class SwinTransformerBlock3D(nn.Module):
    def __init__(self, dim, input_resolution, num_heads, window_size=(4,4,4), shift_size=(0,0,0),
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution # (Depth, Height, Width)
        self.num_heads = num_heads
        self.window_size = window_size if isinstance(window_size, tuple) else (window_size,) * 3
        self.shift_size = shift_size if isinstance(shift_size, tuple) else (shift_size,) * 3
        self.mlp_ratio = mlp_ratio

        # Ensure input_resolution is divisible by window_size for this block
        for i in range(3):
            assert input_resolution[i] % self.window_size[i] == 0, \
                f"Input dim {input_resolution[i]} not divisible by window_size {self.window_size[i]} at dim {i}"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention3D(
            dim, window_size=self.window_size, num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = nn.Identity() if drop_path == 0. else DropPath(drop_path)
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim), act_layer(), nn.Dropout(drop), # Added dropout in MLP
            nn.Linear(mlp_hidden_dim, dim), nn.Dropout(drop)
        )

        if any(s > 0 for s in self.shift_size):
            D, H, W = self.input_resolution
            img_mask = torch.zeros((1, D, H, W, 1)) # Create on CPU first
            
            # Slicing for mask creation, ensuring robustness
            def get_slices(dim_size, win_size, sh_size):
                if sh_size == 0 : # no shift, one segment
                    return [slice(0, dim_size)]
                # If dim_size is small, it implies fewer than 3 segments
                if dim_size <= win_size: # Only one window segment if dim is small
                     return [slice(0, dim_size)] # Could also be more complex if sh_size>0
                else: # Standard Swin logic if dim > win_size
                    return (slice(0, -win_size),
                            slice(-win_size, -sh_size),
                            slice(-sh_size, None))

            slices_d = get_slices(D, self.window_size[0], self.shift_size[0])
            slices_h = get_slices(H, self.window_size[1], self.shift_size[1])
            slices_w = get_slices(W, self.window_size[2], self.shift_size[2])
            
            cnt = 0
            for d_s in slices_d:
                for h_s in slices_h:
                    for w_s in slices_w:
                        img_mask[:, d_s, h_s, w_s, :] = cnt
                        cnt += 1
            
            mask_windows = window_partition_3d(img_mask, self.window_size)
            mask_windows = mask_windows.view(-1, self.window_size[0] * self.window_size[1] * self.window_size[2])
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None
        self.register_buffer("attn_mask", attn_mask, persistent=False)


    def forward(self, x):
        D_orig, H_orig, W_orig = self.input_resolution
        B, L, C = x.shape
        assert L == D_orig * H_orig * W_orig, f"Input feature has wrong size {L} vs {D_orig*H_orig*W_orig}"

        shortcut = x
        x_norm = self.norm1(x)
        x_reshaped = x_norm.view(B, D_orig, H_orig, W_orig, C)

        if any(s > 0 for s in self.shift_size):
            shifted_x = torch.roll(x_reshaped, shifts=(-self.shift_size[0], -self.shift_size[1], -self.shift_size[2]), dims=(1, 2, 3))
        else:
            shifted_x = x_reshaped

        x_windows = window_partition_3d(shifted_x, self.window_size)
        x_windows = x_windows.view(-1, self.window_size[0] * self.window_size[1] * self.window_size[2], C)
        
        attn_windows = self.attn(x_windows, mask=self.attn_mask)

        attn_windows_reshaped = attn_windows.view(-1, self.window_size[0], self.window_size[1], self.window_size[2], C)
        shifted_x_merged = window_reverse_3d(attn_windows_reshaped, self.window_size, D_orig, H_orig, W_orig)

        if any(s > 0 for s in self.shift_size):
            x_reversed_shift = torch.roll(shifted_x_merged, shifts=(self.shift_size[0], self.shift_size[1], self.shift_size[2]), dims=(1, 2, 3))
        else:
            x_reversed_shift = shifted_x_merged
        
        x_output_attn = x_reversed_shift.view(B, D_orig * H_orig * W_orig, C)

        x = shortcut + self.drop_path(x_output_attn)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

class PatchMerging3D(nn.Module):
    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        # Merge 2x2x2 patches, channel dim becomes 8*dim, then projected to 2*dim
        self.reduction = nn.Linear(8 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(8 * dim)

    def forward(self, x):
        D, H, W = self.input_resolution
        B, L, C = x.shape
        assert L == D * H * W, "input feature has wrong size"
        assert D % 2 == 0 and H % 2 == 0 and W % 2 == 0, \
            f"Input dimensions ({D}x{H}x{W}) are not all even for PatchMerging."

        x = x.view(B, D, H, W, C)
        x0 = x[:, 0::2, 0::2, 0::2, :]
        x1 = x[:, 0::2, 0::2, 1::2, :]
        x2 = x[:, 0::2, 1::2, 0::2, :]
        x3 = x[:, 0::2, 1::2, 1::2, :]
        x4 = x[:, 1::2, 0::2, 0::2, :]
        x5 = x[:, 1::2, 0::2, 1::2, :]
        x6 = x[:, 1::2, 1::2, 0::2, :]
        x7 = x[:, 1::2, 1::2, 1::2, :]
        x = torch.cat([x0, x1, x2, x3, x4, x5, x6, x7], -1)
        x = x.view(B, -1, 8 * C)

        x = self.norm(x)
        x = self.reduction(x)
        return x

class BasicLayer3D(nn.Module):
    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth

        self.blocks = nn.ModuleList([
            SwinTransformerBlock3D(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=tuple(0 if (i % 2 == 0) else w // 2 for w in window_size),
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer)
            for i in range(depth)])

        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            x = blk(x)
        if self.downsample is not None:
            x_down = self.downsample(x)
            return x, x_down # Return pre-downsample for skip connection if needed by a larger model
        return x, x # If no downsample, return x twice to match signature


class PatchEmbed3D(nn.Module):
    def __init__(self, img_size=(96,128,128), patch_size=(4,4,4), in_chans=1, embed_dim=96, norm_layer=None):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = [img_size[0] // patch_size[0],
                                   img_size[1] // patch_size[1],
                                   img_size[2] // patch_size[2]]
        self.num_patches = self.patches_resolution[0] * self.patches_resolution[1] * self.patches_resolution[2]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = nn.Identity() # Use nn.Identity if norm_layer is None

    def forward(self, x):
        B, C, D, H, W = x.shape
        assert D == self.img_size[0] and H == self.img_size[1] and W == self.img_size[2], \
            f"Input image size ({D}x{H}x{W}) doesn't match model ({self.img_size[0]}x{self.img_size[1]}x{self.img_size[2]})."
        x = self.proj(x)
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x

# DropPath implementation (if not available in your torch version or for self-containment)
def DropPath(drop_prob: float = 0., scale_by_keep: bool = True):
    if drop_prob == 0. or not True: # Assuming True means model.train()
        return nn.Identity()
    return DropPathLayer(drop_prob, scale_by_keep)

class DropPathLayer(nn.Module):
    def __init__(self, drop_prob=0., scale_by_keep=True):
        super(DropPathLayer, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
        random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
        if keep_prob > 0.0 and self.scale_by_keep:
            random_tensor.div_(keep_prob)
        return x * random_tensor

    def extra_repr(self) -> str:
        return f'drop_prob={round(self.drop_prob,3):0.3f}'


class SwinTransformer3D(nn.Module):
    def __init__(self, img_size=FINAL_SCAN_SIZE, patch_size=(4,4,4), in_chans=1, num_classes=NUM_CLASSES,
                 embed_dim=48, depths=[2, 2, 6], num_heads=[3, 6, 12],
                 window_size=(4,4,4), mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True):
        super().__init__()

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))

        self.patch_embed = PatchEmbed3D(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        self.patches_resolution = self.patch_embed.patches_resolution # Store for use in layers

        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            nn.init.trunc_normal_(self.absolute_pos_embed, std=.02)
        else:
            self.absolute_pos_embed = None


        self.pos_drop = nn.Dropout(p=drop_rate)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        self.layers = nn.ModuleList()
        current_resolution = self.patches_resolution
        for i_layer in range(self.num_layers):
            layer = BasicLayer3D(dim=int(embed_dim * 2 ** i_layer),
                               input_resolution=current_resolution,
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               window_size=window_size, # Use the passed window_size
                               mlp_ratio=mlp_ratio,
                               qkv_bias=qkv_bias, qk_scale=qk_scale,
                               drop=drop_rate, attn_drop=attn_drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               downsample=PatchMerging3D if (i_layer < self.num_layers - 1) else None)
            self.layers.append(layer)
            if i_layer < self.num_layers - 1: # Update resolution for the next layer
                current_resolution = tuple(res // 2 for res in current_resolution)


        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_features(self, x):
        x = self.patch_embed(x)
        if self.absolute_pos_embed is not None:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            _, x = layer(x) # BasicLayer3D now returns pre-downsample (ignored here) and post-downsample

        x = self.norm(x)
        x = self.avgpool(x.transpose(1, 2))
        x = torch.flatten(x, 1)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x


# Instantiate the Swin3D Model
# The key is that window_size must be compatible with feature map dimensions at ALL stages.
# If img=(96,128,128), patch=(4,8,8) => patches_res=(24,16,16)
# Stage 0 input: (24,16,16)
# Stage 1 input: (12,8,8) (after 1 PatchMerging)
# Stage 2 input: (6,4,4) (after 2 PatchMerging)
# A window_size like (3,4,4) would work:
# - Stage 0: (24,16,16) is div by (3,4,4)
# - Stage 1: (12,8,8) is div by (3,4,4)
# - Stage 2: (6,4,4) means 6 is div by 3, 4 by 4.
# So, window_size=(3,4,4) seems appropriate.

swin_model = SwinTransformer3D(
    img_size=FINAL_SCAN_SIZE,
    patch_size=(4,8,8),
    embed_dim=64,
    depths=[2, 2, 2],  # Reduced depth further for faster iteration/less memory
    num_heads=[4, 8, 16],
    window_size=(3,4,4), # Chosen to be compatible with (6,4,4)
    num_classes=NUM_CLASSES,
    drop_path_rate=0.1, # Increased slightly
    ape=False
).to(DEVICE)

print(f"Swin3D Model Instantiated. Number of parameters: {sum(p.numel() for p in swin_model.parameters() if p.requires_grad)}")
try:
    dummy_input = torch.randn(BATCH_SIZE, 1, *FINAL_SCAN_SIZE).to(DEVICE)
    output = swin_model(dummy_input)
    print(f"\nSwin3D Model output shape: {output.shape}") # Expected: (B, num_classes)
except Exception as e:
    print(f"\nError during Swin3D model test: {e}")
    raise

# --- Loss and Optimizer ---
train_labels_list = [patient_labels[pid] for pid in train_ids]
count_0 = train_labels_list.count(0); count_1 = train_labels_list.count(1)
# pos_weight should be close to 1.0 if train_ids are balanced
pos_weight_val = (count_0 / count_1) if count_1 > 0 and count_0 > 0 and count_0 != count_1 else 1.0
pos_weight_tensor = torch.tensor([pos_weight_val], device=DEVICE)
print(f"Calculated positive weight for BCEWithLogitsLoss (Swin3D): {pos_weight_val:.4f}")
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight_tensor)
# Swin often uses AdamW with higher weight decay
optimizer = optim.AdamW(swin_model.parameters(), lr=LEARNING_RATE, weight_decay=0.05) # Swin default wd
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())


#--- Training and Validation Functions (Identical) ---
def train_one_epoch_patient(model, dataloader, criterion, optimizer, device, scaler):
    model.train(); running_loss = 0.0; total_samples = 0; correct_predictions = 0
    progress_bar = tqdm(dataloader, desc="Training", leave=False, ncols=100)
    for inputs, labels in progress_bar:
        valid_indices = labels != -1
        inputs = inputs[valid_indices].to(device); labels = labels[valid_indices].unsqueeze(1).to(device)
        if inputs.nelement() == 0: continue
        optimizer.zero_grad()
        with torch.amp.autocast(device_type=device.type, dtype=torch.float16, enabled=torch.cuda.is_available()):
            outputs = model(inputs); loss = criterion(outputs, labels)
        if torch.isnan(loss) or torch.isinf(loss): print(f"Invalid loss detected: {loss.item()}! Skipping batch.");torch.cuda.empty_cache(); continue
        scaler.scale(loss).backward(); scaler.step(optimizer); scaler.update()
        running_loss += loss.item() * inputs.size(0); total_samples += inputs.size(0)
        preds = torch.sigmoid(outputs) > 0.5; correct_predictions += (preds == labels.bool()).sum().item()
        progress_bar.set_postfix(loss=loss.item())
    if total_samples == 0: return 0.0, 0.0
    return running_loss / total_samples, correct_predictions / total_samples

def validate_patient(model, dataloader, criterion, device):
    model.eval(); running_loss = 0.0; total_samples = 0; all_preds_proba = []; all_labels = []
    with torch.no_grad():
        progress_bar = tqdm(dataloader, desc="Validating", leave=False, ncols=100)
        for inputs, labels in progress_bar:
            valid_indices = labels != -1
            inputs = inputs[valid_indices].to(device); labels = labels[valid_indices].unsqueeze(1).to(device)
            if inputs.nelement() == 0: continue
            with torch.amp.autocast(device_type=device.type, dtype=torch.float16, enabled=torch.cuda.is_available()):
                outputs = model(inputs); loss = criterion(outputs, labels)
            if torch.isnan(loss) or torch.isinf(loss): print(f"Invalid val_loss detected: {loss.item()}! Skipping batch."); continue
            running_loss += loss.item() * inputs.size(0); total_samples += inputs.size(0)
            all_preds_proba.extend(torch.sigmoid(outputs).cpu().numpy()); all_labels.extend(labels.cpu().numpy())
    if total_samples == 0: return 0.0, np.array([]), np.array([])
    return running_loss / total_samples, np.array(all_labels).flatten(), np.array(all_preds_proba).flatten()

#--- Training Loop ---
print(f"\nStarting Training Swin3D Model (Balanced Data) for {EPOCHS} epochs...")
best_val_loss = float('inf')
train_losses, val_losses, train_accs, val_accs_list = [], [], [], []
MODEL_SAVE_PATH = os.path.join(PREPROCESSED_DSB_PATH, "swin3d_model_corrected_best.pth") # Added _corrected

if torch.cuda.is_available(): torch.cuda.empty_cache()

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    start_epoch_time = time.time()
    train_loss, train_acc = train_one_epoch_patient(swin_model, train_loader, criterion, optimizer, DEVICE, scaler)
    val_loss, val_labels_epoch, val_preds_proba_epoch = validate_patient(swin_model, val_loader, criterion, DEVICE)
    train_losses.append(train_loss); val_losses.append(val_loss); train_accs.append(train_acc)
    end_epoch_time = time.time(); epoch_duration = end_epoch_time - start_epoch_time
    val_acc_epoch = 0.0
    if len(val_labels_epoch) > 0 and val_labels_epoch.size > 0 and val_preds_proba_epoch.size > 0:
        val_acc_epoch = accuracy_score(val_labels_epoch, (val_preds_proba_epoch > 0.5).astype(int))
    val_accs_list.append(val_acc_epoch)
    print(f"Epoch {epoch+1} Summary: Duration: {epoch_duration:.2f}s")
    print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"  Val Loss: {val_loss:.4f}, Val Acc: {val_acc_epoch:.4f}")
    if val_loss < best_val_loss and len(val_labels_epoch) > 0 and not (torch.isinf(torch.tensor(val_loss)) or torch.isnan(torch.tensor(val_loss))): # Ensure val_loss is valid
        best_val_loss = val_loss
        try: torch.save(swin_model.state_dict(), MODEL_SAVE_PATH); print(f"  Best model saved to {MODEL_SAVE_PATH}")
        except Exception as e: print(f"Error saving model: {e}")
    if torch.cuda.is_available(): torch.cuda.empty_cache()

print("\nSwin3D Model (Balanced Data) Training Finished.")

# --- Plot Training History & Evaluation ---
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1); plt.plot(range(1, EPOCHS + 1), train_losses, label='Train Loss'); plt.plot(range(1, EPOCHS + 1), val_losses, label='Val Loss'); plt.xlabel('Epochs'); plt.ylabel('Loss'); plt.title('Loss Curve (Swin3D Corrected, Balanced)'); plt.legend(); plt.grid(True)
plt.subplot(1, 2, 2); plt.plot(range(1, EPOCHS + 1), train_accs, label='Train Acc'); plt.plot(range(1, EPOCHS + 1), val_accs_list, label='Val Acc'); plt.xlabel('Epochs'); plt.ylabel('Accuracy'); plt.title('Accuracy Curve (Swin3D Corrected, Balanced)'); plt.legend(); plt.grid(True)
plt.tight_layout(); plot_save_path = os.path.join(PREPROCESSED_DSB_PATH, "training_curves_swin3d_corrected_balanced.png")
plt.savefig(plot_save_path); print(f"Training curves saved to {plot_save_path}"); plt.close()

print("\nEvaluating Swin3D Model (Balanced Data) on Validation Set...")
if os.path.exists(MODEL_SAVE_PATH):
    try: # Re-instantiate with the same parameters as trained
        eval_model = SwinTransformer3D(
            img_size=FINAL_SCAN_SIZE, patch_size=(4,8,8), embed_dim=64,
            depths=[2,2,2], num_heads=[4,8,16], window_size=(3,4,4),
            num_classes=NUM_CLASSES, drop_path_rate=0.1, ape=False
        ).to(DEVICE)
        eval_model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))
        print(f"Loaded best Swin3D model from {MODEL_SAVE_PATH}")
    except Exception as e: print(f"Could not load best model: {e}. Using last epoch model."); eval_model = swin_model
else: print("Best model file not found. Using last epoch model."); eval_model = swin_model

val_loss_final, final_val_labels, final_val_preds_proba = validate_patient(eval_model, val_loader, criterion, DEVICE)
if len(final_val_labels) == 0: print("No valid validation predictions.")
else:
    print(f"\nFinal Validation Loss (Swin3D Corrected, Balanced): {val_loss_final:.4f}")
    final_val_preds_binary = (final_val_preds_proba > 0.5).astype(int)
    accuracy = accuracy_score(final_val_labels, final_val_preds_binary)
    precision = precision_score(final_val_labels, final_val_preds_binary, zero_division=0)
    recall = recall_score(final_val_labels, final_val_preds_binary, zero_division=0)
    f1 = f1_score(final_val_labels, final_val_preds_binary, zero_division=0); auc_roc = float('nan')
    if len(np.unique(final_val_labels)) > 1:
        try: auc_roc = roc_auc_score(final_val_labels, final_val_preds_proba)
        except ValueError as e: print(f"AUC-ROC Error: {e}.")
    else: print("AUC-ROC not calculated: only one class in y_true.")
    print("\n--- Final Validation Metrics (Swin3D Corrected, Balanced) ---")
    print(f"Accuracy:  {accuracy:.4f}\nPrecision: {precision:.4f}\nRecall:    {recall:.4f}")
    print(f"F1-Score:  {f1:.4f}\nAUC-ROC:   {auc_roc:.4f}")
    target_names = ['Non-Cancer (0)', 'Cancer (1)']
    print("\nClassification Report (Swin3D Corrected, Balanced):")
    if len(np.unique(final_val_labels)) > 1: print(classification_report(final_val_labels, final_val_preds_binary, target_names=target_names, zero_division=0))
    else: print("Classification report not generated: only one class in y_true.")
    print("\nConfusion Matrix (Swin3D Corrected, Balanced):")
    cm = confusion_matrix(final_val_labels, final_val_preds_binary, labels=[0,1]) # Ensure labels for CM
    disp = ConfusionMatrixDisplay(cm, display_labels=target_names); disp.plot(cmap=plt.cm.Blues); plt.title("Confusion Matrix (Swin3D Corrected, Balanced)")
    cm_plot_save_path = os.path.join(PREPROCESSED_DSB_PATH, "confusion_matrix_swin3d_corrected_balanced.png")
    plt.savefig(cm_plot_save_path); print(f"Confusion matrix plot saved to {cm_plot_save_path}"); plt.close()
    if not np.isnan(auc_roc) and not (torch.isinf(torch.tensor(auc_roc))): # check for inf as well
        fpr, tpr, _ = roc_curve(final_val_labels, final_val_preds_proba)
        plt.figure(figsize=(8,6)); plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {auc_roc:.2f})'); plt.plot([0,1],[0,1],color='navy',lw=2,linestyle='--'); plt.xlim([0.0,1.0]); plt.ylim([0.0,1.05]); plt.xlabel('False Positive Rate'); plt.ylabel('True Positive Rate'); plt.title('ROC Curve (Swin3D Corrected, Balanced)'); plt.legend(loc="lower right"); plt.grid(True)
        roc_plot_save_path = os.path.join(PREPROCESSED_DSB_PATH, "roc_curve_swin3d_corrected_balanced.png")
        plt.savefig(roc_plot_save_path); print(f"ROC curve plot saved to {roc_plot_save_path}"); plt.close()
    else: print("ROC curve not plotted (AUC is NaN or Inf).")
print("\nScript finished.")

Using device: cuda
--- Loading Data and Selecting EVEN Number of Scans Per Class (up to 50 each) ---
Available Cancerous: 27, Non-Cancerous: 71
Selecting 27 from each class.
Total scans selected: 54

Preprocessing for 54 scans (if not already done)...


Preprocessing: 100%|██████████| 54/54 [03:08<00:00,  3.49s/it]



Preprocessing finished/checked in 188.68 seconds.
Final patient count for training/validation: 54
Swin3D Model Instantiated. Number of parameters: 2437849

Swin3D Model output shape: torch.Size([2, 1])
Calculated positive weight for BCEWithLogitsLoss (Swin3D): 0.9545


  scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())



Starting Training Swin3D Model (Balanced Data) for 50 epochs...

Epoch 1/50


                                                                                                    

Epoch 1 Summary: Duration: 5.11s
  Train Loss: 0.8260, Train Acc: 0.4186
  Val Loss: 0.6914, Val Acc: 0.4545
  Best model saved to C:\Users\rouaa\Documents\Final_Pneumatect\Preprocessed_Data_DSB_Swin3D_Corrected\swin3d_model_corrected_best.pth

Epoch 2/50


                                                                                                    

Epoch 2 Summary: Duration: 3.07s
  Train Loss: 0.7017, Train Acc: 0.3953
  Val Loss: 0.6913, Val Acc: 0.4545
  Best model saved to C:\Users\rouaa\Documents\Final_Pneumatect\Preprocessed_Data_DSB_Swin3D_Corrected\swin3d_model_corrected_best.pth

Epoch 3/50


                                                                                                    

Epoch 3 Summary: Duration: 3.19s
  Train Loss: 0.6874, Train Acc: 0.4419
  Val Loss: 0.6841, Val Acc: 0.4545
  Best model saved to C:\Users\rouaa\Documents\Final_Pneumatect\Preprocessed_Data_DSB_Swin3D_Corrected\swin3d_model_corrected_best.pth

Epoch 4/50


                                                                                                    

Epoch 4 Summary: Duration: 2.94s
  Train Loss: 0.7270, Train Acc: 0.5116
  Val Loss: 0.6803, Val Acc: 0.4545
  Best model saved to C:\Users\rouaa\Documents\Final_Pneumatect\Preprocessed_Data_DSB_Swin3D_Corrected\swin3d_model_corrected_best.pth

Epoch 5/50


                                                                                                    

Epoch 5 Summary: Duration: 3.09s
  Train Loss: 0.7026, Train Acc: 0.5116
  Val Loss: 0.6797, Val Acc: 0.5455
  Best model saved to C:\Users\rouaa\Documents\Final_Pneumatect\Preprocessed_Data_DSB_Swin3D_Corrected\swin3d_model_corrected_best.pth

Epoch 6/50


                                                                                                    

Epoch 6 Summary: Duration: 3.17s
  Train Loss: 0.7067, Train Acc: 0.6047
  Val Loss: 0.6792, Val Acc: 0.5455
  Best model saved to C:\Users\rouaa\Documents\Final_Pneumatect\Preprocessed_Data_DSB_Swin3D_Corrected\swin3d_model_corrected_best.pth

Epoch 7/50


                                                                                                    

Epoch 7 Summary: Duration: 3.05s
  Train Loss: 0.6818, Train Acc: 0.4884
  Val Loss: 0.6809, Val Acc: 0.5455

Epoch 8/50


                                                                                                    

Epoch 8 Summary: Duration: 2.96s
  Train Loss: 0.6926, Train Acc: 0.4651
  Val Loss: 0.6845, Val Acc: 0.3636

Epoch 9/50


                                                                                                    

Epoch 9 Summary: Duration: 2.94s
  Train Loss: 0.6766, Train Acc: 0.4884
  Val Loss: 0.6796, Val Acc: 0.5455

Epoch 10/50


                                                                                                    

Epoch 10 Summary: Duration: 2.96s
  Train Loss: 0.6769, Train Acc: 0.5116
  Val Loss: 0.6989, Val Acc: 0.4545

Epoch 11/50


                                                                                                    

Epoch 11 Summary: Duration: 2.90s
  Train Loss: 0.6690, Train Acc: 0.6279
  Val Loss: 0.6947, Val Acc: 0.4545

Epoch 12/50


                                                                                                    

Epoch 12 Summary: Duration: 2.97s
  Train Loss: 0.6472, Train Acc: 0.7442
  Val Loss: 0.7088, Val Acc: 0.1818

Epoch 13/50


                                                                                                    

Epoch 13 Summary: Duration: 2.92s
  Train Loss: 0.6339, Train Acc: 0.6279
  Val Loss: 0.7324, Val Acc: 0.1818

Epoch 14/50


                                                                                                    

Epoch 14 Summary: Duration: 2.94s
  Train Loss: 0.4974, Train Acc: 0.7907
  Val Loss: 1.0938, Val Acc: 0.4545

Epoch 15/50


                                                                                                    

Epoch 15 Summary: Duration: 3.03s
  Train Loss: 0.2839, Train Acc: 0.9070
  Val Loss: 1.6706, Val Acc: 0.1818

Epoch 16/50


                                                                                                    

Epoch 16 Summary: Duration: 3.01s
  Train Loss: 0.2110, Train Acc: 0.8605
  Val Loss: 1.5032, Val Acc: 0.2727

Epoch 17/50


                                                                                                    

Epoch 17 Summary: Duration: 2.92s
  Train Loss: 0.1434, Train Acc: 0.9302
  Val Loss: 1.7229, Val Acc: 0.3636

Epoch 18/50


                                                                                                    

Epoch 18 Summary: Duration: 2.97s
  Train Loss: 0.1983, Train Acc: 0.9535
  Val Loss: 2.1522, Val Acc: 0.0909

Epoch 19/50


                                                                                                    

Epoch 19 Summary: Duration: 3.59s
  Train Loss: 0.1254, Train Acc: 0.9302
  Val Loss: 1.8719, Val Acc: 0.2727

Epoch 20/50


                                                                                                    

Epoch 20 Summary: Duration: 4.57s
  Train Loss: 0.0274, Train Acc: 1.0000
  Val Loss: 2.0059, Val Acc: 0.3636

Epoch 21/50


                                                                                                    

Epoch 21 Summary: Duration: 4.43s
  Train Loss: 0.0367, Train Acc: 0.9767
  Val Loss: 2.0584, Val Acc: 0.2727

Epoch 22/50


                                                                                                    

Epoch 22 Summary: Duration: 3.64s
  Train Loss: 0.0205, Train Acc: 1.0000
  Val Loss: 2.0950, Val Acc: 0.2727

Epoch 23/50


                                                                                                    

Epoch 23 Summary: Duration: 2.97s
  Train Loss: 0.0528, Train Acc: 0.9767
  Val Loss: 2.1016, Val Acc: 0.3636

Epoch 24/50


                                                                                                    

Epoch 24 Summary: Duration: 2.87s
  Train Loss: 0.0115, Train Acc: 1.0000
  Val Loss: 1.9019, Val Acc: 0.3636

Epoch 25/50


                                                                                                    

Epoch 25 Summary: Duration: 2.80s
  Train Loss: 0.0343, Train Acc: 0.9767
  Val Loss: 1.9967, Val Acc: 0.1818

Epoch 26/50


                                                                                                    

Epoch 26 Summary: Duration: 2.82s
  Train Loss: 0.0218, Train Acc: 1.0000
  Val Loss: 2.1529, Val Acc: 0.2727

Epoch 27/50


                                                                                                    

Epoch 27 Summary: Duration: 2.76s
  Train Loss: 0.0162, Train Acc: 1.0000
  Val Loss: 2.2026, Val Acc: 0.1818

Epoch 28/50


                                                                                                    

Epoch 28 Summary: Duration: 2.84s
  Train Loss: 0.0216, Train Acc: 1.0000
  Val Loss: 2.1876, Val Acc: 0.3636

Epoch 29/50


                                                                                                    

Epoch 29 Summary: Duration: 2.80s
  Train Loss: 0.0267, Train Acc: 0.9767
  Val Loss: 1.9990, Val Acc: 0.5455

Epoch 30/50


                                                                                                    

Epoch 30 Summary: Duration: 2.78s
  Train Loss: 0.0246, Train Acc: 1.0000
  Val Loss: 1.9189, Val Acc: 0.4545

Epoch 31/50


                                                                                                    

Epoch 31 Summary: Duration: 2.74s
  Train Loss: 0.0063, Train Acc: 1.0000
  Val Loss: 1.8903, Val Acc: 0.4545

Epoch 32/50


                                                                                                    

Epoch 32 Summary: Duration: 2.80s
  Train Loss: 0.0121, Train Acc: 1.0000
  Val Loss: 2.0291, Val Acc: 0.4545

Epoch 33/50


                                                                                                    

Epoch 33 Summary: Duration: 2.76s
  Train Loss: 0.0062, Train Acc: 1.0000
  Val Loss: 2.0723, Val Acc: 0.4545

Epoch 34/50


                                                                                                    

Epoch 34 Summary: Duration: 2.81s
  Train Loss: 0.0163, Train Acc: 1.0000
  Val Loss: 2.1892, Val Acc: 0.5455

Epoch 35/50


                                                                                                    

Epoch 35 Summary: Duration: 2.72s
  Train Loss: 0.0051, Train Acc: 1.0000
  Val Loss: 2.3220, Val Acc: 0.4545

Epoch 36/50


                                                                                                    

Epoch 36 Summary: Duration: 2.78s
  Train Loss: 0.0079, Train Acc: 1.0000
  Val Loss: 2.2067, Val Acc: 0.4545

Epoch 37/50


                                                                                                    

Epoch 37 Summary: Duration: 2.82s
  Train Loss: 0.0046, Train Acc: 1.0000
  Val Loss: 2.2117, Val Acc: 0.4545

Epoch 38/50


                                                                                                    

Epoch 38 Summary: Duration: 2.75s
  Train Loss: 0.0037, Train Acc: 1.0000
  Val Loss: 2.2218, Val Acc: 0.4545

Epoch 39/50


                                                                                                    

Epoch 39 Summary: Duration: 2.76s
  Train Loss: 0.0050, Train Acc: 1.0000
  Val Loss: 2.2510, Val Acc: 0.4545

Epoch 40/50


                                                                                                    

Epoch 40 Summary: Duration: 2.75s
  Train Loss: 0.0048, Train Acc: 1.0000
  Val Loss: 2.2675, Val Acc: 0.4545

Epoch 41/50


                                                                                                    

Epoch 41 Summary: Duration: 2.79s
  Train Loss: 0.0052, Train Acc: 1.0000
  Val Loss: 2.2475, Val Acc: 0.4545

Epoch 42/50


                                                                                                    

Epoch 42 Summary: Duration: 2.78s
  Train Loss: 0.0041, Train Acc: 1.0000
  Val Loss: 2.2481, Val Acc: 0.4545

Epoch 43/50


                                                                                                    

Epoch 43 Summary: Duration: 2.79s
  Train Loss: 0.0127, Train Acc: 1.0000
  Val Loss: 2.3977, Val Acc: 0.4545

Epoch 44/50


                                                                                                    

Epoch 44 Summary: Duration: 2.75s
  Train Loss: 0.0251, Train Acc: 0.9767
  Val Loss: 2.3607, Val Acc: 0.2727

Epoch 45/50


                                                                                                    

Epoch 45 Summary: Duration: 2.73s
  Train Loss: 0.0040, Train Acc: 1.0000
  Val Loss: 2.3738, Val Acc: 0.3636

Epoch 46/50


                                                                                                    

Epoch 46 Summary: Duration: 2.81s
  Train Loss: 0.0042, Train Acc: 1.0000
  Val Loss: 2.3702, Val Acc: 0.2727

Epoch 47/50


                                                                                                    

Epoch 47 Summary: Duration: 2.72s
  Train Loss: 0.0338, Train Acc: 0.9767
  Val Loss: 2.3677, Val Acc: 0.3636

Epoch 48/50


                                                                                                    

Epoch 48 Summary: Duration: 2.79s
  Train Loss: 0.0095, Train Acc: 1.0000
  Val Loss: 2.0857, Val Acc: 0.3636

Epoch 49/50


                                                                                                    

Epoch 49 Summary: Duration: 2.74s
  Train Loss: 0.0049, Train Acc: 1.0000
  Val Loss: 2.2089, Val Acc: 0.3636

Epoch 50/50


                                                                                                    

Epoch 50 Summary: Duration: 2.78s
  Train Loss: 0.0036, Train Acc: 1.0000
  Val Loss: 2.2947, Val Acc: 0.4545

Swin3D Model (Balanced Data) Training Finished.


  eval_model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))


Training curves saved to C:\Users\rouaa\Documents\Final_Pneumatect\Preprocessed_Data_DSB_Swin3D_Corrected\training_curves_swin3d_corrected_balanced.png

Evaluating Swin3D Model (Balanced Data) on Validation Set...
Loaded best Swin3D model from C:\Users\rouaa\Documents\Final_Pneumatect\Preprocessed_Data_DSB_Swin3D_Corrected\swin3d_model_corrected_best.pth


                                                                                                    


Final Validation Loss (Swin3D Corrected, Balanced): 0.6792

--- Final Validation Metrics (Swin3D Corrected, Balanced) ---
Accuracy:  0.5455
Precision: 0.0000
Recall:    0.0000
F1-Score:  0.0000
AUC-ROC:   0.1667

Classification Report (Swin3D Corrected, Balanced):
                precision    recall  f1-score   support

Non-Cancer (0)       0.55      1.00      0.71         6
    Cancer (1)       0.00      0.00      0.00         5

      accuracy                           0.55        11
     macro avg       0.27      0.50      0.35        11
  weighted avg       0.30      0.55      0.39        11


Confusion Matrix (Swin3D Corrected, Balanced):
Confusion matrix plot saved to C:\Users\rouaa\Documents\Final_Pneumatect\Preprocessed_Data_DSB_Swin3D_Corrected\confusion_matrix_swin3d_corrected_balanced.png
ROC curve plot saved to C:\Users\rouaa\Documents\Final_Pneumatect\Preprocessed_Data_DSB_Swin3D_Corrected\roc_curve_swin3d_corrected_balanced.png

Script finished.
