In [None]:
# # edited from this repository fine-tune-train_segment_anything_2_in_60_lines_of_code
# fine_tune_sam2_from_npymasks.py
# Fine-tune SAM2 from image/mask pairs saved by your app as <stem>_mask.npy

import os
from pathlib import Path
import numpy as np
import torch
import cv2

from contextlib import nullcontext
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from PIL import Image, ImageEnhance


# ----------------- Config -----------------
TRAIN_DIR  = "/Users/sambra/Desktop/training_images/"  # your folder
IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".tif", ".tiff"}
TARGET_SIZE = 512  # resize long side to <=1024 and pad to 1024x1024

# Model/opt
CKPT_PATH = "/Users/sambra/Documents/GitHub/sam2_clone/checkpoints/sam2.1_hiera_tiny.pt"
CFG_PATH  = "/configs/sam2.1/sam2.1_hiera_t.yaml"
DEVICE    = ("cuda" if torch.cuda.is_available()
             else "mps" if torch.backends.mps.is_available() else "cpu")
LR        = 1e-5
WEIGHT_DECAY = 4e-5
ITERS     = 5000
BATCH_SIZE = 8  # keep 1 for simplicity with SAM2 prompt/feature plumbing
SAVE_EVERY = 100
OUT_WEIGHTS = "/Users/sambra/Documents/GitHub/sam2_clone/checkpoints/sam2.1_hiera_l_finetuned.pt"

# -----------------------------------------

def increase_contrast_pil(img_in, factor=5):
    img_pil = img_in if isinstance(img_in, Image.Image) else Image.fromarray(img_in)
    return np.array(ImageEnhance.Contrast(img_pil.convert("RGB")).enhance(factor), dtype=np.uint8)


def _safe_load_npy(fp: str) -> np.ndarray:
    """Load .npy (prefer no pickle; fallback if needed). Expect (N,H,W)."""
    try:
        arr = np.load(fp, allow_pickle=False)
    except ValueError as e:
        if "pickled" not in str(e).lower():
            raise
        # Only load trusted files with pickle!
        arr = np.load(fp, allow_pickle=True)
        print(f"[warn] Loaded with allow_pickle=True: {fp}")
    arr = np.asarray(arr)
    if arr.ndim == 2:  # (H,W) -> (1,H,W)
        arr = arr[None, ...]
    return (arr > 0).astype(np.uint8)


def collect_pairs(root: str):
    """Return list of (image_path, mask_npy_path) where *_mask.npy exists."""
    pairs = []
    for name in os.listdir(root):
        p = Path(root) / name
        if not p.is_file():
            continue
        if p.suffix.lower() in IMAGE_EXTS:
            stem = p.stem
            mask_fp = Path(root) / f"{stem}_mask.npy"
            if mask_fp.exists():
                pairs.append((str(p), str(mask_fp)))
    if not pairs:
        raise RuntimeError(f"No image/mask pairs found in {root}. "
                           f"Expect files like image.png and image_mask.npy.")
    print(f"[data] Found {len(pairs)} image/mask pairs.")
    return pairs


def resize_and_pad_1024(img: np.ndarray, mask: np.ndarray):
    """
    Resize image (H,W,3) and mask (H,W) with same ratio (<=1024) then pad to (1024,1024).
    """
    H, W = img.shape[:2]
    r = min(TARGET_SIZE / W, TARGET_SIZE / H)
    new_w, new_h = int(W * r), int(H * r)

    img_r = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    msk_r = cv2.resize(mask, (new_w, new_h), interpolation=cv2.INTER_NEAREST)

    # pad to TARGET_SIZE square
    pad_bottom = TARGET_SIZE - new_h
    pad_right  = TARGET_SIZE - new_w

    if pad_bottom > 0:
        img_r = np.concatenate(
            [img_r, np.zeros([pad_bottom, new_w, 3], dtype=np.uint8)], axis=0
        )
        msk_r = np.concatenate(
            [msk_r, np.zeros([pad_bottom, new_w], dtype=np.uint8)], axis=0
        )
    if pad_right > 0:
        img_r = np.concatenate(
            [img_r, np.zeros([TARGET_SIZE, pad_right, 3], dtype=np.uint8)], axis=1
        )
        msk_r = np.concatenate(
            [msk_r, np.zeros([TARGET_SIZE, pad_right], dtype=np.uint8)], axis=1
        )
    return img_r, msk_r


