# Setup

In [None]:
# # Install the necessary dependencies
# !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128 
# !pip install datasets
# !pip install evaluate 
# !pip install albumentations
# !pip install git+https://github.com/huggingface/transformers.git

# # # We will use this to push our trained model to HF Hub
# !pip install huggingface_hub 
# !pip install torchmetrics 
# !pip install 'accelerate>=1.1.0'
# !pip install matplotlib
# !pip install pycocotools


In [1]:
# Import the necessary packages
import random
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import json
from pathlib import Path
from PIL import Image, ImageDraw
import torch
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import albumentations as A
import numpy as np
import pandas as pd
from datasets import DatasetDict, Dataset, load_from_disk
from transformers import (
    Mask2FormerConfig,
    Mask2FormerImageProcessor,
    Mask2FormerModel,
    Mask2FormerForUniversalSegmentation,
    Trainer,
    HfArgumentParser,
    TrainingArguments,
)
from transformers.trainer import EvalPrediction
from torchmetrics.detection.mean_ap import MeanAveragePrecision
import evaluate
from huggingface_hub import notebook_login
from dataclasses import dataclass, field
from typing import Any, Optional
import logging
import transformers
import sys

  from .autonotebook import tqdm as notebook_tqdm


# Load Dataset

In [None]:
# Load COCO-style annotations from the 'buildings' dataset and convert them to the instance segmentation format

def coco2seg(dataset_dir, splits=('train', 'val', 'test')):
    """
    Convert COCO-style JSON annotations into per-image instance-segmentation
    PNGs and build a Hugging Face DatasetDict that points at them.

    For every split the function reads ``<dataset_dir>/<split>/<split>_512.json``,
    rasterizes each polygon annotation and writes one PNG per image into
    ``<dataset_dir>/<split>/annotation/``. The resulting DatasetDict is also
    saved to ``<dataset_dir>/hf``.

    Args:
        dataset_dir: path to the dataset root
        splits: names of the splits to load and convert

    Returns:
        DatasetDict whose rows store file paths as strings (not decoded images):
            - 'image': path of the image inside ``<split>/image_512``
            - 'annotation': path of the generated PNG, where
                            R channel = category_id + 1 (0 = no annotation)
                            G channel = instance id (1-based, unique per image, at most 255)
                            B channel = 0
        A 'test' split that has images but no annotation JSON only gets the
        'image' column. Splits with nothing on disk are skipped.
    """
    dataset_dir = Path(dataset_dir)
    result = {}

    for split in splits:
        coco_path = dataset_dir / f"{split}/{split}_512.json"
        img_dir = dataset_dir / split / "image_512"

        if not coco_path.exists() or not img_dir.exists():
            if split == "test" and img_dir.exists():
                print(f"⚠️ No annotation file found for '{split}' — loading images only.")
                # Sorted so the row order is reproducible between runs.
                images = [{"image": str(f)} for f in sorted(img_dir.glob("*.tif"))]
                result[split] = Dataset.from_list(images)
                continue
            print(f"WARNING: Missing split '{split}', skipping.")
            continue

        print(f"Processing split '{split}'...")

        # Load COCO annotation JSON
        with open(coco_path, "r") as f:
            coco = json.load(f)

        images = {img["id"]: img for img in coco["images"]}

        # Group annotations by image_id
        anns_by_img = {}
        for ann in coco["annotations"]:
            anns_by_img.setdefault(ann["image_id"], []).append(ann)

        # Output folder for the rasterized annotation PNGs of this split
        ann_dir = dataset_dir / split / "annotation"
        ann_dir.mkdir(parents=True, exist_ok=True)

        records = []
        for img_id, img_info in tqdm(images.items()):
            file_name = Path(img_info["file_name"]).name
            width, height = img_info["width"], img_info["height"]

            image_path = img_dir / file_name
            if not image_path.exists():
                # Listed in the JSON but missing on disk
                continue

            # One single-channel canvas per encoded quantity
            r = Image.new("L", (width, height), 0)  # Category (+1)
            g = Image.new("L", (width, height), 0)  # Instance
            draw_r = ImageDraw.Draw(r)
            draw_g = ImageDraw.Draw(g)

            instance_counter = 1
            for ann in anns_by_img.get(img_id, []):
                cat_id = int(ann["category_id"])
                polygons = ann.get("segmentation", [])
                if not polygons or not isinstance(polygons, list):
                    continue

                # An 8-bit channel can only hold ids 1..255. Checking before
                # drawing means the warning only fires when an instance is
                # really being dropped.
                if instance_counter > 255:
                    print(f"WARNING: Too many instances in {file_name}, clipping to 255.")
                    break

                # Each polygon in COCO is a list of [x1, y1, x2, y2, ...]
                for poly in polygons:
                    if len(poly) < 6:  # invalid polygon
                        continue
                    # Stop at len - 1 so a stray trailing coordinate in an
                    # odd-length list is ignored instead of raising IndexError.
                    xy = [(poly[i], poly[i + 1]) for i in range(0, len(poly) - 1, 2)]
                    draw_r.polygon(xy, fill=cat_id + 1)
                    draw_g.polygon(xy, fill=instance_counter)

                instance_counter += 1

            # Merge the R and G canvases (plus an empty B channel) into one RGB array
            ann_img = np.stack(
                [np.array(r), np.array(g), np.zeros((height, width), np.uint8)],
                axis=-1,
            )

            # SAVE AS PNG SUPER IMPORTANT FOR NO DATA LOSS
            png_path = ann_dir / f"{Path(img_info['file_name']).stem}.png"
            Image.fromarray(ann_img).save(png_path)

            records.append({
                "image": str(image_path),
                "annotation": str(png_path)
            })

        result[split] = Dataset.from_list(records)

    dataset = DatasetDict(result)

    dataset.save_to_disk(dataset_dir / "hf")

    return dataset

