In [1]:
import torch
import torch.nn as nn
import torch.utils.data as data
import torch.nn.functional as F
print(torch.__version__)
from skimage.metrics import structural_similarity as ssim  # Import SSIM function
import os
import numpy as np
from glob import glob
import random
import imageio
from torchvision import transforms
from collections import defaultdict
from PIL import Image
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from shapely.geometry import Polygon
import sys
sys.path.append("D:/pytorch")  # Adds the parent directory to the system path
import segmentation_models_pytorch as smp
from ipywidgets import widgets, VBox, HBox, IntText, ToggleButtons, Button, IntSlider, Label, Checkbox, Layout
from IPython.display import display, clear_output
import pandas as pd
from sklearn.metrics import confusion_matrix
import shutil
import gc
import traceback

# Use the ipympl widget backend so figures are interactive; the click/drag
# annotation handlers connected later with mpl_connect rely on this.
%matplotlib widget

# NOTE(review): the two InlineBackend settings below configure the *inline*
# backend; with `%matplotlib widget` active they probably have no effect — confirm.
%config InlineBackend.figure_format = 'retina'
# Turn off IPython's history manager for this kernel session.
get_ipython().history_manager.enabled = False
%config InlineBackend.close_figures = True

# Enable the ipympl and ipywidgets front-end extensions.
# NOTE(review): `jupyter nbextension` targets the classic Notebook (< 7) and runs a
# shell command on every execution of this cell; consider moving it to environment setup.
!jupyter nbextension enable --py --sys-prefix ipympl
!jupyter nbextension enable --py --sys-prefix widgetsnbextension

1.11.0


Enabling notebook extension jupyter-matplotlib/extension...
      - Validating: ok
Enabling notebook extension jupyter-js-widgets/extension...
      - Validating: ok


In [2]:
BASE_DIR = "D:/SPARSE/"
# Sort every glob: glob() returns files in filesystem-dependent order, and images are
# paired with masks purely by list position, so unsorted lists can silently pair an
# image with the wrong mask. (Assumes image and mask file names sort in the same order.)
train_imgs = sorted(glob(os.path.join(BASE_DIR, "IMG/image_train/*tiff")))
train_masks = sorted(glob(os.path.join(BASE_DIR, "ROAD/initial_dataset/*png")))

val_imgs = sorted(glob(os.path.join(BASE_DIR, "ROAD/ROAD DATASET2/image_val/*tiff")))
val_masks = sorted(glob(os.path.join(BASE_DIR, "ROAD/ROAD DATASET2/class_val/*png")))

In [3]:
def pix_count(mask_images, ignore_index=None, num_classes=None):
    """Count pixels per mask value over a collection of mask files.

    Args:
        mask_images: iterable of mask file paths readable by ``imageio.imread``.
        ignore_index: mask value to leave out of the counts (in this notebook 2 marks
            unlabeled pixels), or None to count every value.
        num_classes: expected number of classes. When given, values ``>= num_classes``
            are removed from the counts and their pixels reported as overflow. When
            None (the default) it is inferred as ``max(value) + 1``, in which case no
            value can overflow and ``overflow_count`` is always 0.

    Returns:
        ``(pixel_counts, num_classes, overflow_count)``: ``pixel_counts`` maps each
        (int) mask value to its (int) pixel count. With no pixels counted at all
        (e.g. an empty ``mask_images``) the counts are ``{}`` and an inferred
        ``num_classes`` is 0 instead of raising.
    """
    pixel_counts = defaultdict(int)

    for mask_path in mask_images:
        mask_image = imageio.imread(mask_path)
        unique_values, unique_counts = np.unique(mask_image, return_counts=True)

        # Drop the ignored value (if any) before accumulating.
        if ignore_index is not None:
            keep = unique_values != ignore_index
            unique_values = unique_values[keep]
            unique_counts = unique_counts[keep]

        # Convert to plain Python ints: with uint8 masks, `max(key) + 1` below could
        # otherwise wrap around (255 + 1 -> 0 under NumPy 2's type promotion).
        for value, count in zip(unique_values.tolist(), unique_counts.tolist()):
            pixel_counts[value] += count

    if num_classes is None:
        num_classes = max(pixel_counts) + 1 if pixel_counts else 0

    # Move values outside [0, num_classes) into the overflow total.
    overflow_count = 0
    for key in [k for k in pixel_counts if k >= num_classes]:
        overflow_count += pixel_counts.pop(key)

    return dict(pixel_counts), num_classes, overflow_count

# Mask value 2 marks unlabeled pixels; leave it out of the class statistics.
ignore_index_value = 2
pixel_counts, num_classes, overflow_count = pix_count(mask_images=train_masks, ignore_index=ignore_index_value)

print(pixel_counts)
print("Number of classes:", num_classes)
print("Overflow count:", overflow_count)

{0: 1001, 1: 932}
Number of classes: 2
Overflow count: 0


In [4]:
class CustomDataset(data.Dataset):
    """Index-aligned (image, mask) pairs for segmentation training and validation.

    Images are read as float32 arrays and masks as int64 arrays. Mask values above 2
    are clamped to 2 (the value this notebook uses for unlabeled pixels); when
    ``is_validation`` is set, that unlabeled value is folded into class 0.

    Args:
        image_paths: image file paths, aligned by index with ``target_paths``.
        target_paths: mask file paths.
        transform: optional callable applied to the image array.
        transform_label: optional callable applied to the mask array; its result is
            then squeezed along axis 0.
        is_validation: remap mask value 2 to 0.
    """

    def __init__(self, image_paths, target_paths, transform=None, transform_label=None, is_validation=False):
        self.image_paths = image_paths
        self.target_paths = target_paths
        self.transform = transform
        self.transform_label = transform_label
        self.is_validation = is_validation

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        img = np.asarray(imageio.imread(self.image_paths[index]), dtype='float32')
        target = np.asarray(imageio.imread(self.target_paths[index]), dtype='int64')
        target[target > 2] = 2
        if self.is_validation:
            target[target == 2] = 0

        # One seed per item, re-applied before each transform; intended so that random
        # augmentations make the same choices for the image and its mask.
        shared_seed = np.random.randint(2147483647)

        def _reseed():
            random.seed(shared_seed)
            torch.manual_seed(shared_seed)

        _reseed()
        if self.transform:
            img = self.transform(img)

        _reseed()
        if self.transform_label:
            target = self.transform_label(target).squeeze(0)

        return img, target
    
# Image/mask transforms. Only ToTensor is applied, so no data augmentation happens.
# NOTE: ToTensor rescales only uint8 input to [0, 1]; CustomDataset hands it float32
# arrays, so tensors keep the raw 0-255 range (display_image divides by 255 for display).
# CustomDataset re-seeds `random`/`torch` with one shared seed before the image and the
# mask transform, so random augmentations added here (e.g. transforms.RandomHorizontalFlip(),
# transforms.RandomVerticalFlip()) are meant to act identically on both — verify for any
# transform you add.
transform = transforms.Compose([transforms.ToTensor()])

train_dataset = CustomDataset(train_imgs, train_masks, transform=transform, transform_label=transform)
# Validation masks are not transformed; the DataLoader's default collate turns the
# numpy arrays into tensors.
val_dataset = CustomDataset(val_imgs, val_masks, transform=transforms.ToTensor(), transform_label=None, is_validation=True)

In [5]:
# Model configuration
ENCODER = 'efficientnet-b7'
ENCODER_WEIGHTS = 'imagenet'
# NOTE(review): the label name looks like a leftover from a car-segmentation setup (the
# datasets here are ROAD); in this notebook only len(CLASSES) is used, so it does not
# affect behaviour.
CLASSES = ['car']
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# A single output channel gets a sigmoid; several channels get a softmax.
if len(CLASSES) == 1:
    ACTIVATION = 'sigmoid'
else:
    ACTIVATION = 'softmax'

# U-Net with the chosen pretrained encoder and 3-channel input.
model = smp.Unet(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    in_channels=3,
    classes=len(CLASSES),
    activation=ACTIVATION,
)

# Encoder-specific input normalisation.
# NOTE(review): preprocessing_fn is computed but never applied to the inputs in this
# notebook, so the ImageNet-pretrained encoder sees raw 0-255 values — confirm intended.
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

In [6]:
# Define training parameters
LEARNING_RATE = 0.001
TRAIN_BATCH_SIZE = 10
VALID_BATCH_SIZE = 50

# Validation loss: Dice loss, ignoring unlabeled pixels (mask value 2). The training loop
# selects checkpoints on valid_logs['dice_loss'], i.e. the log entry of this loss.
d_loss = smp.utils.losses.DiceLoss(ignore_index=2)

# Training loss: Dice loss that up-weights confidently wrong predictions.
# NOTE(review): the model is built with activation='sigmoid', so it already outputs
# probabilities. If this loss applies `activation` to its input, sigmoid is applied twice;
# check DynamicWeightedConfidenceDiceLoss and consider activation=None.
loss = smp.utils.losses.DynamicWeightedConfidenceDiceLoss(
    eps=1.0,
    beta=1.0,
    amplification_factor=4.0,       # Amplify weights by a factor of 4 for confident incorrect predictions
    confidence_threshold=0.8,       # Threshold for high confidence
    correctness_threshold=0.5,      # Threshold for significant error
    activation="sigmoid",           # See NOTE above about double activation
    ignore_index=2                  # Ignore unlabeled pixels (mask value 2)
)

# IoU (Intersection over Union) is reported alongside the losses.
metrics = [smp.utils.metrics.IoU()]

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

train_epoch = smp.utils.train.TrainEpoch(
    model,
    loss=loss,
    metrics=metrics,
    optimizer=optimizer,
    device=DEVICE,
    verbose=True,
)

valid_epoch = smp.utils.train.ValidEpoch(
    model,
    loss=d_loss,
    metrics=metrics,
    device=DEVICE,
    verbose=True,
)

# shuffle=False matters: ModelManager.cache_predictions assigns image ids by position in
# train_loader, and those ids index filename_map / train_imgs.
# NOTE: run_training_loop() rebuilds train_loader via reload_train_loader() before training.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=False, num_workers=0)
valid_loader = torch.utils.data.DataLoader(val_dataset, batch_size=VALID_BATCH_SIZE, shuffle=False, num_workers=0)

### Training Loop

The cells below set up and run the active-learning training loop for our segmentation model:

- **Epochs**: Each active-learning iteration trains the model for `config["epochs_per_iteration"]` epochs (50 by default). In each epoch, the model is trained on the current training masks and then evaluated on the validation dataset.

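- **Monitoring Performance**: During each epoch, the training loss (`DynamicWeightedConfidenceDiceLoss`), the validation Dice loss and IoU are reported. After training, precision, recall, F-score and IoU on the validation set are appended to `metrics_history.csv` in the session folder.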

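- **Model Saving**: Whenever the validation Dice loss improves on the best value seen so far in the current iteration, the model's state dictionary is saved to `models/iteration_<n>.pth` in the session folder. If it also improves on the best value seen since the kernel started, it is additionally saved to `models/best_model.pth`. At the end of the iteration, the best `iteration_<n>.pth` checkpoint is reloaded and used to generate predictions.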

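By the end of each iteration, we aim to have a model that performs well on the validation dataset, indicating its potential to generalize well to unseen data. Its predictions on the training images are then shown in an interactive viewer where you can add points (left-click), drag them, or delete them (right-click) before clicking **Continue Training** to start the next iteration.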

In [7]:
# Configuration dictionary for paths and settings
BASE_DIR = "D:/SPARSE/"
config = {
    "base_dir": BASE_DIR,
    "data_folder": os.path.join(BASE_DIR, "SESSIONS/SESSIONS_ROAD/Session2"),
    "epochs_per_iteration": 50,
    "device": "cuda" if torch.cuda.is_available() else "cpu"
}


# Original dataset paths. Sorted because glob() order is filesystem-dependent while
# filename_map, CustomDataset and reload_train_loader all pair images with masks by
# list position (assumes image and mask file names sort in the same order).
train_imgs = sorted(glob(os.path.join(BASE_DIR, "IMG/image_train/*tiff")))
train_masks = sorted(glob(os.path.join(BASE_DIR, "ROAD/initial_dataset/*png")))
val_imgs = sorted(glob(os.path.join(BASE_DIR, "ROAD/ROAD DATASET2/image_val/*tiff")))
val_masks = sorted(glob(os.path.join(BASE_DIR, "ROAD/ROAD DATASET2/class_val/*png")))
if len(train_imgs) != len(train_masks):
    print(f"[WARNING] {len(train_imgs)} training images but {len(train_masks)} training masks; "
          "images and masks are paired by position.")
# image index -> mask file name; used to locate each image's mask in every iteration folder
filename_map = {i: os.path.basename(mask) for i, mask in enumerate(train_masks)}

# Set up paths within the main session data folder for iteration storage
config["dataset_folder"] = os.path.join(config["data_folder"], "datasets")
config["model_folder"] = os.path.join(config["data_folder"], "models")
config["metrics_history_path"] = os.path.join(config["data_folder"], "metrics_history.csv")

# Ensure directory structure is created if it doesn't exist
os.makedirs(config["dataset_folder"], exist_ok=True)
os.makedirs(config["model_folder"], exist_ok=True)

# --- Annotation UI widgets ---------------------------------------------------
transparency_slider = widgets.FloatSlider(
    value=0.5, min=0.0, max=1.0, step=0.05,
    description='Transparency:', continuous_update=True,
)
foreground_count = IntText(value=0, description="Foreground Points:", disabled=True)
background_count = IntText(value=0, description="Background Points:", disabled=True)
label_selector = ToggleButtons(options=['Foreground', 'Background'])

backward_button = Button(description="<")
forward_button = Button(description=">")
continue_training_button = Button(
    description="Continue Training",
    layout=Layout(margin='20px 0px 0px 0px'),
)
# NOTE(review): epoch_slider is not displayed or read anywhere in this notebook;
# training uses config["epochs_per_iteration"].
epoch_slider = IntSlider(value=config["epochs_per_iteration"], min=1, max=100, description="Epochs:")

image_name_label = Label(value="Image: ")
coordinates_label = Label(value="Coordinates: ")
prediction_value_label = Label(value="Prediction: ")
labeled_checkbox = Checkbox(description="Labeled", value=False)

# --- Annotation / session state (module-level, mutated by the handlers) ------
predictions_cache = {}                        # image index -> {"input": ..., "prediction": ...}
selected_points = defaultdict(list)           # image index -> [(y, x, label), ...]
current_iteration_points = defaultdict(list)  # points added since the last "Continue Training"
changed_masks = set()                         # image indices whose masks must be rewritten
images_labeled = set()                        # image indices labelled in this iteration
current_image_index = 0
min_dice_loss = float('inf')                  # best validation Dice loss seen in this kernel
points_collected_total = 0                    # appears unused in this notebook
dragging_point = None                         # index into selected_points[img] being dragged
drag_start = None                             # appears unused in this notebook
event_connections = []                        # matplotlib callback ids, replaced by display_image
overlay = None                                # AxesImage holding the prediction overlay

In [8]:
# DatasetManager
class DatasetManager:
    """Manages the per-iteration mask folders of an active-learning session.

    Masks for iteration ``n`` live in ``<config["dataset_folder"]>/iteration_<n>/`` and
    are addressed by image index through the module-level ``filename_map``
    (index -> mask file name). In the masks, 1 is foreground, 0 is background and 2
    marks unlabeled pixels.
    """

    def __init__(self, config, train_imgs, train_masks, val_imgs, val_masks):
        self.config = config
        self.train_imgs = train_imgs
        self.train_masks = train_masks
        self.val_imgs = val_imgs
        self.val_masks = val_masks
        # NOTE(review): owns its own SessionManager; within this class it is only used
        # below to locate the newest iteration folder.
        self.session_manager = SessionManager(config)

        # Newest iteration_<n> folder on disk (0 when none exist yet).
        latest_iteration = self.session_manager.get_current_iteration() - 1
        if latest_iteration == 0:
            # Fresh session: (re)seed iteration_0 with copies of the original masks.
            self.modified_mask_dir = self.get_modified_mask_dir(0)
            self._setup_initial_masks()
        else:
            self.modified_mask_dir = self.get_modified_mask_dir(latest_iteration)

    def validate_masks(self, iteration):
        """Raise FileNotFoundError if any mask listed in ``filename_map`` is missing for ``iteration``."""
        iteration_dir = self.get_modified_mask_dir(iteration)
        for img_id, mask_name in filename_map.items():
            mask_path = os.path.join(iteration_dir, mask_name)
            if not os.path.exists(mask_path):
                raise FileNotFoundError(f"Missing mask for img_id {img_id} in iteration {iteration}")

    def initialize_selected_points(self, selected_points):
        """Populate ``selected_points`` from the labelled pixels of the current masks.

        Every pixel equal to 0 or 1 in a mask under ``self.modified_mask_dir`` becomes a
        ``(y, x, label)`` tuple of plain Python ints in ``selected_points[img_id]``.
        Points whose (y, x) is already present for that image are skipped, so calling
        this again (e.g. when the setup cell is re-run) does not duplicate points.

        Args:
            selected_points: mapping img_id -> list of (y, x, label); a
                ``defaultdict(list)`` in this notebook. Mutated in place.
        """
        latest_iteration_dir = self.modified_mask_dir
        print(f"[INFO] Initializing selected points from masks in {latest_iteration_dir}...")
        for img_id, mask_name in filename_map.items():
            mask_path = os.path.join(latest_iteration_dir, mask_name)
            if not os.path.exists(mask_path):
                continue
            mask = imageio.imread(mask_path)
            # assumes 2-D single-channel masks (one label per pixel)
            ys, xs = np.nonzero((mask == 0) | (mask == 1))
            labels = mask[ys, xs]
            points = selected_points[img_id]
            seen = {(py, px) for py, px, _ in points}
            for y, x, label in zip(ys.tolist(), xs.tolist(), labels.tolist()):
                if (y, x) not in seen:
                    points.append((y, x, label))
                    seen.add((y, x))
        print(f"[INFO] Loaded selected points for {len(selected_points)} images.")

    def get_modified_mask_dir(self, iteration):
        """Return the directory for a specific iteration's masks."""
        return os.path.join(self.config["dataset_folder"], f"iteration_{iteration}")

    def _setup_initial_masks(self):
        """Copy the original training masks into the iteration_0 folder."""
        os.makedirs(self.modified_mask_dir, exist_ok=True)
        for mask in self.train_masks:
            shutil.copy(mask, os.path.join(self.modified_mask_dir, os.path.basename(mask)))

    def save_all_masks(self, selected_points, changed_masks, current_iteration):
        """Write every mask for ``current_iteration``.

        Masks of images not in ``changed_masks`` are copied unchanged from the previous
        iteration. Masks of changed images are rebuilt: all pixels set to 2 (unlabeled),
        then each ``(y, x, label)`` of ``selected_points[img_id]`` painted in.
        ``changed_masks`` is cleared afterwards.
        """
        iteration_dir = self.get_modified_mask_dir(current_iteration)
        previous_iteration_dir = self.get_modified_mask_dir(current_iteration - 1)
        os.makedirs(iteration_dir, exist_ok=True)

        modified_count = 0

        for img_id, mask_name in filename_map.items():
            previous_mask_path = os.path.join(previous_iteration_dir, mask_name)
            current_mask_path = os.path.join(iteration_dir, mask_name)

            if img_id not in changed_masks:
                shutil.copy(previous_mask_path, current_mask_path)
                continue

            # Start from an all-unlabeled mask with the previous mask's shape/dtype.
            # np.full_like avoids mutating the imread result, which is not guaranteed
            # to be writable.
            mask = np.full_like(imageio.imread(previous_mask_path), 2)
            for y, x, label in selected_points[img_id]:
                mask[y, x] = label

            imageio.imsave(current_mask_path, mask)
            modified_count += 1

        print(f"Total modified masks saved in iteration {current_iteration}: {modified_count}")
        print(f"Total masks (including unmodified) saved in iteration {current_iteration}: {len(filename_map)}")

        # Clear `changed_masks` for the next iteration
        changed_masks.clear()

    def count_foreground_background_pixels(self, iteration_dir):
        """Return ``(foreground, background)`` pixel totals (values 1 and 0) over all files in ``iteration_dir``."""
        fg_count, bg_count = 0, 0
        for mask_file in os.listdir(iteration_dir):
            mask = imageio.imread(os.path.join(iteration_dir, mask_file))
            fg_count += int(np.sum(mask == 1))
            bg_count += int(np.sum(mask == 0))
        return fg_count, bg_count

# SessionManager
class SessionManager:
    """Tracks the active-learning iteration number and the per-iteration metrics table.

    The iteration counter starts one past the highest ``iteration_<n>`` folder in
    ``config["dataset_folder"]``. Metrics live in a DataFrame that is loaded from, and
    saved to, the CSV at ``config["metrics_history_path"]``.
    """

    def __init__(self, config):
        self.config = config
        # Continue from the iteration after the newest one found on disk.
        self.current_iteration = self._get_latest_iteration() + 1
        history_path = config["metrics_history_path"]
        if os.path.exists(history_path):
            self.metrics_history = pd.read_csv(history_path)
        else:
            self.metrics_history = pd.DataFrame(columns=[
                'Iteration', 'TP', 'FP', 'FN', 'TN',
                'Precision', 'Recall', 'F-Score',
                'IoU', 'Foreground Pixels', 'Background Pixels'
            ])

    def _get_latest_iteration(self):
        """Return the largest ``n`` among ``iteration_<n>`` folders, or 0 if there are none."""
        dataset_folder = self.config["dataset_folder"]
        iteration_folders = [
            int(folder.split("_")[-1])
            for folder in os.listdir(dataset_folder)
            if folder.startswith("iteration_") and folder.split("_")[-1].isdigit()
        ]
        return max(iteration_folders, default=0)

    def increment_iteration(self):
        """Advance the iteration counter by one."""
        self.current_iteration += 1

    def get_current_iteration(self):
        """Return the current iteration number."""
        return self.current_iteration

    def calculate_metrics_from_confusion_matrix(self, y_true, y_pred):
        """Binary metrics for labels 0 (negative) and 1 (positive).

        Uses the module-level ``confusion_matrix`` (scikit-learn) restricted to labels
        ``[0, 1]``. Precision, recall, F-score and IoU are percentages rounded to two
        decimals; each is 0 when its denominator is 0.

        Returns:
            ``(tp, fp, fn, tn, precision, recall, f1_score, iou)``
        """
        tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
        precision = tp / (tp + fp) * 100 if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) * 100 if (tp + fn) > 0 else 0
        f1_score = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
        iou = tp / (tp + fp + fn) * 100 if (tp + fp + fn) > 0 else 0
        return int(tp), int(fp), int(fn), int(tn), round(precision, 2), round(recall, 2), round(f1_score, 2), round(iou, 2)

    def update_metrics_history(self, tp, fp, fn, tn, precision, recall, f1_score, iou, fg_count, bg_count):
        """Append one row for the current iteration and rewrite the metrics CSV."""
        new_metrics = pd.DataFrame([{
            'Iteration': self.current_iteration,
            'TP': tp,
            'FP': fp,
            'FN': fn,
            'TN': tn,
            'Precision': precision,
            'Recall': recall,
            'F-Score': f1_score,
            'IoU': iou,
            'Foreground Pixels': fg_count,
            'Background Pixels': bg_count
        }])

        if self.metrics_history.empty:
            # Concatenating onto an empty frame is deprecated in recent pandas
            # (FutureWarning) and can coerce columns to object dtype; start from the row.
            self.metrics_history = new_metrics
        else:
            self.metrics_history = pd.concat([self.metrics_history, new_metrics], ignore_index=True)

        self.metrics_history.to_csv(self.config["metrics_history_path"], index=False)
        print("[INFO] Metrics history updated and saved.")


# ModelManager
class ModelManager:
    """Checkpointing and inference helper around the segmentation model.

    Checkpoints are state dicts stored in ``config["model_folder"]`` as
    ``iteration_<n>.pth`` or ``best_model.pth``.
    """

    def __init__(self, config, model):
        self.config = config
        self.model = model

    def _checkpoint_path(self, iteration):
        """Path of the checkpoint for an iteration number, or for ``"best_model"``."""
        filename = "best_model.pth" if iteration == "best_model" else f"iteration_{iteration}.pth"
        return os.path.join(self.config["model_folder"], filename)

    def save_model(self, iteration, is_best=False):
        """Save the model's state dict for ``iteration`` (an int, or ``"best_model"``).

        ``is_best`` only changes the log message.
        """
        model_path = self._checkpoint_path(iteration)
        torch.save(self.model.state_dict(), model_path)
        if is_best:
            print(f"[INFO] Updated global best model at: {model_path}")
        else:
            print(f"[INFO] Model saved for iteration {iteration} at: {model_path}")

    def load_model(self, iteration):
        """Load the weights saved by ``save_model(iteration)`` into ``self.model``.

        Accepts the same ``iteration`` values as ``save_model``, including ``"best_model"``.

        Raises:
            FileNotFoundError: if the checkpoint file does not exist.
        """
        model_path = self._checkpoint_path(iteration)
        if not os.path.exists(model_path):
            print(f"[ERROR] Model not found for iteration {iteration}: {model_path}")
            raise FileNotFoundError(f"Model not found: {model_path}")
        print(f"[INFO] Loading model from: {model_path}")
        # map_location='cpu' lets a checkpoint written from a CUDA model load on a
        # CPU-only machine; load_state_dict copies the tensors onto the model's device.
        self.model.load_state_dict(torch.load(model_path, map_location="cpu"))

    def cache_predictions(self, train_loader, device="cpu"):
        """Run the model over ``train_loader`` and return per-image inputs and predictions.

        Image ids are positions in the loader's iteration order (0, 1, 2, ...), so the
        loader must not shuffle if ids are to match ``filename_map``.

        Args:
            train_loader: yields ``(inputs, targets)`` batches; targets are ignored.
            device: device the model already lives on; inputs are moved there.

        Returns:
            dict ``img_id -> {"input": HWC array, "prediction": squeezed model output}``.
        """
        self.model.eval()

        predictions_cache = {}
        with torch.no_grad():
            for inputs, _ in train_loader:
                inputs = inputs.to(device)
                outputs = self.model(inputs).cpu().numpy()
                inputs_np = inputs.cpu().numpy()  # one transfer per batch, not per image
                for image, output in zip(inputs_np, outputs):
                    img_id = len(predictions_cache)
                    predictions_cache[img_id] = {
                        "input": np.transpose(image, (1, 2, 0)),  # CHW -> HWC for imshow
                        "prediction": output.squeeze()
                    }
        print(f"[INFO] Cached predictions for {len(predictions_cache)} images.")
        return predictions_cache

In [9]:
def update_counts(img_id):
    """Refresh the foreground (label 1) / background (label 0) point counters for one image."""
    point_labels = [lbl for _, _, lbl in selected_points[img_id]]
    foreground_count.value = point_labels.count(1)
    background_count.value = point_labels.count(0)

def find_nearest_point(img_id, x, y, threshold=2):
    """Return the index of the labelled point of ``img_id`` closest to (x, y), or None.

    A point is a candidate when both ``|px - x|`` and ``|py - y|`` are below
    ``threshold`` pixels. Among candidates, the one with the smallest squared Euclidean
    distance wins; ties go to the earliest point in the list. Taking the first
    candidate instead could grab a farther point when two points are close together.
    """
    best_idx, best_dist = None, None
    for i, (py, px, _) in enumerate(selected_points[img_id]):
        dx, dy = px - x, py - y
        if abs(dx) < threshold and abs(dy) < threshold:
            dist = dx * dx + dy * dy
            if best_dist is None or dist < best_dist:
                best_idx, best_dist = i, dist
    return best_idx

def delete_nearest_point(img_id, x, y):
    """Remove the labelled point nearest to (x, y) on image ``img_id``, if there is one.

    The point is popped from ``selected_points[img_id]``; every entry at the same (y, x)
    is filtered out of ``current_iteration_points[img_id]``; and ``img_id`` is added to
    ``changed_masks`` so its mask gets rewritten on the next save. Nothing happens when
    ``find_nearest_point`` finds no point.
    """
    idx = find_nearest_point(img_id, x, y)
    if idx is None:
        return

    removed_y, removed_x, _ = selected_points[img_id].pop(idx)

    current_iteration_points[img_id] = [
        point for point in current_iteration_points[img_id]
        if (point[0], point[1]) != (removed_y, removed_x)
    ]

    changed_masks.add(img_id)

In [10]:
def plot_all_points(ax, img_id):
    """Draw every labelled point of image ``img_id`` on ``ax``.

    Label 1 (foreground) is drawn lime, any other label red, with marker size 3.
    """
    for py, px, lbl in selected_points[img_id]:
        is_foreground = lbl == 1
        ax.plot(px, py, 'o', color='lime' if is_foreground else 'red', markersize=3)

# Event handlers for adding, dragging, and highlighting points
def on_click(event):
    """Mouse-press handler for the annotation canvas.

    * Left-click near an existing point starts dragging it.
    * Left-click elsewhere adds a point labelled from ``label_selector``
      (Foreground -> 1, Background -> 0).
    * Right-click deletes the nearest point.

    Other buttons (e.g. middle-click, which used to fall through to "add a point") and
    clicks outside the axes are ignored.
    """
    global dragging_point
    if not event.inaxes:
        return
    x, y = int(event.xdata), int(event.ydata)

    if event.button == 3:  # right-click: delete
        delete_nearest_point(current_image_index, x, y)
        # Keep the foreground/background counters in sync after a deletion.
        update_counts(current_image_index)
        redraw_image_with_points(event.inaxes)
        return

    if event.button != 1:
        return

    nearest_point_idx = find_nearest_point(current_image_index, x, y)
    if nearest_point_idx is not None:  # left-click on a point: start dragging it
        dragging_point = nearest_point_idx
        return

    # Left-click on empty space: add a new labelled point.
    label = 1 if label_selector.value == 'Foreground' else 0
    selected_points[current_image_index].append((y, x, label))
    current_iteration_points[current_image_index].append((y, x, label))
    changed_masks.add(current_image_index)
    update_counts(current_image_index)
    images_labeled.add(current_image_index)
    labeled_checkbox.value = True
    redraw_image_with_points(event.inaxes)

def on_motion(event):
    """Mouse-move handler: update readouts, highlight the hovered point, drag the active one.

    While a drag is in progress, the dragged point is moved to the cursor, the matching
    entry in ``current_iteration_points`` is moved too, and the image is flagged in
    ``changed_masks``. Without that flag a drag-only edit would never be written by
    ``save_all_masks``.
    """
    global dragging_point
    if not event.inaxes:
        return
    x, y = int(event.xdata), int(event.ydata)
    prediction_value_label.value = f"Prediction: {predictions_cache[current_image_index]['prediction'][y, x]:.2f}"
    coordinates_label.value = f"Coordinates: ({x}, {y})"

    if dragging_point is None:
        # Hover: highlight the point under the cursor (if any).
        nearest_point_idx = find_nearest_point(current_image_index, x, y)
        if nearest_point_idx is not None:
            redraw_image_with_points(event.inaxes, highlight_idx=nearest_point_idx)
        else:
            redraw_image_with_points(event.inaxes)
        return

    points = selected_points[current_image_index]
    old_y, old_x, label = points[dragging_point]
    points[dragging_point] = (y, x, label)
    # Keep this iteration's bookkeeping pointing at the moved coordinates.
    current_iteration_points[current_image_index] = [
        (y, x, lbl) if (py, px) == (old_y, old_x) else (py, px, lbl)
        for py, px, lbl in current_iteration_points[current_image_index]
    ]
    changed_masks.add(current_image_index)
    redraw_image_with_points(event.inaxes, highlight_idx=dragging_point)

def on_release(event):
    """Mouse-release handler: finish any point drag started by a left-click."""
    global dragging_point
    dragging_point = None  # nothing is being dragged any more

In [11]:
def reload_train_loader():
    """Rebuild the global ``train_loader`` on the previous iteration's masks.

    Image ``i`` of ``train_imgs`` is paired with ``filename_map[i]`` inside the mask
    folder of iteration ``current - 1``. The loader is unshuffled with batch size 10.
    """
    global train_loader
    mask_dir = dataset_manager.get_modified_mask_dir(session_manager.get_current_iteration() - 1)
    masks_for_iteration = [os.path.join(mask_dir, filename_map[idx]) for idx, _ in enumerate(train_imgs)]
    dataset = CustomDataset(train_imgs, masks_for_iteration, transform=transform, transform_label=transform)
    train_loader = torch.utils.data.DataLoader(dataset, batch_size=10, shuffle=False, num_workers=0)

def update_overlay_alpha(change):
    """Transparency-slider observer: apply the new alpha to the prediction overlay and redraw.

    ``change['new']`` is the slider's new value. Without the redraw, the new alpha only
    became visible at the next unrelated canvas update. Does nothing before the first
    image has been drawn (``overlay`` is None).
    """
    if overlay is None:
        return
    overlay.set_alpha(change['new'])
    # Redraw through the overlay's own figure rather than relying on a global `fig`.
    figure = getattr(overlay, "figure", None)
    if figure is not None and getattr(figure, "canvas", None) is not None:
        figure.canvas.draw_idle()
        
# Navigation handlers
def _step_image(offset, button):
    """Move ``current_image_index`` by ``offset`` (wrapping around) and redraw.

    ``button`` is disabled while drawing and always re-enabled, even if drawing raises.
    Does nothing when there are no cached predictions (avoids a modulo-by-zero).
    """
    global current_image_index
    if not predictions_cache:
        return
    button.disabled = True
    try:
        current_image_index = (current_image_index + offset) % len(predictions_cache)
        display_image(alpha=transparency_slider.value)
    finally:
        button.disabled = False


def on_forward_clicked(b):
    """Show the next image (wraps to the first)."""
    _step_image(1, forward_button)


def on_backward_clicked(b):
    """Show the previous image (wraps to the last)."""
    _step_image(-1, backward_button)
        
# On "Continue Training" Click
def on_continue_training_clicked(b):
    """Persist this iteration's annotations, then train and show the next iteration.

    Saves the masks for the current iteration and advances the iteration counter. It
    then resets the per-iteration UI state (labelled images, current-iteration points,
    image index 0) *before* training, so the view drawn at the end of
    ``run_training_loop()`` already reflects the reset. ``b`` (the clicked button) is
    disabled while this runs so it cannot be clicked again mid-training.
    """
    global current_image_index  # without this, the reset below only bound a local
    b.disabled = True
    try:
        clear_output()
        current_iteration = session_manager.get_current_iteration()
        print(f"[INFO] Saving masks for iteration {current_iteration}...")
        dataset_manager.save_all_masks(selected_points, changed_masks, current_iteration)

        # Increment the iteration count *after* saving masks
        session_manager.increment_iteration()

        # Reset per-iteration UI state before the new predictions are displayed.
        images_labeled.clear()
        current_image_index = 0
        current_iteration_points.clear()

        run_training_loop()
        display_image(alpha=transparency_slider.value)
    finally:
        b.disabled = False

In [12]:
def display_image(alpha=0.5):
    """Render the current image, its prediction overlay, labelled points and controls.

    Reuses the global ``fig``/``ax`` when they exist (creating them on first use),
    re-wires the matplotlib mouse handlers, keeps exactly one transparency observer
    registered by this function, and re-displays the canvas with the control widgets.

    Args:
        alpha: initial opacity of the prediction overlay (0 = invisible, 1 = opaque).
    """
    global overlay, current_image_index, fig, ax, event_connections, _alpha_observer

    # Replace the previous render instead of stacking outputs.
    clear_output(wait=True)

    # Create the figure only once so repeated calls draw into the same canvas.
    if 'fig' not in globals() or 'ax' not in globals():
        fig, ax = plt.subplots(figsize=(7, 7))

    ax.clear()

    if current_image_index < 0 or current_image_index >= len(predictions_cache):
        return

    img_data = predictions_cache[current_image_index]
    # Scale to [0, 1] for imshow (inputs appear to be in the 0-255 range).
    inp_unit = img_data["input"] / 255.0
    ax.imshow(inp_unit, cmap='gray', aspect='auto')

    overlay = ax.imshow(img_data["prediction"], cmap='Reds', alpha=alpha, aspect='auto')

    ax.axis('off')

    # Lay out before adding the title.
    fig.tight_layout(pad=1.0)

    changed_image_count = len(changed_masks)
    points_collected_total = sum(len(points) for points in selected_points.values())
    current_iteration_points_count = sum(len(points) for points in current_iteration_points.values())
    fig.suptitle(f"Images changed: {changed_image_count}, "
                 f"Pts (current iteration): {current_iteration_points_count}, "
                 f"Pts (total): {points_collected_total}", fontsize=8)

    image_name_label.value = f"Image: {filename_map[current_image_index]}"
    labeled_checkbox.value = current_image_index in images_labeled

    plot_all_points(ax, current_image_index)

    def update_overlay_alpha(change):
        overlay.set_alpha(change['new'])
        fig.canvas.draw_idle()

    # Swap out the observer registered by the previous call; previously every call
    # added another one, so observers piled up with each image shown.
    previous_observer = globals().get('_alpha_observer')
    if previous_observer is not None:
        try:
            transparency_slider.unobserve(previous_observer, names='value')
        except ValueError:
            pass  # it was registered on a slider instance that has since been recreated
    transparency_slider.observe(update_overlay_alpha, names='value')
    _alpha_observer = update_overlay_alpha

    # Exactly one set of mouse handlers per figure.
    for cid in event_connections:
        fig.canvas.mpl_disconnect(cid)
    event_connections = [
        fig.canvas.mpl_connect('button_press_event', on_click),
        fig.canvas.mpl_connect('button_release_event', on_release),
        fig.canvas.mpl_connect('motion_notify_event', on_motion)
    ]

    controls_box = VBox([
        HBox([
            backward_button,    # < button
            forward_button,     # > button
            label_selector      # Foreground/Background
        ]),
        HBox([
            transparency_slider,  # Transparency slider
            image_name_label,     # Image name label
            labeled_checkbox      # Labeled checkbox
        ]),
        HBox([
            prediction_value_label,  # updated by on_motion; previously never shown
            coordinates_label
        ]),
        HBox([
            continue_training_button  # Continue Training button
        ])
    ])

    display(VBox([fig.canvas, controls_box]))

def redraw_image_with_points(ax, highlight_idx=None):
    """Redraw the current image, its prediction overlay and labelled points on ``ax``.

    The point at ``highlight_idx`` (if any) is drawn yellow and larger. The new overlay
    image is stored in the global ``overlay`` so the transparency slider keeps acting
    on the visible overlay. (Previously it was bound only locally, so the slider stopped
    working after the first hover redraw. An earlier duplicate definition, shadowed by
    this one, is merged here.)
    """
    global overlay
    ax.clear()
    img_data = predictions_cache[current_image_index]
    ax.imshow(img_data["input"] / 255.0, cmap='gray', aspect='auto')
    overlay = ax.imshow(img_data["prediction"], cmap='Reds', alpha=transparency_slider.value, aspect='auto')

    ax.axis('off')

    for i, (y, x, label) in enumerate(selected_points[current_image_index]):
        highlighted = i == highlight_idx
        if highlighted:
            color, size = 'yellow', 4
        else:
            color, size = ('lime' if label == 1 else 'red'), 2
        ax.plot(x, y, 'o', color=color, markersize=size)

    ax.figure.canvas.draw_idle()

In [13]:
# Run Training Loop
def _collect_validation_predictions(net, loader, device):
    """Return flattened ``(y_true, y_pred)`` numpy arrays over ``loader``.

    Predictions are the model outputs rounded to 0/1 (the model is configured with a
    sigmoid activation in this notebook). Concatenating per-batch numpy chunks avoids
    building Python lists holding one object per pixel.
    """
    true_chunks, pred_chunks = [], []
    net.eval()
    with torch.no_grad():
        for images, labels in loader:
            predictions = net(images.to(device)).round()
            true_chunks.append(labels.cpu().numpy().ravel())
            pred_chunks.append(predictions.cpu().numpy().ravel())
    if not true_chunks:
        return np.array([], dtype=np.int64), np.array([], dtype=np.float32)
    return np.concatenate(true_chunks), np.concatenate(pred_chunks)


def run_training_loop():
    """Run one active-learning iteration.

    Steps:
    1. Load the previous iteration's checkpoint (if any).
    2. Validate and count the previous iteration's masks.
    3. Train for ``config["epochs_per_iteration"]`` epochs, checkpointing on validation
       Dice loss (per-iteration and global best).
    4. Reload the iteration's best checkpoint.
    5. Cache predictions for labelling.
    6. Append validation metrics to the history.
    7. Show the first image.
    """
    global min_dice_loss, predictions_cache
    current_iteration = session_manager.get_current_iteration()

    print(f"\n--- Active Learning Iteration {current_iteration} ---\n")

    if current_iteration > 1:
        model_manager.load_model(current_iteration - 1)
    else:
        print("[INFO] Starting from scratch. Initializing model with default weights.")

    previous_iteration_dir = dataset_manager.get_modified_mask_dir(current_iteration - 1)
    print(f"[INFO] Validating masks for iteration {current_iteration - 1}...")
    dataset_manager.validate_masks(current_iteration - 1)
    print(f"[INFO] All masks validated for iteration {current_iteration - 1}.")

    fg_count, bg_count = dataset_manager.count_foreground_background_pixels(previous_iteration_dir)
    print(f"[INFO] Foreground pixels: {fg_count}, Background pixels: {bg_count}")

    # Train on the masks from the previous iteration.
    reload_train_loader()

    iteration_min_loss = float('inf')
    for epoch in range(config["epochs_per_iteration"]):
        print(f"[INFO] Epoch {epoch + 1}/{config['epochs_per_iteration']} (Iteration {current_iteration})")
        train_epoch.run(train_loader)
        valid_logs = valid_epoch.run(valid_loader)
        # valid_logs is keyed by the validation loss name ('dice_loss' for d_loss).
        val_loss = valid_logs['dice_loss']

        if val_loss < iteration_min_loss:
            iteration_min_loss = val_loss
            model_manager.save_model(current_iteration)

        if val_loss < min_dice_loss:
            model_manager.save_model("best_model", is_best=True)
            min_dice_loss = val_loss

    print("[INFO] Loading the best model for the current iteration...")
    model_manager.load_model(current_iteration)

    predictions_cache = model_manager.cache_predictions(train_loader, config["device"])
    print(f"[INFO] Cached predictions for {len(predictions_cache)} images.")

    y_true, y_pred = _collect_validation_predictions(model_manager.model, valid_loader, config["device"])
    tp, fp, fn, tn, precision, recall, f1_score, iou = session_manager.calculate_metrics_from_confusion_matrix(y_true, y_pred)
    session_manager.update_metrics_history(tp, fp, fn, tn, precision, recall, f1_score, iou, fg_count, bg_count)
    print(f"[INFO] Metrics updated for iteration {current_iteration}: "
          f"Precision={precision:.2f}%, Recall={recall:.2f}%, F1={f1_score:.2f}%, IoU={iou:.2f}%")

    print("\nMetrics across iterations:")
    display(session_manager.metrics_history)

    display_image(alpha=transparency_slider.value)

In [14]:
# Initialization and Setup
dataset_manager = DatasetManager(config, train_imgs, train_masks, val_imgs, val_masks)
# Reuse the SessionManager the DatasetManager already built from the same on-disk state,
# so there is a single iteration counter / metrics table instead of two.
session_manager = dataset_manager.session_manager
model_manager = ModelManager(config, model)

# Initialize selected points from the latest masks on disk
dataset_manager.initialize_selected_points(selected_points)

# Register widget callbacks, first removing the ones a previous run of this cell added.
# After cell re-runs the handler functions are new objects, so without this the old
# handlers stay attached (e.g. a single click on "Continue Training" would train twice).
_previous_alpha_handler = globals().get("_slider_alpha_handler")
if _previous_alpha_handler is not None:
    try:
        transparency_slider.unobserve(_previous_alpha_handler, names='value')
    except ValueError:
        pass  # it belonged to a slider instance that has since been recreated
transparency_slider.observe(update_overlay_alpha, names='value')
_slider_alpha_handler = update_overlay_alpha

for _button, _handler in globals().get("_button_handlers", []):
    _button.on_click(_handler, remove=True)
_button_handlers = [
    (forward_button, on_forward_clicked),
    (backward_button, on_backward_clicked),
    (continue_training_button, on_continue_training_clicked),
]
for _button, _handler in _button_handlers:
    _button.on_click(_handler)

run_training_loop()

VBox(children=(Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Ba…