def sample_single(pairs):
    """
    Pick one image/mask pair, choose a random instance from the stack,
    return (Img_1024x1024, mask_1024x1024, [[x,y]]) with a positive point.
    """
    # pick random pair
    img_fp, npy_fp = pairs[np.random.randint(len(pairs))]

    # load image (BGR->RGB)
    Img = cv2.imread(img_fp)
    if Img is None:
        raise RuntimeError(f"Failed to read image: {img_fp}")
    Img = Img[..., ::-1]

    # load masks stack (N,H,W); choose one non-empty slice
    ms = _safe_load_npy(npy_fp)  # (N,H,W) uint8
    if ms.size == 0 or ms.shape[0] == 0:
        # no masks; resample
        return sample_single(pairs)

    # pick random non-empty mask
    non_empty = [i for i in range(ms.shape[0]) if ms[i].any()]
    if not non_empty:
        return sample_single(pairs)
    ind = non_empty[np.random.randint(len(non_empty))]
    mask = (ms[ind] > 0).astype(np.uint8)

    # resize & pad
    Img, mask = resize_and_pad_1024(Img, mask)

    # positive point inside mask
    coords = np.argwhere(mask > 0)
    if coords.size == 0:
        return sample_single(pairs)
    yx = coords[np.random.randint(len(coords))]
    # SAM expects [[x,y]]
    input_point = [[int(yx[1]), int(yx[0])]]
    return Img, mask, input_point


def read_batch(pairs, batch_size=BATCH_SIZE):
    images, masks, pts = [], [], []
    for _ in range(batch_size):
        I, M, P = sample_single(pairs)
        images.append(I)
        masks.append(M)
        pts.append(P)
    # labels: 1 for positive points
    labels = np.ones([batch_size, 1], dtype=np.int64)
    masks = np.asarray(masks, dtype=np.uint8)  # (B, H, W)
    pts = np.asarray(pts, dtype=np.int64)      # (B, 1, 2)
    return images, masks, pts, labels


# ---------------- Build / Prepare model ----------------
sam = build_sam2(CFG_PATH, CKPT_PATH, device=DEVICE)
predictor = SAM2ImagePredictor(sam)

# trainable
predictor.model.sam_mask_decoder.train(True)
predictor.model.sam_prompt_encoder.train(True)
predictor.model.image_encoder.train(False)  # NOTE: if the repo froze grads with no_grad(), remove those

