# Сегментация по текстовому промпту с Grounding DINO + SAM

В этом ноутбуке мы реализуем пайплайн сегментации объектов по текстовому описанию, используя две мощные модели:
- **Grounding DINO** - для детекции объектов по текстовому промпту
- **SAM (Segment Anything Model)** - для точной сегментации найденных объектов


## Установка зависимостей

Устанавливаем необходимые библиотеки для работы с Grounding DINO и SAM.

**Важно:** Для работы Grounding DINO требуются скомпилированные C++ расширения. Если возникают ошибки компиляции, убедитесь, что установлены необходимые инструменты компиляции (gcc, g++, CUDA toolkit при использовании GPU).


In [1]:
!pip install -q segment-anything transformers torch torchvision opencv-python
!git clone https://github.com/IDEA-Research/GroundingDINO.git 2>/dev/null || echo "GroundingDINO already exists"
import os
os.chdir('GroundingDINO')
!pip install -q -e . || echo "Installation completed"
os.chdir('..')



[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.1.1[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
GroundingDINO already exists
[33m  DEPRECATION: Legacy editable install of groundingdino==0.1.0 from file:///home/tam2511/mounts/0/arcadia/market/robotics/cv/ml/user_data/shad/cv2025/lesson3/seminar/GroundingDINO (setup.py develop) is deprecated. pip 25.3 will enforce this behaviour change. A possible replacement is to add a pyproject.toml or enable --use-pep517, and use setuptools >= 64. If the resulting installation is not behaving as expected, try using --config-settings editable_mode=compat. Please consult the setuptools documentation for more information. Discussion can be found at https://github.com/pypa/pip/issues/11457[0m[33m
[0m    [1;31merror[0m: [1msubprocess-exited-with-error[0m
    
    [31m×[0m [32mpython setup.py 

## Импорты и настройка

Импортируем необходимые библиотеки для работы с изображениями, моделями и визуализацией.


In [2]:
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from pathlib import Path
import os
import cv2

torch.manual_seed(42)
np.random.seed(42)


## Проверка и компиляция C++ расширений

Проверяем наличие C++ расширений Grounding DINO и при необходимости компилируем их.


In [3]:
import os
import subprocess
import glob

groundingdino_dir = 'GroundingDINO'
cpp_ext_pattern = os.path.join(groundingdino_dir, 'groundingdino', 'models', 'GroundingDINO', 'ms_deform_attn.cpython-*.so')
so_files = glob.glob(cpp_ext_pattern)

if not so_files:
    print("C++ extensions not found. Attempting to compile...")
    original_dir = os.getcwd()
    try:
        os.chdir(groundingdino_dir)
        result = subprocess.run(['python', 'setup.py', 'build_ext', '--inplace'], 
                              capture_output=True, text=True)
        if result.returncode == 0:
            print("C++ extensions compiled successfully")
        else:
            print(f"Warning: Compilation had issues. Check output:")
            print(result.stderr[:500])
            print("\nModel may still work but could be slower.")
    except Exception as e:
        print(f"Warning: Could not compile C++ extensions: {e}")
        print("Model may still work but could be slower.")
    finally:
        os.chdir(original_dir)
else:
    print(f"C++ extensions found: {so_files[0]}")


C++ extensions not found. Attempting to compile...
/usr/bin/nvcc: 3: exec: /usr/lib/nvidia-cuda-toolkit/bin/nvcc: not found
Traceback (most recent call last):
  File "/home/tam2511/mounts/0/arcadia/market/robotics/cv/ml/user_data/shad/cv2025/lesson3/semin

Model may still work but could be slower.


## Загрузка моделей

Инициализируем модели Grounding DINO и SAM. Grounding DINO будет использоваться для детекции объектов по текстовому промпту, а SAM - для получения масок сегментации.


In [4]:
import sys
import subprocess
import warnings

sys.path.insert(0, 'GroundingDINO')

warnings.filterwarnings('ignore', category=UserWarning)

from groundingdino.util.inference import load_model, load_image, predict, annotate
from groundingdino.util import box_ops
import groundingdino.datasets.transforms as T
from segment_anything import sam_model_registry, SamPredictor

groundingdino_config_path = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
groundingdino_checkpoint_path = "groundingdino_swint_ogc.pth"

sam_checkpoint_path = "sam_vit_h_4b8939.pth"
sam_model_type = "vit_h"

device = "cuda" if torch.cuda.is_available() else "cpu"

if not os.path.exists(groundingdino_checkpoint_path):
    print("Downloading Grounding DINO checkpoint...")
    subprocess.check_call(['wget', '-q', 'https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth', '-O', groundingdino_checkpoint_path])

if not os.path.exists(sam_checkpoint_path):
    print("Downloading SAM checkpoint...")
    subprocess.check_call(['wget', '-q', 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth', '-O', sam_checkpoint_path])

os.environ['USE_REENTRANT'] = 'False'

grounding_model = load_model(groundingdino_config_path, groundingdino_checkpoint_path)
grounding_model = grounding_model.to(device)
grounding_model.eval()

sam = sam_model_registry[sam_model_type](checkpoint=sam_checkpoint_path)
sam = sam.to(device)
sam_predictor = SamPredictor(sam)




final text_encoder_type: bert-base-uncased




## Загрузка датасета CamVid

Загружаем изображения из датасета CamVid для тестирования пайплайна сегментации.


In [5]:
class CamVidLoader:
    def __init__(self, data_dir, split='train'):
        self.data_dir = data_dir
        self.split = split
        txt_file = os.path.join(data_dir, split + '.txt')
        self.image_paths = []
        
        with open(txt_file, 'r') as f:
            for line in f:
                line = line.strip()
                if line:
                    parts = line.split()
                    if len(parts) >= 1:
                        img_path_abs = parts[0]
                        img_filename = os.path.basename(img_path_abs)
                        img_path = os.path.join(data_dir, split, img_filename)
                        if os.path.exists(img_path):
                            self.image_paths.append(img_path)
    
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')
        return np.array(image), img_path

data_dir = '../data'
dataset = CamVidLoader(data_dir, split='val')
print(f"Loaded {len(dataset)} images from validation set")


Loaded 101 images from validation set


## Функция детекции объектов через Grounding DINO

Grounding DINO принимает изображение и текстовый промпт, возвращает координаты bounding box'ов обнаруженных объектов.


In [6]:
def detect_objects(image_source, text_prompt, box_threshold=0.3, text_threshold=0.25):
    transform = T.Compose([
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    
    if isinstance(image_source, str):
        image, _ = load_image(image_source)
    elif isinstance(image_source, np.ndarray):
        image = Image.fromarray(image_source).convert('RGB')
    else:
        image = image_source
    
    image_transformed, _ = transform(image, None)
    
    boxes, logits, phrases = predict(
        model=grounding_model,
        image=image_transformed,
        caption=text_prompt,
        box_threshold=box_threshold,
        text_threshold=text_threshold
    )
    
    return boxes, logits, phrases


## Функция сегментации через SAM

SAM принимает изображение и координаты bounding box'ов, возвращает маски сегментации для каждого объекта.


In [7]:
def segment_objects(image, boxes):
    sam_predictor.set_image(image)
    
    H, W, _ = image.shape
    boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H])
    
    transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_xyxy, image.shape[:2])
    transformed_boxes = transformed_boxes.to(device)
    
    masks, scores, logits = sam_predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes,
        multimask_output=False,
    )
    
    return masks.cpu().numpy(), scores.cpu().numpy()


## Пайплайн сегментации по текстовому промпту

Объединяем Grounding DINO и SAM в единый пайплайн: сначала детектируем объекты по текстовому промпту, затем сегментируем их.


In [8]:
def segment_by_text_prompt(image_source, text_prompt, box_threshold=0.3, text_threshold=0.25):
    if isinstance(image_source, str):
        image = np.array(Image.open(image_source).convert('RGB'))
    else:
        image = image_source
    
    boxes, logits, phrases = detect_objects(image, text_prompt, box_threshold, text_threshold)
    
    if len(boxes) == 0:
        return image, None, None, None
    
    masks, scores = segment_objects(image, boxes)
    
    return image, boxes, masks, phrases


## Визуализация результатов

Функция для визуализации исходного изображения с наложенными масками сегментации и bounding box'ами.


In [9]:
def visualize_results(image, boxes, masks, phrases, alpha=0.5):
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    
    axes[0].imshow(image)
    axes[0].set_title('Original Image')
    axes[0].axis('off')
    
    axes[1].imshow(image)
    if boxes is not None and len(boxes) > 0:
        H, W = image.shape[:2]
        boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H])
        
        for i, (box, phrase) in enumerate(zip(boxes_xyxy, phrases)):
            x1, y1, x2, y2 = box.numpy()
            rect = patches.Rectangle((x1, y1), x2-x1, y2-y1, linewidth=2, 
                                   edgecolor='red', facecolor='none')
            axes[1].add_patch(rect)
            axes[1].text(x1, y1-5, phrase, fontsize=10, color='red', 
                        bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
    axes[1].set_title('Detections (Grounding DINO)')
    axes[1].axis('off')
    
    axes[2].imshow(image)
    if masks is not None and len(masks) > 0:
        for i, mask in enumerate(masks):
            mask_bool = mask[0].astype(bool)
            color = np.random.rand(3)
            colored_mask = np.zeros_like(image)
            colored_mask[mask_bool] = color
            axes[2].imshow(colored_mask, alpha=alpha)
    axes[2].set_title('Segmentation Masks (SAM)')
    axes[2].axis('off')
    
    plt.tight_layout()
    plt.show()


## Тестирование на примерах из CamVid

Тестируем пайплайн на нескольких изображениях из датасета CamVid с различными текстовыми промптами.


In [10]:
test_indices = [0, 5, 10]
text_prompts = [
    "car. vehicle. automobile",
    "road. street. pavement",
    "tree. vegetation. plant",
    "building. house. structure",
    "person. pedestrian. human"
]

for idx in test_indices:
    image, img_path = dataset[idx]
    print(f"\nProcessing image: {os.path.basename(img_path)}")
    
    for prompt in text_prompts[:2]:
        print(f"  Prompt: {prompt}")
        result_image, boxes, masks, phrases = segment_by_text_prompt(
            image, 
            prompt,
            box_threshold=0.3,
            text_threshold=0.25
        )
        
        if boxes is not None and len(boxes) > 0:
            print(f"    Found {len(boxes)} objects")
            visualize_results(result_image, boxes, masks, phrases)
        else:
            print(f"    No objects found for this prompt")



Processing image: 0016E5_07959.png
  Prompt: car. vehicle. automobile




NameError: name '_C' is not defined

## Сравнение с ground truth

Сравниваем результаты сегментации по текстовому промпту с ground truth масками из датасета CamVid.


In [None]:
def load_ground_truth_mask(img_path, data_dir):
    img_filename = os.path.basename(img_path)
    mask_path = os.path.join(data_dir, 'valannot', img_filename)
    if os.path.exists(mask_path):
        mask = np.array(Image.open(mask_path).convert('L'))
        return mask
    return None

def compare_with_ground_truth(image, masks, gt_mask, prompt):
    if gt_mask is None:
        return
    
    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    
    axes[0].imshow(image)
    axes[0].set_title('Original Image')
    axes[0].axis('off')
    
    axes[1].imshow(gt_mask, cmap='tab20', vmin=0, vmax=11)
    axes[1].set_title('Ground Truth Mask')
    axes[1].axis('off')
    
    combined_mask = np.zeros_like(gt_mask)
    if masks is not None and len(masks) > 0:
        for mask in masks:
            mask_resized = cv2.resize(mask[0].astype(np.uint8), (gt_mask.shape[1], gt_mask.shape[0]), interpolation=cv2.INTER_NEAREST)
            combined_mask = np.logical_or(combined_mask, mask_resized.astype(bool))
    
    axes[2].imshow(combined_mask.astype(int), cmap='gray')
    axes[2].set_title(f'Predicted Mask (Prompt: {prompt})')
    axes[2].axis('off')
    
    overlay = image.copy()
    if image.shape[:2] != combined_mask.shape:
        combined_mask = cv2.resize(combined_mask.astype(np.uint8), (image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST).astype(bool)
    overlay[combined_mask] = overlay[combined_mask] * 0.6 + np.array([1, 0, 0]) * 0.4
    axes[3].imshow(overlay)
    axes[3].set_title('Overlay')
    axes[3].axis('off')
    
    plt.tight_layout()
    plt.show()

for idx in test_indices[:1]:
    image, img_path = dataset[idx]
    gt_mask = load_ground_truth_mask(img_path, data_dir)
    
    prompt = "car. vehicle"
    result_image, boxes, masks, phrases = segment_by_text_prompt(
        image, 
        prompt,
        box_threshold=0.3,
        text_threshold=0.25
    )
    
    if gt_mask is not None:
        compare_with_ground_truth(result_image, masks, gt_mask, prompt)
