In [None]:
# NOTE(review): `datasets` is not imported by any visible cell, while segment_anything and monai
# (imported in the next cell) are not installed here — confirm the runtime setup.
!pip install datasets

In [None]:
from tqdm import tqdm
from segment_anything import sam_model_registry, SamPredictor
import torch, os, monai
import numpy as np
import cv2
from PIL import Image
from sklearn.model_selection import train_test_split
from transformers import SamProcessor, SamModel
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F


In [None]:
import os

from google.colab import drive

# Mount Google Drive first: every path below lives under /content/drive, and on a
# fresh runtime the dedicated mount cell further down has not run yet, so the
# os.listdir call would fail under Restart & Run All.
drive.mount('/content/drive')

TEST_IMAGES_DIR = "/content/drive/MyDrive/Kuliah/Skripsi S1/test/images"
TEST_MASKS_DIR  = "/content/drive/MyDrive/Kuliah/Skripsi S1/test/targets"

# Regular *.png files only, sorted so the tile order (and any MAX_TILES subset) is deterministic.
all_test_imgs = sorted(
    f for f in os.listdir(TEST_IMAGES_DIR)
    if os.path.isfile(os.path.join(TEST_IMAGES_DIR, f)) and f.endswith(".png")
)

print("Total test images:", len(all_test_imgs))

# Running sums for the zero-shot evaluation (the evaluation cell re-initializes them as well).
sam_iou  = 0.0
sam_dice = 0.0
sam_bf   = 0.0
n_used   = 0

# NOTE(review): `rng` is not used by any visible cell — confirm before relying on it for sampling.
rng = np.random.default_rng(42)

Total test images: 1867


## Zero-shot SAM

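Baseline: the pretrained SAM ViT-B model, evaluated zero-shot (no fine-tuning) with one bounding-box prompt per test image, on the first `MAX_TILES` (200) test tiles.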

In [None]:

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Baseline model: the official pretrained SAM ViT-B checkpoint, no fine-tuning.
model_type = "vit_b"
sam_checkpoint = "/content/sam_vit_b_01ec64.pth"

build_sam = sam_model_registry[model_type]
sam = build_sam(checkpoint=sam_checkpoint)
sam = sam.to(DEVICE)
sam.eval()
# Predictor used by the evaluation loop: set_image() once per tile, then predict() with a box prompt.
predictor = SamPredictor(sam)

# Re-list the test tiles so this cell is self-contained.
TEST_IMAGES_DIR = "/content/drive/MyDrive/Kuliah/Skripsi S1/test/images"
TEST_MASKS_DIR  = "/content/drive/MyDrive/Kuliah/Skripsi S1/test/targets"


def _is_test_png(name):
    """True for regular files in TEST_IMAGES_DIR whose name ends with '.png'."""
    return os.path.isfile(os.path.join(TEST_IMAGES_DIR, name)) and name.endswith(".png")


all_test_imgs = sorted(filter(_is_test_png, os.listdir(TEST_IMAGES_DIR)))

print("Total test images found:", len(all_test_imgs))

Total test images found: 1867


In [None]:
def iou_np(pred, gt, eps=1e-7):
    """Intersection-over-union of two masks; any non-zero value counts as foreground.

    ``eps`` keeps the ratio defined when both masks are empty (the score is then 1.0).
    """
    pred_fg = pred.astype(bool)
    gt_fg = gt.astype(bool)
    overlap = (pred_fg & gt_fg).sum()
    combined = (pred_fg | gt_fg).sum()
    return (overlap + eps) / (combined + eps)

def dice_np(pred, gt, eps=1e-7):
    """Dice coefficient 2|P & G| / (|P| + |G|); any non-zero value counts as foreground.

    ``eps`` keeps the ratio defined when both masks are empty (the score is then 1.0).
    """
    pred_fg = pred.astype(bool)
    gt_fg = gt.astype(bool)
    total_area = pred_fg.sum() + gt_fg.sum()
    overlap = np.logical_and(pred_fg, gt_fg).sum()
    return (2 * overlap + eps) / (total_area + eps)

def boundary_f_score_np(pred, gt, dilation_ratio=0.02, eps=1e-7):
    """Boundary F-score (BF) between a predicted and a ground-truth 2-D mask.

    A mask's boundary is the set of foreground pixels removed by one 3x3 erosion.
    A predicted boundary pixel is matched if it falls inside the dilated
    ground-truth boundary (and vice versa). The dilation uses an elliptical
    kernel of radius ``tol = max(1, round(dilation_ratio * max(h, w)))`` pixels.
    Precision and recall are the matched fractions of each boundary; the score
    is their harmonic mean.

    Args:
        pred, gt: 2-D masks of equal shape; any non-zero value is foreground
            (same convention as iou_np / dice_np).
        dilation_ratio: match tolerance as a fraction of the longer image side.
        eps: smoothing constant for the ratios.

    Returns:
        float: 1.0 if both masks are empty or neither has boundary pixels,
        0.0 if exactly one of them has no boundary pixels.
    """
    # Binarize exactly like iou_np/dice_np. Casting straight to uint8 (as before)
    # truncated fractional values such as probabilities to 0, wrapped values >= 256,
    # and let different non-zero labels create spurious edges where they touch.
    pred = pred.astype(bool).astype(np.uint8)
    gt   = gt.astype(bool).astype(np.uint8)

    h, w = pred.shape
    if pred.max() == 0 and gt.max() == 0:
        return 1.0

    # Boundary = mask minus its 3x3 erosion (cannot underflow: erosion <= mask).
    # NOTE(review): with cv2.erode's default border handling, foreground pixels on the
    # image edge are presumably not counted as boundary — confirm if that matters.
    kernel = np.ones((3, 3), np.uint8)
    pred_erode = cv2.erode(pred, kernel, iterations=1)
    gt_erode   = cv2.erode(gt,   kernel, iterations=1)
    pred_b = pred - pred_erode
    gt_b   = gt   - gt_erode

    # Tolerance radius scales with image size; at least 1 pixel.
    tol = max(1, int(round(dilation_ratio * max(h, w))))
    dil_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2 * tol + 1, 2 * tol + 1))
    pred_dil = cv2.dilate(pred_b, dil_kernel)
    gt_dil   = cv2.dilate(gt_b,   dil_kernel)

    pred_match = np.logical_and(pred_b > 0, gt_dil > 0)
    gt_match   = np.logical_and(gt_b   > 0, pred_dil > 0)

    pred_b_sum = (pred_b > 0).sum()
    gt_b_sum   = (gt_b   > 0).sum()
    pred_match_sum = pred_match.sum()
    gt_match_sum   = gt_match.sum()

    if pred_b_sum == 0 and gt_b_sum == 0:
        return 1.0
    if pred_b_sum == 0 or gt_b_sum == 0:
        return 0.0

    precision = (pred_match_sum + eps) / (pred_b_sum + eps)
    recall    = (gt_match_sum   + eps) / (gt_b_sum   + eps)
    return (2 * precision * recall + eps) / (precision + recall + eps)


In [None]:

# Only the first MAX_TILES files (in sorted order) are visited.
# NOTE(review): tiles skipped below (missing/unreadable/empty mask) still count toward
# MAX_TILES, so fewer tiles are actually scored (190 in the saved run).
MAX_TILES = 200

sam_iou  = 0.0
sam_dice = 0.0
sam_bf   = 0.0
n_used   = 0


def _mask_path_for(image_name):
    """Target-mask path for a test image name ("<stem>.png" -> "<stem>_target.png").

    Duplicated images carrying a " (1)" suffix share the mask of the original.
    """
    stem = image_name[:-4].replace(" (1)", "")
    return os.path.join(TEST_MASKS_DIR, stem + "_target.png")


def _limit_side(image_rgb, gt_bin, limit=512):
    """Downscale image (bilinear) and mask (nearest) so the longest side is at most `limit`."""
    height, width = gt_bin.shape
    longest = max(height, width)
    if longest <= limit:
        return image_rgb, gt_bin
    factor = float(limit) / longest
    size = (int(round(width * factor)), int(round(height * factor)))  # cv2 wants (w, h)
    return (
        cv2.resize(image_rgb, size, interpolation=cv2.INTER_LINEAR),
        cv2.resize(gt_bin, size, interpolation=cv2.INTER_NEAREST),
    )


for idx, fname in enumerate(tqdm(all_test_imgs, desc="SAM bbox test")):
    if idx >= MAX_TILES:
        break

    mask_path = _mask_path_for(fname)
    if not os.path.exists(mask_path):
        print("Missing mask for:", fname)
        continue

    image_bgr = cv2.imread(os.path.join(TEST_IMAGES_DIR, fname), cv2.IMREAD_COLOR)
    if image_bgr is None:
        continue
    gt_raw = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if gt_raw is None:
        continue

    # Any non-zero label is foreground; tiles with no foreground cannot yield a box prompt.
    gt_bin = (gt_raw > 0).astype(np.uint8)
    if not gt_bin.any():
        continue

    # NOTE(review): SamPredictor.set_image resizes internally, so shrinking to 512 px here
    # presumably only discards detail — confirm this is intended for the baseline.
    image_rgb, gt_bin = _limit_side(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB), gt_bin)

    # Single box prompt: tight bounding box of all foreground pixels, as [x0, y0, x1, y1].
    rows, cols = np.nonzero(gt_bin)
    box = np.array([cols.min(), rows.min(), cols.max(), rows.max()])

    predictor.set_image(image_rgb)
    masks, _scores, _logits = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=box[None, :],
        multimask_output=False,
    )

    pred_mask = masks[0].astype(np.uint8)
    if pred_mask.shape != gt_bin.shape:
        pred_mask = cv2.resize(pred_mask, gt_bin.shape[::-1], interpolation=cv2.INTER_NEAREST)

    sam_iou  += iou_np(pred_mask, gt_bin)
    sam_dice += dice_np(pred_mask, gt_bin)
    sam_bf   += boundary_f_score_np(pred_mask, gt_bin)
    n_used   += 1

# Per-tile means (guarded against zero scored tiles).
denom = max(n_used, 1)
sam_iou, sam_dice, sam_bf = sam_iou / denom, sam_dice / denom, sam_bf / denom

print("SAM (pretrained) + 1 bbox per image")
print("Tiles used:", n_used)
print(f"IoU:  {sam_iou:.4f}")
print(f"Dice: {sam_dice:.4f}")
print(f"BF:   {sam_bf:.4f}")

SAM bbox test:  11%|█         | 200/1867 [12:49<1:46:57,  3.85s/it]

=== SAM (pretrained) + 1 bbox per image (subset) ===
Tiles used: 190
IoU:  0.1693
Dice: 0.2241
BF:   0.4582





## Fine-Tuned SAM

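This SAM model is fine-tuned with one bounding-box prompt per image (the tight box around all foreground pixels of the mask). Note that it is evaluated on a 10% validation split held out from the training tiles, not on the test tiles used for the zero-shot baseline above, so the two sets of scores are not directly comparable.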

In [None]:
from google.colab import drive

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ROOT = "/content/drive/MyDrive/Kuliah/Skripsi S1"
IMAGES_DIR = os.path.join(ROOT, "train/images")
MASKS_DIR  = os.path.join(ROOT, "train/targets")
SAVE_DIR = os.path.join(ROOT, "checkpoints")
CACHE = os.path.join(ROOT, "cache")

# Mount Drive *before* creating CACHE. Otherwise os.makedirs creates local folders
# under /content/drive, and a later drive.mount('/content/drive') fails because the
# mount point is no longer empty.
drive.mount('/content/drive')
os.makedirs(CACHE, exist_ok=True)
print('Device:', DEVICE)

In [None]:
# Mount Google Drive; the data and checkpoint paths in this notebook live under /content/drive.
# NOTE(review): cells above already read from /content/drive, so on a fresh runtime the
# mount has to happen before them (Restart & Run All otherwise fails at the first listdir).
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import os, glob, random

# Training tiles: pre-disaster images only, sorted for a deterministic order.
pre_imgs = sorted(glob.glob(os.path.join(IMAGES_DIR, "*_pre_disaster.png")))
print("Found PRE images:", len(pre_imgs))


def _mask_for(image_path):
    """Expected mask path in MASKS_DIR: '<name>_pre_disaster.png' -> '<name>_pre_disaster_target.png'."""
    name = os.path.basename(image_path)
    return os.path.join(MASKS_DIR, name.replace("_pre_disaster.png", "_pre_disaster_target.png"))


# Keep only images whose mask file exists; order follows pre_imgs.
candidates = ((img_path, _mask_for(img_path)) for img_path in pre_imgs)
pairs = [(img_path, mask_path) for img_path, mask_path in candidates if os.path.exists(mask_path)]

print("Paired (img,mask):", len(pairs))

In [None]:
def one_image_box(gt):
    """Tight bounding box around all foreground (> 0) pixels of a 2-D mask.

    Returns:
        [x_min, y_min, x_max, y_max] as Python ints (x = column, y = row, both
        ends inclusive), or None when the mask has no foreground pixel.
    """
    fg = gt > 0
    if not fg.any():
        return None

    rows_with_fg = np.flatnonzero(fg.any(axis=1))
    cols_with_fg = np.flatnonzero(fg.any(axis=0))
    return [
        int(cols_with_fg[0]),
        int(rows_with_fg[0]),
        int(cols_with_fg[-1]),
        int(rows_with_fg[-1]),
    ]

In [None]:
class XBDForSAM(Dataset):
    """(image, mask) pairs prepared for the Hugging Face SAM processor, one box prompt each.

    Each item is the processor output with its leading batch dimension squeezed,
    plus ``ground_truth_mask``: the mask binarized to uint8 {0, 1} at the mask
    file's own resolution. The box prompt is the tight box around all foreground
    pixels (``one_image_box``).

    Args:
        pairs: sequence of (image_path, mask_path) tuples.
        processor: called as ``processor(image, input_boxes=..., return_tensors="pt")``
            (e.g. ``SamProcessor``).
        skip_empty: if True, a sample whose mask has no foreground is replaced by a
            randomly chosen one; if False it is kept with a dummy box [0, 0, 1, 1].
        max_resample: how many random replacements to try (skip_empty=True) before
            raising RuntimeError.
    """

    def __init__(self, pairs, processor, skip_empty=True, max_resample=100):
        self.samples = pairs
        self.processor = processor
        self.skip_empty = skip_empty
        self.max_resample = max_resample

    def __len__(self):
        return len(self.samples)

    def _load(self, idx):
        """Read sample ``idx``; return (RGB PIL image, uint8 {0,1} mask array, box or None)."""
        ip, mp = self.samples[idx]
        # Context managers close the files as soon as the pixels are decoded.
        with Image.open(ip) as img:
            image = img.convert("RGB")
        with Image.open(mp) as mask:
            gt = (np.array(mask.convert("L")) > 0).astype(np.uint8)
        # Fix: one_image_box() takes only the mask. The previous call passed
        # mode=self.box_mode, an attribute that was never set (AttributeError).
        return image, gt, one_image_box(gt)

    def __getitem__(self, idx):
        image, gt, bbox = self._load(idx)

        if bbox is None and self.skip_empty:
            # Replace an empty-mask sample with random others. A bounded loop replaces
            # the previous unbounded recursion, which could end in RecursionError.
            for _ in range(self.max_resample):
                image, gt, bbox = self._load(np.random.randint(len(self.samples)))
                if bbox is not None:
                    break
            else:
                raise RuntimeError(
                    f"XBDForSAM: no non-empty mask found after {self.max_resample} random resamples"
                )

        if bbox is None:
            # skip_empty=False: keep the sample with a dummy box prompt.
            bbox = [0, 0, 1, 1]

        x0, y0, x1, y1 = map(float, bbox)
        input_boxes = [[[x0, y0, x1, y1]]]  # one image, one box [x0, y0, x1, y1]

        inputs = self.processor(image, input_boxes=input_boxes, return_tensors="pt")
        # Drop the processor's leading batch dimension of 1; collate_fn re-stacks items.
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}
        inputs["ground_truth_mask"] = gt
        return inputs

In [None]:
# Hugging Face SAM ViT-B processor; XBDForSAM calls it on every (image, box prompt) pair.
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

def collate_fn(batch):
    """Stack a list of XBDForSAM items into one batch dict, key by key.

    ``ground_truth_mask`` entries are numpy arrays: converted to tensors, cast to
    uint8, stacked and returned as float. All other entries are tensors and are
    stacked along a new leading dimension.
    """
    def _stack(key):
        if key == "ground_truth_mask":
            masks = [torch.from_numpy(item[key]).to(torch.uint8) for item in batch]
            return torch.stack(masks).float()
        return torch.stack([item[key] for item in batch])

    return {key: _stack(key) for key in batch[0]}

# 90/10 train/validation split of the paired training tiles, with a fixed seed.
train_pairs, val_pairs = train_test_split(pairs, test_size=0.1, random_state=42)

train_ds = XBDForSAM(train_pairs, processor)
val_ds   = XBDForSAM(val_pairs, processor)


def _make_loader(dataset, shuffle):
    """Batch-size-2, single-process DataLoader that batches items with collate_fn."""
    return DataLoader(
        dataset,
        batch_size=2,
        shuffle=shuffle,
        num_workers=0,
        pin_memory=True,
        collate_fn=collate_fn,
    )


train_dataloader = _make_loader(train_ds, shuffle=True)
val_dataloader   = _make_loader(val_ds, shuffle=False)

# sanity check: pull one batch from each loader and print every tensor's shape
b = next(iter(train_dataloader))
print(b["input_boxes"].shape, b["ground_truth_mask"].shape)
for name, tensor in b.items():
    print(name, tuple(tensor.shape))
b = next(iter(val_dataloader))
for name, tensor in b.items():
    print(name, tuple(tensor.shape))



In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# segment_anything ViT-B SAM built from the checkpoint kept in SAVE_DIR.
# NOTE(review): this replaces the zero-shot `sam` from earlier; in the visible cells only
# `seg_loss` is used afterwards (by evaluate_model) — `sam`/`optimizer` presumably serve a
# training loop not shown here.
sam_ckpt_path = os.path.join(SAVE_DIR, "sam_vit_b_01ec64.pth")
sam = sam_model_registry["vit_b"](checkpoint=sam_ckpt_path)
sam.to(device)

# Freeze the image and prompt encoders; only mask-decoder parameters go to the optimizer.
for frozen_module in (sam.image_encoder, sam.prompt_encoder):
    for param in frozen_module.parameters():
        param.requires_grad = False

optimizer = torch.optim.Adam(sam.mask_decoder.parameters(), lr=1e-5)

# Dice + cross-entropy computed on logits (sigmoid=True), mean-reduced.
seg_loss = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction="mean")

In [None]:
# Fix: call is_available(). The bare function object is always truthy, so the original
# picked "cuda" even on CPU-only runtimes.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SamModel.from_pretrained("./sam-vit-base").to(device)

# NOTE(review): torch.load unpickles the file; only load checkpoints you created yourself.
ckpt = torch.load(f"{SAVE_DIR}/best.pt", map_location=device)
load_info = model.load_state_dict(ckpt["model"], strict=False)
# strict=False silently ignores mismatched keys; report them so a wrong checkpoint is noticed.
if load_info.missing_keys or load_info.unexpected_keys:
    print(
        f"Warning: {len(load_info.missing_keys)} missing and "
        f"{len(load_info.unexpected_keys)} unexpected keys when loading best.pt"
    )

# Format best_dice only when present; applying ':.3f' to the '?' fallback raised ValueError.
best_dice = ckpt.get("best_dice")
best_dice_str = "?" if best_dice is None else f"{float(best_dice):.3f}"
print(f"Loaded fine-tuned weights from epoch {ckpt.get('epoch', '?')} (Dice={best_dice_str})")

model.eval()

In [None]:
def _norm_logits_shape(pred_masks):
    logits = pred_masks
    while logits.dim() > 4:
        logits = logits.squeeze(1)
    # Ensure channel = 1
    if logits.dim() == 3:
        logits = logits.unsqueeze(1)
    if logits.shape[1] != 1:
        logits = logits[:, :1, ...]
    return logits

def compute_dice_iou(pred, gt, eps=1e-6):
    """Batch-mean Dice and IoU of thresholded 4-D predictions.

    Args:
        pred: probabilities, 4-D (sums run over dims 1-3); binarized at > 0.5.
        gt: target of the same shape, expected binary {0, 1} float.
        eps: smoothing so empty prediction + empty target scores 1.0.

    Returns:
        (mean_dice, mean_iou) as Python floats, averaged over the batch dim.
    """
    hard = (pred > 0.5).float()
    dims = (1, 2, 3)
    overlap = (hard * gt).sum(dim=dims)
    union = (hard + gt - hard * gt).sum(dim=dims)
    area_sum = hard.sum(dim=dims) + gt.sum(dim=dims)
    per_sample_dice = (2 * overlap + eps) / (area_sum + eps)
    per_sample_iou = (overlap + eps) / (union + eps)
    return per_sample_dice.mean().item(), per_sample_iou.mean().item()

In [None]:
# compute_dice_iou(pred, gt, eps) is defined once, in the previous cell next to
# _norm_logits_shape (which evaluate_model below also needs). It is not re-defined here,
# so two copies cannot drift apart or silently shadow each other.

@torch.no_grad()
def evaluate_model(model, dataloader, loss_fn=None, device=None):
    """Evaluate a box-prompted Hugging Face SamModel; print and return mean Dice/IoU.

    Args:
        model: called as ``model(pixel_values=..., input_boxes=..., multimask_output=False)``
            and expected to return an object with ``pred_masks`` logits.
        dataloader: yields dicts with "pixel_values", "input_boxes" and a
            (B, H, W) "ground_truth_mask" (see collate_fn).
        loss_fn: optional ``loss_fn(logits, gt)`` returning a scalar tensor; its
            sample-weighted average is printed as the validation loss.
        device: where to run; defaults to the device of the model's parameters
            (avoids relying on a module-level ``device`` global).

    Returns:
        (mean_dice, mean_iou) averaged over samples (NaN if the loader is empty).
    """
    model.eval()
    if device is None:
        device = next(model.parameters()).device

    n_seen = 0
    dice_sum = iou_sum = loss_sum = 0.0
    pbar = tqdm(dataloader, desc="Evaluating", leave=False)

    for batch in pbar:
        pv = batch["pixel_values"].to(device)
        ib = batch["input_boxes"].to(device)
        gt = batch["ground_truth_mask"].to(device).float().unsqueeze(1)
        batch_size = gt.shape[0]

        out = model(pixel_values=pv, input_boxes=ib, multimask_output=False)
        logits_low = _norm_logits_shape(out.pred_masks)
        # Upsample low-resolution mask logits to the ground-truth resolution before scoring.
        logits = F.interpolate(logits_low, size=gt.shape[-2:], mode="bilinear", align_corners=False)
        prob = torch.sigmoid(logits)

        dice, iou = compute_dice_iou(prob, gt)
        # compute_dice_iou returns batch means: weight them by batch size so a smaller
        # final batch is not over-counted and the result is a true per-sample mean
        # (matching the per-tile mean of the zero-shot evaluation).
        dice_sum += dice * batch_size
        iou_sum += iou * batch_size
        if loss_fn is not None:
            loss_sum += loss_fn(logits, gt).item() * batch_size
        n_seen += batch_size

        pbar.set_postfix(dice=f"{dice:.3f}", iou=f"{iou:.3f}")

    mean_dice = dice_sum / n_seen if n_seen else float("nan")
    mean_iou = iou_sum / n_seen if n_seen else float("nan")
    print(f"\n✅ Mean Dice: {mean_dice:.4f} | Mean IoU: {mean_iou:.4f}")
    if loss_fn is not None and n_seen:
        print(f"Avg Val Loss: {loss_sum / n_seen:.4f}")
    return mean_dice, mean_iou

In [None]:
# Score the fine-tuned model on the validation loader (the 10% split held out from the
# training pairs, not the test tiles), also reporting the average DiceCE loss.
evaluate_model(model, val_dataloader, loss_fn=seg_loss)

✅ Mean Dice: 0.3160 | Mean IoU: 0.2199
Avg Val Loss: 0.7576