In [None]:
!nvidia-smi
!pip -q install opencv-python pycocotools tqdm

In [None]:
import os, json, random, math, time
import numpy as np
from tqdm import tqdm
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from shapely import wkt
import cv2, os
from segment_anything import sam_model_registry
from segment_anything.utils.transforms import ResizeLongestSide

/bin/bash: line 1: nvidia-smi: command not found


In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install git+https://github.com/facebookresearch/segment-anything.git

Collecting git+https://github.com/facebookresearch/segment-anything.git
  Cloning https://github.com/facebookresearch/segment-anything.git to /tmp/pip-req-build-svolwcv4
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/segment-anything.git /tmp/pip-req-build-svolwcv4
  Resolved https://github.com/facebookresearch/segment-anything.git to commit dca509fe793f601edb92606367a655c15ac00fdf
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: segment_anything
  Building wheel for segment_anything (setup.py) ... [?25l[?25hdone
  Created wheel for segment_anything: filename=segment_anything-1.0-py3-none-any.whl size=36592 sha256=4b2d195e112f595baae2e9892d5f3cf2affbfd67b33a9cf12ad0853d641adf8d
  Stored in directory: /tmp/pip-ephem-wheel-cache-et2lai1n/wheels/29/82/ff/04e2be9805a1cb48bec0b85b5a6da6b63f647645750a0e42d4
Successfully built segment_anything
Installing collected packages: segment_anything
  Attempting 

In [None]:
!rm -rf /content/segment-anything
!git clone -q https://github.com/facebookresearch/segment-anything.git /content/segment-anything
!pip -q install -e /content/segment-anything
!pip -q install shapely

  Preparing metadata (setup.py) ... [?25l[?25hdone


In [None]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("DEVICE:", DEVICE)

IMAGE_DIR = "/content/drive/MyDrive/Kuliah/Skripsi S1/train/images"
LABEL_DIR = "/content/drive/MyDrive/Kuliah/Skripsi S1/train/labels"

SAM_CKPT  = "/content/drive/MyDrive/Kuliah/Skripsi S1/checkpoints/sam_vit_b_01ec64.pth"
MODEL_TYPE = "vit_b"

OUT_DIR = "/content/drive/MyDrive/Kuliah/Skripsi S1/checkpoints"
os.makedirs(OUT_DIR, exist_ok=True)

EPOCHS = 20
BATCH_SIZE = 2
LR = 1e-4
WEIGHT_DECAY = 0.0
GRAD_ACCUM_STEPS = 8
AMP = True
INSTANCES_PER_IMAGE = 32


VAL_RATIO = 0.1
SEED = 42

DEVICE: cpu
OUT_DIR: /content/drive/MyDrive/Kuliah/Skripsi S1/checkpoints


In [None]:
def list_images(image_dir):
    exts = (".png")
    files = [f for f in os.listdir(image_dir) if f.lower().endswith(exts)]
    files.sort()
    return files

def compute_bbox_from_mask(bin_mask):
    ys, xs = np.where(bin_mask > 0)
    if len(xs) == 0 or len(ys) == 0:
        return None
    x0, x1 = xs.min(), xs.max()
    y0, y1 = ys.min(), ys.max()
    return np.array([x0, y0, x1, y1], dtype=np.float32)

def polygon_to_mask(poly, h, w):
    mask = np.zeros((h, w), dtype=np.uint8)

    ext = np.array(list(poly.exterior.coords), dtype=np.float32)
    ext[:, 0] = np.clip(ext[:, 0], 0, w - 1)
    ext[:, 1] = np.clip(ext[:, 1], 0, h - 1)
    ext_i = np.round(ext).astype(np.int32)
    cv2.fillPoly(mask, [ext_i], 1)

    for ring in poly.interiors:
        hole = np.array(list(ring.coords), dtype=np.float32)
        hole[:, 0] = np.clip(hole[:, 0], 0, w - 1)
        hole[:, 1] = np.clip(hole[:, 1], 0, h - 1)
        hole_i = np.round(hole).astype(np.int32)
        cv2.fillPoly(mask, [hole_i], 0)

    return mask


In [None]:
class XBDWKTInstanceDataset(Dataset):
    def __init__(self, files, image_dir, label_dir, instances_per_image=16):
        self.files = files
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.instances_per_image = instances_per_image

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        fname = self.files[idx]
        base = os.path.splitext(fname)[0]

        img_path = os.path.join(self.image_dir, fname)
        json_path = os.path.join(self.label_dir, base + ".json")

        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if img is None:
            raise FileNotFoundError(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        h, w = img.shape[:2]

        if not os.path.exists(json_path):
            raise FileNotFoundError(f"Missing JSON: {json_path}")

        with open(json_path, "r") as f:
            data = json.load(f)

        feats = data.get("features", {}).get("xy", [])  # ✅ FIXED
        masks = []

        for feat in feats:
            wkt_str = feat.get("wkt", None)
            if not wkt_str:
                continue

            geom = wkt.loads(wkt_str)

            if geom.geom_type == "Polygon":
                m = polygon_to_mask(geom, h, w)
                if m.sum() > 0:
                    masks.append(m)

            elif geom.geom_type == "MultiPolygon":
                for poly in geom.geoms:
                    m = polygon_to_mask(poly, h, w)
                    if m.sum() > 0:
                        masks.append(m)

        if len(masks) == 0:
            gt = np.zeros((h, w), dtype=np.uint8)
            bbox = np.array([0, 0, w-1, h-1], dtype=np.float32)
            return [(img, gt, bbox)]

        k = min(self.instances_per_image, len(masks))
        chosen = random.sample(masks, k)

        samples = []
        for gt in chosen:
            bbox = compute_bbox_from_mask(gt)
            if bbox is None:
                continue
            samples.append((img, gt, bbox))

        if len(samples) == 0:
            gt = np.zeros((h, w), dtype=np.uint8)
            bbox = np.array([0, 0, w-1, h-1], dtype=np.float32)
            samples = [(img, gt, bbox)]

        return samples

def collate_flatten(batch):
    flat = []
    for item in batch:
        flat.extend(item if isinstance(item, list) else [item])
    imgs, gts, bboxes = zip(*flat)
    return list(imgs), list(gts), torch.tensor(np.stack(bboxes), dtype=torch.float32)


In [None]:
all_files = list_images(IMAGE_DIR)

PRE_KEYWORD = ['pre_disaster']

pre_files = []
for f in all_files:
  name = f.lower()
  if any(k in name for k in PRE_KEYWORD) and ("post" not in name):
        pre_files.append(f)

print("All images:", len(all_files))
print("pre only:", len(pre_files))

random.seed(SEED)
random.shuffle(pre_files)

n_val = int(len(pre_files) * VAL_RATIO)
val_files = pre_files[:n_val]
train_files = pre_files[n_val:]

train_ds = XBDWKTInstanceDataset(train_files, IMAGE_DIR, LABEL_DIR, instances_per_image=INSTANCES_PER_IMAGE)
val_ds   = XBDWKTInstanceDataset(val_files,   IMAGE_DIR, LABEL_DIR, instances_per_image=INSTANCES_PER_IMAGE)

train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    collate_fn=collate_flatten
)

val_loader = DataLoader(
    val_ds,
    batch_size=1,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
    collate_fn=collate_flatten
)

print("Total images:", len(pre_files))
print("Train images:", len(train_ds))
print("Val images:", len(val_ds))


All images: 5598
pre only: 2799
Total images: 2799
Train images: 2520
Val images: 279


In [None]:
fname = "guatemala-volcano_00000002_pre_disaster.png"
base = os.path.splitext(fname)[0]
json_path = os.path.join(LABEL_DIR, base + ".json")

print("JSON path:", json_path)
print("Exists:", os.path.exists(json_path))

with open(json_path, "r") as f:
    data = json.load(f)

print("Top-level keys:", data.keys())
print("xy exists:", "xy" in data)
print("Number of features:", len(data["features"]["xy"]))

feat_list = data["features"]["xy"]
print("type:", type(feat_list))
print("len:", len(feat_list))

feat = feat_list[0]
print("Feature keys:", feat.keys())
print("WKT head:", feat["wkt"][:200])

JSON path: /content/drive/MyDrive/Kuliah/Skripsi S1/train/labels/guatemala-volcano_00000002_pre_disaster.json
Exists: True
Top-level keys: dict_keys(['features', 'metadata'])
xy exists: False
Number of features: 1
type: <class 'list'>
len: 1
Feature keys: dict_keys(['properties', 'wkt'])
WKT head: POLYGON ((1024 238.0106127841054, 1020.035908691126 233.848316909788, 1016.666431078584 237.2177945223307, 1011.909521507935 232.4608849516822, 1008.341839329949 236.2267716951123, 1001.40467953942 23


In [None]:
geom = wkt.loads(feat["wkt"])
print("geom type:", geom.geom_type)
print("bounds:", geom.bounds)

img = cv2.imread(os.path.join(IMAGE_DIR, fname))
h, w = img.shape[:2]
print("image w,h:", w, h)

geom type: Polygon
bounds: (983.368064084044, 230.0824301663579, 1024.0, 335.1360018036964)
image w,h: 1024 1024


In [None]:
sam = sam_model_registry[MODEL_TYPE](checkpoint=SAM_CKPT).to(DEVICE)
sam.train()

# Freeze image encoder + prompt encoder (cheaper)
for p in sam.image_encoder.parameters():
    p.requires_grad = False
for p in sam.prompt_encoder.parameters():
    p.requires_grad = False

# Train mask decoder
for p in sam.mask_decoder.parameters():
    p.requires_grad = True

transform = ResizeLongestSide(sam.image_encoder.img_size)

optimizer = torch.optim.AdamW(
    sam.mask_decoder.parameters(),
    lr=LR,
    weight_decay=WEIGHT_DECAY
)

scaler = torch.cuda.amp.GradScaler(enabled=AMP)

print("SAM loaded:", MODEL_TYPE)


SAM loaded: vit_b


  scaler = torch.cuda.amp.GradScaler(enabled=AMP)


In [None]:
bce = nn.BCEWithLogitsLoss()

def dice_loss(logits, targets, eps=1e-6):
    probs = torch.sigmoid(logits)
    probs = probs.flatten(1)
    targets = targets.flatten(1)
    num = 2 * (probs * targets).sum(dim=1)
    den = (probs + targets).sum(dim=1) + eps
    return 1 - (num / den).mean()

@torch.no_grad()
def dice_iou_from_logits(logits, targets, thresh=0.5, eps=1e-6):
    probs = torch.sigmoid(logits)
    preds = (probs > thresh).float()

    preds = preds.flatten(1)
    targets = targets.flatten(1)

    inter = (preds * targets).sum(dim=1)
    union = (preds + targets - preds * targets).sum(dim=1)

    dice = (2 * inter + eps) / (preds.sum(dim=1) + targets.sum(dim=1) + eps)
    iou  = (inter + eps) / (union + eps)
    return dice.mean().item(), iou.mean().item()


In [None]:
def preprocess_image_np(img_rgb):
    resized = transform.apply_image(img_rgb)
    t = torch.as_tensor(resized, device=DEVICE).permute(2, 0, 1).contiguous()[None, ...]
    t = sam.preprocess(t)
    resized_hw = resized.shape[:2]
    return t, resized_hw

def preprocess_mask_np(mask_uint8):
    return torch.from_numpy(mask_uint8.astype(np.float32))[None, None, ...]  # 1x1xH0xW0

@torch.no_grad()
def evaluate_sam(sam_model, loader, max_instances=800):
    sam_model.eval()
    dice_list, iou_list = [], []
    count = 0

    for imgs, gts, bboxes in tqdm(loader, desc="Evaluating", leave=False):
        for i in range(len(imgs)):
            img = imgs[i]
            gt  = gts[i]
            bbox = bboxes[i].cpu().numpy()

            original_size = img.shape[:2]
            input_image, resized_hw = preprocess_image_np(img)

            image_embedding = sam_model.image_encoder(input_image)

            box = transform.apply_boxes(bbox[None, :].astype(np.float32), original_size)
            box_torch = torch.as_tensor(box, dtype=torch.float32, device=DEVICE)

            sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
                points=None, boxes=box_torch, masks=None
            )

            low_res_masks, _ = sam_model.mask_decoder(
                image_embeddings=image_embedding,
                image_pe=sam_model.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=False,
            )

            up_masks = sam_model.postprocess_masks(
                low_res_masks,
                input_size=resized_hw,
                original_size=original_size,
            )

            gt_t = preprocess_mask_np(gt).to(DEVICE)

            d, j = dice_iou_from_logits(up_masks, gt_t)
            dice_list.append(d)
            iou_list.append(j)

            count += 1
            if count >= max_instances:
                break
        if count >= max_instances:
            break

    sam_model.train()
    return float(np.mean(dice_list)) if dice_list else 0.0, float(np.mean(iou_list)) if iou_list else 0.0


In [None]:
best_dice = -1.0
global_step = 0

for epoch in range(1, EPOCHS + 1):
    sam.train()
    running_loss = 0.0
    seen = 0

    optimizer.zero_grad(set_to_none=True)

    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}")
    for imgs, gts, bboxes in pbar:
        total_loss_this_batch = 0.0

        for i in range(len(imgs)):
            img = imgs[i]
            gt  = gts[i]
            bbox = bboxes[i].cpu().numpy()

            original_size = img.shape[:2]

            input_image, resized_hw = preprocess_image_np(img)
            with torch.no_grad():
                image_embedding = sam.image_encoder(input_image)

            box = transform.apply_boxes(bbox[None, :].astype(np.float32), original_size)
            box_torch = torch.as_tensor(box, dtype=torch.float32, device=DEVICE)

            with torch.no_grad():
                sparse_embeddings, dense_embeddings = sam.prompt_encoder(
                    points=None, boxes=box_torch, masks=None
                )

            gt_t = preprocess_mask_np(gt).to(DEVICE)

            with torch.cuda.amp.autocast(enabled=AMP):
                low_res_masks, _ = sam.mask_decoder(
                    image_embeddings=image_embedding,
                    image_pe=sam.prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=sparse_embeddings,
                    dense_prompt_embeddings=dense_embeddings,
                    multimask_output=False,
                )

                up_masks = sam.postprocess_masks(
                    low_res_masks,
                    input_size=resized_hw,
                    original_size=original_size,
                )

                loss = 0.5 * bce(up_masks, gt_t) + 0.5 * dice_loss(up_masks, gt_t)

            loss = loss / GRAD_ACCUM_STEPS
            scaler.scale(loss).backward()
            total_loss_this_batch += loss.item()

        if (global_step + 1) % GRAD_ACCUM_STEPS == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

        global_step += 1
        running_loss += total_loss_this_batch
        seen += 1

        avg_loss = running_loss / max(seen, 1)
        pbar.set_postfix({"loss": f"{avg_loss:.4f}"})

    epoch_loss = running_loss / max(seen, 1)
    print(f"\nEpoch {epoch} Train loss: {epoch_loss:.6f}")

    # Evaluate
    val_dice, val_iou = evaluate_sam(sam, val_loader, max_instances=800)
    print(f"Epoch {epoch} VAL Dice: {val_dice:.4f} | VAL IoU: {val_iou:.4f}")

    if val_dice > best_dice:
        best_dice = val_dice
        save_path = os.path.join(OUT_DIR, "sam_xbd_maskdecoder_best.pth")
        torch.save({
            "epoch": epoch,
            "model_type": MODEL_TYPE,
            "sam_ckpt": SAM_CKPT,
            "mask_decoder_state_dict": sam.mask_decoder.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "best_dice": best_dice,
            "val_iou": val_iou,
        }, save_path)
        print("✅ Saved best:", save_path)

print("\nDONE. Best Val Dice:", best_dice)


  with torch.cuda.amp.autocast(enabled=AMP):
  with torch.cuda.amp.autocast(enabled=AMP):
  with torch.cuda.amp.autocast(enabled=AMP):
Epoch 1/20: 100%|██████████| 1260/1260 [1:29:46<00:00,  4.27s/it, loss=0.2487]



Epoch 1 Train loss: 0.248732




Epoch 1 VAL Dice: 0.9016 | VAL IoU: 0.8248
✅ Saved best: /content/drive/MyDrive/Kuliah/Skripsi S1/checkpoints/sam_xbd_maskdecoder_best.pth


Epoch 2/20: 100%|██████████| 1260/1260 [1:29:16<00:00,  4.25s/it, loss=0.2289]



Epoch 2 Train loss: 0.228868




Epoch 2 VAL Dice: 0.9063 | VAL IoU: 0.8326
✅ Saved best: /content/drive/MyDrive/Kuliah/Skripsi S1/checkpoints/sam_xbd_maskdecoder_best.pth


Epoch 3/20: 100%|██████████| 1260/1260 [1:29:54<00:00,  4.28s/it, loss=0.2242]



Epoch 3 Train loss: 0.224161




Epoch 3 VAL Dice: 0.9022 | VAL IoU: 0.8261


Epoch 4/20: 100%|██████████| 1260/1260 [1:29:40<00:00,  4.27s/it, loss=0.2202]



Epoch 4 Train loss: 0.220219




Epoch 4 VAL Dice: 0.9041 | VAL IoU: 0.8296


Epoch 5/20:   3%|▎         | 43/1260 [03:14<57:49,  2.85s/it, loss=0.2232]  