## 1. Import library, Set up seed and GPU

In [9]:
import os
import json
import random
import numpy as np
import torch
import torchvision
from torchvision import transforms, models
from torchvision.ops import box_iou
from torch.utils.data import Dataset, DataLoader
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
from sklearn.metrics import jaccard_score
from tqdm import tqdm
import cv2

In [10]:
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False

set_seed()

In [11]:
# Check GPU availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'current working device : {device}')

current working device : cpu


## 2. Download Dataset for segmentation

In [12]:
if not os.path.exists("seg_dataset"):
    # !pip install gdown
    !gdown https://drive.google.com/uc?id=1AbU0ghVIxP81SR8_L17wtzdZwSvZvRZl
    !unzip -qq seg_dataset.zip

zsh:1: no matches found: https://drive.google.com/uc?id=1AbU0ghVIxP81SR8_L17wtzdZwSvZvRZl
unzip:  cannot find or open seg_dataset.zip, seg_dataset.zip.zip or seg_dataset.zip.ZIP.


## 3. Define a class for the dataset

In [None]:
# Define the class for custom dataset
class SegmentationDataset(Dataset):
    """
    Segmentation Dataset Class
    - images_dir: literally it's a directory that images are located
    - labels_dir: it's a directory that labels (JSON' files) are located
    - image_transform: image transformation function
    - mask_transform: mask transformation function
    """

    # Initialize directories of images, labels, transformations
    def __init__(self, images_dir, labels_dir, image_transform=None, mask_transform=None):
        self.images_dir = images_dir
        self.labels_dir = labels_dir
        self.image_transform = image_transform
        self.mask_transform = mask_transform
        self.image_files = sorted(os.listdir(images_dir))
        self.label_files = sorted(os.listdir(labels_dir))

        # Check whether the name of the image and label are matched or not
        assert len(self.image_files) == len(self.label_files), "Number of image and label files do not match."
        for img_file, lbl_file in zip(self.image_files, self.label_files):
            assert os.path.splitext(img_file)[0] == os.path.splitext(lbl_file)[0], f"Image and label filename's do not match: {img_file}, {lbl_file}"

    # Return length of the dataset
    def __len__(self):
        return len(self.image_files)

    # Return the image and mask for the given index
    def __getitem__(self, idx):
        # Set up paths for image and label files
        image_path = os.path.join(self.images_dir, self.image_files[idx])
        label_path = os.path.join(self.labels_dir, self.label_files[idx])

        # Load image, then convert it to RGB
        try:
            image = Image.open(image_path).convert("RGB")
        except Exception as e:
            print(f"Error occured while loading your image: {image_path}\n{e}")
            raise

        # Generate a Mask
        mask = self.create_mask(label_path, image.size)

        # Apply transformation to each image and mask
        if self.image_transform is not None:
            image = self.image_transform(image)
        if self.mask_transform is not None:
            mask = self.mask_transform(mask)

        return image, mask

    # Create the mask from label (.JSON) file that has polygons
    def create_mask(self, label_path, image_size):
        # Create an empty mask image
        mask = Image.new('L', image_size, 0)
        with open(label_path, 'r') as f:
            try:
                label_data = json.load(f)
            except json.JSONDecodeError:
                print(f"Failed to parse JSON file at : {label_path}")
                return mask  # Return empty mask

        draw = ImageDraw.Draw(mask)

        # Get data(object information) by using 'segmentaion' key value
        shapes = label_data.get('segmentation', None)

        if shapes is None:
            print(f"There is no data from the label with key of 'segmentation': {label_path}")
            return mask  # Return empty mask

        for shape in shapes:
            # Bring coordinates of polygon
            points = shape.get('polygon', None)
            if points is None:
                print(f"객체에 'polygon' 키가 없습니다: {label_path}")
                continue 

            # Convert the polygon coornidates into integers
            points = [tuple(map(int, point)) for point in points]

            # Get the label
            label = shape.get('name', None)
            if label == 'Parking Space':
                draw.polygon(points, outline=1, fill=1)
            elif label == 'Driveable Space' or label == 'Drivable Space':
                draw.polygon(points, outline=2, fill=2)
            else:
                print(f"Unknown label '{label}' is found at : {label_path}")
                
        return mask

