In [7]:
# LoRA on image encoder is compute-heavy (still runs full ViT forward). It saves trainable params, not forward cost.

# If your domain gap is mostly “mask style,” decoder-only fine-tune is often stronger per compute than LoRA. LoRA shines when you need to shift representation.

In [8]:
import torch
import numpy as np
from torch.utils.data import Dataset
from PIL import Image
from pycocotools.coco import COCO
from segment_anything.utils.transforms import ResizeLongestSide
import torch.nn.functional as F


class LiveCellSAMDataset(Dataset):
    def __init__(self, img_dir, ann_file, sam_model):
        self.img_dir = img_dir
        self.coco = COCO(ann_file)
        self.image_ids = list(self.coco.imgs.keys())

        self.transform = ResizeLongestSide(1024)
        self.pixel_mean = sam_model.pixel_mean
        self.pixel_std = sam_model.pixel_std

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        img_info = self.coco.imgs[image_id]

        # ---------- load image ----------
        image = Image.open(
            f"{self.img_dir}/{img_info['file_name']}"
        ).convert("RGB")
        image = np.array(image)
        H, W = image.shape[:2]

        # ---------- load annotations ----------
        ann_ids = self.coco.getAnnIds(imgIds=image_id)
        anns = self.coco.loadAnns(ann_ids)

        # pick ONE instance
        ann = anns[np.random.randint(len(anns))]

        gt_mask = self.coco.annToMask(ann)      # [H,W]
        x, y, w, h = ann["bbox"]
        box = np.array([[x, y, x + w, y + h]])  # [1,4]

        # ---------- resize ----------
        image = self.transform.apply_image(image)
        gt_mask = self.transform.apply_image(gt_mask)
        box = self.transform.apply_boxes(box, (H, W))

        # ---------- pad to 1024 ----------
        new_h, new_w = image.shape[:2]
        pad_h = 1024 - new_h
        pad_w = 1024 - new_w

        image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
        gt_mask = torch.from_numpy(gt_mask).unsqueeze(0).float()
        box = torch.from_numpy(box).float()

        image = F.pad(image, (0, pad_w, 0, pad_h))
        gt_mask = F.pad(gt_mask, (0, pad_w, 0, pad_h))

        # ---------- normalize ----------
        #image = (image - self.pixel_mean) / self.pixel_std

        return {
            "image": image,          # [3,1024,1024]
            "mask": gt_mask,         # [1,1024,1024]
            "box": box,              # [1,4]
            "image_id": image_id,
        }


In [11]:
import torch
from torch.utils.data import DataLoader
from segment_anything import sam_model_registry
from tqdm import tqdm

device = "cuda"

sam = sam_model_registry["vit_b"](checkpoint="../sam/sam_vit_b_01ec64.pth")
sam = sam.to(device)

dataset = LiveCellSAMDataset(
    img_dir="../data/livecell/images/train",
    ann_file="../data/livecell/annotations/train.json",
    sam_model=sam
)

loading annotations into memory...
Done (t=5.64s)
creating index...
index created!


In [12]:
from torch.utils.data import DataLoader
from tqdm import tqdm


optimizer = torch.optim.AdamW(
    sam.mask_decoder.parameters(), lr=1e-5
)

loader = DataLoader(dataset, batch_size=1, shuffle=True) # very useful not just for convenience

In [13]:
# define LoRA trainable replacement layer

In [27]:
import math
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """
    Wraps a nn.Linear with a low-rank update:
      y = xW^T + b + (alpha/r) * x (BA)^T
    where A: (r, in), B: (out, r)

    you can patch a transformer layer without rewriting the model. This is the core “adapter” idea: wrap instead of rewrite.

    """
    def __init__(self, base: nn.Linear, r: int = 8, alpha: int = 16, dropout: float = 0.0):
        super().__init__()
        assert isinstance(base, nn.Linear) # will wrap linear layer qkv layer
        self.base = base
        self.r = r # LoRA parameter
        self.alpha = alpha # LoRA parameter
        self.scale = alpha / r if r > 0 else 0.0
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        # name format consistent with Torch
        in_f = base.in_features
        out_f = base.out_features

        # Low-rank factors
        self.A = nn.Parameter(torch.zeros(r, in_f))
        self.B = nn.Parameter(torch.zeros(out_f, r))
        # because LoRALinear is a module object in Torch and because we set attribute as Parameter
        # A and B will be visible in the compute graph

        # Init: A ~ Kaiming, B = 0 makes initial update = 0 (safe)
        nn.init.kaiming_uniform_(self.A, a=math.sqrt(5))
        nn.init.zeros_(self.B)
        # wrap should start with BA = 0 and thus the same linear layer as before

        # Freeze base params
        for p in self.base.parameters():
            p.requires_grad = False

    def forward(self, x):
        y = self.base(x)
        # x: (..., in_f)
        # (x @ A^T): (..., r)
        # (..., r) @ B^T: (..., out_f)
        lora = (self.dropout(x) @ self.A.t()) @ self.B.t()
        return y + self.scale * lora


