In [12]:
import torch
import numpy as np
import cv2
from PIL import Image
import albumentations as A
from transformers import Mask2FormerImageProcessor, Mask2FormerForUniversalSegmentation
import matplotlib.pyplot as plt

In [6]:
CONFIG = dict(
    MODEL_CHECKPOINT="facebook/mask2former-swin-tiny-coco-instance",  # HF checkpoint for architecture + processor
    WEIGHTS_PATH="mask2former.pth",  # state dict loaded on top of the checkpoint
    INFER_SIZE=512,   # square side length fed to the model
    ORIG_SIZE=2048,   # required tile side length; masks are returned at this size
    DEVICE="cuda" if torch.cuda.is_available() else "cpu",
)

# Label id compared against the semantic map to select "road" pixels.
ROAD_CLASS_ID = 3


In [None]:
processor = Mask2FormerImageProcessor.from_pretrained(
    CONFIG["MODEL_CHECKPOINT"],
    do_resize=False,     # images are resized with albumentations before reaching the processor
    do_rescale=True,
    do_normalize=True
)

model = Mask2FormerForUniversalSegmentation.from_pretrained(
    CONFIG["MODEL_CHECKPOINT"],
    # Must match the class-head size of the weights in WEIGHTS_PATH; the strict
    # load_state_dict below fails on a shape mismatch otherwise.
    num_labels=7,
    # The checkpoint's own class head has a different size; let it be re-initialised,
    # it is overwritten by the state dict loaded next anyway.
    ignore_mismatched_sizes=True
)

