In [None]:
import torch
import random
import evaluate

import matplotlib.pyplot as plt
import numpy as np

from PIL import Image
from tqdm import tqdm
from datasets import load_dataset
from transformers import (
    MaskFormerForInstanceSegmentation,
    MaskFormerImageProcessor
)


# Load Model and Preprocessor

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hub repository holding the trained MaskFormer checkpoint and its processor config.
# Kept in one constant so the model and its preprocessing always come from the same repo.
MODEL_HF_DIR = "tomascanivari/maskformer-swin-base-building-instance"

model = MaskFormerForInstanceSegmentation.from_pretrained(MODEL_HF_DIR).to(device)
processor = MaskFormerImageProcessor.from_pretrained(MODEL_HF_DIR)


# Load Dataset

In [None]:
# Hub dataset repository with the converted building-extraction data
DATASET_HF_DIR = "tomascanivari/building_extraction"

# Load every split at once (no `split=` argument -> a DatasetDict)
dataset = load_dataset(DATASET_HF_DIR)

# Note: the test split's annotations are only placeholders
print(dataset)

In [None]:
# Inspect the first training image and its annotation
example = dataset["train"][0]
img = example["image"]
ann = example["annotation"]

# Convert the PIL images to NumPy arrays for indexing and plotting
image = np.array(img.convert("RGB"))
annotation = np.array(ann)

# The annotation packs two maps into its colour channels:
#   red   (channel 0) -> category IDs
#   green (channel 1) -> instance IDs
category_ids = np.unique(annotation[..., 0])
instance_ids = np.unique(annotation[..., 1])
print("Category IDs: ", category_ids, f"({len(category_ids)} distinct)")
print("Instance IDs: ", instance_ids, f"({len(instance_ids)} distinct)")

# Plot the original image next to the class and instance maps
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
panels = [
    (image, "Original"),
    (annotation[..., 0], "Class Map (R)"),
    (annotation[..., 1], "Instance Map (G)"),
]
for ax, (panel, title) in zip(axes, panels):
    ax.imshow(panel)
    ax.set_title(title)
    ax.axis("off")
plt.show()

# Show the binary mask of instance ID 1 (0 presumably marks background)
instance_id = 1
print(f"Instance {instance_id}")
mask = annotation[..., 1] == instance_id
visual_mask = (mask * 255).astype(np.uint8)
Image.fromarray(visual_mask)

In [None]:
# Visualize RLE obtained from instance segmentation annotation
from pycocotools import mask as mask_utils

def instance_mask_to_rle(instance_mask, ignore_index=255):
    """
    Encode every instance of a 2-D instance-ID map as a COCO run-length encoding.

    Args:
        instance_mask: H x W array of instance IDs (e.g. the green channel of an
            annotation after label reduction, or a post-processed prediction map).
        ignore_index: ID treated as background/unassigned and skipped.
            Defaults to 255.

    Returns:
        List of RLE dicts, one per remaining instance ID in ascending order, as
        produced by ``pycocotools.mask.encode``. Each ``counts`` entry is decoded
        from bytes to ``str`` so the result can be serialised to JSON.
    """
    instance_ids = np.unique(instance_mask)
    instance_ids = instance_ids[instance_ids != ignore_index]

    rles = []
    for inst_id in instance_ids:
        # pycocotools expects a Fortran-ordered uint8 binary mask
        binary_mask = np.asfortranarray((instance_mask == inst_id).astype(np.uint8))
        rle = mask_utils.encode(binary_mask)
        rle["counts"] = rle["counts"].decode("utf-8")
        rles.append(rle)
    return rles

def visualize_rle_on_image(image, rle_list, alpha=0.5, color=(0, 0, 255)):
    """
    Blend RLE-encoded masks onto an image and display the result.

    Args:
        image: H x W x 3 NumPy array (the original image). It is copied, not modified.
        rle_list: list of RLE dicts accepted by ``pycocotools.mask.decode``.
        alpha: weight of the overlay colour inside each mask
            (0 = mask invisible, 1 = solid colour).
        color: RGB colour used for every mask. Defaults to blue.

    Returns:
        None. The overlay is displayed with ``plt.show()``.
    """
    overlay = image.copy()
    # Same dtype the original uint8 colour array promoted to when multiplied by a float alpha
    overlay_color = np.asarray(color, dtype=np.float64)

    for rle in rle_list:
        # decode() yields an H x W array of 0/1; build the boolean index only once
        inside = mask_utils.decode(rle) == 1
        blended = (1 - alpha) * overlay[inside] + alpha * overlay_color
        # Explicit cast back to the image dtype (truncates, like the implicit assignment cast)
        overlay[inside] = blended.astype(overlay.dtype)

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(overlay)
    ax.axis("off")
    plt.show()