In [None]:
# Define the MaskRCNNDataset class
class MaskRCNNDataset(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, mask = self.dataset[idx]

        # Initialize masks and boxes
        masks = []
        boxes = []
        labels = []

        for class_id in [1, 2]:
            class_mask = (mask == class_id)
            if class_mask.sum() > 0:
                num_labels, labels_img = cv2.connectedComponents(
                    class_mask.numpy().astype(np.uint8))

                for label_id in range(1, num_labels):
                    instance_mask = labels_img == label_id
                    if instance_mask.sum() > 50:
                        instance_mask = torch.from_numpy(instance_mask).float()

                        pos = torch.where(instance_mask)
                        if len(pos[0]) > 0 and len(pos[1]) > 0:
                            xmin = float(pos[1].min())
                            xmax = float(pos[1].max())
                            ymin = float(pos[0].min())
                            ymax = float(pos[0].max())

                            if xmax > xmin and ymax > ymin:
                                masks.append(instance_mask)
                                boxes.append([xmin, ymin, xmax, ymax])
                                labels.append(class_id)

        # Handle empty predictions
        if not masks:
            boxes = torch.tensor([[0.0, 0.0, 1.0, 1.0]], dtype=torch.float32)
            labels = torch.tensor([0], dtype=torch.int64)
            masks = torch.zeros((1, mask.shape[0], mask.shape[1]), dtype=torch.float32)
        else:
            boxes = torch.tensor(boxes, dtype=torch.float32)
            labels = torch.tensor(labels, dtype=torch.int64)
            masks = torch.stack(masks)

        # Normalize the image
        if isinstance(image, torch.Tensor) and image.max() > 1.0:
            image = image / 255.0

        target = {
            "boxes": boxes,
            "labels": labels,
            "masks": masks
        }

        return image, target


## 4. Data Transformation and Data Loader (loading data during model training and evaluation)

In [None]:
# Define data transformations
# Transformations for semantic segmentation
image_transform = transforms.Compose([
    transforms.Resize((256, 256)),  # Resize the image to 256x256
    transforms.ToTensor(),          # Convert the image to a tensor
])

mask_transform = transforms.Compose([
    transforms.Resize((256, 256), interpolation=Image.NEAREST),  # Resize the mask with nearest-neighbor interpolation
    transforms.PILToTensor(),                                   # Convert the mask to a tensor
    transforms.Lambda(lambda x: x.squeeze().long()),            # Remove extra dimensions and convert to long type
])


In [None]:
# Transform the image for Mask R-CNN (to keep the same image size)
maskrcnn_image_transform = transforms.Compose([
    transforms.ToTensor(),
])

maskrcnn_mask_transform = transforms.Compose([
    transforms.PILToTensor(),
    transforms.Lambda(lambda x: x.squeeze().long()),
])

In [None]:
# Set up directories for each training, testing, and validation 
train_images_dir = 'seg_dataset/train/images/'
train_labels_dir = 'seg_dataset/train/labels/'

val_images_dir = 'seg_dataset/validation/images/'
val_labels_dir = 'seg_dataset/validation/labels/'

test_images_dir = 'seg_dataset/test/images/'
test_labels_dir = 'seg_dataset/test/labels/'

In [None]:
# Create dataset for semantic segmentaion
train_dataset = SegmentationDataset(
    train_images_dir, train_labels_dir,
    image_transform=image_transform,
    mask_transform=mask_transform
)

val_dataset = SegmentationDataset(
    val_images_dir, val_labels_dir,
    image_transform=image_transform,
    mask_transform=mask_transform
)

test_dataset = SegmentationDataset(
    test_images_dir, test_labels_dir,
    image_transform=image_transform,
    mask_transform=mask_transform
)

In [None]:
# Create dataset for Mask-RCNN
maskrcnn_train_dataset = SegmentationDataset(
    train_images_dir, train_labels_dir,
    image_transform=maskrcnn_image_transform,
    mask_transform=maskrcnn_mask_transform
)

maskrcnn_val_dataset = SegmentationDataset(
    val_images_dir, val_labels_dir,
    image_transform=maskrcnn_image_transform,
    mask_transform=maskrcnn_mask_transform
)

maskrcnn_test_dataset = SegmentationDataset(
    test_images_dir, test_labels_dir,
    image_transform=maskrcnn_image_transform,
    mask_transform=maskrcnn_mask_transform
)


In [None]:
# Create data loader
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False, num_workers=2)