def mark_only_lora_trainable(model: nn.Module):
    for n, p in model.named_parameters():
        p.requires_grad = ("A" in n or "B" in n)


In [14]:
# implement encoder layer replacment

In [26]:
import torch.nn as nn
#from lora import LoRALinear


def _replace_module(parent: nn.Module, child_name: str, new_module: nn.Module):
    setattr(parent, child_name, new_module)

# Why name-based? Because you want to patch specific logical roles:
# qkv: attention query/key/value projection (high leverage)
# proj: attention output projection (also high leverage)

def apply_lora_to_sam_image_encoder(
    sam_model: nn.Module,
    r: int = 8,
    alpha: int = 16,
    dropout: float = 0.0,
    target_keywords=("qkv", "proj"),
):
    """
    Safely replaces selected nn.Linear layers under sam_model.image_encoder with LoRALinear.
    Two-pass approach avoids recursion issues caused by modifying modules during traversal.
    """
    # 1) Collect candidates first (no mutation here)
    to_replace = []
    for parent_name, parent in sam_model.image_encoder.named_modules():
        # parent_name is a dotted path-like string (e.g., "blocks.3.attn")
        # parent is the module object at that path
        for child_name, child in parent.named_children():
            # You want to identify the exact nn.Linear objects that are direct attributes of parent.
            # Replacement must happen at that attribute boundary.
            if isinstance(child, nn.Linear):
                full_name = f"{parent_name}.{child_name}".strip(".")
                # This creates a human-readable identifier like:
                # blocks.5.attn.qkv
                # blocks.5.attn.proj
                if any(k in full_name for k in target_keywords):
                    to_replace.append((parent, child_name, full_name, child))
                    # parent: the object to mutate
                    # child_name: which attribute to set on that parent
                    # full_name: logging only
                    # child: the original linear module (becomes base in LoRALinear)

    # 2) Replace in a second pass
    replaced_names = []
    for parent, child_name, full_name, child in to_replace:
        setattr(parent, child_name, LoRALinear(child, r=r, alpha=alpha, dropout=dropout))
        replaced_names.append(full_name)

    return replaced_names



In [15]:
# make the training into a function, it is OK as well

In [28]:
import torch
import torch.nn.functional as F
from segment_anything import sam_model_registry
# from sam_lora_patch import apply_lora_to_sam_image_encoder
# from lora import mark_only_lora_trainable


def dice_loss(logits, targets, eps=1e-6):
    probs = torch.sigmoid(logits)
    probs = probs.flatten(1)
    targets = targets.flatten(1)
    inter = (probs * targets).sum(dim=1)
    union = probs.sum(dim=1) + targets.sum(dim=1)
    dice = (2 * inter + eps) / (union + eps)
    return 1 - dice.mean()


