In [None]:
import torch
import torch.nn as nn
import torch.utils.data as data
import torch.nn.functional as F
print(torch.__version__)
from skimage.metrics import structural_similarity as ssim  # Import SSIM function
import os
import numpy as np
from glob import glob
import random
import imageio
from torchvision import transforms
from collections import defaultdict
from PIL import Image
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import sys
sys.path.append("D:/pytorch")  # Adds the parent directory to the system path
import segmentation_models_pytorch as smp
from ipywidgets import widgets, VBox, HBox, IntText, ToggleButtons, Button, IntSlider, Label, Checkbox, Layout
from IPython.display import display, clear_output
import pandas as pd
from sklearn.metrics import confusion_matrix
import shutil
import gc
import traceback
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support

# Use the widget backend for better interactivity in Jupyter
%matplotlib widget

%config InlineBackend.figure_format = 'retina'
get_ipython().history_manager.enabled = False
%config InlineBackend.close_figures = True

!jupyter nbextension enable --py --sys-prefix ipympl
!jupyter nbextension enable --py --sys-prefix widgetsnbextension

In [None]:
# Root folder that holds every dataset variant used by this notebook.
BASE_DIR = "D:/SPARSE/"

# glob() returns matches in arbitrary, filesystem-dependent order. Images and
# masks are paired purely by list position further down (CustomDataset indexes
# both lists with the same index), so each listing is sorted to make that
# pairing deterministic.
# NOTE(review): this assumes image and mask file names sort into the same
# order (e.g. identical stems) — confirm against the data on disk.
train_imgs = sorted(glob(os.path.join(BASE_DIR, "MULTICLASS/DENSE/image_train/*tiff")))
train_masks = sorted(glob(os.path.join(BASE_DIR, "MULTICLASS/SPARSE/class_train/*png")))

val_imgs = sorted(glob(os.path.join(BASE_DIR, "MULTICLASS/DENSE/image_val/*tiff")))
val_masks = sorted(glob(os.path.join(BASE_DIR, "COMBINED_MASKS3/*png")))
#val_masks = glob(os.path.join(BASE_DIR, "MULTICLASS/DENSE/class_val/*png"))

#val_imgs = glob(os.path.join(BASE_DIR, "CAR/D_DATASET/image_val/*tiff"))
#val_masks = glob(os.path.join(BASE_DIR, "CAR/D_DATASET/class_val/*png"))

In [None]:
def pix_count(mask_images, ignore_index=None):
    """Accumulate per-label pixel totals over a collection of mask files.

    Args:
        mask_images: iterable of mask file paths, each read with imageio.
        ignore_index: label value left out of the totals (None keeps all).

    Returns:
        A tuple ``(counts, num_classes)`` where ``counts`` maps each label
        value seen to its total pixel count and ``num_classes`` is the largest
        counted label plus one (0 when nothing was counted).
    """
    totals = defaultdict(int)

    for path in mask_images:
        labels, freqs = np.unique(imageio.imread(path).astype(int), return_counts=True)
        for label, freq in zip(labels, freqs):
            # Skip the ignored label; everything else is summed across files.
            if label != ignore_index:
                totals[label] += freq

    if totals:
        n_classes = max(totals.keys()) + 1
    else:
        n_classes = 0  # no masks, or only ignored pixels
    return dict(totals), n_classes

# Count pixels per label over the training masks, leaving label 5 out of the
# totals (5 is the value passed as ignore_index here).
pixel_counts, num_classes = pix_count(mask_images=train_masks, ignore_index=5)

# Show the raw counts and the derived class count (largest counted label + 1).
print(pixel_counts)
print("Number of classes:", num_classes)

In [None]:
class CustomDataset(data.Dataset):
    """Dataset that pairs image files and mask files by list position.

    Each item is read with imageio: the image is converted to float32 and the
    mask to int64, with every mask value above 5 clamped to 5. Optional
    transforms are then applied to the image and to the mask.

    ``is_validation`` is stored but not used by the active code; an earlier
    experiment remapped class 2 to 0 for validation masks.
    """

    def __init__(self, image_paths, target_paths, transform=None, transform_label=None, is_validation=False):
        self.image_paths, self.target_paths = image_paths, target_paths
        self.transform, self.transform_label = transform, transform_label
        self.is_validation = is_validation

    def __len__(self):
        # The image list alone defines the dataset size.
        return len(self.image_paths)

    def __getitem__(self, index):
        """Return ``(image, mask)`` for ``index`` after the optional transforms."""
        image = np.asarray(imageio.imread(self.image_paths[index]), dtype='float32')
        mask = np.asarray(imageio.imread(self.target_paths[index]), dtype='int64')
        # Any label above 5 collapses into label 5.
        mask = np.minimum(mask, 5)

        # One seed is drawn per item and re-applied before each transform so
        # that random transforms make the same draws for image and mask.
        shared_seed = np.random.randint(2147483647)

        random.seed(shared_seed)
        torch.manual_seed(shared_seed)
        if self.transform:
            image = self.transform(image)

        random.seed(shared_seed)
        torch.manual_seed(shared_seed)
        if self.transform_label:
            # Drop the leading axis of the transformed mask when it has size 1.
            mask = self.transform_label(mask).squeeze(0)

        return image, mask

# Single-step pipeline: ToTensor is the only transform applied.
transform = transforms.Compose([transforms.ToTensor()])

# Training set: the same ToTensor pipeline is applied to both image and mask.
train_dataset = CustomDataset(train_imgs, train_masks, transform=transform, transform_label=transform)
# Validation set: only the image is transformed; with transform_label=None the
# mask is returned as the int64 NumPy array produced inside __getitem__.
val_dataset = CustomDataset(val_imgs, val_masks, transform=transforms.ToTensor(), transform_label=None, is_validation=True)

In [None]:
# Define model parameters
ENCODER = 'efficientnet-b7'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = ['class_0', 'class_1', 'class_2', 'class_3', 'class_4']  # Five output classes; len(CLASSES) sets the number of model output channels
ACTIVATION = 'softmax'  # Output activation name handed to smp.Unet
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'  # Use GPU if available, else CPU

# Initialize the U-Net model with the specified encoder.
# The network takes 3-channel input and emits one channel per entry in CLASSES.
model = smp.Unet(
    encoder_name=ENCODER, 
    encoder_weights=ENCODER_WEIGHTS, 
    classes=len(CLASSES),
    activation=ACTIVATION,
    in_channels=3
)