optimizer = torch.optim.AdamW(params=predictor.model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
use_amp = (DEVICE == "cuda")
scaler = torch.amp.GradScaler('cuda', enabled=use_amp)

# --------------- Data -------------------
pairs = collect_pairs(TRAIN_DIR)

# --------------- Train loop -------------
mean_iou = 0.0
for itr in range(ITERS):
    images, masks_np, points_np, labels_np = read_batch(pairs, batch_size=BATCH_SIZE)
    images = [increase_contrast_pil(image, factor=5) for image in images]
    if masks_np.shape[0] == 0:
        continue

    # Single-sample forward for stability with predictor internals
    # (keeps close to the original 60-line recipe)
    total_loss = 0.0
    total_iou  = 0.0

    for b in range(len(images)):
        with (torch.autocast("cuda", dtype=torch.bfloat16) if use_amp else nullcontext()):
            # Encode image
            predictor.set_image(images[b])  # (H,W,3) uint8

            # Prepare prompts
            mask_input, unnorm_coords, labels_t, _ = predictor._prep_prompts(
                points_np[b:b+1], labels_np[b:b+1], box=None, mask_logits=None, normalize_coords=True
            )
            sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
                points=(unnorm_coords, labels_t), boxes=None, masks=None
            )

            # Decode masks
            # features prepared by set_image()
            high_res_features = predictor._features["high_res_feats"]  # list of tensors
            low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
                image_embeddings=predictor._features["image_embed"],
                image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=True,
                repeat_image=True,
                high_res_features=high_res_features,
            )
            # Upscale to original (post-transform) resolution
            prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])

            # Losses
            gt_mask = torch.tensor(masks_np[b:b+1].astype(np.float32), device=DEVICE)  # (1,H,W)
            pred_prob = torch.sigmoid(prd_masks[:, 0])  # (1,H,W)

            # BCE
            seg_loss = (-gt_mask * torch.log(pred_prob + 1e-5)
                        - (1 - gt_mask) * torch.log(1 - pred_prob + 1e-5)).mean()

            # IoU score loss
            inter = (gt_mask * (pred_prob > 0.5)).sum((1, 2))
            union = gt_mask.sum((1, 2)) + (pred_prob > 0.5).sum((1, 2)) - inter
            iou = inter / (union + 1e-6)
            score_loss = torch.abs(prd_scores[:, 0].to(DEVICE) - iou).mean()

            loss = seg_loss + 0.05 * score_loss

        predictor.model.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += float(loss.detach().cpu())
        total_iou  += float(iou.detach().cpu().mean())

    mean_iou = (0.99 * mean_iou) + (0.01 * (total_iou / len(images)))
    if itr % 10 == 0:
        print(f"step {itr:06d} | loss {total_loss/len(images):.4f} | mean IoU {mean_iou:.4f}")

    if itr % SAVE_EVERY == 0 and itr > 0:
        #torch.save(predictor.model.state_dict(), OUT_WEIGHTS)
        torch.save({"model": predictor.model.state_dict()}, OUT_WEIGHTS)

# final save
torch.save({"model": predictor.model.state_dict()}, OUT_WEIGHTS)
print("Saved:", OUT_WEIGHTS)


  scaler = torch.cuda.amp.GradScaler(enabled=use_amp)


[data] Found 88 image/mask pairs.
step 000000 | loss 0.3114 | mean IoU 0.0047
step 000010 | loss 0.0151 | mean IoU 0.0708
step 000020 | loss 0.0087 | mean IoU 0.1349
step 000030 | loss 0.0078 | mean IoU 0.1950
step 000040 | loss 0.0041 | mean IoU 0.2498
step 000050 | loss 0.0046 | mean IoU 0.3021
step 000060 | loss 0.0080 | mean IoU 0.3435
step 000070 | loss 0.0084 | mean IoU 0.3870
step 000080 | loss 0.0044 | mean IoU 0.4255
step 000090 | loss 0.0035 | mean IoU 0.4592
step 000100 | loss 0.0056 | mean IoU 0.4938
step 000110 | loss 0.0067 | mean IoU 0.5225
step 000120 | loss 0.0054 | mean IoU 0.5443
step 000130 | loss 0.0100 | mean IoU 0.5625
step 000140 | loss 0.0051 | mean IoU 0.5867
step 000150 | loss 0.0073 | mean IoU 0.6092
step 000160 | loss 0.0073 | mean IoU 0.6286
step 000170 | loss 0.0097 | mean IoU 0.6414
step 000180 | loss 0.0030 | mean IoU 0.6562
step 000190 | loss 0.0055 | mean IoU 0.6696
step 000200 | loss 0.0032 | mean IoU 0.6826
step 000210 | loss 0.0047 | mean IoU 0.688

KeyboardInterrupt: 

In [9]:
import numpy as np
import torch

from contextlib import nullcontext
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

import numpy as np
import torch

from sam2 import SAM2AutomaticMaskGenerator

# --- New: utilities for multi-seed training ---