def train_lora_sam(
    base_ckpt: str,
    model_type: str,
    dataloader,
    device="cuda",
    r=8,
    alpha=16,
    dropout=0.0,
    lr=1e-4,
    epochs=1,
):
    sam = sam_model_registry[model_type](checkpoint=base_ckpt)

    # Freeze everything first
    for p in sam.parameters():
        p.requires_grad = False

    # Apply LoRA to image encoder (qkv/proj)
    replaced = apply_lora_to_sam_image_encoder(sam, r=r, alpha=alpha, dropout=dropout)
    print(f"LoRA injected into {len(replaced)} Linear layers")

    sam.to(device)
    sam.train()

    # Ensure only LoRA params train
    mark_only_lora_trainable(sam)

    optim = torch.optim.AdamW([p for p in sam.parameters() if p.requires_grad], lr=lr)
    #Optimizer state is allocated only for trainable params, which is the main memory savings of LoRA.

    for epoch in range(epochs):
        for step, batch in enumerate(dataloader):
            # Expect your batch provides:
            # image: (B,3,1024,1024) float32 normalized as your SAM pipeline expects
            # box: (B,4) in xyxy, 1024-space
            # gt_mask: (B,1,1024,1024) {0,1}
            image = batch["image"].to(device)
            box = batch["box"].to(device)
            gt = batch["mask"].to(device) # this is actuall gt_mask, just naming difference

            # ---- SAM forward (core idea) ----
            # 1) image embeddings from image encoder
            image_embeddings = sam.image_encoder(image)

            # 2) prompt encoder with boxes
            sparse_embeddings, dense_embeddings = sam.prompt_encoder(
                points=None,
                boxes=box,
                masks=None,
            )

            # 3) mask decoder → low-res logits
            low_res_masks, iou_preds = sam.mask_decoder(
                image_embeddings=image_embeddings,
                image_pe=sam.prompt_encoder.get_dense_pe(), #image_pe is positional encoding for the dense grid.
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=False,
            )

            # 4) Upsample logits to full 1024 for loss
            logits = F.interpolate(low_res_masks, size=gt.shape[-2:], mode="bilinear", align_corners=False)

            # ---- Loss ----
            bce = F.binary_cross_entropy_with_logits(logits, gt)
            dloss = dice_loss(logits, gt)
            loss = bce + dloss
            
            # BCE handles local correctness
            # Dice handles overlap / class imbalance

            optim.zero_grad()
            loss.backward()
            optim.step()
            if step % 100 == 0:
                print(f"step {step}, loss {loss.item():.4f}")

        print(f"Epoch {epoch+1}/{epochs} | loss={loss.item():.4f}")

    return sam


In [25]:
train_lora_sam(
    "../sam/sam_vit_b_01ec64.pth",
    "vit_b",
    loader,
    device="cuda",
    r=8,
    alpha=16,
    dropout=0.0,
    lr=1e-5,
    epochs=1,
)

LoRA injected into 24 Linear layers
step 0, loss 0.4252
step 100, loss 0.3073
step 200, loss 0.7332
step 300, loss 0.2148
step 400, loss 0.1726
step 500, loss 0.3024
step 600, loss 0.3768
step 700, loss 0.1789
step 800, loss 0.1493
step 900, loss 0.1052
step 1000, loss 0.1245
step 1100, loss 0.2540
step 1200, loss 0.0760
step 1300, loss 0.1631
step 1400, loss 0.4814
step 1500, loss 0.1520
step 1600, loss 0.0587
step 1700, loss 0.1204
step 1800, loss 0.1397
step 1900, loss 0.1150
step 2000, loss 0.2016
step 2100, loss 0.2310
step 2200, loss 0.0810
step 2300, loss 0.1263
step 2400, loss 0.0411
step 2500, loss 0.1364
step 2600, loss 0.1685
step 2700, loss 0.4014
step 2800, loss 0.0856
step 2900, loss 0.0611
step 3000, loss 0.0627
step 3100, loss 0.0628
step 3200, loss 0.1153
Epoch 1/1 | loss=0.1826


Sam(
  (image_encoder): ImageEncoderViT(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (blocks): ModuleList(
      (0-11): 12 x Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): LoRALinear(
            (base): Linear(in_features=768, out_features=2304, bias=True)
            (dropout): Identity()
          )
          (proj): LoRALinear(
            (base): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Identity()
          )
        )
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (lin1): Linear(in_features=768, out_features=3072, bias=True)
          (lin2): Linear(in_features=3072, out_features=768, bias=True)
          (act): GELU(approximate='none')
        )
      )
    )
    (neck): Sequential(
      (0): Conv2d(768, 256, kernel_size=(1, 1), stride=(1, 1), bia

In [5]:
import torch

def save_lora_weights(sam_model, path="lora_image_encoder.pt"):
    lora_state = {}
    for name, param in sam_model.named_parameters():
        # LoRA params are named like "...A" and "...B" in our module
        if name.endswith(".A") or name.endswith(".B"):
            lora_state[name] = param.detach().cpu()
    torch.save(lora_state, path)


In [6]:
import torch
from segment_anything import sam_model_registry
#from sam_lora_patch import apply_lora_to_sam_image_encoder

def load_sam_with_lora(base_ckpt, model_type, lora_path, device="cuda", r=8, alpha=16, dropout=0.0):
    sam = sam_model_registry[model_type](checkpoint=base_ckpt)
    apply_lora_to_sam_image_encoder(sam, r=r, alpha=alpha, dropout=dropout)

    lora_state = torch.load(lora_path, map_location="cpu")
    missing, unexpected = sam.load_state_dict(lora_state, strict=False)
    # missing is fine (we are only loading LoRA params)
    sam.to(device)
    sam.eval()
    return sam