# Get the preprocessing function for the chosen encoder.
# NOTE(review): preprocessing_fn is not used anywhere in the visible cells —
# confirm whether the inputs are meant to be normalised with it.
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

In [None]:
# Define training parameters
LEARNING_RATE = 0.0001

# Two losses are configured: `d_loss` (cross-entropy) is used by the validation
# runner below, `loss` (confidence-weighted multiclass Dice) by the training runner.
# NOTE(review): DynamicWeightedConfidenceMulticlassDiceLoss is presumably provided
# by the local segmentation_models_pytorch copy on sys.path — confirm.
# NOTE(review): the model is built with a softmax activation and d_loss is created
# without an ignore_index — confirm that this CrossEntropyLoss expects
# probabilities and how label 5 is meant to be treated during validation.
#d_loss = smp.utils.losses.DiceLoss(ignore_index=5)
d_loss = smp.utils.losses.CrossEntropyLoss()
loss = smp.utils.losses.DynamicWeightedConfidenceMulticlassDiceLoss(
    eps=1.0,
    beta=1.0,
    amplification_factor=5.0,
    confidence_threshold=0.8,
    correctness_threshold=0.5,
    activation="softmax",
    ignore_index=5  # Label 5 is passed as the index to ignore
)

# Define the metric for evaluation: mean IoU, reported by both epoch runners.
metrics = [smp.utils.metrics.mIoU()]

# Initialize the optimizer: Adam over all model parameters.
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Training runner: optimises `loss` with the Adam optimizer.
train_epoch = smp.utils.train.TrainEpoch(
    model, 
    loss=loss, 
    metrics=metrics, 
    optimizer=optimizer,
    device=DEVICE,
    verbose=True,
)

# Validation runner: evaluates with `d_loss`; no optimizer is passed.
valid_epoch = smp.utils.train.ValidEpoch(
    model, 
    loss=d_loss, 
    metrics=metrics, 
    device=DEVICE,
    verbose=True,
)

# shuffle=False on the training loader: ModelManager.cache_predictions derives
# image ids from batch order, so the loader order must match the file lists.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=10, shuffle=False, num_workers=0)
valid_loader = torch.utils.data.DataLoader(val_dataset, batch_size=50, shuffle=False, num_workers=0)
#test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=20, shuffle=False, num_workers=0)

### Training Loop

In this cell, we're executing the training loop for our segmentation model:

- **Epochs**: We're training the model for a total of `100` epochs. In each epoch, the model is trained on the entire training dataset and then validated on the validation dataset.

- **Monitoring Performance**: After each epoch, we print the Dice Loss for both training and validation to monitor the model's performance.

- **Model Saving**: If the validation performance (measured by Dice Loss) improves compared to previous epochs, we save the model's state dictionary. This way, we ensure that we retain the model weights that give the best performance on the validation set. The model is saved as `T1_Car2.pth`.

By the end of this loop, we aim to have a model that performs well on the validation dataset, indicating its potential to generalize well to unseen data.

In [None]:
# Configuration dictionary for paths and settings
BASE_DIR = "D:/SPARSE/"
config = {
    "base_dir": BASE_DIR,
    "data_folder": os.path.join(BASE_DIR, "SESSIONS/SESSIONS_MULTI/Session_v2"),
    "epochs_per_iteration": 300,
    "device": "cuda" if torch.cuda.is_available() else "cpu"
}

# glob() returns matches in arbitrary, filesystem-dependent order. Everything
# below pairs images and masks by list position (filename_map is keyed by the
# mask index and later used with the same index into train_imgs), so each
# listing is sorted to keep that pairing deterministic across runs.
# NOTE(review): this assumes image and mask file names sort into the same
# order (e.g. identical stems) — confirm against the data on disk.
train_imgs = sorted(glob(os.path.join(BASE_DIR, "MULTICLASS/DENSE/image_train/*tiff")))
train_masks = sorted(glob(os.path.join(BASE_DIR, "MULTICLASS/SPARSE/class_train/*png")))

val_imgs = sorted(glob(os.path.join(BASE_DIR, "MULTICLASS/DENSE/image_val/*tiff")))
val_masks = sorted(glob(os.path.join(BASE_DIR, "COMBINED_MASKS3/*png")))
#val_masks = glob(os.path.join(BASE_DIR, "MULTICLASS/DENSE/class_val/*png"))

# Image id (position in the training lists) -> mask file name. Masks are stored
# per iteration under this name, so the id is what ties images to saved masks.
filename_map = {i: os.path.basename(mask) for i, mask in enumerate(train_masks)}

# Set up paths within the main session data folder for iteration storage
config["dataset_folder"] = os.path.join(config["data_folder"], "datasets")
config["model_folder"] = os.path.join(config["data_folder"], "models")
config["metrics_history_path"] = os.path.join(config["data_folder"], "metrics_history.csv")

# Ensure directory structure is created if it doesn't exist
os.makedirs(config["dataset_folder"], exist_ok=True)
os.makedirs(config["model_folder"], exist_ok=True)

# Widget controls setup
# Slider controlling the alpha of the prediction overlay (0 = invisible, 1 = opaque).
transparency_slider = widgets.FloatSlider(value=0.5, min=0.0, max=1.0, step=0.05, description='Transparency:', continuous_update=True)
# Class assigned to newly clicked points; on_click parses the trailing number of the selected option.
label_selector = ToggleButtons(options=['Class 0', 'Class 1', 'Class 2', 'Class 3', 'Class 4'])

forward_button = Button(description=">")
backward_button = Button(description="<")
continue_training_button = Button(description="Continue Training", layout=Layout(margin='20px 0px 0px 0px'))
# NOTE(review): the initial value (config["epochs_per_iteration"] = 300) is above
# max=100, so the slider presumably starts clamped at 100 — confirm which is intended.
epoch_slider = IntSlider(value=config["epochs_per_iteration"], min=1, max=100, description="Epochs:")
prediction_value_label = Label(value="Prediction: ")
coordinates_label = Label(value="Coordinates: ")
image_name_label = Label(value="Image: ")
labeled_checkbox = Checkbox(description="Labeled", value=False)

# Initialize data structures (module-level state shared by the event handlers)
predictions_cache = {}                         # image id -> {"input": ..., "prediction": ...}
selected_points = defaultdict(list)            # image id -> list of (y, x, label) points
changed_masks = set()                          # image ids whose points changed since the last save_all_masks
images_labeled = set()                         # image ids that received a point in the current iteration
current_image_index = 0                        # image id currently shown
current_iteration_points = defaultdict(list)   # image id -> points added in the current iteration
min_dice_loss = float('inf')
max_iou=0.0
points_collected_total = 0
dragging_point = None                          # index of the point being dragged, or None
drag_start = None
event_connections = []
overlay = None                                 # image artist of the prediction overlay, set by display_image