# Round-trip one validation annotation through RLE to sanity-check the encoding
idx = 1

image = np.array(dataset["val"][idx]["image"].convert("RGB"))

# Reduce labels: shift every ID down by one and send the former 0 to the ignore
# index. Work in a signed dtype: the annotation presumably decodes as uint8,
# where `-= 1` silently wraps 0 to 255 and an `== -1` test never matches.
annotation = np.array(dataset["val"][idx]["annotation"]).astype(np.int32) - 1
annotation[annotation == -1] = 255  # ignore_index

rles = instance_mask_to_rle(annotation[..., 1])

# These two numbers should agree: one RLE per non-ignored instance ID
print(len(rles), np.count_nonzero(np.unique(annotation[..., 1]) != 255))

visualize_rle_on_image(image, rles)

# Plot the original image next to the reduced class and instance maps
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
panels = [
    (image, "Original"),
    (annotation[..., 0], "Class Map (R)"),
    (annotation[..., 1], "Instance Map (G)"),
]
for ax, (panel, title) in zip(axes, panels):
    ax.imshow(panel)
    ax.set_title(title)
    ax.axis("off")

In [None]:
# Run the model on the validation image inspected above (index `idx`)
image = dataset["val"][idx]["image"].convert("RGB")
target_size = image.size[::-1]  # PIL reports (width, height); post-processing wants (height, width)

# Preprocess image
inputs = processor(images=image, return_tensors="pt").to(device)

# Inference
model.eval()
with torch.no_grad():
    outputs = model(**inputs)

# Print the items returned by the model and their shapes
print("Outputs...")
for key, value in outputs.items():
    # Fall back to the type name for any output field that is not a tensor
    print(f"  {key}: {getattr(value, 'shape', type(value).__name__)}")

# Post-process results to retrieve the instance segmentation map
result = processor.post_process_instance_segmentation(
    outputs,
    threshold=0.5,
    target_sizes=[target_size],
)[0]  # a single image was passed in, so take the single result

instance_seg_mask = result["segmentation"].detach().cpu().numpy()
# -1 marks pixels without an instance; map them to the 255 ignore value
instance_seg_mask[instance_seg_mask == -1] = 255

rles = instance_mask_to_rle(instance_seg_mask)

# These two numbers should agree: one RLE per predicted (non-ignored) instance ID
print(len(rles), np.count_nonzero(np.unique(instance_seg_mask) != 255))

visualize_rle_on_image(np.array(image), rles)

print(f"Final mask shape: {instance_seg_mask.shape}")
print("Segments Information...")
for info in result["segments_info"]:
    print(f"  {info}")

print(np.unique(instance_seg_mask))

In [None]:
# Separate, initially empty dicts for ground-truth and predicted RLEs
gt_rles, pred_rles = {}, {}

In [None]:
# Load the Mean IoU metric
metrics = evaluate.load("mean_iou")

# Set model in evaluation mode on the target device
model.eval()
model.to(device)

# Test set doesn't have annotations, so we use validation
ground_truths, preds = [], []

for data in tqdm(dataset["val"]):
    image = data["image"].convert("RGB")
    target_size = image.size[::-1]  # PIL (width, height) -> (height, width)

    # Ground-truth *semantic* map = the class-ID (red) channel. The green channel
    # holds instance IDs, which are not semantic labels: after reduction every
    # instance beyond the second would fall outside the 2-label range.
    # Cast to a signed dtype first so reducing labels turns 0 into -1 instead of
    # wrapping around as uint8 would; then map -1 to the ignore index.
    annotation = np.array(data["annotation"])[..., 0].astype(np.int32) - 1
    annotation[annotation == -1] = 255  # ignore_index
    ground_truths.append(annotation)

    # Preprocess image
    inputs = processor(images=image, return_tensors="pt").to(device)

    # Inference
    with torch.no_grad():
        outputs = model(**inputs)

    # Post-process to a per-pixel class-ID map at the original resolution
    result = processor.post_process_semantic_segmentation(
        outputs, target_sizes=[target_size]
    )[0]
    semantic_seg_mask = result.cpu().numpy()
    preds.append(semantic_seg_mask)

# Compute metric
results = metrics.compute(
    predictions=preds,
    references=ground_truths,
    num_labels=2,
    ignore_index=255,
)

print(f"Mean IoU: {results['mean_iou']} | Mean Accuracy: {results['mean_accuracy']} | Overall Accuracy: {results['overall_accuracy']}")