# weights_only=True restricts unpickling to tensors/containers, so a tampered
# .pth file cannot execute arbitrary code during loading.
state_dict = torch.load(CONFIG["WEIGHTS_PATH"], map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
model.to(CONFIG["DEVICE"])
model.eval()

print("Model loaded")


In [8]:
# Albumentations pipeline: resize any image to INFER_SIZE x INFER_SIZE
# (aspect ratio is not preserved) before it is handed to the HF processor.
resize_to_model = A.Compose(
    [A.Resize(height=CONFIG["INFER_SIZE"], width=CONFIG["INFER_SIZE"])]
)


In [9]:
def predict_road_mask_2048(image_path, road_class_id=None):
    """Segment roads in a square tile and return a full-resolution binary mask.

    The tile is resized to ``CONFIG["INFER_SIZE"]`` for the model. The per-class
    scores are then interpolated straight back to the tile's original resolution
    before the arg-max. This gives smoother road boundaries than nearest-neighbour
    upscaling of a low-resolution label map.

    Args:
        image_path: Path to an image file readable by PIL. It must be
            ``CONFIG["ORIG_SIZE"]`` x ``CONFIG["ORIG_SIZE"]`` pixels.
        road_class_id: Label id treated as "road". Defaults to the
            module-level ``ROAD_CLASS_ID``.

    Returns:
        ``(orig_image, road_mask)``:
            ``orig_image`` is the RGB image as an ``(H, W, 3)`` array.
            ``road_mask`` is an ``(H, W)`` uint8 array, 1 for road and 0 elsewhere.

    Raises:
        ValueError: If the image is not ``ORIG_SIZE`` x ``ORIG_SIZE``.
    """
    if road_class_id is None:
        road_class_id = ROAD_CLASS_ID
    expected = CONFIG["ORIG_SIZE"]

    # Context manager releases the file handle (PIL may otherwise keep TIFFs open).
    with Image.open(image_path) as img:
        orig_image = np.array(img.convert("RGB"))
    h, w = orig_image.shape[:2]

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if (h, w) != (expected, expected):
        raise ValueError(
            f"Image must be {expected}×{expected}, got {w}×{h}: {image_path}"
        )

    # Resize for the model; the processor only rescales/normalises (do_resize=False).
    resized = resize_to_model(image=orig_image)["image"]
    inputs = processor(images=resized, return_tensors="pt").to(CONFIG["DEVICE"])

    with torch.no_grad():
        outputs = model(**inputs)
        # Interpolate class scores directly to the original (H, W), then arg-max,
        # instead of arg-maxing at INFER_SIZE and nearest-upscaling the labels.
        semantic_map = processor.post_process_semantic_segmentation(
            outputs,
            target_sizes=[(h, w)]
        )[0].cpu().numpy()

    # Road vs. everything else.
    road_mask = (semantic_map == road_class_id).astype(np.uint8)
    return orig_image, road_mask


In [13]:
def visualize_matplotlib(orig_image, road_mask):
    """Show the original image, the road mask and a red road overlay side by side.

    Args:
        orig_image: RGB image of shape ``(H, W, 3)``.
        road_mask: Mask of shape ``(H, W)``. Non-zero pixels are treated as road
            (masks produced by this notebook use 1 = road).
    """
    is_road = road_mask.astype(bool)

    # Paint road pixels solid red on a copy so the caller's image is untouched.
    overlay = orig_image.copy()
    overlay[is_road] = [255, 0, 0]  # RED roads in RGB

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    panels = (
        ("Original Image", orig_image, {}),
        # Pin the colour range: a uniform mask (all road or no road) would
        # otherwise auto-scale to vmin == vmax and render entirely black.
        ("Road Mask", road_mask, {"cmap": "gray", "vmin": 0, "vmax": 1}),
        ("Road Overlay", overlay, {}),
    )
    for ax, (title, image, imshow_kwargs) in zip(axes, panels):
        ax.imshow(image, **imshow_kwargs)
        ax.set_title(title)
        ax.axis("off")

    plt.tight_layout()
    plt.show()
    # Release the figure so looping over many tiles does not accumulate memory.
    plt.close(fig)


In [None]:
# Single-tile sanity check: segment one tile and plot image, mask and overlay.
image_path = "roads/tile_10240_47104.tif"
orig_image, road_mask = predict_road_mask_2048(image_path=image_path)
visualize_matplotlib(orig_image=orig_image, road_mask=road_mask)


In [16]:
import os
import matplotlib.pyplot as plt
def run_folder_inference(
    folder_path,
    extensions=(".tif", ".tiff"),
    pause=False
):
    """Run road segmentation on every matching image in a folder and plot each result.

    Args:
        folder_path: Directory to scan (non-recursive). Sub-directories are ignored.
        extensions: A file suffix, or an iterable of suffixes, to include.
            Matching is case-insensitive, so ``".TIF"`` and ``".tif"`` are equivalent.
        pause: If True, wait for ENTER between images. Interrupting the prompt
            (Ctrl+C / kernel interrupt) stops the loop cleanly. If False,
            advance automatically.

    Images whose inference raises are reported and skipped. A failure count is
    printed at the end if any failed.
    """
    if isinstance(extensions, str):
        extensions = (extensions,)
    # str.endswith requires a tuple; lower-case both sides for case-insensitive matching.
    suffixes = tuple(ext.lower() for ext in extensions)

    image_files = sorted(
        os.path.join(folder_path, name)
        for name in os.listdir(folder_path)
        if name.lower().endswith(suffixes)
        and os.path.isfile(os.path.join(folder_path, name))
    )

    total = len(image_files)
    print(f"Found {total} images")

    failed = []
    for idx, image_path in enumerate(image_files, 1):
        print(f"\n[{idx}/{total}] Processing: {os.path.basename(image_path)}")

        try:
            orig_image, road_mask = predict_road_mask_2048(image_path)
        except Exception as e:
            # Deliberate best-effort batch run: report and move on to the next tile.
            print(f"Failed on {image_path}: {e}")
            failed.append(image_path)
            continue

        visualize_matplotlib(orig_image, road_mask)

        if pause:
            try:
                input("Press ENTER to continue (Ctrl+C to stop)...")
            except (KeyboardInterrupt, EOFError):
                print("Stopped by user.")
                break

    if failed:
        print(f"\n{len(failed)} of {total} images failed")


In [None]:
# Batch run: segment and plot every .tif/.tiff tile in the folder, auto-advancing.
folder_path = "roads/"
run_folder_inference(folder_path=folder_path, pause=False)