def load_pair(img_fp: str, mask_fp: str):
    """Return RGB image (H,W,3) uint8 and masks (N,H,W) uint8."""
    import cv2
    img = cv2.imread(img_fp)
    if img is None:
        raise RuntimeError(f"Failed to read image: {img_fp}")
    img = img[..., ::-1]  # BGR->RGB
    ms = _safe_load_npy(mask_fp)  # (N,H,W) uint8 0/1
    return img, ms

def resize_and_pad_square(img, masks, target=TARGET_SIZE):
    """Resize long side to target and pad to (target,target). Masks: (N,H,W)."""
    import cv2
    H, W = img.shape[:2]
    r = min(target / W, target / H)
    new_w, new_h = int(W * r), int(H * r)
    img_r = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    masks_r = np.stack(
        [cv2.resize(m.astype(np.uint8), (new_w, new_h), interpolation=cv2.INTER_NEAREST) for m in masks],
        axis=0
    ) if masks.size else masks

    # pad to square
    pad_b, pad_r = target - new_h, target - new_w
    if pad_b > 0:
        img_r = np.concatenate([img_r, np.zeros([pad_b, new_w, 3], np.uint8)], 0)
        masks_r = np.concatenate([masks_r, np.zeros([masks_r.shape[0], pad_b, new_w], np.uint8)], 1)
    if pad_r > 0:
        img_r = np.concatenate([img_r, np.zeros([target, pad_r, 3], np.uint8)], 1)
        masks_r = np.concatenate([masks_r, np.zeros([masks_r.shape[0], target, pad_r], np.uint8)], 2)
    return img_r, masks_r

def sample_points_per_instance(masks: np.ndarray, k: int = 2, rng: np.random.Generator = None):
    """
    For each non-empty instance mask, pick up to k random interior pixels.
    Return:
      pts_all: (P, 1, 2) int [[x,y]] per prompt
      gt_idx:  (P,) int -> which mask each prompt should learn
    """
    if rng is None:
        rng = np.random.default_rng()
    PTS, IDX = [], []
    for i, m in enumerate(masks):
        ys, xs = np.nonzero(m)
        if len(xs) == 0:
            continue
        sel = rng.choice(len(xs), size=min(k, len(xs)), replace=(len(xs) < k))
        for s in np.atleast_1d(sel):
            PTS.append([[int(xs[s]), int(ys[s])]])  # [[x,y]]
            IDX.append(i)
    if not PTS:
        return np.zeros((0,1,2), np.int64), np.zeros((0,), np.int64)
    return np.asarray(PTS, np.int64), np.asarray(IDX, np.int64)

# --------------- Data list -------------------
pairs = collect_pairs(TRAIN_DIR)  # unchanged

# --------------- Train loop -------------
mean_iou = 0.0
rng = np.random.default_rng()