In [None]:
# Function of data loader
def create_dataloaders(train_dataset, val_dataset, test_dataset, batch_size=2):
    train_loader = DataLoader(
        MaskRCNNDataset(train_dataset),     # get train dateset
        batch_size = batch_size,            # set upi batch size for model training
        shuffle = True,                     # shuffle them during model training
        num_workers = 4,                    # num of worker threads to load dataset
        pin_memory = True,                  # for faster GPU data transfer
        collate_fn = lambda x: tuple(zip(*x))   # custom collation function to batch data
    )

    val_loader = DataLoader(
        MaskRCNNDataset(val_dataset),
        batch_size = 1,
        shuffle = False,
        num_workers = 4,
        pin_memory = True,
        collate_fn = lambda x: tuple(zip(*x))
    )

    test_loader = DataLoader(
        MaskRCNNDataset(test_dataset),
        batch_size = 1,
        shuffle = False,
        num_workers = 4,
        pin_memory = True,
        collate_fn = lambda x: tuple(zip(*x))
    )

    return train_loader, val_loader, test_loader

In [None]:
# Calling 'create_dataloaders' in order to generate data loader for Mask RCNN
maskrcnn_train_loader, maskrcnn_val_loader, maskrcnn_test_loader = create_dataloaders(
    maskrcnn_train_dataset, maskrcnn_val_dataset, maskrcnn_test_dataset, batch_size=2
)

## 5. Visualizing Sample Datasets 

In [None]:
# Dataset visualization function
def visualize_dataset_sample(dataset, num_samples=5):
    """Visualize samples from the dataset."""
    indices = random.sample(range(len(dataset)), num_samples)  # Randomly select sample indices

    for idx in indices:
        image, mask = dataset[idx]
        image_np = image.permute(1, 2, 0).numpy()  # Convert image to NumPy array (HWC format)
        mask_np = mask.numpy()  # Convert mask to NumPy array

        # Colorize the mask
        mask_color = np.zeros((mask_np.shape[0], mask_np.shape[1], 3))  # Initialize color mask
        mask_color[mask_np == 1] = [0, 1, 0]  # Parking space: Green
        mask_color[mask_np == 2] = [0, 0, 1]  # Drivable space: Blue

        # Create an overlay
        overlay = 0.5 * image_np + 0.5 * mask_color  # Blend image and mask
        overlay = np.clip(overlay, 0, 1)  # Ensure values are within [0, 1]

        # Visualization
        plt.figure(figsize=(15, 5))

        plt.subplot(1, 3, 1)
        plt.imshow(image_np)  # Show original image
        plt.title('Original')
        plt.axis('off')

        plt.subplot(1, 3, 2)
        plt.imshow(mask_np, cmap='gray')  # Show mask
        plt.title('Mask')
        plt.axis('off')

        plt.subplot(1, 3, 3)
        plt.imshow(overlay)  # Show overlay
        plt.title('Overlay')
        plt.axis('off')

        plt.tight_layout()
        plt.show()

In [None]:
# Check dataset samples
print("Visualizing dataset samples:")
visualize_dataset_sample(train_dataset, num_samples = 3)

## 6. Model Definition

In [None]:
# Mask R-CNN model definition
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

def get_maskrcnn_model(num_classes):
    # Load a pretrained Mask R-CNN model with a ResNet50 backbone and FPN
    model = models.detection.maskrcnn_resnet50_fpn(
        pretrained=True,            # Use pretrained weights
        box_detections_per_img = 100, # Number of object detections per image
        min_size = 800,               # Min image size
        max_size = 1333               # Max image size
    )

    # Modify the box predictor (for bounding boxes)
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # Mask prediction (for segmentation masks)
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256  # size of hidden layer
    model.roi_heads.mask_predictor = MaskRCNNPredictor(
        in_features_mask,
        hidden_layer,
        num_classes
    )

    return model

## 7. Training Log

