- **Dataset:** LIVECell (microscopy, instance segmentation)
- **Model:** SAM ViT-B
- **Training scope:**
  - Image encoder: frozen
  - Prompt encoder: frozen
  - Mask decoder: trainable

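This follows the same recipe commonly used to adapt large pretrained models (LLMs included): keep the expensive pretrained backbone frozen and fine-tune only a small task-specific head, which makes training much cheaper.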

In [1]:
%pip install opencv-python
# This magic install is the last resort due to mixmatch of venv in bash and notebook kernel, no choice

Collecting opencv-python
  Using cached opencv_python-4.12.0.88-cp37-abi3-win_amd64.whl.metadata (19 kB)
Collecting numpy<2.3.0,>=2 (from opencv-python)
  Using cached numpy-2.2.6-cp312-cp312-win_amd64.whl.metadata (60 kB)
Using cached opencv_python-4.12.0.88-cp37-abi3-win_amd64.whl (39.0 MB)
Using cached numpy-2.2.6-cp312-cp312-win_amd64.whl (12.6 MB)
Installing collected packages: numpy, opencv-python
  Attempting uninstall: numpy
    Found existing installation: numpy 2.3.5
    Uninstalling numpy-2.3.5:
      Successfully uninstalled numpy-2.3.5
Successfully installed numpy-2.2.6 opencv-python-4.12.0.88
Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 24.3.1 -> 25.3
[notice] To update, run: python.exe -m pip install --upgrade pip


In [3]:
%pip install pycocotools

Collecting pycocotools
  Downloading pycocotools-2.0.11-cp312-abi3-win_amd64.whl.metadata (1.3 kB)
Downloading pycocotools-2.0.11-cp312-abi3-win_amd64.whl (77 kB)
Installing collected packages: pycocotools
Successfully installed pycocotools-2.0.11
Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 24.3.1 -> 25.3
[notice] To update, run: python.exe -m pip install --upgrade pip


In [4]:
import torch
import cv2
import numpy as np
from torch.utils.data import Dataset
from pycocotools.coco import COCO
import random

class LiveCellDataset(Dataset):
    """LIVECell images, each paired with one random instance mask and a point prompt.

    ``__getitem__`` returns ``(image, mask, point)``:

    * ``image`` -- float tensor ``(3, image_size, image_size)``, RGB scaled to [0, 1].
    * ``mask`` -- float tensor ``(1, image_size, image_size)``, 1 inside one
      randomly chosen annotated instance, 0 elsewhere.
    * ``point`` -- float tensor ``(1, 2)``: one ``(x, y)`` pixel inside that mask,
      in the resized ``image_size`` frame.

    Images with no annotations are skipped at construction, so every index has
    at least one instance to sample.

    Args:
        img_dir: directory holding the image files named in ``ann_file``.
        ann_file: COCO-format annotation JSON.
        image_size: side length images and masks are resized to (the resize is
            to a square, so the aspect ratio is not preserved).
    """

    def __init__(self, img_dir, ann_file, image_size=512):
        self.coco = COCO(ann_file)
        self.img_dir = img_dir
        # Drop unannotated images: there is no instance to sample from them,
        # and picking from an empty annotation list would fail mid-epoch.
        self.ids = [
            img_id
            for img_id in self.coco.imgs
            if self.coco.getAnnIds(imgIds=img_id)
        ]
        self.image_size = image_size

    def __len__(self):
        return len(self.ids)

    @staticmethod
    def _random_point(mask):
        """Return ``[[x, y]]`` of a random nonzero pixel of ``mask``, or None if it is empty."""
        ys, xs = np.nonzero(mask)  # row indices are y, column indices are x
        if xs.size == 0:
            return None
        k = random.randrange(xs.size)
        return np.array([[xs[k], ys[k]]])

    def __getitem__(self, idx):
        img_id = self.ids[idx]
        img_info = self.coco.loadImgs(img_id)[0]

        path = f"{self.img_dir}/{img_info['file_name']}"
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            raise FileNotFoundError(f"could not read image: {path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (self.image_size, self.image_size))

        anns = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id))
        # Visit annotations in random order and keep the first whose mask is
        # still non-empty after nearest-neighbour resizing: tiny instances can
        # disappear when downsampling, leaving no pixel to place a prompt on.
        for ann in random.sample(anns, len(anns)):
            mask = cv2.resize(
                self.coco.annToMask(ann),
                (self.image_size, self.image_size),
                interpolation=cv2.INTER_NEAREST,
            )
            point = self._random_point(mask)
            if point is not None:
                break
        else:
            raise ValueError(
                f"image {img_id}: no annotation has a non-empty mask at "
                f"{self.image_size}x{self.image_size}"
            )

        return (
            torch.from_numpy(img).permute(2, 0, 1).float() / 255.0,
            torch.from_numpy(mask).unsqueeze(0).float(),
            torch.from_numpy(point).float(),
        )


In [5]:
import torch
import torch.nn.functional as F

def dice_loss(pred, target, eps=1e-6):
    """Soft Dice loss between probabilities ``pred`` and a binary ``target``.

    Intersection and cardinality are summed over every element of the batch,
    giving a single global Dice score.

    Args:
        pred: probabilities in [0, 1]; must have exactly the same shape as ``target``.
        target: binary ground-truth mask.
        eps: smoothing term added to both numerator and denominator. It keeps
            the ratio defined when both tensors are empty, and makes an empty
            prediction of an empty target a loss of 0 rather than 1.

    Returns:
        Scalar tensor ``1 - Dice``.

    Raises:
        ValueError: if ``pred`` and ``target`` shapes differ (broadcasting
            would otherwise silently compute a meaningless score).
    """
    if pred.shape != target.shape:
        raise ValueError(
            f"pred and target shapes differ: {tuple(pred.shape)} vs {tuple(target.shape)}"
        )
    intersection = (pred * target).sum()
    cardinality = pred.sum() + target.sum()
    return 1 - (2 * intersection + eps) / (cardinality + eps)

def seg_loss(pred, target):
    """Dice + binary cross-entropy loss on raw mask logits.

    Args:
        pred: raw (pre-sigmoid) mask logits, same shape as ``target``.
        target: float binary ground-truth mask.

    Returns:
        Scalar tensor: ``dice_loss(sigmoid(pred), target)`` plus mean BCE.
    """
    # BCE is taken directly from the logits: binary_cross_entropy_with_logits
    # is numerically stable for saturated logits, whereas sigmoid followed by
    # binary_cross_entropy clamps log(0) to -100 and is rejected under autocast.
    bce = F.binary_cross_entropy_with_logits(pred, target)
    return dice_loss(torch.sigmoid(pred), target) + bce


In [1]:
from segment_anything import sam_model_registry

ModuleNotFoundError: No module named 'segment_anything'

In [None]:
# Fine-tune only SAM's mask decoder on LIVECell point prompts.
# LiveCellDataset and seg_loss are defined in the cells above.
import random

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from segment_anything import sam_model_registry
from tqdm.auto import tqdm

SEED = 0
EPOCHS = 10
BATCH_SIZE = 4
LR = 1e-4
SAM_CHECKPOINT = "sam/sam_vit_b_01ec64.pth"
DECODER_OUT = "sam/sam_vit_b_livecell_mask_decoder.pth"  # trained decoder weights

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"

sam = sam_model_registry["vit_b"](checkpoint=SAM_CHECKPOINT)
sam.to(device)

# Freeze everything except the mask decoder.
for p in sam.image_encoder.parameters():
    p.requires_grad = False
for p in sam.prompt_encoder.parameters():
    p.requires_grad = False

optimizer = torch.optim.AdamW(sam.mask_decoder.parameters(), lr=LR)

# Load images at the encoder's native square input size so the point prompts
# are already in the coordinate frame the prompt encoder expects.
dataset = LiveCellDataset(
    img_dir="data/livecell/images/train",
    ann_file="data/livecell/annotations/train.json",
    image_size=sam.image_encoder.img_size,
)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Frozen parts stay in eval mode; only the decoder is trained.
sam.eval()
sam.mask_decoder.train()

for epoch in range(EPOCHS):
    pbar = tqdm(loader, desc=f"epoch {epoch}")
    for img, mask, point in pbar:
        img = img.to(device)
        mask = mask.to(device)
        point = point.to(device)
        # Every prompt is a single foreground click: label 1, shape (B, N).
        point_labels = torch.ones(point.shape[:2], device=device)

        with torch.no_grad():
            # The dataset yields RGB in [0, 1]; sam.preprocess normalizes with
            # SAM's 0-255 pixel mean/std and pads to the encoder input size.
            image_embedding = sam.image_encoder(sam.preprocess(img * 255.0))
            sparse, dense = sam.prompt_encoder(
                points=(point, point_labels),
                boxes=None,
                masks=None,
            )
        dense_pe = sam.prompt_encoder.get_dense_pe()

        # SAM's mask decoder pairs ONE image embedding with a batch of prompts
        # (it repeat_interleaves the embedding), so decode each image on its own.
        low_res_masks = torch.cat([
            sam.mask_decoder(
                image_embeddings=image_embedding[i:i + 1],
                image_pe=dense_pe,
                sparse_prompt_embeddings=sparse[i:i + 1],
                dense_prompt_embeddings=dense[i:i + 1],
                multimask_output=False,
            )[0]
            for i in range(img.shape[0])
        ])

        # The decoder returns low-resolution logits; upsample them to the target size.
        logits = F.interpolate(
            low_res_masks, size=mask.shape[-2:], mode="bilinear", align_corners=False
        )
        loss = seg_loss(logits, mask)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.set_description(f"epoch {epoch} | loss {loss.item():.4f}")

    # Save after every epoch so an interrupted run keeps its progress.
    torch.save(sam.mask_decoder.state_dict(), DECODER_OUT)