for itr in range(ITERS):
    # minibatch of images; we will process images one-by-one (set_image per image)
    batch_pairs = [pairs[rng.integers(len(pairs))] for _ in range(BATCH_SIZE)]

    total_loss = 0.0
    total_iou  = 0.0
    total_prompts = 0

    for (img_fp, msk_fp) in batch_pairs:
        # --- load & prep data ---
        img, masks_stack = load_pair(img_fp, msk_fp)            # (H,W,3), (N,H,W)
        img = increase_contrast_pil(img, factor=5)
        img, masks_stack = resize_and_pad_square(img, masks_stack, target=TARGET_SIZE)

        # sample multiple seeds per instance (e.g., k=2)
        pts_all, gt_idx = sample_points_per_instance(masks_stack, k=2, rng=rng)
        if pts_all.shape[0] == 0:
            continue  # no instances

        # --- forward ---
        use_amp = (DEVICE == "cuda")
        with (torch.autocast("cuda", dtype=torch.bfloat16) if use_amp else nullcontext()):
            predictor.set_image(img)  # one image encoding

            # prompts for P seeds
            # shapes expected by predictor internals: (P,1,2) and (P,1)
            labels_np = np.ones((pts_all.shape[0], 1), dtype=np.int64)  # all positive clicks
            mask_input, unnorm_coords, labels_t, _ = predictor._prep_prompts(
                pts_all, labels_np, box=None, mask_logits=None, normalize_coords=True
            )

            sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
                points=(unnorm_coords, labels_t), boxes=None, masks=None
            )

            # decode a SINGLE proposal per prompt (easier loss)
            low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
                image_embeddings=predictor._features["image_embed"],
                image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,   # (P, C)
                dense_prompt_embeddings=dense_embeddings,     # (P, C, H', W')
                multimask_output=False,                       # <â€” one mask per seed
                repeat_image=True,
                high_res_features=predictor._features["high_res_feats"],
            )
            prd_masks = predictor._transforms.postprocess_masks(
                low_res_masks, predictor._orig_hw[-1]
            )  # (P,1,H,W)

            # --- losses across all prompts for THIS image ---
            P = prd_masks.shape[0]
            gt = torch.from_numpy(masks_stack[gt_idx].astype(np.float32)).to(DEVICE)  # (P,H,W)
            pred_prob = torch.sigmoid(prd_masks[:, 0])                                # (P,H,W)

            # BCE loss
            eps = 1e-5
            seg_loss = (-gt * torch.log(pred_prob + eps) - (1 - gt) * torch.log(1 - pred_prob + eps)).mean()

            # IoU calibration (pred score vs hard IoU of binarized pred)
            pred_bin = (pred_prob > 0.5).float()
            inter = (gt * pred_bin).sum((1, 2))
            union = gt.sum((1, 2)) + pred_bin.sum((1, 2)) - inter + 1e-6
            iou = inter / union
            score_loss = torch.abs(prd_scores[:, 0].to(DEVICE) - iou).mean()

            loss = seg_loss + 0.05 * score_loss

        predictor.model.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += float(loss.detach().cpu())
        total_iou  += float(iou.detach().cpu().mean())
        total_prompts += P

    if total_prompts == 0:
        continue

    mean_iou = 0.99 * mean_iou + 0.01 * (total_iou / max(1, len(batch_pairs)))

    if itr % 10 == 0:
        print(f"step {itr:06d} | loss {total_loss/max(1,len(batch_pairs)):.4f} | mean IoU {mean_iou:.4f}")

    if itr % SAVE_EVERY == 0 and itr > 0:
        torch.save({"model": predictor.model.state_dict()}, OUT_WEIGHTS)

# final save
torch.save({"model": predictor.model.state_dict()}, OUT_WEIGHTS)
print("Saved:", OUT_WEIGHTS)


ImportError: cannot import name 'SAM2AutomaticMaskGenerator' from 'sam2' (/Users/sambra/miniforge3/envs/sam2_env/lib/python3.10/site-packages/sam2/__init__.py)

In [7]:
import numpy as np
import torch

from contextlib import nullcontext
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

def increase_contrast_pil(img_in, factor=5):
    img_pil = img_in if isinstance(img_in, Image.Image) else Image.fromarray(img_in)
    return np.array(ImageEnhance.Contrast(img_pil.convert("RGB")).enhance(factor), dtype=np.uint8)

def touches_border(m):
    # m is a boolean/uint8 mask HxW
    return m[0,:].any() or m[-1,:].any() or m[:,0].any() or m[:,-1].any()

def filter_background_like(anns, max_area_frac=0.6, drop_edge_touch=False, max_bbox_cover=0.98):
    out = []
    Aimg = H * W
    for ann in anns:
        m = ann["segmentation"]           # HxW bool/uint8
        area = int(ann.get("area", m.sum()))
        x, y, w, h = ann.get("bbox", [0,0, W, H])  # xywh in SAM
        # 1) remove masks that are too large
        if area > max_area_frac * Aimg:
            continue
        # 2) remove masks whose bbox nearly covers the image
        if (w / W) > max_bbox_cover or (h / H) > max_bbox_cover:
            continue
        # 3) optionally remove masks touching the image border
        if drop_edge_touch and touches_border(m):
            continue
        out.append(ann)
    return out

