Dataset: LIVECell (microscopy, instance segmentation)
Model: SAM ViT-B
Training scope:

Image encoder: frozen

Prompt encoder: frozen

Mask decoder: trainable

This mirrors how LLMs are adapted in production.

In [2]:
#%pip install opencv-python
#%pip install pycocotools
# This magic install is the last resort due to mixmatch of venv in bash and notebook kernel, no choice
# this issue is fixed by force install into the absolute path
# use bash not cmd, not ps, just bash

In [1]:
import torch
import cv2 # so confusing, pip install opencv but module name is cv2
import numpy as np
from torch.utils.data import Dataset
from pycocotools.coco import COCO # install dataset
import random

class LiveCellDataset(Dataset):
    def __init__(self, img_dir, ann_file, image_size=512):
        self.coco = COCO(ann_file)
        self.img_dir = img_dir
        self.ids = list(self.coco.imgs.keys())
        self.image_size = image_size

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        img_id = self.ids[idx]
        img_info = self.coco.loadImgs(img_id)[0]

        img = cv2.imread(f"{self.img_dir}/{img_info['file_name']}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        ann_ids = self.coco.getAnnIds(imgIds=img_id)
        anns = self.coco.loadAnns(ann_ids)

        mask = np.zeros(img.shape[:2], dtype=np.uint8)
        ann = random.choice(anns)
        mask = self.coco.annToMask(ann)

        img = cv2.resize(img, (self.image_size, self.image_size))
        mask = cv2.resize(mask, (self.image_size, self.image_size),
                          interpolation=cv2.INTER_NEAREST)

        # generate point prompt inside object
        ys, xs = np.where(mask > 0)
        idx = random.randint(0, len(xs) - 1)
        point = np.array([[xs[idx], ys[idx]]])

        return (
            torch.from_numpy(img).permute(2, 0, 1).float() / 255.0,
            torch.from_numpy(mask).unsqueeze(0).float(),
            torch.from_numpy(point).float()
        )


In [2]:
import torch
import torch.nn.functional as F

def dice_loss(pred, target, eps=1e-6):
    num = 2 * (pred * target).sum()
    den = pred.sum() + target.sum() + eps
    return 1 - num / den

def seg_loss(pred, target):
    pred = torch.sigmoid(pred)
    return dice_loss(pred, target) + F.binary_cross_entropy(pred, target)


In [3]:
import torch
from torch.utils.data import DataLoader
from segment_anything import sam_model_registry
from dataset import LiveCellDataset
from losses import seg_loss
from tqdm import tqdm

device = "cuda"

sam = sam_model_registry["vit_b"](checkpoint="sam/sam_vit_b_01ec64.pth")
sam.to(device)

# freeze everything except mask decoder
for p in sam.image_encoder.parameters():
    p.requires_grad = False
for p in sam.prompt_encoder.parameters():
    p.requires_grad = False

optimizer = torch.optim.AdamW(
    sam.mask_decoder.parameters(), lr=1e-4
)

dataset = LiveCellDataset(
    img_dir="data/livecell/images/train",
    ann_file="data/livecell/annotations/train.json"
)

loader = DataLoader(dataset, batch_size=4, shuffle=True)

sam.train()

for epoch in range(10):
    pbar = tqdm(loader)
    for img, mask, point in pbar:
        img = img.to(device)
        mask = mask.to(device)
        point = point.to(device)

        with torch.no_grad():
            image_embedding = sam.image_encoder(img)

        sparse, dense = sam.prompt_encoder(
            points=(point, torch.ones(len(point), 1).to(device)),
            boxes=None,
            masks=None
        )

        low_res_masks, _ = sam.mask_decoder(
            image_embedding,
            sam.prompt_encoder.get_dense_pe(),
            sparse,
            dense,
            multimask_output=False
        )

        loss = seg_loss(low_res_masks, mask)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.set_description(f"epoch {epoch} | loss {loss.item():.4f}")
        break


loading annotations into memory...
Done (t=5.50s)
creating index...
index created!


  0%|          | 0/814 [00:00<?, ?it/s]


RuntimeError: The size of tensor a (32) must match the size of tensor b (64) at non-singleton dimension 2