In [None]:
class DatasetManager:
    """Manages the per-iteration mask folders of an annotation session.

    Masks live in ``<dataset_folder>/iteration_<n>/`` under the file names
    listed in the module-level ``filename_map``. Point and change tracking is
    read from the module-level ``selected_points`` and ``changed_masks``.
    """

    def __init__(self, config, train_imgs, train_masks, val_imgs, val_masks):
        """Store the file lists and pick the mask folder to start from.

        When the latest iteration found on disk is 0, the original training
        masks are copied into ``iteration_0``; otherwise the latest existing
        iteration folder is used as-is.
        """
        self.config = config
        self.train_imgs = train_imgs
        self.train_masks = train_masks
        self.val_imgs = val_imgs
        self.val_masks = val_masks
        self.session_manager = SessionManager(config)

        # SessionManager starts at "latest on disk + 1", so subtract one to get
        # the iteration whose masks already exist.
        latest_iteration = self.session_manager.get_current_iteration() - 1
        if latest_iteration == 0:
            self.modified_mask_dir = self.get_modified_mask_dir(0)
            self._setup_initial_masks()
        else:
            self.modified_mask_dir = self.get_modified_mask_dir(latest_iteration)

    def validate_masks(self, iteration):
        """Ensure all masks for the specified iteration exist.

        Raises:
            FileNotFoundError: if any file named in ``filename_map`` is missing
                from the iteration folder.
        """
        iteration_dir = self.get_modified_mask_dir(iteration)
        for img_id, mask_name in filename_map.items():
            mask_path = os.path.join(iteration_dir, mask_name)
            if not os.path.exists(mask_path):
                raise FileNotFoundError(f"Missing mask for img_id {img_id} in iteration {iteration}")

    def initialize_selected_points(self, selected_points):
        """Fill ``selected_points`` from the masks in the current mask folder.

        Every pixel with a label in 0-4 becomes a ``(y, x, label)`` entry;
        pixels holding any other value (including the unlabeled value 5) are
        skipped. Entries for images whose mask file is missing are left alone.

        Args:
            selected_points: mapping of image id -> list of points, updated
                in place.
        """
        latest_iteration_dir = self.modified_mask_dir
        print(f"[INFO] Initializing selected points from masks in {latest_iteration_dir}...")

        for img_id, mask_name in filename_map.items():
            mask_path = os.path.join(latest_iteration_dir, mask_name)
            if not os.path.exists(mask_path):
                continue

            mask = imageio.imread(mask_path)

            # Coordinates of labeled pixels (0-4) and the label at each of them.
            ys, xs = np.where((mask >= 0) & (mask <= 4))
            labels = mask[ys, xs]

            # Plain Python ints keep these entries consistent with the points
            # appended by the click handler.
            selected_points[img_id] = [
                (int(y), int(x), int(label))
                for y, x, label in zip(ys, xs, labels)
            ]

        print(f"[INFO] Loaded selected points for {len(selected_points)} images.")

    def get_modified_mask_dir(self, iteration):
        """Return the directory for a specific iteration's masks."""
        return os.path.join(self.config["dataset_folder"], f"iteration_{iteration}")

    def _setup_initial_masks(self):
        """Copy the original training masks into the ``iteration_0`` folder."""
        os.makedirs(self.modified_mask_dir, exist_ok=True)
        for mask in self.train_masks:
            shutil.copy(mask, os.path.join(self.modified_mask_dir, os.path.basename(mask)))

    @staticmethod
    def _rasterize_points(mask, points, unlabeled=5):
        """Return a mask of ``mask``'s shape and dtype holding only ``points``.

        Every pixel is set to ``unlabeled`` and each ``(y, x, label)`` entry is
        then stamped in list order, so a later duplicate coordinate wins. The
        input ``mask`` is not modified.
        """
        rebuilt = np.full(mask.shape, unlabeled, dtype=mask.dtype)
        for y, x, label in points:
            rebuilt[y, x] = label
        return rebuilt

    def save_all_masks(self, current_iteration):
        """Create the mask folder for ``current_iteration``.

        Every mask is first copied from the previous iteration. Masks whose
        image id is in the module-level ``changed_masks`` are then rebuilt from
        the module-level ``selected_points``: pixels without a point become 5
        (unlabeled), which is how deletions are persisted. ``changed_masks`` is
        cleared at the end.

        Raises:
            FileNotFoundError: propagated from ``shutil.copy`` when a mask is
                missing from the previous iteration folder.
        """
        iteration_dir = self.get_modified_mask_dir(current_iteration)
        previous_iteration_dir = self.get_modified_mask_dir(current_iteration - 1)

        # Ensure the new iteration directory exists
        os.makedirs(iteration_dir, exist_ok=True)

        modified_count = 0  # Counter for modified masks

        for img_id, mask_name in filename_map.items():
            previous_mask_path = os.path.join(previous_iteration_dir, mask_name)
            current_mask_path = os.path.join(iteration_dir, mask_name)

            # Start from the previous iteration's mask so the file always exists.
            shutil.copy(previous_mask_path, current_mask_path)

            if img_id not in changed_masks:
                continue

            # Rebuild the changed mask from its point list in one array fill
            # instead of visiting every pixel in Python; the result is the same
            # as clearing all non-point pixels to 5 and then writing the labels.
            mask = imageio.imread(current_mask_path)
            mask = self._rasterize_points(mask, selected_points[img_id])
            imageio.imsave(current_mask_path, mask)
            modified_count += 1

        print(f"[INFO] Copied {len(filename_map)} masks from iteration {current_iteration - 1}.")
        print(f"[INFO] Modified and saved {modified_count} masks for iteration {current_iteration}.")

        # All pending changes are now on disk.
        changed_masks.clear()

    def count_class_pixels(self, iteration_dir):
        """Count the number of pixels for each class (0-4) in the masks of the given iteration.

        Every file in ``iteration_dir`` is read as a mask. Returns a dict
        mapping each class 0-4 to its total; the unlabeled value 5 is not
        counted.
        """
        class_counts = {i: 0 for i in range(5)}  # Initialize count for classes 0-4

        for mask_file in os.listdir(iteration_dir):
            mask_path = os.path.join(iteration_dir, mask_file)
            mask = imageio.imread(mask_path)

            # Count pixels for each class
            for i in range(5):  # Ignore class 5
                class_counts[i] += np.sum(mask == i)

        return class_counts