def show_anns(anns, borders=True):
    if len(anns) == 0:
        return
    H, W = anns[0]["segmentation"].shape[:2]
    ax = plt.gca()
    ax.set_autoscale_on(False)

    # base: do NOT overwrite after drawing masks
    # (call plt.imshow(image_np) BEFORE calling show_anns)

    # build a single RGBA overlay
    overlay = np.zeros((H, W, 4), dtype=np.float32)
    for ann in sorted(anns, key=lambda x: x.get("area", 0), reverse=True):
        m = ann["segmentation"].astype(bool)
        color = np.random.rand(3)
        rgba = np.r_[color, [0.5]]  # alpha 0.5
        overlay[m] = rgba

        if borders:
            import cv2
            cnts, _ = cv2.findContours(m.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
            cnts = [cv2.approxPolyDP(c, epsilon=0.01, closed=True) for c in cnts]
            cv2.drawContours(overlay, cnts, -1, (0, 0, 1, 0.4), thickness=1)

    ax.imshow(overlay)  # draw on top


device = "mps" if torch.backends.mps.is_available() else "cpu"

In [8]:
ckpt = "/Users/sambra/Documents/GitHub/sam2_clone/checkpoints/sam2.1_hiera_l_finetuned.pt"
cfg  = "/configs/sam2.1/sam2.1_hiera_t.yaml"

img_path = "/Users/sambra/Desktop/training_images/MeOH_plate_A2-5945.tif"  # change this to your file
img = np.array(Image.open(img_path).convert("RGB"))

# turn off the CUDA-only post-processing:
sam2 = build_sam2(cfg, ckpt, device=device, apply_postprocessing=False) # must be false otherwise cuda throws a paddy
predictor = SAM2ImagePredictor(sam2)

mask_generator = SAM2AutomaticMaskGenerator(
    model=sam2,
    pred_iou_thresh=0.2,
)

image_pil = Image.open(img_path).convert("RGB")
image_np  = increase_contrast_pil(image_pil, factor=5)   # HWC uint8

masks = mask_generator.generate(image_np)

H, W = image_np.shape[:2]
masks = filter_background_like(masks, max_area_frac=0.6, drop_edge_touch=True)

plt.figure(figsize=(20, 20))
plt.imshow(image_np)     # base
show_anns(masks)         # overlays on top
plt.axis('off')
plt.show()

print("raw proposals:", len(masks))
print("after filter:", len(masks))

NameError: name 'SAM2AutomaticMaskGenerator' is not defined

In [None]:
ckpt = "/Users/sambra/Documents/GitHub/sam2_clone/checkpoints/sam2.1_hiera_tiny.pt"
cfg  = "/configs/sam2.1/sam2.1_hiera_t.yaml"

img_path = "/Users/sambra/Desktop/training_images/MeOH_plate_A2-5945.tif"  # change this to your file
img = np.array(Image.open(img_path).convert("RGB"))

# turn off the CUDA-only post-processing:
sam2 = build_sam2(cfg, ckpt, device=device, apply_postprocessing=False) # must be false otherwise cuda throws a paddy
predictor = SAM2ImagePredictor(sam2)

mask_generator = SAM2AutomaticMaskGenerator(
    model=sam2,
    pred_iou_thresh=0.2,
)

image_pil = Image.open(img_path).convert("RGB")
image_np  = increase_contrast_pil(image_pil, factor=5)   # HWC uint8

masks = mask_generator.generate(image_np)

H, W = image_np.shape[:2]
masks = filter_background_like(masks, max_area_frac=0.6, drop_edge_touch=True)

plt.figure(figsize=(20, 20))
plt.imshow(image_np)     # base
show_anns(masks)         # overlays on top
plt.axis('off')
plt.show()

print("raw proposals:", len(masks))
print("after filter:", len(masks))