**2.5D Stack Dataset Loader (with Optional Segmentation Masks)**

This `Dataset` class loads 2.5D CT stacks for classification (and, if available, segmentation):
- Loads each stack as a `[C, H, W]` tensor (C = stack size × 3).
- Optionally loads and returns a matching segmentation mask stack.
- Applies Albumentations transforms to both image and mask.
- Returns: `(image stack, label, mask stack)`, where the mask is `None` when the row has no segmentation.


In [None]:
class Stack2p5DDataset(Dataset):
    """
    Loads 2.5D stacks of CT slices with optional segmentation masks.

    Each row of ``index_csv`` describes one stack. Columns read here:
    ``stack_npy_files`` (text form of a Python list of ``.npy`` file names),
    ``has_mask``, ``series_id``, ``stack_slice_ids`` (text form of a list of
    slice indices into the series mask volume) and ``label``.

    ``__getitem__`` returns ``(image, label, mask)``:
      - image: float32 tensor ``[C, H, W]``, built by stacking the per-file
        arrays and merging the file and channel axes
        (C = n_files * channels_per_file);
      - label: int64 scalar tensor;
      - mask:  float32 tensor ``[n_slices, H, W]``, or ``None`` when the row
        has no mask.
    """

    def __init__(self, index_csv, npy_dir, mask_dir, transform=None):
        self.df = pd.read_csv(index_csv)
        self.npy_dir = npy_dir
        self.mask_dir = mask_dir
        # Albumentations-style callable, invoked as transform(image=..., mask=...)
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        import ast  # literal_eval parses the list columns without executing code

        row = self.df.iloc[idx]
        # The list columns are stored as text. literal_eval (unlike eval) cannot
        # run arbitrary code if the CSV is malformed or tampered with.
        stack_files = ast.literal_eval(row['stack_npy_files'])

        # Each file is assumed to be [3, H, W] (TODO confirm); stacking gives
        # [n_files, 3, H, W], and the reshape merges the first two axes -> [C, H, W].
        stack = np.stack(
            [np.load(os.path.join(self.npy_dir, fname)) for fname in stack_files],
            axis=0,
        )
        stack = stack.reshape(-1, stack.shape[2], stack.shape[3])

        # Albumentations works on channel-last images: [H, W, C]
        stack_img = np.transpose(stack, (1, 2, 0))

        mask_stack = None
        if row['has_mask']:
            series_id = str(row['series_id'])
            mask_path = os.path.join(self.mask_dir, f"{series_id}.nii")
            mask_data = nib.load(mask_path).get_fdata()  # indexed as [:, :, slice] below
            slice_indices = ast.literal_eval(row['stack_slice_ids'])
            # Distinct loop name so it is not confused with the dataset index ``idx``.
            mask_slices = [mask_data[:, :, slice_idx] for slice_idx in slice_indices]
            mask_stack = np.stack(mask_slices, axis=-1).astype(np.float32)  # [H, W, n_slices]

        if self.transform:
            if mask_stack is not None:
                augmented = self.transform(image=stack_img, mask=mask_stack)
                stack_img = augmented['image']
                mask_stack = augmented['mask']
            else:
                stack_img = self.transform(image=stack_img)['image']

        # Back to PyTorch channel-first layout.
        stack = np.transpose(stack_img, (2, 0, 1))
        label = int(row['label'])

        mask_tensor = None
        if mask_stack is not None:
            mask_tensor = torch.tensor(np.transpose(mask_stack, (2, 0, 1)), dtype=torch.float32)

        return torch.tensor(stack, dtype=torch.float32), torch.tensor(label, dtype=torch.long), mask_tensor

**2.5D ResNet18 Feature Extractor**

- **IMAGENET1K_V1** refers to the official ResNet-18 pretrained weights provided by torchvision (PyTorch's computer vision library).
- These weights are trained on the large-scale [ImageNet-1k dataset](https://www.image-net.org/), which contains 1,000 classes and over a million natural images.
- Using `weights="IMAGENET1K_V1"` is the modern way (torchvision ≥ 0.13) to enable pretrained models for transfer learning in PyTorch.

**References:**
- [torchvision.models.resnet18 Documentation](https://pytorch.org/vision/stable/models/generated/torchvision.models.resnet18.html)
- [ImageNet-1k Dataset](https://www.image-net.org/)
- [Torchvision Model Weights Reference](https://pytorch.org/vision/stable/models.html#torchvision.models.ResNet18_Weights)


In [None]:
def build_2p5d_resnet18(in_channels=3, pretrained=False):
    """
    Create a ResNet-18 convolutional backbone for 2.5D stacks.

    The final average-pooling and fully-connected layers are dropped. The
    returned module therefore maps ``[N, in_channels, H, W]`` to 512-channel
    feature maps that are spatially downsampled (about H/32 x W/32).

    Args:
        in_channels: number of input channels (e.g. 3 for one 3-channel slice,
            9 for three stacked 3-channel slices).
        pretrained: if True, load torchvision's ImageNet-1k weights
            (``"IMAGENET1K_V1"``; requires torchvision >= 0.13).

    Returns:
        ``nn.Sequential`` containing the layers conv1 ... layer4.
    """
    model = models.resnet18(weights="IMAGENET1K_V1" if pretrained else None)

    # The stock conv1 already takes 3 channels. Replacing it unconditionally
    # would throw away the pretrained first-layer filters, so only swap it
    # when a different channel count is requested.
    if in_channels != 3:
        new_conv = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        if pretrained:
            # Initialise the new conv from the pretrained RGB filters instead
            # of starting it from random weights.
            with torch.no_grad():
                rgb_weight = model.conv1.weight  # [64, 3, 7, 7]
                if in_channels == 1:
                    # Collapse the RGB filters into a single-channel filter.
                    new_weight = rgb_weight.sum(dim=1, keepdim=True)
                else:
                    # Tile the RGB filters across the input channels and rescale
                    # so the summed response keeps roughly the pretrained magnitude.
                    repeats = -(-in_channels // 3)  # ceil(in_channels / 3)
                    new_weight = rgb_weight.repeat(1, repeats, 1, 1)[:, :in_channels]
                    new_weight = new_weight * (3.0 / in_channels)
                new_conv.weight.copy_(new_weight)
        model.conv1 = new_conv

    # Keep conv1 ... layer4; drop the final avgpool and fc layers.
    return nn.Sequential(*list(model.children())[:-2])

**CNN+RNN Multi-Task Model for 2.5D CT Bleeding Detection**

This module defines a **multi-task deep learning model** for trauma CT analysis:
- **Input:** 2.5D stacks of multi-channel CT images ([B, T, C, H, W])
- **Backbone:** Modified ResNet-18 CNN for feature extraction
- **RNN:** LSTM aggregates stack-level features for patient-level classification
- **Segmentation Head:** Outputs per-slice (stack) segmentation masks (upsampled to original resolution)
- **Outputs:**
    - Patient-level bleeding classification logits `[B, num_classes]`
    - Stack-level segmentation mask logits `[B, T, H, W]`

Main Architecture Flow:
1. Each 2.5D stack is encoded using the CNN backbone.
2. The sequence of stack features is aggregated via a bidirectional LSTM.
3. The classification head predicts if the patient has bleeding.
4. The segmentation head predicts a binary mask for each stack (slice) in the sequence, resized to the input size.

**Key implementation features:**
- Uses `F.interpolate` to upsample segmentation mask logits to match input dimensions.
- Flexible and ready for multi-task learning with PyTorch.

In [None]:
class CNN_RNN_PatientClassifier_MTL(nn.Module):
    """
    Multi-task model:
      - CNN (ResNet18 backbone) extracts features from each stack.
      - LSTM aggregates the per-stack features of each patient.
      - A linear classifier predicts bleeding at patient level.
      - A 1x1 conv segmentation head predicts one mask per stack.

    ``cnn_feature_dim`` must equal the number of output channels of the
    backbone (512 for ResNet-18).
    """
    def __init__(self, in_channels=3, cnn_feature_dim=512, rnn_hidden=128, num_classes=2, rnn_layers=1, bidirectional=True):
        super().__init__()
        self.cnn = build_2p5d_resnet18(in_channels=in_channels)
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.rnn = nn.LSTM(
            input_size=cnn_feature_dim,
            hidden_size=rnn_hidden,
            num_layers=rnn_layers,
            batch_first=True,
            bidirectional=bidirectional
        )
        rnn_output_dim = rnn_hidden * (2 if bidirectional else 1)
        self.classifier = nn.Linear(rnn_output_dim, num_classes)
        self.seg_head = nn.Conv2d(cnn_feature_dim, 1, kernel_size=1)

    def forward(self, x, lengths=None):
        """
        Args:
            x: [B, T, C, H, W]. B is the number of patients and T the number of
                (possibly zero-padded) stacks per patient.
            lengths: optional [B] number of real stacks per patient.
                - When given, padded stacks are excluded from the LSTM (via
                  packing) and from the temporal mean.
                - When None, all T stacks are used, as before.

        Returns:
            out: [B, num_classes] patient-level logits.
            seg_logits: [B, T, H, W] segmentation logits for every stack.
                Padded positions are included; callers should ignore them.
        """
        B, T, C, H, W = x.shape
        # reshape (not view) also accepts non-contiguous inputs
        feats = self.cnn(x.reshape(B * T, C, H, W))               # [B*T, F, h', w']

        # Per-stack segmentation, upsampled back to the input resolution
        seg_small = self.seg_head(feats)                           # [B*T, 1, h', w']
        seg_logits = F.interpolate(seg_small, size=(H, W), mode='bilinear', align_corners=False)
        seg_logits = seg_logits.reshape(B, T, H, W)

        # Sequence of pooled stack features for the LSTM
        seq = self.gap(feats).reshape(B, T, -1)                    # [B, T, F]

        if lengths is None:
            rnn_out, _ = self.rnn(seq)                             # [B, T, D]
            aggregated = rnn_out.mean(dim=1)                       # [B, D]
        else:
            # pack_padded_sequence needs CPU int64 lengths; clamp keeps them valid for T.
            lengths = torch.as_tensor(lengths, dtype=torch.long).cpu().clamp(min=1, max=T)
            packed = nn.utils.rnn.pack_padded_sequence(
                seq, lengths, batch_first=True, enforce_sorted=False
            )
            packed_out, _ = self.rnn(packed)
            rnn_out, _ = nn.utils.rnn.pad_packed_sequence(
                packed_out, batch_first=True, total_length=T
            )                                                      # [B, T, D], zeros past length
            valid = torch.arange(T, device=rnn_out.device)[None, :] < lengths.to(rnn_out.device)[:, None]
            valid = valid.unsqueeze(-1).to(rnn_out.dtype)          # [B, T, 1]
            # Mean over real time steps only
            aggregated = (rnn_out * valid).sum(dim=1) / valid.sum(dim=1)

        out = self.classifier(aggregated)                          # [B, num_classes]
        return out, seg_logits

**PatientStackDataset: Loads All Stacks for a Patient**

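It loads every stack row of a patient, sorted by `center_slice`. For each stack it uses the first `.npy` file listed and, when available, the center-slice segmentation mask (otherwise an all-zero mask). It then applies augmentation and returns tensors ready for model input.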

In [None]:
# import ast

class PatientStackDataset(Dataset):
    """
    Loads all 2.5D stacks for a patient with optional segmentation masks.

    ``df`` has one row per stack. Columns read here:
      - ``patient_id``: moved to the index if present;
      - ``center_slice``;
      - ``stack_npy_files``: text form of a list; only its first file is loaded;
      - ``has_mask``, ``series_id`` and ``label``.

    ``__getitem__(idx)`` returns float32 tensors ``(stacks, labels, masks)``
    for ``patient_ids[idx]``, ordered by ``center_slice``:
      - stacks: ``[S, C, H, W]``
      - labels: ``[S]``
      - masks:  ``[S, Cm, H, W]`` (Cm is 1 for zero masks and 2-D mask slices)

    S counts the rows whose ``.npy`` file could be loaded. Three empty
    tensors are returned when no row could be loaded.
    """

    def __init__(self, df, patient_ids, npy_dir, mask_dir, transform=None):
        # Index by patient_id so each patient's rows can be fetched with .loc
        if 'patient_id' in df.columns:
            df = df.set_index('patient_id')
        self.df = df
        self.patient_ids = patient_ids
        self.npy_dir = npy_dir
        self.mask_dir = mask_dir
        self.transform = transform

    def __len__(self):
        return len(self.patient_ids)

    def _load_stack(self, row):
        """Return the row's first stack file as an array, or None if the entry is malformed or missing."""
        import ast  # literal_eval parses the list text without executing code

        try:
            stack_filename = ast.literal_eval(row['stack_npy_files'])[0]
        except (ValueError, SyntaxError, IndexError, TypeError):
            return None  # malformed, empty or non-list entry

        npy_path = os.path.join(self.npy_dir, stack_filename)
        if not os.path.exists(npy_path):
            return None

        stack = np.load(npy_path)
        if stack.shape[0] == 3:
            stack = np.transpose(stack, (1, 2, 0))  # channel-first -> (H, W, C) for augmentation
        return stack

    def _load_mask(self, row, height, width, volume_cache):
        """Return the row's center-slice mask as (H, W, C); zeros of shape (H, W, 1) when unavailable."""
        mask = np.zeros((height, width, 1), dtype=np.float32)
        if not row['has_mask']:
            return mask

        # Mask volumes are stored per series as <mask_dir>/<series_id>.npy
        mask_series_path = os.path.join(self.mask_dir, f"{int(row['series_id'])}.npy")
        if not os.path.exists(mask_series_path):
            return mask

        try:
            # Several stacks of one patient may share a series: load each volume
            # only once per __getitem__ call instead of once per stack.
            if mask_series_path not in volume_cache:
                volume_cache[mask_series_path] = np.load(mask_series_path)
            full_mask_volume = volume_cache[mask_series_path]
            center_slice_idx = int(row['center_slice'])

            # Out-of-range slices keep the zero mask
            if 0 <= center_slice_idx < full_mask_volume.shape[0]:
                # Copy so augmentation can never modify the cached volume
                mask_slice = np.array(full_mask_volume[center_slice_idx])
                mask = np.expand_dims(mask_slice, axis=-1) if mask_slice.ndim == 2 else mask_slice
        except Exception as e:
            # Best effort: keep the zero mask, but report the failure
            print(f"Warning: Could not load mask slice for {row['series_id']}/{row['center_slice']}. Error: {e}")
        return mask

    def __getitem__(self, idx):
        patient_id = self.patient_ids[idx]

        # Double brackets keep a DataFrame even if the patient has a single row
        patient_df = self.df.loc[[patient_id]].sort_values(by='center_slice')

        stacks, labels, masks = [], [], []
        volume_cache = {}  # mask path -> loaded mask volume, reused across this patient's rows

        for _, row in patient_df.iterrows():
            stack = self._load_stack(row)
            if stack is None:
                continue

            mask = self._load_mask(row, stack.shape[0], stack.shape[1], volume_cache)

            if self.transform:
                augmented = self.transform(image=stack, mask=mask)
                stack = augmented['image']
                mask = augmented['mask']

            # Channel-first layout expected by PyTorch: (C, H, W)
            stacks.append(np.transpose(stack, (2, 0, 1)))
            masks.append(np.transpose(mask, (2, 0, 1)))
            labels.append(row['label'])

        if not stacks:
            return torch.tensor([]), torch.tensor([]), torch.tensor([])

        return (
            torch.tensor(np.array(stacks), dtype=torch.float32),
            torch.tensor(np.array(labels), dtype=torch.float32),
            torch.tensor(np.array(masks), dtype=torch.float32),
        )

**Dataset Initialization and Model Setup**

Defines file paths, prepares train/val/test datasets, and initializes the multi-task model.


In [None]:
# Locations of the preprocessed 3-channel slices and of the segmentation masks
npy_dir = '/content/drive/MyDrive/MSc_project/preproc_npy_3ch'
mask_dir = '/content/segmentations'

# One patient-level dataset per split; only the training split is augmented
train_patient_dataset = PatientStackDataset(
    df, train_patient_ids, npy_dir, mask_dir, transform=train_transform
)
val_patient_dataset = PatientStackDataset(
    df, val_patient_ids, npy_dir, mask_dir, transform=None
)
test_patient_dataset = PatientStackDataset(
    df, test_patient_ids, npy_dir, mask_dir, transform=None
)

print(f"Train patient dataset: {len(train_patient_dataset)} patients")
print(f"Validation patient dataset: {len(val_patient_dataset)} patients")
print(f"Test patient dataset: {len(test_patient_dataset)} patients")

# Prefer the GPU when one is visible, otherwise run on the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model = CNN_RNN_PatientClassifier_MTL(in_channels=3).to(device)

**Custom Collate Function for Patient-Level Batching**

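This function drops patients for which no stack could be loaded. It pads each remaining patient's sequence of stacks, labels and masks to the longest sequence in the batch, and also returns the original sequence lengths.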


In [None]:
# from torch.nn.utils.rnn import pad_sequence

def collate_fn_padd_with_mask(batch):
    """
    Collate patient samples ``(stacks, labels, masks)`` into a zero-padded batch.

    Samples whose stack tensor has no elements are discarded first. If nothing
    is left, four empty tensors are returned.

    Returns:
        (padded_stacks, padded_labels, padded_masks, lengths). Every padded
        tensor is batch-first, and ``lengths`` holds the unpadded sequence
        length of each kept sample.
    """
    kept = [sample for sample in batch if sample[0].numel() > 0]
    if len(kept) == 0:
        return torch.tensor([]), torch.tensor([]), torch.tensor([]), torch.tensor([])

    images, targets, masks = ([sample[pos] for sample in kept] for pos in range(3))
    seq_lengths = torch.tensor([len(seq) for seq in images])

    return (
        pad_sequence(images, batch_first=True, padding_value=0),
        pad_sequence(targets, batch_first=True, padding_value=0),
        pad_sequence(masks, batch_first=True, padding_value=0),
        seq_lengths,
    )

**DataLoader Instantiation**

Creates DataLoaders for training, validation, and testing, using the custom collate function for padded patient sequences.

In [None]:
# Patient-level DataLoaders: two patients per batch, padded by
# collate_fn_padd_with_mask. Only the training loader shuffles.
train_loader = DataLoader(
    train_patient_dataset,
    batch_size=2,
    shuffle=True,
    collate_fn=collate_fn_padd_with_mask,
    num_workers=0,
    pin_memory=True,
)
val_loader = DataLoader(
    val_patient_dataset,
    batch_size=2,
    shuffle=False,
    collate_fn=collate_fn_padd_with_mask,
    num_workers=0,
    pin_memory=True,
)
test_loader = DataLoader(
    test_patient_dataset,
    batch_size=2,
    shuffle=False,
    collate_fn=collate_fn_padd_with_mask,
    num_workers=0,
    pin_memory=True,
)

**DataLoader Shape Check**

Checks the output shapes from the DataLoader to ensure batching and padding are correct.

In [None]:
print("Checking the output shape of the train_loader...")

# Inspect a single batch. The for/else reports an empty loader instead of
# claiming success when no batch was produced.
for batch_x, batch_y, batch_mask, lengths in train_loader:
    print("\nSuccessfully loaded one batch!")
    print('  - Images Batch Shape  (B, S_max, C, H, W):', batch_x.shape)
    print('  - Labels Batch Shape  (B, S_max):', batch_y.shape)
    print('  - Masks Batch Shape   (B, S_max, 1, H, W):', batch_mask.shape)
    print('  - Lengths Batch Shape (B):', lengths.shape)
    print('  - Original sequence lengths in this batch:', lengths.tolist())
    print("\nShape check complete. The DataLoader is working correctly.")
    break
else:
    print("\nWarning: train_loader yielded no batches; the shape check was not performed.")

**EarlyStopping Utility**

Implements early stopping to halt training when validation performance stops improving, and keeps an in-memory copy of the best model weights.


In [None]:
class EarlyStopping:
    """
    Track a monitored score and signal when it stops improving.

    Args:
        patience: number of consecutive non-improving calls tolerated before
            ``early_stop`` is set to True.
        mode: 'max' if larger scores are better (e.g. AUC), 'min' if smaller
            scores are better (e.g. a loss).
        min_delta: minimum change beyond ``best_score`` that counts as an
            improvement. The default 0.0 treats any strict improvement as one.

    Attributes:
        best_score: best score seen so far (None before the first call).
        counter: consecutive calls without improvement.
        early_stop: True once ``counter`` reaches ``patience``.
        best_model_wts: CPU copy of the model's state_dict at the best score,
            kept in memory only.

    Raises:
        ValueError: if ``mode`` is not 'max' or 'min'.
    """
    def __init__(self, patience=7, mode='max', min_delta=0.0):
        # An unknown mode would otherwise never register an improvement.
        if mode not in ('max', 'min'):
            raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")
        self.patience = patience
        self.mode = mode
        self.min_delta = min_delta
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.best_model_wts = None

    def _is_improvement(self, score):
        """Return True if ``score`` beats ``best_score`` by more than ``min_delta``."""
        if self.mode == 'max':
            return score > self.best_score + self.min_delta
        return score < self.best_score - self.min_delta

    def _snapshot(self, model):
        """Store a detached CPU copy of the model's current weights."""
        self.best_model_wts = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

    def __call__(self, score, model):
        if self.best_score is None:
            self.best_score = score
            self._snapshot(model)
        elif self._is_improvement(score):
            self.best_score = score
            self.counter = 0
            self._snapshot(model)
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

**Metric Computation Utility**

Calculates accuracy, sensitivity, specificity, and AUC for model evaluation.

In [None]:
# from sklearn.metrics import confusion_matrix, accuracy_score, roc_auc_score

def compute_metrics(y_true, y_pred, y_prob):
    """
    Summarise binary classification performance.

    Args:
        y_true: ground-truth labels in {0, 1}.
        y_pred: predicted labels in {0, 1}.
        y_prob: predicted probabilities for the positive class.

    Returns:
        (accuracy, sensitivity, specificity, auc)
          - sensitivity is TP / (TP + FN) and specificity is TN / (TN + FP);
            each falls back to 0.0 when its denominator is zero.
          - auc falls back to 0.5 when ``roc_auc_score`` raises ValueError
            (e.g. when only one class is present in ``y_true``).
    """
    def ratio(num, den):
        return num / den if den > 0 else 0.0

    # labels=[0, 1] forces a 2x2 matrix even if a class is absent
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()

    accuracy = accuracy_score(y_true, y_pred)

    try:
        auc = roc_auc_score(y_true, y_prob)
    except ValueError:
        auc = 0.5

    return accuracy, ratio(tp, tp + fn), ratio(tn, tn + fp), auc

**Segmentation Loss Functions**

Defines the loss functions for segmentation: Binary Cross Entropy (BCE), Dice loss, and a weighted combination of both.

In [None]:
# Pixel-wise binary cross-entropy on raw logits, averaged over all elements
# (the default 'mean' reduction); used by combo_loss below.
bce_loss_fn = nn.BCEWithLogitsLoss()

def dice_loss(pred, target, eps=1e-6):
    """
    Soft Dice loss on logits.

    Dim 0 of ``pred`` and ``target`` is treated as the sample axis and all
    remaining dims are flattened. The result is 1 minus the mean per-sample
    Dice coefficient, where ``eps`` smooths both numerator and denominator so
    that empty prediction/target pairs score a Dice of 1.
    """
    probs = torch.sigmoid(pred).contiguous().view(pred.size(0), -1)
    truth = target.float().contiguous().view(target.size(0), -1)

    overlap = torch.sum(probs * truth, dim=1)
    denom = torch.sum(probs, dim=1) + torch.sum(truth, dim=1)
    score = (overlap * 2. + eps) / (denom + eps)
    return 1 - score.mean()

def combo_loss(pred, target, bce_weight=0.5, dice_weight=0.5):
    """
    Weighted sum of BCE and Dice loss for segmentation.

    Args:
        pred: raw segmentation logits.
        target: binary target with the same number of elements as ``pred``.
            If the shapes differ but the element counts match (e.g. a
            ``[1, H, W]`` target for an ``[H, W]`` prediction), the target is
            reshaped to ``pred.shape``. Without this, BCEWithLogitsLoss rejects
            the size mismatch.
        bce_weight: weight of the BCE term.
        dice_weight: weight of the Dice term.
            Dice is computed per entry along dim 0 of ``pred`` and averaged.

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: if ``pred`` and ``target`` have different element counts.
    """
    if target.shape != pred.shape:
        if target.numel() != pred.numel():
            raise ValueError(
                f"target {tuple(target.shape)} is incompatible with pred {tuple(pred.shape)}"
            )
        target = target.reshape(pred.shape)
    # BCEWithLogitsLoss needs a floating target matching the logits' dtype
    target = target.to(pred.dtype)

    bce = bce_loss_fn(pred, target)
    dice = dice_loss(pred, target)
    return bce_weight * bce + dice_weight * dice

**Loss Functions for Training**

Defines the loss functions for both segmentation and classification tasks.

In [None]:
# Loss for the organ-mask head: BCE on raw logits, averaged over all pixels
# (the default 'mean' reduction).
seg_criterion = nn.BCEWithLogitsLoss()

# Loss for the patient-level bleeding classifier: class logits vs. integer labels.
cls_criterion = nn.CrossEntropyLoss()

**Training Function for One Epoch**

Trains the model for one epoch, handling both classification and segmentation, and computes metrics.

In [None]:
def train_one_epoch(model, loader, cls_criterion, seg_criterion, optimizer, device, epoch, num_epochs, seg_weight=1.0):
    """
    Train the multi-task model for one epoch over patient-level batches.

    Each batch is ``(batch_x, batch_y, batch_mask, lengths)`` as produced by
    the padding collate function.
      - The patient label is the max stack label: any positive stack makes the
        patient positive.
      - The segmentation loss is ``combo_loss`` (BCE + Dice), averaged over
        real (non-padded) stacks whose target mask contains foreground. It is
        0 when there are none.

    Args:
        model: returns ``(logits [B, num_classes], seg_logits [B, T, H, W])``.
        loader: iterable of padded patient batches.
        cls_criterion: classification loss, called as ``cls_criterion(logits, labels)``.
        seg_criterion: unused; kept for signature compatibility (``combo_loss`` is used).
        optimizer: optimizer stepping ``model``'s parameters.
        device: device the batch tensors are moved to.
        epoch: zero-based epoch index (used in the progress-bar label).
        num_epochs: unused; kept for signature compatibility.
        seg_weight: weight of the segmentation loss in the total loss.

    Returns:
        ``(mean_total_loss, mean_cls_loss, mean_seg_loss, acc, sensitivity,
        specificity, auc)``.

    Raises:
        ValueError: if the loader produced no non-empty batch.
    """
    model.train()
    losses, seg_losses, cls_losses = [], [], []
    all_y, all_probs = [], []

    progress = tqdm(loader, desc=f"Epoch {epoch+1} [Train Patient]", leave=False)
    for batch_x, batch_y, batch_mask, lengths in progress:
        # The collate function returns empty tensors if every patient in the batch was empty
        if batch_x.numel() == 0:
            continue

        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        batch_mask = batch_mask.to(device)

        logits, seg_logits = model(batch_x)

        # Patient is positive if any of its stacks is positive (padded labels are 0)
        true_patient_labels = torch.max(batch_y, dim=1).values.long()
        cls_loss = cls_criterion(logits, true_patient_labels)
        cls_losses.append(cls_loss.item())

        # Segmentation loss. Per-stack predictions are [H, W] while the targets
        # are [1, H, W]; reshaping both to [B, T, H, W] avoids the size mismatch
        # that made BCE/Dice fail, and lets all selected stacks be scored at once.
        num_steps = seg_logits.shape[1]
        target = batch_mask.reshape(seg_logits.shape).to(seg_logits.dtype)
        is_real = torch.arange(num_steps, device=device)[None, :] < lengths.to(device)[:, None]
        has_fg = target.flatten(start_dim=2).amax(dim=2) > 0
        selected = is_real & has_fg  # [B, T]
        if selected.any():
            # combo_loss on [N, H, W] equals the mean of the per-stack combo losses
            seg_loss = combo_loss(seg_logits[selected], target[selected])
        else:
            seg_loss = torch.tensor(0.0, device=device)
        seg_losses.append(seg_loss.item())

        total_loss = cls_loss + seg_weight * seg_loss
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        losses.append(total_loss.item())

        all_y.append(true_patient_labels.cpu().numpy())
        all_probs.append(torch.softmax(logits, 1)[:, 1].detach().cpu().numpy())
        progress.set_postfix({"loss": f"{total_loss.item():.4f}", "cls": f"{cls_loss.item():.4f}", "seg": f"{seg_loss.item():.4f}"})

    if not all_y:
        raise ValueError("train loader produced no non-empty batches")

    all_y = np.concatenate(all_y)
    all_probs = np.concatenate(all_probs)
    all_pred = (all_probs >= 0.5).astype(int)
    acc, sens, spec, auc = compute_metrics(all_y, all_pred, all_probs)
    return np.mean(losses), np.mean(cls_losses), np.mean(seg_losses), acc, sens, spec, auc

**Validation Function**

Evaluates the model on the validation set (no gradient updates) and returns the same losses and metrics as the training function.

In [None]:
@torch.no_grad()
def validate(model, loader, cls_criterion, seg_criterion, device, epoch, num_epochs, seg_weight=1.0):
    """
    Evaluate the multi-task model on patient-level batches without updating weights.

    Uses the same losses as training:
      - the patient label is the max stack label;
      - the segmentation loss is ``combo_loss`` (BCE + Dice), averaged over
        real (non-padded) stacks whose target mask contains foreground, and
        0 when there are none.

    Args:
        model: returns ``(logits [B, num_classes], seg_logits [B, T, H, W])``.
        loader: iterable of ``(batch_x, batch_y, batch_mask, lengths)`` padded batches.
        cls_criterion: classification loss, called as ``cls_criterion(logits, labels)``.
        seg_criterion: unused; kept for signature compatibility (``combo_loss`` is used).
        device: device the batch tensors are moved to.
        epoch: zero-based epoch index (used in the progress-bar label).
        num_epochs: unused; kept for signature compatibility.
        seg_weight: weight of the segmentation loss in the total loss.

    Returns:
        ``(mean_total_loss, mean_cls_loss, mean_seg_loss, acc, sensitivity,
        specificity, auc)``.

    Raises:
        ValueError: if the loader produced no non-empty batch.
    """
    model.eval()
    losses, seg_losses, cls_losses = [], [], []
    all_y, all_probs = [], []

    progress = tqdm(loader, desc=f"Epoch {epoch+1} [Val Patient]", leave=False)
    for batch_x, batch_y, batch_mask, lengths in progress:
        # The collate function returns empty tensors if every patient in the batch was empty
        if batch_x.numel() == 0:
            continue

        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        batch_mask = batch_mask.to(device)

        logits, seg_logits = model(batch_x)

        # Patient is positive if any of its stacks is positive (padded labels are 0)
        true_patient_labels = torch.max(batch_y, dim=1).values.long()
        cls_loss = cls_criterion(logits, true_patient_labels)
        cls_losses.append(cls_loss.item())

        # Segmentation loss over real stacks with foreground, using matching
        # [B, T, H, W] shapes for predictions and targets.
        num_steps = seg_logits.shape[1]
        target = batch_mask.reshape(seg_logits.shape).to(seg_logits.dtype)
        is_real = torch.arange(num_steps, device=device)[None, :] < lengths.to(device)[:, None]
        has_fg = target.flatten(start_dim=2).amax(dim=2) > 0
        selected = is_real & has_fg  # [B, T]
        if selected.any():
            seg_loss = combo_loss(seg_logits[selected], target[selected])
        else:
            seg_loss = torch.tensor(0.0, device=device)
        seg_losses.append(seg_loss.item())

        total_loss = cls_loss + seg_weight * seg_loss
        losses.append(total_loss.item())

        all_y.append(true_patient_labels.cpu().numpy())
        all_probs.append(torch.softmax(logits, 1)[:, 1].detach().cpu().numpy())
        progress.set_postfix({"loss": f"{total_loss.item():.4f}", "cls": f"{cls_loss.item():.4f}", "seg": f"{seg_loss.item():.4f}"})

    if not all_y:
        raise ValueError("validation loader produced no non-empty batches")

    all_y = np.concatenate(all_y)
    all_probs = np.concatenate(all_probs)
    all_pred = (all_probs >= 0.5).astype(int)
    acc, sens, spec, auc = compute_metrics(all_y, all_pred, all_probs)
    return np.mean(losses), np.mean(cls_losses), np.mean(seg_losses), acc, sens, spec, auc

**Model Training Loop**

Main training loop for the multi-task model, including early stopping, logging, and history saving.

In [None]:
def run_training(
    model, train_loader, val_loader, device, num_epochs=30, lr=1e-4,
    seg_weight=1.0, patience=7, ckpt_path=None, hist_path=None
):
    """
    Main training loop for the multi-task model (classification + segmentation).

    Each epoch:
      - runs ``train_one_epoch`` and then ``validate``, using Adam and the
        module-level ``cls_criterion`` / ``seg_criterion``;
      - logs the losses and metrics;
      - feeds validation AUC to early stopping.
    After training, the best weights seen (if any epoch ran) are loaded back
    into ``model``.

    Args:
        model: the multi-task model; moved to ``device``.
        train_loader, val_loader: patient-level DataLoaders.
        device: torch device to train on.
        num_epochs: maximum number of epochs.
        lr: Adam learning rate.
        seg_weight: weight of the segmentation loss in the total loss.
        patience: early-stopping patience in epochs.
        ckpt_path: if given, the final state_dict is saved here with torch.save.
        hist_path: if given, the history dict is pickled here.

    Returns:
        ``(model, history)``. ``history`` maps metric names ('train_loss',
        'val_auc', 'val_sens', ...) to per-epoch lists.
    """
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    early_stopper = EarlyStopping(patience=patience, mode='max')

    history = {
        'train_loss': [], 'train_cls_loss': [], 'train_seg_loss': [], 'train_acc': [], 'train_auc': [],
        'val_loss': [], 'val_cls_loss': [], 'val_seg_loss': [], 'val_acc': [], 'val_auc': [],
        # Per-epoch sensitivity / specificity
        'train_sens': [], 'train_spec': [], 'val_sens': [], 'val_spec': [],
    }

    for epoch in range(num_epochs):
        train_loss, train_cls_loss, train_seg_loss, train_acc, train_sens, train_spec, train_auc = train_one_epoch(
            model, train_loader, cls_criterion, seg_criterion, optimizer, device, epoch, num_epochs, seg_weight)
        val_loss, val_cls_loss, val_seg_loss, val_acc, val_sens, val_spec, val_auc = validate(
            model, val_loader, cls_criterion, seg_criterion, device, epoch, num_epochs, seg_weight)

        epoch_values = {
            'train_loss': train_loss, 'train_cls_loss': train_cls_loss, 'train_seg_loss': train_seg_loss,
            'train_acc': train_acc, 'train_auc': train_auc,
            'val_loss': val_loss, 'val_cls_loss': val_cls_loss, 'val_seg_loss': val_seg_loss,
            'val_acc': val_acc, 'val_auc': val_auc,
            'train_sens': train_sens, 'train_spec': train_spec,
            'val_sens': val_sens, 'val_spec': val_spec,
        }
        for key, value in epoch_values.items():
            history[key].append(value)

        print(
            f"Epoch {epoch+1}/{num_epochs} | "
            f"Train Loss: {train_loss:.4f} (Acc: {train_acc:.4f}, AUC: {train_auc:.4f}) | "
            f"Val Loss: {val_loss:.4f} (Acc: {val_acc:.4f}, AUC: {val_auc:.4f}) | "
            f"Val Sens: {val_sens:.4f} | Val Spec: {val_spec:.4f}"
        )

        early_stopper(val_auc, model)
        if early_stopper.early_stop:
            print("Early stopping triggered.")
            break

    # With num_epochs <= 0 there is no best score or snapshot to restore.
    if early_stopper.best_model_wts is not None:
        print(f"\nBest validation AUC: {early_stopper.best_score:.4f}")
        model.load_state_dict(early_stopper.best_model_wts)
    else:
        print("\nNo epochs were run; keeping the current model weights.")

    if ckpt_path:
        torch.save(model.state_dict(), ckpt_path)
        print(f"Best model saved to {ckpt_path}")
    if hist_path:
        with open(hist_path, 'wb') as f:
            pickle.dump(history, f)
        print(f"Training history saved to {hist_path}")

    return model, history

print("Starting the patient-level multi-task training...")

# Launch training. The best checkpoint and the pickled history are written
# under /content/drive/MyDrive/MSc_project.
trained_model, training_history = run_training(
    model,
    train_loader,
    val_loader,
    device,
    num_epochs=30,
    lr=1e-4,
    seg_weight=1.0,
    patience=7,
    ckpt_path='/content/drive/MyDrive/MSc_project/best_patient_model_mtl.pth',
    hist_path='/content/drive/MyDrive/MSc_project/training_history.pkl',
)

print("\nPatient-level multi-task training completed.")