In [None]:
#!/usr/bin/env python3
"""
Demo: Segment Everything with fine-tuned MobileSAM using SamAutomaticMaskGenerator
"""

import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt
from mobile_sam import sam_model_registry, SamAutomaticMaskGenerator


def main(
    *,
    checkpoint_path=(
        '/ssd1/dannyliu/work/MobileSAM-fast-finetuning/logs/best_student.pth'
    ),
    image_path=(
        '/ssd1/dannyliu/work/MobileSAM-fast-finetuning/datasets/train/mask/'
        'example1/example1.png'
    ),
    points_per_batch=1,
):
    """Run "segment everything" with a fine-tuned MobileSAM and plot mask outlines.

    Args:
        checkpoint_path: Path to the fine-tuned MobileSAM ('vit_t') checkpoint.
        image_path: Path to the image to segment (read with OpenCV).
        points_per_batch: Number of grid point prompts decoded per forward pass,
            forwarded to ``SamAutomaticMaskGenerator``. Defaults to 1; the
            generator's own default is 64 (see the note in step 3).

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image_path``.
    """
    # 1. Select the compute device.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # 2. Load the fine-tuned MobileSAM model and switch it to inference mode.
    sam = sam_model_registry['vit_t'](checkpoint=checkpoint_path).to(device)
    sam.eval()

    # 3. Instantiate the automatic mask generator.
    # NOTE(review): with the generator's default of 64 prompts per batch,
    # generate() failed with "Sizes of tensors must match except in dimension 1.
    # Expected size 1 but got size 64". This suggests the fine-tuned mask
    # decoder concatenates per-image tensors (batch 1 here) with per-prompt
    # tensors (batch = points_per_batch) without broadcasting. Confirm this
    # against the fine-tuning repo's decoder. Decoding one prompt at a time
    # keeps both batches at 1. Raise this value once the decoder supports
    # multi-prompt batches, because larger batches are faster.
    mask_generator = SamAutomaticMaskGenerator(
        sam, points_per_batch=points_per_batch
    )

    # 4. Read the image. OpenCV returns BGR, or None on failure; SAM expects RGB.
    bgr = cv2.imread(image_path)
    if bgr is None:
        raise FileNotFoundError(f"Image not found: {image_path}")
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    # 5. Generate masks.
    masks = mask_generator.generate(rgb)
    print(f"Generated {len(masks)} masks")

    # 6. Overlay each mask's outline on the image, using one colour per mask.
    plt.figure(figsize=(10, 10))
    plt.imshow(rgb)
    for idx, mask in enumerate(masks):
        seg = mask['segmentation'].astype(np.uint8)
        # [-2] selects the contour list under both OpenCV 3.x, which returns
        # (image, contours, hierarchy), and OpenCV 4.x, which returns
        # (contours, hierarchy).
        contours = cv2.findContours(
            seg, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
        )[-2]
        color = f'C{idx % 10}'  # cycle through matplotlib's default palette
        for c in contours:
            pts = c[:, 0, :]
            # Repeat the first vertex so the outline is drawn as a closed loop.
            pts = np.vstack([pts, pts[:1]])
            plt.plot(pts[:, 0], pts[:, 1], color=color, linewidth=0.5)
    plt.axis('off')
    plt.show()


if __name__ == "__main__":
    # Execute the demo only when run as a script, never on import.
    main()


RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 1 but got size 64 for tensor number 1 in the list.