# Q2 — Text-Driven Image Segmentation with SAM 2

Pipeline: Load image -> accept text prompt -> convert text to region seeds (via CLIPSeg) -> feed seeds to SAM 2 -> display final mask overlay.

This notebook is designed to run on Colab with a GPU runtime. If the SAM 2 API or its checkpoints change, adjust the install cell and the SAM 2 segmentation cell accordingly.

In [None]:
# Install runtime dependencies into the kernel's environment (%pip, not !pip).
# NOTE(review): no versions are pinned, so re-runs may resolve different releases; pin once a working set is known.
# Comments are kept on their own lines: text after a %pip command would be passed to pip as arguments.
%pip -q install opencv-python matplotlib pillow einops --upgrade
# Attempt to install SAM 2 (adjust if the official package/repo name changes)
# NOTE(review): upstream appears to have moved to github.com/facebookresearch/sam2 — confirm this URL still resolves.
# The `|| echo ...` is meant to print a hint instead of failing silently (assumes %pip hands the line to a shell — TODO confirm).
%pip -q install git+https://github.com/facebookresearch/segment-anything-2.git || echo 'If this fails, update to the latest official SAM 2 install instructions.'
# NOTE(review): the CLIPSeg README installs from git (timojl/clipseg); confirm that the PyPI `clipseg`
# package provides `clipseg.models.clipseg.CLIPDensePredT`, which the next cell imports.
%pip -q install clipseg


In [None]:
import os, cv2, torch, numpy as np, matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as T

# CLIPSeg is optional at import time: when it cannot be imported, CLIPDensePredT stays None
# and text_to_heatmap_clipseg() refuses to run.
CLIPDensePredT = None
try:
    import clipseg
    from clipseg.models.clipseg import CLIPDensePredT
except Exception as e:
    print('CLIPSeg not available:', e)

# SAM 2 is optional as well: `sam2` stays None when the package cannot be imported.
# (Placeholder import — update if the SAM 2 package layout changes.)
sam2 = None
try:
    import sam2
except Exception as e:
    print('SAM 2 not available yet. Please update install cell to latest official instructions.', e)

# Run on the GPU whenever PyTorch can see one.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Device:', device)


In [None]:
def load_image(path_or_url, timeout=30):
    """Load an image from a local path or an http(s) URL as an RGB PIL image.

    Args:
        path_or_url: Local file path (str or os.PathLike) or an ``http://`` /
            ``https://`` URL.
        timeout: Seconds to wait for the HTTP server when fetching a URL
            (ignored for local paths).

    Returns:
        A PIL image in RGB mode, detached from the underlying file/stream.

    Raises:
        requests.HTTPError: if the URL responds with an error status. Checking
            this up front avoids PIL failing later with an opaque
            "cannot identify image file" on an HTML error page.
    """
    source = os.fspath(path_or_url)  # also accept pathlib.Path
    if source.startswith(('http://', 'https://')):
        import io
        import requests
        resp = requests.get(source, timeout=timeout)  # never hang forever on a dead host
        resp.raise_for_status()
        with Image.open(io.BytesIO(resp.content)) as im:
            return im.convert('RGB')
    # The context manager closes the file handle; convert() returns an independent copy.
    with Image.open(source) as im:
        return im.convert('RGB')

def show_overlay(img_pil, mask_np, alpha=0.6):
    """Display `img_pil` with the foreground of `mask_np` tinted blue.

    Pixels where ``mask_np > 0`` are alpha-blended with RGB (30, 144, 255)
    using weight `alpha` for the tint; all other pixels are shown unchanged.
    Blended values are truncated (not rounded) to uint8. Returns None.
    """
    canvas = np.array(img_pil)          # fresh array; safe to modify in place
    fg = (mask_np > 0).astype(bool)     # foreground selector, same shape as the mask
    tint = np.array([30, 144, 255], dtype=np.uint8)
    blended = alpha * tint + (1 - alpha) * canvas[fg]
    canvas[fg] = blended.astype(np.uint8)

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(canvas)
    ax.axis('off')
    plt.show()

# NOTE(review): upstream CLIPSeg distributes weights in a zip archive; confirm this URL keeps resolving.
_CLIPSEG_WEIGHTS_URL = 'https://huggingface.co/timojl/clipseg/resolve/main/rd64-uni.pth'
_clipseg_models = {}  # device string -> ready-to-use CLIPSeg model


def _get_clipseg_model(dev):
    """Build the CLIPSeg model (ViT-B/16, reduce_dim=64) once per device and cache it."""
    key = str(dev)
    if key not in _clipseg_models:
        model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)
        model.eval()
        model.to(dev)
        state = torch.hub.load_state_dict_from_url(_CLIPSEG_WEIGHTS_URL, map_location=dev)
        # The checkpoint stores only the decoder weights; the CLIP backbone is created by
        # CLIPDensePredT itself, so a strict load would fail on the missing backbone keys.
        model.load_state_dict(state, strict=False)
        _clipseg_models[key] = model
    return _clipseg_models[key]


