### Step 1: Initialize SAM2


In [None]:
import os
# Restrict the CUDA devices visible to this process to GPU index 1.
# NOTE(review): this looks machine-specific — adjust or remove for your setup.
# It is set here, before `import torch` below.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# if using Apple MPS, fall back to CPU for unsupported ops
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image

# choose the compute backend, preferring CUDA, then Apple MPS, then plain CPU
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print(f"using device: {device}")

if device.type == "mps":
    print(
        "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
        "give numerically different outputs and sometimes degraded performance on MPS. "
        "See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
    )
elif device.type == "cuda":
    # enter a bfloat16 autocast context that is never exited, so it stays
    # active for the entire notebook
    torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
    # turn on tfloat32 when the first visible GPU reports compute capability
    # major >= 8, i.e. Ampere or newer
    # (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True


In [2]:
# Seed NumPy's global RNG; show_mask(random_color=True) draws its colour from it.
np.random.seed(42)

def show_mask(mask, ax, random_color=False, borders = True):
    """Draw a semi-transparent mask overlay (optionally outlined) on ``ax``.

    Args:
        mask: Array whose last two dimensions are (height, width). It is cast
            to ``uint8`` and reshaped to ``(height, width, 1)`` before drawing.
        ax: Object exposing ``imshow`` (e.g. a matplotlib Axes).
        random_color: If true, take the RGB colour from NumPy's global RNG
            instead of the fixed blue. Alpha is 0.6 in both cases.
        borders: If true, trace the mask's external contours with OpenCV and
            draw them on the overlay.
    """
    if random_color:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        rgba = np.array([30/255, 144/255, 255/255, 0.6])

    height, width = mask.shape[-2:]
    binary = mask.astype(np.uint8)
    # broadcast (h, w, 1) against (1, 1, 4) to obtain an RGBA image
    overlay = binary.reshape(height, width, 1) * rgba.reshape(1, 1, -1)

    if borders:
        import cv2  # only needed when outlines are requested

        outlines, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        # Try to smooth contours
        smoothed = []
        for outline in outlines:
            smoothed.append(cv2.approxPolyDP(outline, epsilon=0.01, closed=True))
        overlay = cv2.drawContours(overlay, smoothed, -1, (1, 1, 1, 0.5), thickness=2)

    ax.imshow(overlay)