class SessionManager:
    """Tracks the session's iteration counter and its metrics history.

    The counter starts one past the highest ``iteration_<n>`` entry found in
    the dataset folder. The metrics history is a DataFrame that is loaded from
    and written back to ``config["metrics_history_path"]``.
    """

    # Per-class metric names, in the order they appear as CSV columns.
    _CLASS_METRICS = ('TP', 'FP', 'FN', 'Precision', 'Recall', 'F1-Score', 'IoU')

    def __init__(self, config):
        self.config = config
        self.current_iteration = self._get_latest_iteration() + 1  # Start from the next iteration

        history_path = config["metrics_history_path"]
        if os.path.exists(history_path):
            self.metrics_history = pd.read_csv(history_path)
        else:
            self.metrics_history = pd.DataFrame(columns=self._history_columns())

    @classmethod
    def _history_columns(cls, num_classes=5):
        """Return the column names of an empty metrics history."""
        summary = [
            'Iteration', 'Total_TP', 'Total_FP', 'Total_FN',  # Summed TP, FP, FN across classes
            'mPrecision', 'mRecall', 'mF1-Score', 'mIoU'  # Mean precision, recall, F1-score, IoU
        ]
        per_class = [
            f'Class_{i}_{metric}'
            for i in range(num_classes)
            for metric in cls._CLASS_METRICS
        ]
        return summary + per_class

    def _get_latest_iteration(self):
        """Determine the latest iteration based on available datasets.

        Returns the largest ``n`` among entries named ``iteration_<n>`` in the
        dataset folder, or 0 when there are none or the folder does not exist.
        """
        dataset_folder = self.config["dataset_folder"]
        if not os.path.isdir(dataset_folder):
            # A missing folder simply means no iteration has been stored yet.
            return 0
        iteration_folders = [
            int(folder.split("_")[-1])
            for folder in os.listdir(dataset_folder)
            if folder.startswith("iteration_") and folder.split("_")[-1].isdigit()
        ]
        return max(iteration_folders, default=0)  # Default to 0 if no iterations exist

    def increment_iteration(self):
        """Increment the iteration counter."""
        self.current_iteration += 1

    def get_current_iteration(self):
        """Return the current iteration."""
        return self.current_iteration

    def calculate_metrics_from_confusion_matrix(self, y_true, y_pred, num_classes=5):
        """Calculate evaluation metrics for multiclass segmentation.

        Args:
            y_true: ground-truth labels, as accepted by sklearn's
                ``confusion_matrix``.
            y_pred: predicted labels, aligned with ``y_true``.
            num_classes: labels ``0 .. num_classes - 1`` are evaluated.

        Returns:
            dict with per-class lists under ``"TP"``, ``"FP"``, ``"FN"``,
            ``"Precision"``, ``"Recall"``, ``"F1-Score"``, ``"IoU"`` and the
            scalar ``"Mean IoU"``.
        """
        class_labels = list(range(num_classes))

        # Rows of the confusion matrix are true labels, columns are predictions.
        cm = confusion_matrix(y_true, y_pred, labels=class_labels)

        # Per-class TP, FN and FP derived from the matrix.
        tp = np.diag(cm)  # True Positives for each class
        fn = cm.sum(axis=1) - tp  # False Negatives
        fp = cm.sum(axis=0) - tp  # False Positives

        # Compute precision, recall, and F1-score for each class
        precision, recall, f1_score, _ = precision_recall_fscore_support(
            y_true, y_pred, average=None, labels=class_labels
        )

        # Compute per-class IoU
        iou_per_class = tp / (tp + fp + fn + 1e-8)  # Avoid division by zero
        mean_iou = np.mean(iou_per_class)  # Mean IoU across all classes

        return {
            "TP": tp.tolist(),
            "FP": fp.tolist(),
            "FN": fn.tolist(),
            "Precision": precision.tolist(),
            "Recall": recall.tolist(),
            "F1-Score": f1_score.tolist(),
            "IoU": iou_per_class.tolist(),
            "Mean IoU": mean_iou
        }

    def update_metrics_history(self, metrics_dict):
        """Update metrics history and save it to a CSV file.

        Args:
            metrics_dict: a dict shaped like the return value of
                ``calculate_metrics_from_confusion_matrix``. The number of
                classes is taken from the length of its ``"TP"`` list, so it
                works for any ``num_classes`` used there.
        """
        # Compute totals across all classes
        total_tp = sum(metrics_dict["TP"])
        total_fp = sum(metrics_dict["FP"])
        total_fn = sum(metrics_dict["FN"])

        # Row for this iteration: summary values first, then per-class values.
        metrics_data = {
            "Iteration": self.current_iteration,
            "Total_TP": total_tp, "Total_FP": total_fp, "Total_FN": total_fn,
            "mPrecision": np.mean(metrics_dict["Precision"]),
            "mRecall": np.mean(metrics_dict["Recall"]),
            "mF1-Score": np.mean(metrics_dict["F1-Score"]),
            "mIoU": metrics_dict["Mean IoU"]
        }

        # Follow the number of classes actually present in the metrics rather
        # than assuming five.
        for i in range(len(metrics_dict["TP"])):
            for metric in self._CLASS_METRICS:
                metrics_data[f'Class_{i}_{metric}'] = metrics_dict[metric][i]

        new_metrics = pd.DataFrame([metrics_data])

        if self.metrics_history.empty:
            # Nothing recorded yet: start the history from this row instead of
            # concatenating onto an empty frame.
            self.metrics_history = new_metrics
        else:
            self.metrics_history = pd.concat([self.metrics_history, new_metrics], ignore_index=True)

        # Save to disk
        self.metrics_history.to_csv(self.config["metrics_history_path"], index=False)
        print("[INFO] Metrics history updated and saved.")


