In [None]:
import os, json, random
import numpy as np
import cv2
import torch
from tqdm import tqdm
from shapely import wkt as shapely_wkt
import random
from segment_anything import sam_model_registry
from segment_anything.utils.transforms import ResizeLongestSide
import matplotlib.pyplot as plt

In [None]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

MODEL_TYPE = "vit_b"
SAM_CKPT   = "sam_vit_b_01ec64.pth"
FT_CKPT    = "sam_xbd_maskdecoder_best.pth"

sam = sam_model_registry[MODEL_TYPE](checkpoint=SAM_CKPT).to(DEVICE)
sam.eval()

transform = ResizeLongestSide(sam.image_encoder.img_size)
print("Loaded Zero Shot SAM")


Loaded Zero Shot SAM


In [3]:
def polygon_to_mask(poly, h, w):
    mask = np.zeros((h, w), dtype=np.uint8)
    ext = np.array(list(poly.exterior.coords), dtype=np.float32)
    ext[:, 0] = np.clip(ext[:, 0], 0, w - 1)
    ext[:, 1] = np.clip(ext[:, 1], 0, h - 1)
    ext_i = np.round(ext).astype(np.int32)
    cv2.fillPoly(mask, [ext_i], 1)
    for ring in poly.interiors:
        hole = np.array(list(ring.coords), dtype=np.float32)
        hole[:, 0] = np.clip(hole[:, 0], 0, w - 1)
        hole[:, 1] = np.clip(hole[:, 1], 0, h - 1)
        hole_i = np.round(hole).astype(np.int32)
        cv2.fillPoly(mask, [hole_i], 0)
    return mask

def bbox_from_mask(m):
    ys, xs = np.where(m > 0)
    if len(xs) == 0:
        return None
    return np.array([xs.min(), ys.min(), xs.max(), ys.max()], dtype=np.float32)

def iou_dice(pred, gt, eps=1e-6):
    pred = (pred > 0).astype(np.uint8)
    gt   = (gt > 0).astype(np.uint8)
    inter = (pred & gt).sum()
    union = (pred | gt).sum()
    iou = (inter + eps) / (union + eps)
    dice = (2*inter + eps) / (pred.sum() + gt.sum() + eps)
    return float(iou), float(dice)

def boundary_map(mask):
    # 1-pixel boundary using morphological gradient
    mask = (mask > 0).astype(np.uint8)
    k = np.ones((3,3), np.uint8)
    dil = cv2.dilate(mask, k, iterations=1)
    ero = cv2.erode(mask, k, iterations=1)
    b = (dil - ero) > 0
    return b.astype(np.uint8)

def boundary_f1(pred, gt, tol=2):
    pred_b = boundary_map(pred)
    gt_b   = boundary_map(gt)

    if pred_b.sum() == 0 and gt_b.sum() == 0:
        return 1.0
    if pred_b.sum() == 0 or gt_b.sum() == 0:
        return 0.0

    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2*tol+1, 2*tol+1))
    gt_dil   = cv2.dilate(gt_b, kernel, iterations=1)
    pred_dil = cv2.dilate(pred_b, kernel, iterations=1)

    # precision: pred boundary matched by dilated gt boundary
    prec = (pred_b & gt_dil).sum() / (pred_b.sum() + 1e-6)
    # recall: gt boundary matched by dilated pred boundary
    rec  = (gt_b & pred_dil).sum() / (gt_b.sum() + 1e-6)

    if prec + rec == 0:
        return 0.0
    return float(2 * prec * rec / (prec + rec))


In [4]:
@torch.no_grad()
def sam_predict_mask_from_bbox(img_rgb, bbox_xyxy):
    h, w = img_rgb.shape[:2]

    resized = transform.apply_image(img_rgb)
    input_t = torch.as_tensor(resized, device=DEVICE).permute(2,0,1)[None, ...].contiguous()
    input_t = sam.preprocess(input_t)
    resized_hw = resized.shape[:2]

    image_embedding = sam.image_encoder(input_t)

    box = transform.apply_boxes(np.array(bbox_xyxy, dtype=np.float32)[None, :], (h, w))
    box_t = torch.as_tensor(box, dtype=torch.float32, device=DEVICE)

    sparse, dense = sam.prompt_encoder(points=None, boxes=box_t, masks=None)

    low_res_masks, _ = sam.mask_decoder(
        image_embeddings=image_embedding,
        image_pe=sam.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse,
        dense_prompt_embeddings=dense,
        multimask_output=False,
    )

    up = sam.postprocess_masks(low_res_masks, input_size=resized_hw, original_size=(h, w))
    prob = torch.sigmoid(up)[0,0].detach().cpu().numpy()
    return (prob > 0.5).astype(np.uint8)


In [None]:
IMAGE_DIR = r"E:\Nana\test\images"
LABEL_DIR = r"E:\Nana\test\labels"

def list_images(d):
    exts = (".png", ".jpg", ".jpeg", ".tif", ".tiff")
    fs = [f for f in os.listdir(d) if f.lower().endswith(exts)]
    fs.sort()
    return fs

print(len(list_images(IMAGE_DIR)))
results = []
files = [f for f in list_images(IMAGE_DIR) if "pre_disaster" in f.lower()]  # pre-only
print(len(files))

# Use a small number of test file due to computational limitations
half = len(files) // 16
random.shuffle(files)
files = files[:half]
print(len(files))

In [17]:


for fname in tqdm(files, desc="Evaluating images"):
    base = os.path.splitext(fname)[0]
    img_path = os.path.join(IMAGE_DIR, fname)
    js_path  = os.path.join(LABEL_DIR, base + ".json")

    if not os.path.exists(js_path):
        continue

    img = cv2.imread(img_path, cv2.IMREAD_COLOR)
    if img is None:
        continue
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    h, w = img.shape[:2]

    with open(js_path, "r") as f:
        data = json.load(f)

    feats = data.get("features", {}).get("xy", [])
    for feat in feats:
        wkt_str = feat.get("wkt", None)
        uid = feat.get("properties", {}).get("uid", "")

        if not wkt_str:
            continue

        geom = shapely_wkt.loads(wkt_str)
        polys = [geom] if geom.geom_type == "Polygon" else (list(geom.geoms) if geom.geom_type=="MultiPolygon" else [])
        for poly in polys:
            gt = polygon_to_mask(poly, h, w)
            if gt.sum() == 0:
                continue

            bbox = bbox_from_mask(gt)
            if bbox is None:
                continue

            pred = sam_predict_mask_from_bbox(img, bbox)

            iou, dice = iou_dice(pred, gt)
            # bf1 = boundary_f1(pred, gt, tol=2)

            results.append({
                "image": fname,
                "uid": uid,
                "iou": iou,
                "dice": dice,
                # "boundary_f1": bf1,
                "gt_area": int(gt.sum())
            })

print("Instances evaluated:", len(results))
print("Mean IoU:", np.mean([r["iou"] for r in results]) if results else None)
print("Mean Dice:", np.mean([r["dice"] for r in results]) if results else None)
# print("Mean Boundary F1:", np.mean([r["boundary_f1"] for r in results]) if results else None)


Evaluating images: 100%|██████████| 58/58 [18:36<00:00, 19.25s/it]  

Instances evaluated: 3635
Mean IoU: 0.7580287116919527
Mean Dice: 0.857669089743391