def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax`` as star markers.

    Points whose label equals 1 are drawn in green, points whose label equals
    0 in red; both get a white edge.

    Args:
        coords: Array indexed as ``coords[:, 0]`` (x) and ``coords[:, 1]`` (y).
        labels: Array compared element-wise with 1 and 0 to select rows of
            ``coords``.
        ax: Object exposing ``scatter`` (e.g. a matplotlib Axes).
        marker_size: Value forwarded as the ``s`` argument of ``scatter``.
    """
    # select both groups first, then draw positives followed by negatives
    groups = [
        (coords[labels == label_value], star_color)
        for label_value, star_color in ((1, 'green'), (0, 'red'))
    ]
    for selected, star_color in groups:
        ax.scatter(
            selected[:, 0],
            selected[:, 1],
            color=star_color,
            marker='*',
            s=marker_size,
            edgecolor='white',
            linewidth=1.25,
        )

def show_box(box, ax):
    """Add a green, unfilled rectangle for ``box`` to ``ax``.

    Args:
        box: Indexable as ``[x0, y0, x1, y1]``; width and height are derived
            as ``x1 - x0`` and ``y1 - y0``.
        ax: Object exposing ``add_patch`` (e.g. a matplotlib Axes).
    """
    left, top = box[0], box[1]
    width = box[2] - box[0]
    height = box[3] - box[1]
    # fully transparent face so only the outline is visible
    outline = plt.Rectangle(
        (left, top), width, height, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2
    )
    ax.add_patch(outline)

def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_labels=None, borders=True):
    """Show each mask in its own 10x10 figure on top of ``image``.

    Args:
        image: Image passed to ``plt.imshow`` as the background.
        masks: Iterable of masks, paired element-wise with ``scores``.
        scores: Sized collection of per-mask scores; a title with the score is
            added only when it holds more than one entry.
        point_coords: Optional prompt points to draw; requires ``input_labels``.
        box_coords: Optional box to draw.
        input_labels: Labels for ``point_coords``.
        borders: Forwarded to ``show_mask``.
    """
    for i, (mask, score) in enumerate(zip(masks, scores)):
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        axes = plt.gca()

        show_mask(mask, axes, borders=borders)

        has_points = point_coords is not None
        if has_points:
            assert input_labels is not None
            show_points(point_coords, input_labels, axes)

        has_box = box_coords is not None
        if has_box:
            # boxes
            show_box(box_coords, axes)

        if len(scores) > 1:
            plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)

        plt.axis('off')
        plt.show()

In [3]:
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Checkpoint file and model config handed to build_sam2 below.
sam2_checkpoint = "pretrained_models/sam2.1_hiera_large.pt" # your own SAM2 path
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"

# Build the model on the `device` selected in Step 1's first cell.
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)

# Wrap the model in the image predictor used by the later cells
# (`predictor.set_image` / `predictor.predict`).
predictor = SAM2ImagePredictor(sam2_model)

### Step 2: Load ReferSegDataset

In [None]:
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.getcwd())))  # add the directory two levels above the current working directory to the import path

from refer_seg_dataset import ReferSegDataset
# Name of the referring-segmentation dataset; also used as the key into
# `dataset.refer_seg_data` and in the output file name.
dataset_name = "refcocog"
# "your own data dir" is a placeholder — replace it with the real image root.
dataset = ReferSegDataset(base_image_dir="your own data dir", refer_seg_data=dataset_name, data_split="train")

In [8]:
# Pull the per-dataset tables out of the loaded dataset object.
refer_seg_ds = dataset.refer_seg_data[dataset_name]
# Indexed by position in later cells; each entry is read for "file_name", "id",
# "height" and "width".
images = refer_seg_ds["images"]
# Looked up by annotation id (`annotations[ann_id]`); entries carry "segmentation".
annotations = refer_seg_ds["annotations"]
# Looked up by image id; each value is iterated as a list of refs
# (read for "ann_id" and "sentences").
img2refs = refer_seg_ds["img2refs"]

In [None]:
# Pick one image record to inspect (index 7 is an arbitrary example).
idx = 7
image_info = images[idx]
image_path = image_info["file_name"]
image_id = image_info["id"]
# Bare expression: displayed as the cell output in the notebook.
image_info

In [None]:
# Referring expressions attached to the selected image; show the first one.
refs = img2refs[image_id]
refs[0]

In [None]:
# Load the selected image with OpenCV, convert it from BGR to RGB and display it.
# NOTE: this cell does not hand the image to SAM2; `predictor.set_image` is
# only called inside the generation loop in Step 3.
# NOTE(review): cv2.imread is documented to return None for an unreadable
# path, in which case cvtColor below would fail — confirm `image_path` is valid.
import cv2 
image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
plt.imshow(image)
plt.show()


In [12]:
from scipy import ndimage
def _snap_to_foreground(x, y, foreground, xs, ys):
    """Return ``(x, y)`` as Python ints, moved to the nearest foreground pixel if off the mask."""
    if not foreground[y, x]:
        nearest = np.argmin((xs - x) ** 2 + (ys - y) ** 2)
        x, y = xs[nearest], ys[nearest]
    return int(x), int(y)


def get_two_representative_points(m):
    """Find two points that describe the shape of a mask reasonably well.

    Every nonzero pixel counts as foreground, matching how
    ``ndimage.distance_transform_edt`` binarises its input. For a 0/1 mask
    this is the same pixel set as ``m == 1``.

    The first point is the global maximum of the distance transform. The
    second point is the distance-transform maximum among the foreground
    pixels that lie farther from the first point than the median distance.
    If no pixel is strictly farther than the median (e.g. very small masks),
    the second point is the distance-transform maximum inside a window of up
    to 20x20 pixels around the mask centroid.

    Args:
        m: 2-D binary mask array.

    Returns:
        ``([x1, y1], [x2, y2])`` with plain Python ``int`` coordinates, or
        ``(None, None)`` if the mask has no foreground pixel.
    """
    foreground = np.asarray(m) != 0
    y_indices, x_indices = np.nonzero(foreground)
    if len(y_indices) == 0:
        return None, None

    # Distance transform over the foreground.
    dist_transform = ndimage.distance_transform_edt(foreground)

    # First point: global maximum of the distance transform.
    y1, x1 = np.unravel_index(dist_transform.argmax(), dist_transform.shape)

    # Foreground pixels farther from the first point than the median distance.
    distances_to_first = ((y_indices - y1) ** 2 + (x_indices - x1) ** 2) ** 0.5
    is_far = distances_to_first > np.median(distances_to_first)

    if is_far.any():
        # Second point: the far pixel with the largest distance-transform value.
        far_y, far_x = y_indices[is_far], x_indices[is_far]
        deepest = np.argmax(dist_transform[far_y, far_x])
        y2, x2 = far_y[deepest], far_x[deepest]
    else:
        # No suitable far pixel: search a window around the centroid instead.
        center_y = int(np.mean(y_indices))
        center_x = int(np.mean(x_indices))
        top, left = max(0, center_y - 10), max(0, center_x - 10)
        local_region = dist_transform[
            top:min(foreground.shape[0], center_y + 10),
            left:min(foreground.shape[1], center_x + 10),
        ]
        local_y, local_x = np.unravel_index(local_region.argmax(), local_region.shape)
        y2, x2 = local_y + top, local_x + left

    # Make sure both points lie on the mask and are JSON-friendly ints.
    x1, y1 = _snap_to_foreground(x1, y1, foreground, x_indices, y_indices)
    x2, y2 = _snap_to_foreground(x2, y2, foreground, x_indices, y_indices)

    return [x1, y1], [x2, y2]

In [13]:
def get_mask_from_point(predictor, input_point, input_label, box):
    """Run the predictor with point and box prompts and return its masks, best score first.

    Args:
        predictor: Object exposing ``predict(point_coords=, point_labels=,
            box=, multimask_output=)`` that returns ``(masks, scores, logits)``.
        input_point: Point prompt(s), forwarded as ``point_coords``.
        input_label: Label(s) for the point prompt(s), forwarded as ``point_labels``.
        box: Box prompt, forwarded as ``box``.

    Returns:
        The predicted masks reordered by descending score. The scores and
        logits are not returned.
    """
    masks, scores, _ = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        box=box,
        multimask_output=False,
    )
    # Only the masks are handed back, so only they need reordering.
    return masks[np.argsort(scores)[::-1]]

In [14]:
import numpy as np

def compute_iou(mask1, mask2):
    """Return the intersection-over-union of two masks.

    Both inputs are combined with NumPy's logical and/or, so any nonzero
    element counts as set. When the union is empty the result is 0.
    """
    overlap = np.logical_and(mask1, mask2).sum()
    combined = np.logical_or(mask1, mask2).sum()
    # guard against dividing by zero when neither mask has any pixel set
    return overlap / combined if combined != 0 else 0

### Step 3: Generate annotation list

In [None]:
from pycocotools import mask
import numpy as np
from tqdm import tqdm  # 导入tqdm
import json  # 导入json模块
import cv2

# Minimum IoU between SAM2's prediction and the ground-truth mask for an
# annotation to be kept (original note: "IOU:  0.659445961").
threshold_iou = 0.6
# Stop once this many images have produced an entry (same cut-off as the
# former `cnt > 20` check).
max_images = 21
cnt = 0  # number of images that have contributed an entry so far

seg_zero_annotation_list = []

for image_info in tqdm(images, desc="Processing images"):
    image_path = image_info["file_name"]
    image_id = image_info["id"]
    refs = img2refs[image_id]

    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread reports an unreadable or missing file by returning None.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    predictor.set_image(image)

    texts = []
    bboxes = []
    points = []
    ann_ids = []
    for ref in refs:
        ann_id = ref["ann_id"]

        # First sentence only, normalised: trimmed, trailing punctuation
        # removed, lower-cased.
        text = ref["sentences"][0]["raw"].strip().strip(".?!").lower()

        ann = annotations[ann_id]
        segmentation = ann["segmentation"]
        if len(segmentation) == 0:
            continue  # nothing to decode for this annotation

        if isinstance(segmentation[0], list):  # polygon
            rle = mask.frPyObjects(
                segmentation, image_info["height"], image_info["width"]
            )
        else:
            rle = segmentation
            # RLE counts must be bytes before decoding; encoded in place.
            for part in rle:
                if not isinstance(part["counts"], bytes):
                    part["counts"] = part["counts"].encode()
        m = mask.decode(rle)
        # sometimes there are multiple binary maps (corresponding to multiple segs)
        m = np.sum(m, axis=2)
        m = m.astype(np.uint8)

        # Use every nonzero pixel, consistent with `mask_gt = m.astype(bool)`
        # below; overlapping parts sum to values greater than 1.
        ys, xs = np.nonzero(m)
        if len(xs) == 0:
            continue  # the decoded mask is empty, so there is no box or point
        box = [int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())]

        point, _ = get_two_representative_points(m)
        if point is None:
            continue
        point = [int(v) for v in point]
        label = 1  # the point is a positive (foreground) prompt

        mask_pred = get_mask_from_point(
            predictor, np.array([point]), np.array([label]), np.array(box)
        )

        mask_pred = mask_pred[0].astype(bool)
        mask_gt = m.astype(bool)
        iou = compute_iou(mask_pred, mask_gt)
        if iou < threshold_iou:
            # SAM2 cannot recover this mask from the box + point prompt.
            continue

        bboxes.append(box)
        points.append(point)
        texts.append(text)
        ann_ids.append(str(ann_id))

    if len(bboxes) == 0:
        continue

    seg_zero_annotation_list.append({
        "id": f"{dataset_name}_" + "_".join(ann_ids[:3]),
        "image_id": image_id,
        "image_path": image_path,
        "problem": "'" + "' and '".join(texts) + "'",
        "bboxes": bboxes,
        "center_points": points
    })

    cnt += 1

    if cnt >= max_images:
        break


print(f"Total: {len(seg_zero_annotation_list)}")



In [None]:
# Preview one generated entry (index 10 requires at least 11 entries from Step 3).
seg_zero_annotation_list[10]

In [None]:
# Turn every box and point coordinate into a built-in int before the list is
# written out with json.dump in Step 4.
for item in seg_zero_annotation_list:
    item['bboxes'] = [[int(value) for value in bbox] for bbox in item['bboxes']]
    item['center_points'] = [
        [int(value) for value in center_point] for center_point in item['center_points']
    ]

In [None]:
# Same entry again, now with its coordinates converted to built-in ints.
seg_zero_annotation_list[10]

### Step 4: Save and show examples

In [16]:
# Write the annotation list to seg_zero_<dataset_name>_annotation_list.json in
# the current working directory. ensure_ascii=False keeps non-ASCII text as-is.
with open(f'seg_zero_{dataset_name}_annotation_list.json', 'w', encoding='utf-8') as f:
    json.dump(seg_zero_annotation_list, f, ensure_ascii=False, indent=4)

In [None]:
import cv2 

# Pick an entry to visualise. The generation loop in Step 3 stops after 21
# entries, so a fixed index of 30 would be out of range; clamp it to the last
# available entry instead.
example_idx = min(30, len(seg_zero_annotation_list) - 1)
if example_idx < 0:
    raise IndexError("seg_zero_annotation_list is empty; run Step 3 first")
item = seg_zero_annotation_list[example_idx]

print(item['problem'])
print(item['bboxes'])
print(item['center_points'])

image_path = item['image_path']
image = cv2.imread(image_path)
if image is None:
    # cv2.imread reports an unreadable or missing file by returning None.
    raise FileNotFoundError(f"Could not read image: {image_path}")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# The image is RGB at this point, so (0, 0, 255) draws blue boxes and
# (0, 255, 0) draws filled green dots.
for bbox, center_point in zip(item['bboxes'], item['center_points']):
    cv2.rectangle(image, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 0, 255), 2)
    cv2.circle(image, (center_point[0], center_point[1]), 5, (0, 255, 0), -1)

plt.imshow(image)
plt.show()


### Step 5: Please refer to gen_training_dataset.py