# ModelManager
class ModelManager:
    """Saves, loads and runs the segmentation model for a session."""

    def __init__(self, config, model):
        self.config = config
        self.model = model

    def save_model(self, iteration, is_best=False):
        """Save the model's state dict for a specific iteration.

        Args:
            iteration: iteration number, or the string ``"best_model"`` to
                write ``best_model.pth`` instead of ``iteration_<n>.pth``.
            is_best: only selects which log message is printed.
        """
        if iteration == "best_model":
            model_path = os.path.join(self.config["model_folder"], "best_model.pth")
        else:
            model_path = os.path.join(self.config["model_folder"], f"iteration_{iteration}.pth")

        torch.save(self.model.state_dict(), model_path)
        if is_best:
            print(f"[INFO] Updated global best model at: {model_path}")
        else:
            print(f"[INFO] Model saved for iteration {iteration} at: {model_path}")

    def load_model(self, iteration):
        """Load the weights stored for a specific iteration into the model.

        Raises:
            FileNotFoundError: if ``iteration_<n>.pth`` does not exist.
        """
        model_path = os.path.join(self.config["model_folder"], f"iteration_{iteration}.pth")
        if not os.path.exists(model_path):
            print(f"[ERROR] Model not found for iteration {iteration}: {model_path}")
            raise FileNotFoundError(f"Model not found: {model_path}")

        print(f"[INFO] Loading model from: {model_path}")
        # Map the checkpoint to CPU so a file saved from a CUDA model can be
        # read on a CPU-only machine; load_state_dict then copies the values
        # into the model's parameters wherever they live.
        state_dict = torch.load(model_path, map_location="cpu")
        self.model.load_state_dict(state_dict)

    def cache_predictions(self, train_loader, device="cpu"):
        """Run the model over ``train_loader`` and cache inputs and outputs.

        Args:
            train_loader: iterable of ``(inputs, targets)`` batches; targets
                are ignored. Items are numbered in iteration order, so the
                loader must not shuffle if ids are to match the file lists.
            device: device the inputs are moved to before the forward pass.

        Returns:
            dict mapping a running image id to
            ``{"input": channels-last input array, "prediction": model output
            for that image with size-1 axes squeezed out}``.
        """
        #print(f"[DEBUG] Caching predictions using the current model...")
        self.model.eval()  # Set the model to evaluation mode

        predictions_cache = {}
        next_id = 0  # running id; does not rely on the loader's batch_size
        with torch.no_grad():
            for inputs, _ in train_loader:
                inputs = inputs.to(device)
                outputs = self.model(inputs).cpu().numpy()
                images = inputs.cpu().numpy()
                for image, output in zip(images, outputs):
                    predictions_cache[next_id] = {
                        # (C, H, W) -> (H, W, C) for display
                        "input": np.transpose(image, (1, 2, 0)),
                        "prediction": output.squeeze()
                    }
                    next_id += 1
        print(f"[INFO] Cached predictions for {len(predictions_cache)} images.")
        return predictions_cache

In [None]:
def update_counts(img_id):
    """Print how many selected points image ``img_id`` has in each class 0-4.

    Points whose label is not one of 0-4 are not counted.
    """
    tally = dict.fromkeys(range(5), 0)

    for _y, _x, lbl in selected_points[img_id]:
        if lbl in tally:
            tally[lbl] += 1

    # Keys were inserted as 0..4, so this prints the classes in order.
    for cls, n_points in tally.items():
        print(f"Class {cls} Points: {n_points}")

def find_nearest_point(img_id, x, y, threshold=2):
    """Return the index of the selected point closest to ``(x, y)``.

    Only points whose x and y offsets are both strictly below ``threshold``
    are candidates. Among those, the one with the smallest squared distance
    wins (the earliest in the list on ties), so the result is the nearest
    point rather than merely the first one inside the box.

    Args:
        img_id: key into the module-level ``selected_points``.
        x, y: query position in pixel coordinates.
        threshold: exclusive per-axis distance limit.

    Returns:
        Index into ``selected_points[img_id]``, or ``None`` if no point
        qualifies.
    """
    best_idx = None
    best_dist = None
    for i, (py, px, _) in enumerate(selected_points[img_id]):
        dx, dy = abs(px - x), abs(py - y)
        if dx < threshold and dy < threshold:
            dist = dx * dx + dy * dy
            # Strict comparison keeps the earliest point on equal distance.
            if best_dist is None or dist < best_dist:
                best_idx, best_dist = i, dist
    return best_idx

def delete_nearest_point(img_id, x, y):
    """Remove the selected point found near ``(x, y)`` on image ``img_id``.

    The point is dropped from ``selected_points`` and from
    ``current_iteration_points``, the image is flagged in ``changed_masks``,
    and ``save_current_mask`` is invoked for that image. When no point is
    found, only a warning is printed.
    """
    point_idx = find_nearest_point(img_id, x, y)

    # Guard clause: nothing close enough to delete.
    if point_idx is None:
        print(f"[WARNING] No point found near ({x}, {y}) in image {img_id}")
        return

    print(f"[DEBUG] Before Deletion - Points in img {img_id}: {selected_points[img_id]}")

    # pop() both removes the entry and hands back its (y, x, label) tuple.
    y_coord, x_coord, label = selected_points[img_id].pop(point_idx)

    # Keep every current-iteration point that sits on a different pixel.
    current_iteration_points[img_id] = [
        pt for pt in current_iteration_points[img_id]
        if pt[0] != y_coord or pt[1] != x_coord
    ]

    # Flag the image so the deletion is picked up when masks are saved.
    changed_masks.add(img_id)

    print(f"[INFO] Deleted point at ({x_coord}, {y_coord}) for Class {label} in image {img_id}")
    print(f"[DEBUG] After Deletion - Remaining Points: {selected_points[img_id]}")

    # Write the change through to the mask file straight away.
    save_current_mask(delete_mode=True, img_id=img_id)


In [None]:
def plot_all_points(ax, img_id):
    """Draw every selected point of image ``img_id`` on ``ax``.

    Each point is drawn as a size-4 circle marker whose colour is picked by
    its class label (0-4).
    """
    # Colour per class index; position in the tuple is the class label.
    palette = ('blue', 'green', 'yellow', 'orange', 'purple')

    for row, col, cls in selected_points[img_id]:
        # Matplotlib expects (x, y), i.e. (column, row).
        ax.plot(col, row, 'o', color=palette[cls], markersize=4)