DATASET_DIR = Path("./building-extraction-generalization-2024")

# coco2seg() rasterizes every polygon of every image and then saves its result
# under DATASET_DIR / "hf". Reuse that saved copy when it is already there so
# re-running the notebook does not repeat the conversion. Delete the "hf"
# folder to force a rebuild (e.g. after changing coco2seg or the raw data).
# NOTE(review): "dataset_dict.json" is the marker file DatasetDict.save_to_disk
# is expected to write; if it is absent the conversion simply runs again.
_hf_cache_dir = DATASET_DIR / "hf"
if (_hf_cache_dir / "dataset_dict.json").exists():
    dataset = load_from_disk(str(_hf_cache_dir))
else:
    dataset = coco2seg(DATASET_DIR)

In [3]:
# Upload Dataset to Hugging Face HUB
import os

from datasets import Dataset, Features, Image as ImageHF
from huggingface_hub import login

# SECURITY: an access token used to be hard-coded on this line. Anything that
# was committed must be treated as leaked and revoked in the Hugging Face
# account settings. The token is now read from the HF_TOKEN environment
# variable; when it is not set, login() receives None and asks for the token
# interactively instead.
login(token=os.environ.get("HF_TOKEN"))

import glob
from PIL import Image

# --- Define generator factory functions ---
def make_gen_examples_train_val(images, annotations):
    """
    Build a zero-argument generator function that yields one example per
    (image, annotation) pair.

    Args:
        images: iterable of image file paths
        annotations: iterable of annotation file paths, matched to ``images``
            by position

    Returns:
        A function that, each time it is called, yields dicts of the form
        ``{"image": {"path": ...}, "annotation": {"path": ...}}``.

    Raises:
        ValueError: if the two inputs do not contain the same number of paths.
    """
    # Materialize once so the lengths can be compared and so the returned
    # generator function can be called more than once.
    images = list(images)
    annotations = list(annotations)

    # Pairing is purely positional; zip() would silently drop the surplus of
    # the longer list and could pair every following image with the wrong mask.
    if len(images) != len(annotations):
        raise ValueError(
            f"Got {len(images)} images but {len(annotations)} annotations; "
            "the two lists must have the same length."
        )

    def gen():
        for img_path, ann_path in zip(images, annotations):
            yield {
                "image": {"path": img_path},
                "annotation": {"path": ann_path},
            }
    return gen