def text_to_heatmap_clipseg(img_pil, text):
    """Return a CLIPSeg relevance heatmap for `text`, resized to `img_pil`'s resolution.

    The model is built and its weights loaded only on the first call for a given
    device; later calls reuse the cached model.

    Args:
        img_pil: RGB PIL image.
        text: Free-form text prompt, e.g. ``'a dog'``.

    Returns:
        numpy array of shape (H, W), matching the image (PIL's ``size`` is (W, H)),
        holding sigmoid probabilities in [0, 1].

    Raises:
        AssertionError: if CLIPSeg could not be imported.
    """
    assert CLIPDensePredT is not None, 'CLIPSeg not installed'
    model = _get_clipseg_model(device)
    tf = T.Compose([T.Resize((352,352)), T.ToTensor(), T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])])
    x = tf(img_pil).unsqueeze(0).to(device)
    with torch.no_grad():
        preds = model(x, [text])[0]  # logits, presumably (1, 1, H, W) as in the CLIPSeg README
        heat = torch.sigmoid(preds).squeeze().cpu().numpy()
    # cv2's dsize is (width, height), which is the same ordering as PIL's .size.
    return cv2.resize(heat, img_pil.size, interpolation=cv2.INTER_LINEAR)

def heatmap_to_point_seeds(heat, k=10, thresh=0.5):
    """Turn a 2-D relevance heatmap into at most `k` (x, y) point prompts.

    Candidate seeds are the pixels with ``heat >= thresh``. If no pixel qualifies,
    the ``k`` highest-scoring pixels are used instead, in no particular order.
    When there are more than ``k`` candidates, ``k`` of them are taken at evenly
    spaced positions along row-major scan order, so they spread across the rows
    of the region.

    Args:
        heat: 2-D array of scores indexed as ``heat[y, x]``.
        k: Maximum number of seeds; ``k <= 0`` yields an empty list.
        thresh: Score cutoff for candidate pixels.

    Returns:
        List of ``(x, y)`` int tuples with length ``<= k``.
    """
    heat = np.asarray(heat)
    if k <= 0 or heat.size == 0:
        return []
    ys, xs = np.nonzero(heat >= thresh)  # row-major order
    if xs.size == 0:
        # Fallback: nothing clears the threshold, so take the top-k maxima.
        # Clamp k, because argpartition raises if kth exceeds the array size.
        k_eff = min(k, heat.size)
        flat_idx = np.argpartition(heat.ravel(), -k_eff)[-k_eff:]
        ys, xs = np.unravel_index(flat_idx, heat.shape)
    if xs.size > k:
        # Subsample index arrays before building Python tuples; a large region
        # can contain hundreds of thousands of pixels.
        keep = np.linspace(0, xs.size - 1, k).astype(int)
        ys, xs = ys[keep], xs[keep]
    return [(int(x), int(y)) for y, x in zip(ys, xs)]


In [None]:
# Provide either a URL or upload a local image in Colab
image_url = ''  # e.g., 'https://images.cocodataset.org/val2017/000000039769.jpg'
text_prompt = 'a dog'  # edit your prompt
SEED_K = 10          # maximum number of point seeds handed to SAM 2
SEED_THRESH = 0.6    # CLIPSeg probability a pixel needs to become a seed candidate

img = None
if image_url:
    img = load_image(image_url)
else:
    # Only the "not running in Colab" case is expected and handled here. Upload and
    # decoding errors propagate, so their real cause is shown instead of a misleading assert.
    try:
        from google.colab import files
    except ImportError:
        files = None
        print('google.colab is not available; set image_url to load an image.')
    if files is not None:
        up = files.upload()
        if up:
            fname = next(iter(up))
            img = load_image(fname)
assert img is not None, 'Please provide an image URL or upload an image.'
plt.figure(figsize=(6,6)); plt.imshow(img); plt.axis('off'); plt.title('Input Image'); plt.show()

# text -> heatmap -> point seeds
heat = text_to_heatmap_clipseg(img, text_prompt)
plt.figure(figsize=(6,6)); plt.imshow(heat, cmap='magma'); plt.colorbar(); plt.title('CLIPSeg heatmap'); plt.axis('off'); plt.show()
seeds = heatmap_to_point_seeds(heat, k=SEED_K, thresh=SEED_THRESH)
print('Seeds (x,y):', seeds)


In [None]:
# SAM 2 segmentation: prompt the official SAM 2 image predictor with the CLIPSeg point seeds.
# Whenever SAM 2 cannot run, fall back to the thresholded CLIPSeg heatmap.
SAM2_MODEL_ID = 'facebook/sam2-hiera-large'  # Hugging Face checkpoint id (download needs huggingface_hub)
FALLBACK_THRESH = 0.6  # heatmap cutoff for the fallback mask; keep in sync with the seed threshold

mask = None
if sam2 is None:
    print('SAM 2 is not installed. Please update the install cell with the official package and checkpoint.')
elif not seeds:
    print('No point seeds were extracted; showing heatmap fallback.')
else:
    try:
        from sam2.sam2_image_predictor import SAM2ImagePredictor
        # SAM 2's model builder defaults to CUDA, so pass the notebook's device explicitly.
        predictor = SAM2ImagePredictor.from_pretrained(SAM2_MODEL_ID, device=device)
        points = np.asarray(seeds, dtype=np.float32)   # (N, 2) pixel coords in (x, y) order
        labels = np.ones(len(points), dtype=np.int32)  # label 1 = foreground click for every seed
        with torch.inference_mode():
            predictor.set_image(np.array(img))  # HWC RGB uint8
            # Many clicks on one object are unambiguous enough to request a single mask.
            masks, scores, _ = predictor.predict(
                point_coords=points,
                point_labels=labels,
                multimask_output=False,
            )
        mask = (masks[0] > 0).astype(np.uint8)
        print(f'SAM 2 mask score: {float(scores[0]):.3f}')
    except Exception as e:
        # Best effort: a missing checkpoint, API drift, or an out-of-memory error should not kill the demo.
        print('SAM 2 prediction failed; showing heatmap fallback:', e)

if mask is None:
    mask = (heat >= FALLBACK_THRESH).astype(np.uint8)
show_overlay(img, mask)