# Event handlers for adding, dragging, and highlighting points
def on_click(event):
    """Mouse-press handler: start a drag, delete a point, or add a new one.

    - Left button on an existing point: remember it as the point being dragged.
    - Right button: delete the point near the cursor and redraw.
    - Anything else: add a point of the class chosen in ``label_selector``,
      update the tracking structures, save the mask and redraw.

    Clicks outside the axes are ignored.
    """
    global dragging_point

    if not event.inaxes:
        return

    x, y = int(event.xdata), int(event.ydata)
    nearest_point_idx = find_nearest_point(current_image_index, x, y)

    # Left-click on an existing point starts a drag.
    if nearest_point_idx is not None and event.button == 1:
        dragging_point = nearest_point_idx
        return

    # Right-click removes the point under the cursor.
    if event.button == 3:
        delete_nearest_point(current_image_index, x, y)
        redraw_image_with_points(event.inaxes)
        return

    # Otherwise add a new point; 'Class 2' -> 2.
    label = int(label_selector.value.split()[-1])
    new_point = (y, x, label)

    selected_points[current_image_index].append(new_point)
    current_iteration_points[current_image_index].append(new_point)
    changed_masks.add(current_image_index)

    # Refresh the per-class counts and the "Labeled" indicator.
    update_counts(current_image_index)
    images_labeled.add(current_image_index)
    labeled_checkbox.value = True

    save_current_mask()
    redraw_image_with_points(event.inaxes)

            
def on_motion(event):
    """Mouse-move handler: update the readouts, hover highlight and drag.

    While the cursor is inside the axes this
    - shows the predicted class and the coordinates under the cursor,
    - highlights the point near the cursor when nothing is being dragged, and
    - moves the dragged point (in ``selected_points`` and, if present, in
      ``current_iteration_points``) and flags the image as changed.
    """
    global dragging_point
    if not event.inaxes:
        return

    x, y = int(event.xdata), int(event.ydata)

    # Predicted class at (y, x). Only the class scores of this one pixel are
    # reduced; this gives the same value as taking the argmax of the whole
    # prediction volume and indexing it, without redoing that work on every
    # mouse move.
    prediction = predictions_cache[current_image_index]["prediction"]
    predicted_class = prediction[:, y, x].argmax()
    prediction_value_label.value = f"Prediction: Class {predicted_class}"
    coordinates_label.value = f"Coordinates: ({x}, {y})"

    if dragging_point is None:
        # Hovering: highlight the nearby point, if any.
        nearest_point_idx = find_nearest_point(current_image_index, x, y)
        if nearest_point_idx is not None:
            redraw_image_with_points(event.inaxes, highlight_idx=nearest_point_idx)
        else:
            redraw_image_with_points(event.inaxes)
        return

    # Dragging: move the point to the cursor position.
    old_y, old_x, label = selected_points[current_image_index][dragging_point]
    selected_points[current_image_index][dragging_point] = (y, x, label)

    # Mirror the move in the current-iteration list (first exact match only).
    iteration_points = current_iteration_points[current_image_index]
    for i, (py, px, lbl) in enumerate(iteration_points):
        if py == old_y and px == old_x and lbl == label:
            iteration_points[i] = (y, x, label)
            break

    # Mark the image as changed
    changed_masks.add(current_image_index)
    redraw_image_with_points(event.inaxes, highlight_idx=dragging_point)


def on_release(event):
    """Mouse-release handler: end any drag by clearing ``dragging_point``."""
    global dragging_point
    dragging_point = None

In [None]:
def reload_train_loader():
    """Rebuild the global ``train_loader`` on the previous iteration's masks.

    Mask paths are taken from the folder of iteration ``current - 1``, one per
    training image, using ``filename_map`` to look up each file name.
    """
    global train_loader

    mask_dir = dataset_manager.get_modified_mask_dir(
        session_manager.get_current_iteration() - 1
    )

    # One mask path per training image, in image order.
    updated_train_masks = []
    for img_id in range(len(train_imgs)):
        updated_train_masks.append(os.path.join(mask_dir, filename_map[img_id]))

    refreshed_dataset = CustomDataset(
        train_imgs, updated_train_masks, transform=transform, transform_label=transform
    )
    train_loader = torch.utils.data.DataLoader(
        refreshed_dataset, batch_size=10, shuffle=False, num_workers=0
    )

def update_overlay_alpha(change):
    """Apply a new transparency value to the prediction overlay.

    ``change`` is a dict-like object whose ``'new'`` entry holds the alpha to
    set. Nothing happens while no overlay exists. No canvas redraw is
    requested here.
    """
    global overlay

    if overlay is None:
        return
    overlay.set_alpha(change['new'])
        
def save_current_mask(delete_mode=False, img_id=None):
    """Write the selected points of one image into its current-iteration mask.

    The mask is rebuilt from ``selected_points[img_id]``: every pixel without a
    point becomes 5 (unlabeled) and every point's pixel receives its label, so
    both deletions and newly added points reach the file.

    The image is intentionally left in ``changed_masks``:
    ``DatasetManager.save_all_masks`` copies the previous iteration's mask over
    the target file and only re-applies the points for images still flagged as
    changed, so clearing the flag here would lose the edit at that step.

    Args:
        delete_mode: accepted for call-site compatibility; the rebuild handles
            deletions and additions alike, so it does not change the result.
        img_id: image to save; defaults to the image currently displayed.
    """
    if img_id is None:
        img_id = current_image_index  # Default to current image if not specified

    current_iteration = session_manager.get_current_iteration()
    mask_path = os.path.join(dataset_manager.get_modified_mask_dir(current_iteration), filename_map[img_id])

    if not os.path.exists(mask_path):
        print(f"[WARNING] Mask file not found: {mask_path}. Skipping save.")
        return

    mask = imageio.imread(mask_path)

    # Rebuild in one array fill instead of visiting every pixel in Python.
    rebuilt = np.full(mask.shape, 5, dtype=mask.dtype)
    for y, x, label in selected_points[img_id]:
        rebuilt[y, x] = label

    # Save when the content differs or the image is flagged as changed.
    if img_id in changed_masks or not np.array_equal(rebuilt, mask):
        imageio.imsave(mask_path, rebuilt)
        print(f"[INFO] Mask updated and saved for {mask_path}")


def on_forward_clicked(b):
    """Button callback: save the current mask, then show the next image.

    The index wraps around to the first image after the last one.
    """
    global current_image_index

    # Persist the image being left before switching away from it.
    save_current_mask()

    image_count = len(predictions_cache)
    next_index = (current_image_index + 1) % image_count
    print(f"[DEBUG] Moving forward: {current_image_index} → {next_index}")

    current_image_index = next_index
    display_image(alpha=transparency_slider.value)