# --- Define consistent feature schemas ---
# Both columns are declared as Hugging Face image features.
features = Features({"image": ImageHF(), "annotation": ImageHF()})


# --- Build each split independently ---
def _load_split(image_glob, annotation_glob):
    """Glob and sort both patterns, then wrap the positionally paired paths in a Dataset."""
    image_paths = sorted(glob.glob(image_glob))
    annotation_paths = sorted(glob.glob(annotation_glob))
    split_ds = Dataset.from_generator(
        make_gen_examples_train_val(image_paths, annotation_paths),
        features=features,
    )
    return image_paths, annotation_paths, split_ds


# Train
train_images, train_anns, train_ds = _load_split(
    "./building-extraction-generalization-2024/train/image_512/*.jpg",
    "./building-extraction-generalization-2024/train/annotation/*.png",
)

# Validation
val_images, val_anns, val_ds = _load_split(
    "./building-extraction-generalization-2024/val/image_512/*.jpg",
    "./building-extraction-generalization-2024/val/annotation/*.png",
)

# Test
# NOTE(review): both patterns point at the test images, so each test image is
# also stored in its own "annotation" column — presumably a placeholder that
# keeps the schema identical across splits; confirm this is intended.
test_images, test_anns, test_ds = _load_split(
    "./building-extraction-generalization-2024/test/image_512/*.tif",
    "./building-extraction-generalization-2024/test/image_512/*.tif",
)


ds_dict = DatasetDict({"train": train_ds, "val": val_ds, "test": test_ds})
ds_dict.push_to_hub("tomascanivari/building_extraction")