In [None]:
class AverageMeter:
    """Calculate an avg and current valuees and Save it"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

## 8. Training Function

In [None]:
# Mask R-CNN Training Function
def train_maskrcnn(model, dataloaders, optimizer, num_epochs, device):
    scaler = torch.cuda.amp.GradScaler()
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', patience=2, factor=0.5, verbose=True
    )  # Reduce learning rate

    best_loss = float('inf')
    best_model_wts = None
    history = {'train_loss': [], 'val_loss': []}  # Store losses for each epoch

    for epoch in range(num_epochs):
        print(f"\nMask R-CNN - Epoch {epoch+1}/{num_epochs}")
        print('-' * 20)

        # Iterate through training and validation phases
        for phase in ['train', 'val']:
            running_loss = 0.0
            batch_count = 0

            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.train()  # Set model to training mode (fixing BatchNorm issue below)
                # Set BatchNorm layers to evaluation mode during validation
                for m in model.modules():
                    if isinstance(m, torch.nn.BatchNorm2d) or isinstance(m, torch.nn.BatchNorm1d):
                        m.eval()

            # Loop through batches
            for images, targets in tqdm(dataloaders[phase]):
                images = [img.to(device) for img in images]  # Move images to the device (GPU)
                targets = [{k: v.to(device) for k, v in t.items()} for t in targets]  # Move targets to the device

                optimizer.zero_grad()  #Clear gradients

                with torch.set_grad_enabled(phase == 'train'):  # Enable gradients for training
                    if phase == 'train':
                        # Mixed precision training for faster computation
                        with torch.cuda.amp.autocast():
                            loss_dict = model(images, targets)  # Forward pass
                            losses = sum(loss for loss in loss_dict.values())  # Sum all losses

                        # Backward pass and optimization
                        scaler.scale(losses).backward()
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        # No gradients during validation
                        with torch.no_grad():
                            loss_dict = model(images, targets)
                            losses = sum(loss for loss in loss_dict.values())

                    running_loss += losses.item()  # Accumulate loss
                    batch_count += 1

            epoch_loss = running_loss / (batch_count + 1e-8)  # Calculate average loss for the epoch
            print(f'{phase} Loss: {epoch_loss:.4f}')

            history[f'{phase}_loss'].append(epoch_loss)  # Save loss to history

            # Save the best model based on validation loss
            if phase == 'val':
                scheduler.step(epoch_loss)  # Adjust learning rate based on validation loss
                if epoch_loss < best_loss:
                    best_loss = epoch_loss
                    best_model_wts = model.state_dict().copy()  # Store best model weights

    print(f'Best val Loss: {best_loss:.4f}')
    model.load_state_dict(best_model_wts)  # Load best model weights
    return model, history  # Return the best model and the history of losses


## 10. Visualize the Result of Training Model

In [None]:
# Plot function of training history
def plot_training_history(history, model_name):
    """Visualizes the training history."""
    plt.figure(figsize=(12, 4))

    # Plot training and validation loss
    plt.subplot(1, 2, 1)
    plt.plot(history['train_loss'], label='Train')
    plt.plot(history['val_loss'], label='Validation')
    plt.title(f'{model_name} - Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    # If IoU values are in the history, plot them
    if 'train_iou' in history:
        plt.subplot(1, 2, 2)
        plt.plot(history['train_iou'], label='Train')
        plt.plot(history['val_iou'], label='Validation')
        plt.title(f'{model_name} - IoU')
        plt.xlabel('Epoch')
        plt.ylabel('IoU')
        plt.legend()

    plt.tight_layout()
    plt.show()


## 11. Visualize Object Detection & Segmentation at Once
### Which is a benefit of using Mask R-CNN Model

In [None]:
# Function to visualize Object Detection and Segmentation results simultaneously
def visualize_maskrcnn_results(model, dataset, device, num_samples=3, threshold=0.5):
    """
    Visualize both object detection and segmentation results of Mask R-CNN
    Args:
        model: Mask R-CNN model
        dataset: Dataset
        device: Execution device (CPU/GPU)
        num_samples: Number of samples to visualize
        threshold: Object detection confidence threshold
    """
    model.eval()
    indices = random.sample(range(len(dataset)), num_samples)

    # Define colors for each class
    colors = {
        1: ('Parking Space', (0, 1, 0)),    # Mask color of parking space = Green
        2: ('Driveable Space', (0, 0, 1))   # Msk color of driveable space =  Blue
    }

    for idx in indices:
        image, _ = dataset[idx]
        image_input = image.unsqueeze(0).to(device)

        with torch.no_grad():
            prediction = model(image_input)[0]

        # Convert the image to a numpy array
        image_np = image.permute(1, 2, 0).numpy()

        # Set up the plot for visualizing results
        plt.figure(figsize=(15, 5))

        # 1. Original Image
        plt.subplot(1, 3, 1)
        plt.imshow(image_np)
        plt.title('Original')
        plt.axis('off')

        # 2. Segmentation Mask
        plt.subplot(1, 3, 2)
        seg_image = image_np.copy()

        # Select only the masks with high confidence
        masks = prediction['masks']
        scores = prediction['scores']
        labels = prediction['labels']

        mask_overlay = np.zeros_like(image_np)

        for mask, score, label in zip(masks, scores, labels):
            if score > threshold:
                mask = mask.squeeze().cpu().numpy()
                color = colors[label.item()][1]
                mask_overlay[mask > 0.5] = color

        # Create mask overlay
        seg_result = 0.7 * image_np + 0.3 * mask_overlay
        seg_result = np.clip(seg_result, 0, 1)
        plt.imshow(seg_result)
        plt.title('Mask-Overlay')
        plt.axis('off')

        # 3. Object Detection Results
        plt.subplot(1, 3, 3)
        det_image = image_np.copy()

        # Importing Rectangle from matplotlib for bounding boxes
        from matplotlib.patches import Rectangle

        det_ax = plt.gca()
        det_ax.imshow(det_image)

        # Draw bounding boxes
        for box, score, label in zip(prediction['boxes'], prediction['scores'], prediction['labels']):
            if score > threshold:
                box = box.cpu().numpy()
                class_name = colors[label.item()][0]
                color = colors[label.item()][1]

                # Draw the bounding box
                rect = Rectangle(
                    (box[0], box[1]),
                    box[2] - box[0],
                    box[3] - box[1],
                    linewidth=2,
                    edgecolor=color,
                    facecolor='none'
                )
                det_ax.add_patch(rect)

                # Add label
                det_ax.text(
                    box[0],
                    box[1] - 5,
                    f'{class_name}: {score:.2f}',
                    color=color,
                    fontsize=8,
                    bbox=dict(facecolor='white', alpha=0.5)
                )

        det_ax.set_title('Detection Results')
        det_ax.axis('off')

        plt.tight_layout()
        plt.show()

## 12. Train Mask RCNN Model

In [None]:
# Training Mask R-CNN
print("\nTraining Mask R-CNN...")

maskrcnn_model = get_maskrcnn_model(num_classes=3)  # 3 classes: background, parking space, driveable space
maskrcnn_model = maskrcnn_model.to(device)
maskrcnn_optimizer = torch.optim.Adam([                             # Set up diff learning rates for each structure of Mask Rcnn
    {'params': maskrcnn_model.backbone.parameters(), 'lr': 1e-4},
    {'params': maskrcnn_model.rpn.parameters(), 'lr': 1e-4},
    {'params': maskrcnn_model.roi_heads.parameters(), 'lr': 1e-4}
])

maskrcnn_model, maskrcnn_history = train_maskrcnn(
    maskrcnn_model,
    {'train': maskrcnn_train_loader, 'val': maskrcnn_val_loader},
    maskrcnn_optimizer,
    num_epochs=50,
    device=device
)

# Visualize training curve
plot_training_history(maskrcnn_history, 'Mask R-CNN')

## 13. Visualize Result of Prediction by the Mask-RCNN Model

In [None]:
# Visualizing Mask R-CNN Prediction Results
print("\nVisualizing predictions and Parking lot detection results of the Mask R-CNN model:")
visualize_maskrcnn_results(
    maskrcnn_model,
    maskrcnn_test_dataset,
    device,
    num_samples = 3,
    threshold = 0.5
)

## 14. Visualize Mask R-CNN's Detection and Segmentaion Results in a Single Image

In [None]:
import matplotlib.patches as patches

def visualize_combined_results(model, dataset, device, num_samples = 3, threshold = 0.5):
    """Visualize Mask R-CNN's segmentation and detection results in a single image"""
    model.eval()
    indices = random.sample(range(len(dataset)), num_samples)

    # Define colors for each class
    colors = {
        1: ('Parking Space', (0, 1, 0)),    # mask color of parking space = green
        2: ('Driveable Space', (0, 0, 1))   # mask color for driveable space = blue
    }

    for idx in indices:
        image, _ = dataset[idx]
        image_input = [image.to(device)]

        with torch.no_grad():
            prediction = model(image_input)[0]

        # Convert image to numpy array
        image_np = image.permute(1, 2, 0).numpy()

        # Set up plot for visualization
        plt.figure(figsize=(20, 6))

        # 1. Original Image
        plt.subplot(1, 3, 1)
        plt.imshow(image_np)
        plt.title('Original Image', fontsize = 12, pad = 10)
        plt.axis('off')

        # 2. Combined result of segmentation mask and bounding boxes
        plt.subplot(1, 3, 2)
        combined_image = image_np.copy()

        # Create mask overlay
        masks = prediction['masks']
        scores = prediction['scores']
        labels = prediction['labels']
        boxes = prediction['boxes']

        # Apply masks first
        mask_overlay = np.zeros_like(image_np)
        for mask_, score_, label_ in zip(masks, scores, labels):
            if score_ > threshold:
                mask_ = mask_.squeeze().cpu().numpy()
                color = colors[label_.item()][1]
                mask_overlay[mask_ > 0.5] = color

        # Combine mask and image
        combined_result = 0.7 * combined_image + 0.3 * mask_overlay
        combined_result = np.clip(combined_result, 0, 1)
        plt.imshow(combined_result)

        # Add bounding boxes and labels
        ax = plt.gca()
        for box, score_, label_ in zip(boxes, scores, labels):
            if score_ > threshold:
                box = box.cpu().numpy()
                class_name = colors[label_.item()][0]
                color = colors[label_.item()][1]

                # Draw bounding box
                rect = patches.Rectangle(
                    (box[0], box[1]),
                    box[2] - box[0],
                    box[3] - box[1],
                    linewidth = 2,
                    edgecolor = color,
                    facecolor = 'none'
                )
                ax.add_patch(rect)

                # Add label
                ax.text(
                    box[0],
                    box[1] - 5,
                    f'{class_name}: {score_:.2f}',
                    color = color,
                    fontsize = 8,
                    bbox = dict(facecolor = 'white', alpha = 0.7)
                )

        plt.title('Segmentation & Detection Result', fontsize = 12, pad = 10)
        plt.axis('off')

        # 3. Class-wise Statistics
        plt.subplot(1, 3, 3)
        class_vis = image_np.copy()
        ax = plt.gca()
        plt.imshow(class_vis)

        # Display each class with a different color
        legend_info = []
        for label_id, (class_name, color) in colors.items():
            mask_combined = np.zeros_like(mask_overlay[:,:,0])
            box_count = 0
            confidence_sum = 0

            for mask_, label_, score_, box in zip(masks, labels, scores, boxes):
                if score_ > threshold and label_.item() == label_id:
                    mask_ = mask_.squeeze().cpu().numpy()
                    mask_combined = np.logical_or(mask_combined, mask_ > 0.5)
                    box = box.cpu().numpy()
                    box_count += 1
                    confidence_sum += score_.item()

                    # Draw bounding box
                    rect = patches.Rectangle(
                        (box[0], box[1]),
                        box[2] - box[0],
                        box[3] - box[1],
                        linewidth = 2,
                        edgecolor = color,
                        facecolor = 'none'
                    )
                    ax.add_patch(rect)

            if box_count > 0:
                avg_confidence = confidence_sum / box_count
                legend_info.append(f'{class_name}: {box_count} objects (avg conf: {avg_confidence:.2f})')

        # Add legend
        ax.text(
            10, 20,
            '\n'.join(legend_info),
            fontsize = 10,
            bbox=dict(facecolor = 'white', alpha = 0.7)
        )

        plt.title('Class-wise Detection Statistics', fontsize = 12, pad = 10)
        plt.axis('off')

        plt.tight_layout()
        plt.show()

        # Print detected object count and average confidence
        print(f"\nImage {idx} Detection Statistics:")
        class_stats = {}
        for label_id, (class_name, _) in colors.items():
            class_scores = [score.item() for score, label in zip(scores, labels)
            if label.item() == label_id and score.item() > threshold]
            if class_scores:
                avg_conf = sum(class_scores) / len(class_scores)
                class_stats[class_name] = {
                    'count': len(class_scores),
                    'avg_confidence': avg_conf
                }
                print(f"{class_name}: {len(class_scores)} objects, Average confidence: {avg_conf:.3f}")

# Visualizing Mask R-CNN results
print("\nVisualizing combined Mask R-CNN results...")
visualize_combined_results(
    maskrcnn_model,
    test_dataset,
    device,
    num_samples = 3,
    threshold = 0.5
)