def on_backward_clicked(b):
    """Button callback: save the current mask, then show the previous image.

    The index wraps around to the last image before the first one.
    """
    global current_image_index

    # Persist the image being left before switching away from it.
    save_current_mask()

    image_count = len(predictions_cache)
    prev_index = (current_image_index - 1) % image_count
    print(f"[DEBUG] Moving backward: {current_image_index} → {prev_index}")

    current_image_index = prev_index
    display_image(alpha=transparency_slider.value)

def on_continue_training_clicked(b):
    """Button callback: persist the masks, train, then reset the labeling UI.

    Steps, in order: save all masks for the current iteration, advance the
    iteration counter, run the training loop, then clear the per-iteration
    tracking and show the first image again.
    """
    # Without this declaration the reset below would only bind a local name
    # and the viewer would stay on whatever image was shown before training.
    global current_image_index

    clear_output()
    current_iteration = session_manager.get_current_iteration()
    print(f"[INFO] Preparing iteration {current_iteration}...")

    # Ensure all masks (including deletions) are saved before training.
    dataset_manager.save_all_masks(current_iteration)

    # Increment the iteration only AFTER the masks have been saved.
    session_manager.increment_iteration()

    run_training_loop()

    # Reset UI elements
    images_labeled.clear()
    current_image_index = 0
    current_iteration_points.clear()
    display_image(alpha=transparency_slider.value)

In [None]:
def display_image(alpha=0.5):
    """
    Render the image at ``current_image_index`` with its prediction overlay,
    labeled points, and the control widgets.

    The module-level ``fig``/``ax`` are created on first use and reused
    afterwards. The overlay artist is stored in the global ``overlay`` so the
    transparency slider can adjust it without a full redraw.

    Args:
        alpha: Initial opacity of the prediction overlay.
    """
    global overlay, current_image_index, fig, ax
    global event_connections, _overlay_alpha_observer

    # Validate first, so an out-of-range index does not blank the current view.
    if current_image_index < 0 or current_image_index >= len(predictions_cache):
        print(f"[WARNING] No cached prediction at index {current_image_index}; nothing to display.")
        return

    # Clear previous output to prevent duplication
    clear_output(wait=True)

    # Create a new figure only if fig does not exist to prevent duplication
    if 'fig' not in globals() or 'ax' not in globals():
        fig, ax = plt.subplots(figsize=(7, 7))

    ax.clear()  # Clear the axis content if reusing the figure

    # Fetch current image and prediction data
    img_data = predictions_cache[current_image_index]
    inp_unit = img_data["input"] / 255.0
    ax.imshow(inp_unit, cmap='gray', aspect='auto')

    # argmax over axis 0 collapses the per-class channels into one label map
    prediction_map = img_data["prediction"].argmax(axis=0)

    # Add overlay for predictions with the initial transparency setting
    overlay = ax.imshow(prediction_map, cmap='tab10', alpha=alpha, aspect='auto')

    # Remove axis spines and labels for a clean visualization
    ax.axis('off')

    # Adjust the layout before adding the title
    fig.tight_layout(pad=1.0)

    # Title summarising labeling progress
    changed_image_count = len(changed_masks)
    points_collected_total = sum(len(points) for points in selected_points.values())
    current_iteration_points_count = sum(len(points) for points in current_iteration_points.values())
    fig.suptitle(f"Images changed: {changed_image_count}, "
                 f"Pts (current iteration): {current_iteration_points_count}, "
                 f"Pts (total): {points_collected_total}", fontsize=8)

    image_name_label.value = f"Image: {filename_map[current_image_index]}"
    labeled_checkbox.value = current_image_index in images_labeled

    # Plot labeled points
    plot_all_points(ax, current_image_index)

    # Each call builds a new closure, which the slider cannot de-duplicate.
    # Drop the handler registered by the previous call so that observers do
    # not pile up (one extra redraw per navigation) over a labeling session.
    stale_observer = globals().get('_overlay_alpha_observer')
    if stale_observer is not None:
        try:
            transparency_slider.unobserve(stale_observer, names='value')
        except ValueError:
            # The handler was not registered on this slider instance
            # (for example, the slider widget was re-created).
            pass

    def update_overlay_alpha(change):
        # `overlay` and `fig` are module globals, so this always targets the
        # artist drawn most recently.
        overlay.set_alpha(change['new'])
        fig.canvas.draw_idle()  # Efficient redraw of the updated alpha only

    transparency_slider.observe(update_overlay_alpha, names='value')
    _overlay_alpha_observer = update_overlay_alpha

    # Keep exactly one set of mouse-event connections on the canvas
    for cid in event_connections:
        fig.canvas.mpl_disconnect(cid)
    event_connections = [
        fig.canvas.mpl_connect('button_press_event', on_click),
        fig.canvas.mpl_connect('button_release_event', on_release),
        fig.canvas.mpl_connect('motion_notify_event', on_motion)
    ]

    # Controls UI layout with improved order
    controls_box = VBox([
        # Layer 1: Class Selector (Horizontally aligned)
        HBox([label_selector], layout=Layout(justify_content='flex-start')),

        # Layer 2: Navigation buttons + Transparency slider
        HBox([
            backward_button,
            forward_button,
            transparency_slider
        ], layout=Layout(justify_content='flex-start')),

        # Layer 3: Image name & labeled checkbox
        HBox([
            image_name_label,
            labeled_checkbox
        ], layout=Layout(justify_content='flex-start')),

        # Layer 4: Continue Training button
        HBox([continue_training_button], layout=Layout(justify_content='flex-start'))
    ])

    # Explicitly display both the figure canvas and control elements
    display(VBox([fig.canvas, controls_box]))

    
def redraw_image_with_points(ax, highlight_idx=None):
    """
    Redraw the current image, its prediction overlay, and its labeled points.

    Args:
        ax: Axes to draw on; its previous content is cleared.
        highlight_idx: Index into ``selected_points[current_image_index]`` of
            the point to draw with a larger marker, or ``None`` for no highlight.
    """
    # `ax.clear()` discards the previous overlay artist. Rebinding the global
    # keeps the transparency slider pointed at the artist that is on screen.
    global overlay

    ax.clear()
    img_data = predictions_cache[current_image_index]

    # Display the input image
    ax.imshow(img_data["input"] / 255.0, cmap='gray', aspect='auto')

    # argmax over axis 0 collapses the per-class channels into one label map
    prediction_map = img_data["prediction"].argmax(axis=0)

    # Overlay the prediction mask with a categorical colormap
    overlay = ax.imshow(prediction_map, cmap='tab10', alpha=transparency_slider.value, aspect='auto')

    # Remove axis for a cleaner look
    ax.axis('off')

    # One marker color per class index
    class_colors = ['blue', 'green', 'yellow', 'red', 'pink']  # Adjust as needed for each class

    # Points are stored as (y, x, label); matplotlib expects x first
    for i, (y, x, label) in enumerate(selected_points[current_image_index]):
        color = class_colors[label]
        size = 3 if i != highlight_idx else 5
        ax.plot(x, y, 'o', color=color, markersize=size)

    fig.canvas.draw_idle()