Generating train split: 1000 examples [00:00, 23452.30 examples/s]
Map: 100%|██████████| 3784/3784 [00:01<00:00, 2212.73 examples/s]ards/s]
Creating parquet from Arrow format: 100%|██████████| 4/4 [00:00<00:00, 10.56ba/s]
Processing Files (1 / 1): 100%|██████████|  288MB /  288MB,  237MB/s  

  [2m2025-10-20T15:41:38.032081Z[0m [33m WARN[0m  [33mStatus Code: 500. Retrying..., [1;33mrequest_id[0m[33m: ""[0m
    [2;3mat[0m /home/runner/work/xet-core/xet-core/cas_client/src/http_client.rs:227



Processing Files (1 / 1): 100%|██████████|  288MB /  288MB, 33.0MB/s  
New Data Upload: |          |  0.00B /  0.00B,  0.00B/s  
Uploading the dataset shards: 100%|██████████| 1/1 [00:14<00:00, 14.08s/ shards]
Map: 100%|██████████| 933/933 [00:00<00:00, 2094.86 examples/s]shards/s]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:00<00:00,  9.77ba/s]
Processing Files (1 / 1): 100%|██████████| 71.4MB / 71.4MB, 52.8MB/s  
New Data Upload: |          |  0.00B /  0.00B,  0.00B/s  
Uploading the dataset shards: 100%|██████████| 1/1 [00:02<00:00,  2.73s/ shards]
Map: 100%|██████████| 250/250 [00:00<00:00, 283.20 examples/s] shards/s]
Creating parquet from Arrow format: 100%|██████████| 4/4 [00:00<00:00,  5.14ba/s]
Processing Files (0 / 1):  90%|█████████ |  348MB /  387MB, 6.56MB/s  
New Data Upload: 100%|██████████|  117MB /  117MB, 4.01MB/s  
Map: 100%|██████████| 250/250 [00:00<00:00, 449.36 examples/s]3, 27.96s/ shards]
Creating parquet from Arrow format: 100%|██████████| 4/4

CommitInfo(commit_url='https://huggingface.co/datasets/tomascanivari/building_extraction/commit/3cda4b2efb754cb0d17ce56e1f500fa99aa971b1', commit_message='Upload dataset', commit_description='', oid='3cda4b2efb754cb0d17ce56e1f500fa99aa971b1', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/tomascanivari/building_extraction', endpoint='https://huggingface.co', repo_type='dataset', repo_id='tomascanivari/building_extraction'), pr_revision=None, pr_num=None)

In [None]:
from datasets import load_dataset
# Load Converted Dataset
DATASET_HF_DIR = "tomascanivari/building_extraction"

# Use the constant so the repo id is spelled out in exactly one place
dataset = load_dataset(DATASET_HF_DIR)

print(dataset)

# Let's check the first train image and annotation
example = dataset["train"][0]
img = example["image"]
ann = example["annotation"]

# Convert both PIL images to arrays
image = np.array(img.convert("RGB"))
annotation = np.array(ann)

# np.unique lists the distinct ids that occur, not how many there are
print("Category IDs: ", np.unique(annotation[..., 0]))  # Red channel: category IDs
print("Instance IDs: ", np.unique(annotation[..., 1]))  # Green channel: instance IDs

# Plot the original image next to the two annotation channels
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
panels = [
    (image, "Original"),
    (annotation[..., 0], "Class Map (R)"),
    (annotation[..., 1], "Instance Map (G)"),
]
for ax, (plot_image, title) in zip(axes, panels):
    ax.imshow(plot_image)
    ax.set_title(title)
    ax.axis("off")

# Let's check instance 1 (coco2seg numbers instances from 1; 0 marks unlabeled pixels)
print("Instance 1")
mask = (annotation[..., 1] == 1)
# Scale the boolean mask to 0/255 so it renders as a black-and-white image
visual_mask = (mask * 255).astype(np.uint8)
Image.fromarray(visual_mask)

In [None]:
from transformers import (MaskFormerConfig, MaskFormerForInstanceSegmentation, MaskFormerImageProcessor)

# Single foreground class. Label ids start at 0 here while the annotation PNGs
# store category + 1; the processor below is created with do_reduce_labels=True
# to account for that offset.
id2label = {0: 'building'}
# Derived from id2label so the two mappings cannot drift apart
label2id = {label: idx for idx, label in id2label.items()}

# Load pre-trained weights.
# Both mappings are passed: giving only id2label would leave the checkpoint's
# original label2id in the config next to the new id2label.
# ignore_mismatched_sizes=True is needed because the number of labels differs
# from the checkpoint's.
model = MaskFormerForInstanceSegmentation.from_pretrained(
    "facebook/maskformer-swin-base-coco",
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)
# Load processor.
# Resizing, rescaling and normalization are switched off here because the
# albumentations pipeline already performs them before the processor is called.
processor = MaskFormerImageProcessor(
    do_reduce_labels=True,
    size=(512, 512),
    ignore_index=255,
    do_resize=False,
    do_rescale=False,
    do_normalize=False,
)

In [None]:
from torch.utils.data import Dataset # To make ImageSegmentationDataset a PyTorch dataset !!!!!!!!!!!!!!!!

# Per-channel mean and standard deviation used for normalization. The values
# are written on the 0-255 scale and divided by 255 to express them on 0-1.
ADE_MEAN = np.array([123.675, 116.280, 103.530]) / 255
ADE_STD = np.array([58.395, 57.120, 57.375]) / 255

# Augmentation pipeline, applied in this order:
#   1. resize to 512x512
#   2. horizontal flip with probability 0.3
#   3. normalize with the statistics above
#   4. convert to float
_resize_step = A.Resize(width=512, height=512)
_flip_step = A.HorizontalFlip(p=0.3)
_normalize_step = A.Normalize(mean=ADE_MEAN, std=ADE_STD)
_to_float_step = A.ToFloat()

train_val_transform = A.Compose(
    [_resize_step, _flip_step, _normalize_step, _to_float_step]
)

class ImageSegmentationDataset(Dataset):
    """
    PyTorch dataset that turns (image, annotation) rows into model inputs.

    Each row of ``dataset`` must provide an ``"image"`` exposing
    ``convert("RGB")`` and an ``"annotation"`` that converts to an array whose
    channel 0 is the per-pixel class id and channel 1 the per-pixel instance id.

    Args:
        dataset: indexable collection of rows as described above
        processor: callable invoked as ``processor([image], [instance_map],
            instance_id_to_semantic_id=..., return_tensors="pt")``
        transform: optional callable invoked as ``transform(image=..., mask=...)``
            that returns a dict with ``"image"`` and ``"mask"``
    """

    def __init__(self, dataset, processor, transform=None):
        # Keep references to the data source, the processor and the transform
        self.dataset = dataset
        self.processor = processor
        self.transform = transform

    def __len__(self):
        # Return the number of datapoints
        return len(self.dataset)

    def __getitem__(self, idx):
        """
        Return the processed inputs for row ``idx``.

        The result always holds ``pixel_values``, ``pixel_mask``,
        ``class_labels`` and ``mask_labels``. For an image without any object
        the last two are empty (zero entries / zero masks).
        """
        # Fetch the row once instead of once per column; indexing the
        # underlying dataset may be expensive (e.g. it may decode the images).
        sample = self.dataset[idx]

        # Convert the PIL Image to a NumPy array
        image = np.array(sample["image"].convert("RGB"))

        # Pixel-wise instance id and category id maps, each (height, width)
        annotation = np.array(sample["annotation"])
        instance_seg = annotation[..., 1]
        class_id_map = annotation[..., 0]
        class_labels = np.unique(class_id_map)

        # Build the instance -> class dictionary from the untransformed maps
        inst2class = {}
        for label in class_labels:
            instance_ids = np.unique(instance_seg[class_id_map == label])
            inst2class.update({i: label for i in instance_ids})

        # Apply transforms
        if self.transform is not None:
            transformed = self.transform(image=image, mask=instance_seg)
            image, instance_seg = transformed["image"], transformed["mask"]

            # Convert from channels last to channels first
            image = image.transpose(2, 0, 1)

        if class_labels.shape[0] == 1 and class_labels[0] == 0:
            # Only class 0 is present, i.e. the image contains no objects.
            # The image is still returned, with empty targets.
            inputs = self.processor([image], return_tensors="pt")
            inputs = {k: v.squeeze() for k, v in inputs.items()}
            # Zero labels to go with the zero masks below, so both targets
            # agree on the number of instances. int64 matches the dtype that
            # torch.tensor([...]) would give integer labels.
            inputs["class_labels"] = torch.zeros((0,), dtype=torch.int64)
            inputs["mask_labels"] = torch.zeros(
                (0, inputs["pixel_values"].shape[-2], inputs["pixel_values"].shape[-1])
            )
        else:
            # Otherwise process the image together with its instance map
            inputs = self.processor(
                [image],
                [instance_seg],
                instance_id_to_semantic_id=inst2class,
                return_tensors="pt"
            )
            # Drop the batch dimension: tensors are squeezed, list-valued
            # entries are unwrapped to their single element.
            inputs = {
                k: v.squeeze() if isinstance(v, torch.Tensor) else v[0] for k, v in inputs.items()
            }
        # Return the inputs
        return inputs

# Wrap the "train" and "val" splits in ImageSegmentationDataset. Both use the
# same processor and the same albumentations pipeline.
# NOTE(review): that pipeline contains a random horizontal flip, so the
# validation inputs are randomly flipped too and can differ between passes.
_dataset_kwargs = {"processor": processor, "transform": train_val_transform}

train_dataset = ImageSegmentationDataset(dataset["train"], **_dataset_kwargs)
val_dataset = ImageSegmentationDataset(dataset["val"], **_dataset_kwargs)

In [None]:
# Sanity-check the preprocessing: print the shape of every entry returned for
# the first two training samples.
for heading, sample_idx in (("Train Instance 0", 0), ("\nTrain Instance 1", 1)):
    print(heading)
    inputs = train_dataset[sample_idx]
    for k, v in inputs.items():
        print(k, v.shape)

In [None]:
def collate_fn(batch):
    """
    Merge a list of per-sample dicts into a single batch dict.

    ``pixel_values`` and ``pixel_mask`` are stacked into one tensor each, with
    a new leading batch dimension. ``class_labels`` and ``mask_labels`` are
    kept as plain lists holding one entry per sample.
    """
    stacked_keys = ("pixel_values", "pixel_mask")
    listed_keys = ("class_labels", "mask_labels")

    collated = {
        key: torch.stack([sample[key] for sample in batch]) for key in stacked_keys
    }
    collated.update({key: [sample[key] for sample in batch] for key in listed_keys})
    return collated

# Settings shared by both loaders: one sample per batch, custom collation
_loader_kwargs = dict(batch_size=1, collate_fn=collate_fn)

# Training batches are reshuffled; validation batches keep the dataset order
train_dataloader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_dataloader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

In [None]:
# Pull one batch from the training loader and summarize it: the shape for
# tensor entries, the number of elements for list entries.
batch = next(iter(train_dataloader))
for k, v in batch.items():
    summary = v.shape if isinstance(v, torch.Tensor) else len(v)
    print(k, summary)

In [None]:
# Smoke test: run a single forward pass on `batch` and display the loss
forward_kwargs = {
    "pixel_values": batch["pixel_values"],
    "mask_labels": batch["mask_labels"],
    "class_labels": batch["class_labels"],
}
outputs = model(**forward_kwargs)
outputs.loss

In [None]:
# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Initialize Adam optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)


def _batch_loss(batch):
    """Move one collated batch to `device`, run the model on it and return the loss."""
    outputs = model(
        pixel_values=batch["pixel_values"].to(device),
        mask_labels=[labels.to(device) for labels in batch["mask_labels"]],
        class_labels=[labels.to(device) for labels in batch["class_labels"]],
    )
    return outputs.loss


def _mean(values):
    """Arithmetic mean of a list; NaN for an empty list instead of a ZeroDivisionError."""
    return sum(values) / len(values) if values else float("nan")


# Set number of epochs
num_epochs = 2
for epoch in range(num_epochs):
    print(f"Epoch {epoch} | Training")

    # Set model in training mode
    model.train()
    train_losses = []

    # Training loop
    for batch in tqdm(train_dataloader):
        # Reset the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        loss = _batch_loss(batch)
        train_losses.append(loss.item())

        # Backward propagation and optimization
        loss.backward()
        optimizer.step()

    # Average train epoch loss
    train_loss = _mean(train_losses)

    # Set model in evaluation mode
    model.eval()
    print(f"Epoch {epoch} | Validation")
    val_losses = []
    # No gradients are needed for validation
    with torch.no_grad():
        for batch in tqdm(val_dataloader):
            val_losses.append(_batch_loss(batch).item())

    # Average validation epoch loss
    val_loss = _mean(val_losses)

    # Print epoch losses
    print(f"Epoch {epoch} | train_loss: {train_loss} | validation_loss: {val_loss}")

In [None]:
from transformers import MaskFormerForInstanceSegmentation, MaskFormerImageProcessor

# The checkpoint folders are read back just below, so they have to exist.
# Write them from the in-memory model/processor when they are missing; this
# lets the notebook run top to bottom on a fresh checkout. An existing
# checkpoint is never overwritten — delete the folders to save a newly trained model.
if not Path("models/mf").exists():
    model.save_pretrained("models/mf")
if not Path("models/mf_p").exists():
    processor.save_pretrained("models/mf_p")

model = MaskFormerForInstanceSegmentation.from_pretrained("models/mf")
processor = MaskFormerImageProcessor.from_pretrained("models/mf_p")

# We won't be using albumentations to preprocess images for inference,
# so the processor has to normalize, resize and rescale by itself
processor.do_normalize = True
processor.do_resize = True
processor.do_rescale = True

# Push your model and preprocessor to the Hub
model.push_to_hub("maskformer-swin-base-building-instance")
processor.push_to_hub("maskformer-swin-base-building-instance")