In [None]:
# Run Training Loop
def run_training_loop():
    """
    Run one active-learning iteration: train, keep the best checkpoint,
    refresh the prediction cache, compute validation metrics, and show the
    first image.

    Steps, in order:
      1. Load the previous iteration's model (iterations after the first).
      2. Validate and count pixels of the previous iteration's masks.
      3. Reload the train loader and train for ``config["epochs_per_iteration"]``
         epochs, saving the checkpoint with the lowest validation
         ``cross_entropy_loss`` for this iteration and across all iterations.
      4. Reload this iteration's best checkpoint and rebuild ``predictions_cache``.
      5. Compute per-class metrics on ``valid_loader`` and append them to the
         session history.

    Updates the globals ``min_dice_loss`` and ``predictions_cache``.
    """
    global min_dice_loss, predictions_cache

    # Number of segmentation classes used for reporting and metrics.
    num_classes = 5

    current_iteration = session_manager.get_current_iteration()

    print(f"\n--- Active Learning Iteration {current_iteration} ---\n")

    # Load model from the previous iteration or initialize for the first iteration
    if current_iteration > 1:
        model_manager.load_model(current_iteration - 1)
    else:
        print("[INFO] Starting from scratch. Initializing model with default weights.")

    # Validate masks for the previous iteration
    previous_iteration_dir = dataset_manager.get_modified_mask_dir(current_iteration - 1)
    print(f"[INFO] Validating masks for iteration {current_iteration - 1}...")
    dataset_manager.validate_masks(current_iteration - 1)
    print(f"[INFO] All masks validated for iteration {current_iteration - 1}.")

    # Count pixels for the current training dataset
    class_counts = dataset_manager.count_class_pixels(previous_iteration_dir)

    # Print class pixel counts dynamically
    for i in range(num_classes):
        print(f"[INFO] Class {i} pixels: {class_counts[i]}")

    # Reload train loader with masks from the previous iteration
    reload_train_loader()

    # Training loop
    iteration_min_loss = float('inf')
    for epoch in range(config["epochs_per_iteration"]):
        print(f"[INFO] Epoch {epoch + 1}/{config['epochs_per_iteration']} (Iteration {current_iteration})")
        train_logs = train_epoch.run(train_loader)
        valid_logs = valid_epoch.run(valid_loader)

        # Save the best model for this iteration
        if valid_logs['cross_entropy_loss'] < iteration_min_loss:
            iteration_min_loss = valid_logs['cross_entropy_loss']
            model_manager.save_model(current_iteration)

        # Update the global best model if necessary
        # (`min_dice_loss` holds the best cross-entropy loss seen so far.)
        if valid_logs['cross_entropy_loss'] < min_dice_loss:
            model_manager.save_model("best_model", is_best=True)
            min_dice_loss = valid_logs['cross_entropy_loss']

    # Explicitly load the best model for this iteration
    print("[INFO] Loading the best model for the current iteration...")
    model_manager.load_model(current_iteration)

    # Cache predictions for active learning using the best model
    predictions_cache = model_manager.cache_predictions(train_loader, config["device"])
    print(f"[INFO] Cached predictions for {len(predictions_cache)} images.")

    # Collect flattened labels and predictions one array per batch. Extending
    # a Python list element-wise would allocate one object per pixel.
    true_chunks, pred_chunks = [], []

    # Evaluate in inference mode regardless of what the preceding calls left
    # behind, then restore the previous mode so later training is unaffected.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for images, labels in valid_loader:
                images, labels = images.to(config["device"]), labels.to(config["device"])
                predictions = model(images).argmax(dim=1)  # class index per pixel
                true_chunks.append(labels.cpu().numpy().flatten())
                pred_chunks.append(predictions.cpu().numpy().flatten())
    finally:
        model.train(was_training)

    # NOTE(review): y_true / y_pred are now 1-D NumPy arrays instead of lists;
    # confirm calculate_metrics_from_confusion_matrix accepts array-likes.
    y_true = np.concatenate(true_chunks) if true_chunks else np.array([], dtype=np.int64)
    y_pred = np.concatenate(pred_chunks) if pred_chunks else np.array([], dtype=np.int64)

    # Compute multiclass metrics
    metrics_dict = session_manager.calculate_metrics_from_confusion_matrix(y_true, y_pred, num_classes=num_classes)

    # Print per-class metrics
    for i in range(num_classes):
        print(f"[INFO] Class {i} - Precision: {metrics_dict['Precision'][i]:.2f}%, "
              f"Recall: {metrics_dict['Recall'][i]:.2f}%, "
              f"F1: {metrics_dict['F1-Score'][i]:.2f}%, IoU: {metrics_dict['IoU'][i]:.2f}%")

    print(f"[INFO] Mean IoU: {metrics_dict['Mean IoU']:.2f}")

    # Store in history
    session_manager.update_metrics_history(metrics_dict)

    print("\nMetrics across iterations:")
    display(session_manager.metrics_history)

    # Display the first image of the new predictions
    display_image(alpha=transparency_slider.value)

In [None]:
# Session bootstrap: build the managers, wire the widgets, start iteration one.
dataset_manager = DatasetManager(config, train_imgs, train_masks, val_imgs, val_masks)
session_manager = SessionManager(config)
model_manager = ModelManager(config, model)

# Let the dataset manager populate the shared point store.
dataset_manager.initialize_selected_points(selected_points)

# Register the module-level alpha handler on the slider.
# NOTE(review): `update_overlay_alpha` must already be defined at module level
# by an earlier cell -- confirm.
transparency_slider.observe(update_overlay_alpha, names='value')

# Attach each button to its click handler.
for _nav_button, _nav_handler in (
    (forward_button, on_forward_clicked),
    (backward_button, on_backward_clicked),
    (continue_training_button, on_continue_training_clicked),
):
    _nav_button.on_click(_nav_handler)
del _nav_button, _nav_handler  # do not leave loop temporaries in the notebook namespace

# Start the first training iteration.
run_training_loop()