# Open-GroundingDino Libero 微调模型可视化验证

本notebook用于验证在Libero数据集上微调后的GroundingDINO模型效果。

## 功能
1. 单张图像检测与可视化
2. 批量可视化
3. 阈值敏感性分析
4. 与Ground Truth对比
5. 保存预测结果


In [None]:
import os
import sys
import json
import random
from pathlib import Path

import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as plt
from IPython.display import display

# Use the notebook's current working directory as the script directory and
# make sure it is on sys.path (at the front) so modules located there can be
# imported by the following cells.
SCRIPT_DIR = Path.cwd().resolve()
if all(entry != str(SCRIPT_DIR) for entry in sys.path):
    sys.path.insert(0, str(SCRIPT_DIR))

print(f"Working directory: {SCRIPT_DIR}")


In [None]:
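# Import the Open-GroundingDINO modules: `datasets` and `util` are the local modules found via SCRIPT_DIR; `clean_state_dict` / `get_phrases_from_posmap` are imported from the `groundingdino` package, which must also be importable.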
import datasets.transforms as T
from util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap


## 1. 配置路径


In [None]:
# ============================================
# Path configuration - adjust to your setup
# ============================================

# Project root: two directory levels above the script directory.
PROJECT_ROOT = SCRIPT_DIR.parent.parent

# Model configuration file.
CONFIG_FILE = SCRIPT_DIR.joinpath("config", "cfg_odvg.py")

# Fine-tuned model weights (change this to your actual path).
CHECKPOINT_PATH = PROJECT_ROOT.joinpath(
    "checkpoints", "open_gdino_finetuned", "checkpoint_best_regular.pth"
)
# Alternatively, use the latest checkpoint:
# CHECKPOINT_PATH = PROJECT_ROOT / "checkpoints" / "open_gdino_finetuned" / "checkpoint.pth"

# Dataset directory and the image folder inside it.
DATA_ROOT = PROJECT_ROOT.joinpath("data_processed", "open_gdino_dataset")
IMAGE_ROOT = DATA_ROOT.joinpath("images")

# Report each configured path together with whether it exists on disk.
print(f"Config file: {CONFIG_FILE} (exists: {CONFIG_FILE.exists()})")
print(f"Checkpoint: {CHECKPOINT_PATH} (exists: {CHECKPOINT_PATH.exists()})")
print(f"Image root: {IMAGE_ROOT} (exists: {IMAGE_ROOT.exists()})")


## 2. 工具函数


In [None]:
def load_image(image_path):
    """Open an image and run it through the model's preprocessing pipeline.

    Args:
        image_path: Path of the image file to open.

    Returns:
        A ``(image_pil, image_tensor)`` pair: the image converted to RGB, and
        the output of the resize -> to-tensor -> normalize transform chain.
    """
    rgb_image = Image.open(image_path).convert("RGB")

    # Per-channel normalization statistics passed to T.Normalize.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    preprocess = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize(channel_mean, channel_std),
        ]
    )

    # The transforms take (image, target); there is no target at inference.
    tensor, _unused_target = preprocess(rgb_image, None)
    return rgb_image, tensor


def load_model(config_path, checkpoint_path, device="cuda"):
    """Build a GroundingDINO network for inference and load checkpoint weights.

    Only the network itself is assembled (backbone + transformer); the
    criterion and postprocessors are not built.

    Args:
        config_path: Path of the config file read with ``SLConfig.fromfile``.
        checkpoint_path: Path of the checkpoint file read with ``torch.load``.
        device: Device the model is moved to.

    Returns:
        The model, moved to ``device`` and switched to eval mode.
    """
    from models.GroundingDINO.backbone import build_backbone
    from models.GroundingDINO.transformer import build_transformer
    from models.GroundingDINO.groundingdino import GroundingDINO

    cfg = SLConfig.fromfile(str(config_path))
    cfg.device = device

    # Assemble the two sub-networks directly from the config.
    backbone = build_backbone(cfg)
    transformer = build_transformer(cfg)

    # Constructor options; dn_number is fixed to 0 here rather than read
    # from the config.
    model_kwargs = dict(
        num_queries=cfg.num_queries,
        aux_loss=cfg.aux_loss,
        iter_update=True,
        query_dim=4,
        num_feature_levels=cfg.num_feature_levels,
        nheads=cfg.nheads,
        dec_pred_bbox_embed_share=cfg.dec_pred_bbox_embed_share,
        two_stage_type=cfg.two_stage_type,
        two_stage_bbox_embed_share=cfg.two_stage_bbox_embed_share,
        two_stage_class_embed_share=cfg.two_stage_class_embed_share,
        num_patterns=cfg.num_patterns,
        dn_number=0,
        dn_box_noise_scale=cfg.dn_box_noise_scale,
        dn_label_noise_ratio=cfg.dn_label_noise_ratio,
        dn_labelbook_size=cfg.dn_labelbook_size,
        text_encoder_type=cfg.text_encoder_type,
        sub_sentence_present=cfg.sub_sentence_present,
        max_text_len=cfg.max_text_len,
    )
    net = GroundingDINO(backbone, transformer, **model_kwargs)

    # weights_only=False for PyTorch 2.6+ compatibility.
    # NOTE(review): weights_only=False performs full unpickling, so only load
    # checkpoints that come from a trusted source.
    ckpt = torch.load(str(checkpoint_path), map_location="cpu", weights_only=False)

    # The weights are either stored under the "model" key or are the
    # checkpoint object itself.
    weights = ckpt["model"] if "model" in ckpt else ckpt

    # strict=False: mismatched keys are reported in the result, not raised.
    load_res = net.load_state_dict(clean_state_dict(weights), strict=False)
    print(f"Model loaded: {load_res}")

    net = net.to(device)
    net.eval()
    return net


def get_grounding_output(model, image, caption, box_threshold, text_threshold, device="cuda"):
    """
    Run one forward pass and return the detections above the box threshold.

    Args:
        model: GroundingDINO model; must expose a ``tokenizer`` attribute.
        image: Preprocessed image tensor (a batch dimension is added here).
        caption: Text prompt, e.g. "bowl. plate. drawer."; it is lower-cased,
            stripped, and given a trailing "." if it lacks one.
        box_threshold: A query is kept when its maximum token score exceeds
            this value.
        text_threshold: Token scores above this value are passed to
            ``get_phrases_from_posmap`` to form the phrase of a kept query.
        device: Device the image tensor is moved to before inference.

    Returns:
        boxes_filt: Kept boxes, [N, 4]; presumably normalized
            (cx, cy, w, h), which is how plot_boxes_to_image reads them.
        pred_phrases: One predicted phrase per kept box.
        scores: Maximum token score of each kept box, as Python floats.
    """
    prompt = caption.lower().strip()
    if not prompt.endswith("."):
        prompt += "."

    image = image.to(device)

    with torch.no_grad():
        outputs = model(image.unsqueeze(0), captions=[prompt])

    # First (only) batch element, copied to CPU.
    token_scores = outputs["pred_logits"].sigmoid()[0].cpu().clone()  # (nq, 256) per the original note
    all_boxes = outputs["pred_boxes"][0].cpu().clone()  # (nq, 4)

    # Keep the queries whose best token score clears the box threshold.
    keep = token_scores.max(dim=1)[0] > box_threshold
    kept_scores = token_scores[keep]
    boxes_filt = all_boxes[keep]

    # Tokenize the prompt once; it is shared by every kept query.
    tokenizer = model.tokenizer
    tokenized = tokenizer(prompt)

    pred_phrases = [
        get_phrases_from_posmap(row > text_threshold, tokenized, tokenizer)
        for row in kept_scores
    ]
    scores = [row.max().item() for row in kept_scores]

    return boxes_filt, pred_phrases, scores


In [None]:
# RGB palette used for drawing; plot_boxes_to_image cycles through it by
# box index.
COLORS = [
    # red, green, blue
    (255, 0, 0), (0, 255, 0), (0, 0, 255),
    # yellow, magenta, cyan
    (255, 255, 0), (255, 0, 255), (0, 255, 255),
    # orange, purple, light blue, pink
    (255, 128, 0), (128, 0, 255), (0, 128, 255), (255, 0, 128),
]

def plot_boxes_to_image(image_pil, boxes, labels, scores=None, show_scores=True):
    """
    Draw bounding boxes and their labels on a copy of an image.

    Args:
        image_pil: PIL image to annotate; it is copied, never modified.
        boxes: Boxes, [N, 4], in normalized (cx, cy, w, h) format: box center
            and size as fractions of the image width/height. Each box must
            support ``.tolist()``.
        labels: One label string per box.
        scores: Optional confidence scores, indexed like ``boxes``.
        show_scores: Whether to append the score to the label text.

    Returns:
        The annotated copy of the image.
    """
    image_draw = image_pil.copy()
    draw = ImageDraw.Draw(image_draw)
    W, H = image_pil.size

    # Prefer a scalable TrueType font and fall back to Pillow's built-in font
    # when the font file cannot be read (OSError) or FreeType support is not
    # available (ImportError). Anything else is a real error and propagates.
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 16)
    except (OSError, ImportError):
        font = ImageFont.load_default()

    for i, (box, label) in enumerate(zip(boxes, labels)):
        # Convert normalized (cx, cy, w, h) to pixel corner coordinates.
        cx, cy, w, h = box.tolist()
        x0 = int((cx - w/2) * W)
        y0 = int((cy - h/2) * H)
        x1 = int((cx + w/2) * W)
        y1 = int((cy + h/2) * H)

        # Cycle through the palette so neighbouring boxes get distinct colors.
        color = COLORS[i % len(COLORS)]

        # Box outline.
        draw.rectangle([x0, y0, x1, y1], outline=color, width=3)

        # Label text, optionally with the confidence score.
        if scores is not None and show_scores:
            text = f"{label} ({scores[i]:.2f})"
        else:
            text = label

        # Measure the text. NOTE(review): the getbbox check presumably tells
        # apart fonts/Pillow versions that support textbbox from older ones
        # that only offer textsize - confirm against the Pillow in use.
        if hasattr(font, "getbbox"):
            bbox = draw.textbbox((x0, y0), text, font=font)
        else:
            tw, th = draw.textsize(text, font=font)
            bbox = (x0, y0 - th, x0 + tw, y0)

        # Place the label just above the box, clamped to the top image edge.
        text_h = bbox[3] - bbox[1]
        label_y = max(0, y0 - text_h - 4)

        # Filled background (2 px padding on each side), then the text.
        draw.rectangle([x0, label_y, bbox[2] - bbox[0] + x0 + 4, label_y + text_h + 4], fill=color)
        draw.text((x0 + 2, label_y + 2), text, fill="white", font=font)

    return image_draw


def visualize_detection(image_path, caption, model, box_threshold=0.3, text_threshold=0.25, 
                        device="cuda", figsize=(12, 8), show_scores=True):
    """
    Detect objects in one image and show the result next to the original.

    Args:
        image_path: Path of the image to process.
        caption: Text prompt passed to the model.
        model: Model passed to get_grounding_output.
        box_threshold: Box confidence threshold.
        text_threshold: Text matching threshold.
        device: Device used for inference.
        figsize: Size of the two-panel matplotlib figure.
        show_scores: Whether scores are drawn next to the labels.

    Returns:
        The ``(boxes, phrases, scores)`` triple from get_grounding_output.
    """
    original, model_input = load_image(image_path)

    boxes, phrases, scores = get_grounding_output(
        model, model_input, caption, box_threshold, text_threshold, device
    )

    annotated = plot_boxes_to_image(original, boxes, phrases, scores, show_scores)

    # Left panel: untouched image; right panel: image with detections.
    _, (left_ax, right_ax) = plt.subplots(1, 2, figsize=figsize)
    panels = (
        (left_ax, original, "Original Image"),
        (right_ax, annotated, f"Detection Result\nPrompt: {caption}"),
    )
    for ax, picture, title in panels:
        ax.imshow(picture)
        ax.set_title(title)
        ax.axis("off")

    plt.tight_layout()
    plt.show()

    # Text summary of the detections, numbered from 1.
    print(f"\nDetected {len(boxes)} objects:")
    for rank, (phrase, score) in enumerate(zip(phrases, scores), start=1):
        print(f"  {rank}. {phrase}: {score:.4f}")

    return boxes, phrases, scores


## 3. 加载模型


In [None]:
# Select the compute device: CUDA when available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Build the model from the configured config file and checkpoint.
model = load_model(CONFIG_FILE, CHECKPOINT_PATH, device=device)
print("Model loaded successfully!")


## 4. 加载数据集并测试


In [None]:
# Load the training samples from the dataset's JSONL file
# (one JSON object per non-blank line).
train_jsonl = DATA_ROOT / "train.jsonl"

if train_jsonl.exists():
    # Explicit UTF-8 so non-ASCII text decodes the same way on every
    # platform instead of depending on the locale's default encoding.
    with open(train_jsonl, 'r', encoding="utf-8") as f:
        train_data = [json.loads(line) for line in f if line.strip()]
    print(f"Loaded {len(train_data)} training samples")

    # Show the structure of the first sample; skipped when the file holds no
    # samples, which would otherwise raise IndexError.
    if train_data:
        print(f"\nSample data structure:")
        print(json.dumps(train_data[0], indent=2, ensure_ascii=False))
else:
    print(f"Training data not found at {train_jsonl}")
    train_data = []


In [None]:
# Run detection on one randomly chosen training sample.
if train_data:
    sample = random.choice(train_data)

    # Image path of the sample.
    image_path = IMAGE_ROOT / sample['filename']

    # Build the prompt from the phrases of the annotated regions. Fall back to
    # "object" when there are no regions at all - including an empty regions
    # list, which would otherwise produce the meaningless prompt " ." (this
    # matches the fallback used by visualize_with_gt).
    regions = sample.get('grounding', {}).get('regions', [])
    objects = [r['phrase'] for r in regions]
    prompt = ' . '.join(objects) + ' .' if objects else "object"

    print(f"Image: {image_path}")
    print(f"Prompt: {prompt}")
    print(f"Ground truth regions: {regions}")

    # Run detection only when the image file is actually present.
    if image_path.exists():
        boxes, phrases, scores = visualize_detection(
            image_path, prompt, model,
            box_threshold=0.3,
            text_threshold=0.25,
            device=device
        )
    else:
        print(f"Image not found: {image_path}")


## 5. 批量可视化


In [None]:
def visualize_batch(samples, model, image_root, num_samples=6, box_threshold=0.3, 
                    text_threshold=0.25, device="cuda", cols=3):
    """Show detections for several randomly chosen samples in a grid.

    Args:
        samples: Sequence of sample dicts; each needs a 'filename' key and may
            carry phrases under sample['grounding']['regions'].
        model: Model passed to get_grounding_output.
        image_root: Directory that the sample filenames are relative to.
        num_samples: Maximum number of samples to draw.
        box_threshold: Box confidence threshold.
        text_threshold: Text matching threshold.
        device: Device used for inference.
        cols: Number of grid columns.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    # Clamp the sample count to what is available (and to >= 0), and stop
    # early when nothing is selected: a grid with zero rows cannot be created.
    count = max(0, min(num_samples, len(samples)))
    selected = random.sample(samples, count)
    if not selected:
        print("No samples to visualize")
        return

    rows = (len(selected) + cols - 1) // cols
    # squeeze=False always yields a 2-D array of axes, so flattening works
    # for every rows/cols combination, including a 1x1 grid.
    _, axes = plt.subplots(rows, cols, figsize=(6*cols, 5*rows), squeeze=False)
    axes = axes.flatten()

    for idx, sample in enumerate(selected):
        ax = axes[idx]

        image_path = Path(image_root) / sample['filename']

        if not image_path.exists():
            ax.set_title(f"Image not found")
            ax.axis('off')
            continue

        # Prompt from the annotated phrases; "object" when there are none
        # (an empty regions list would otherwise give the prompt " .").
        regions = sample.get('grounding', {}).get('regions', [])
        objects = [r['phrase'] for r in regions]
        prompt = ' . '.join(objects) + ' .' if objects else "object"

        # Load the image.
        image_pil, image_tensor = load_image(image_path)

        # Run inference.
        boxes, phrases, scores = get_grounding_output(
            model, image_tensor, prompt, box_threshold, text_threshold, device
        )

        # Draw the detections.
        image_with_boxes = plot_boxes_to_image(image_pil, boxes, phrases, scores)

        ax.imshow(image_with_boxes)
        # Keep long prompts from overflowing the subplot title.
        title = prompt[:40] + "..." if len(prompt) > 40 else prompt
        ax.set_title(f"Prompt: {title}", fontsize=10)
        ax.axis('off')

    # Hide the unused cells of the grid.
    for idx in range(len(selected), len(axes)):
        axes[idx].axis('off')

    plt.tight_layout()
    plt.show()


In [None]:
# Batch visualization: show detections for 6 random training samples in a
# 3-column grid.
if train_data:
    visualize_batch(
        train_data, model, IMAGE_ROOT,
        num_samples=6,
        box_threshold=0.3,
        text_threshold=0.25,
        device=device,
        cols=3
    )


## 6. 与Ground Truth对比


In [None]:
def visualize_with_gt(sample, model, image_root, box_threshold=0.3, text_threshold=0.25, device="cuda"):
    """Show the original image, its ground truth, and the predictions side by side.

    Args:
        sample: Sample dict with a 'filename' key and, optionally, annotated
            regions under sample['grounding']['regions'] (each with 'bbox'
            and 'phrase').
        model: Model passed to get_grounding_output.
        image_root: Directory that the sample filename is relative to.
        box_threshold: Box confidence threshold.
        text_threshold: Text matching threshold.
        device: Device used for inference.

    Returns:
        None. Prints a message and returns early when the image is missing.
    """
    image_path = Path(image_root) / sample['filename']

    if not image_path.exists():
        print(f"Image not found: {image_path}")
        return

    image_pil, image_tensor = load_image(image_path)
    W, H = image_pil.size

    # Collect the ground-truth regions, if the sample has any.
    has_regions = 'grounding' in sample and 'regions' in sample['grounding']
    regions = sample['grounding']['regions'] if has_regions else []

    gt_rows = []
    gt_labels = []
    for region in regions:
        # bbox is unpacked as pixel corners x1, y1, x2, y2 and converted to
        # the normalized (cx, cy, w, h) form used by plot_boxes_to_image.
        x1, y1, x2, y2 = region['bbox']
        gt_rows.append([
            (x1 + x2) / 2 / W,
            (y1 + y2) / 2 / H,
            (x2 - x1) / W,
            (y2 - y1) / H,
        ])
        gt_labels.append(region['phrase'])

    if gt_rows:
        gt_boxes = torch.tensor(gt_rows)
    else:
        gt_boxes = torch.zeros(0, 4)

    # Prompt from the ground-truth phrases; "object" when there are none.
    if gt_labels:
        prompt = ' . '.join(gt_labels) + ' .'
    else:
        prompt = "object"

    # Run inference.
    pred_boxes, pred_phrases, pred_scores = get_grounding_output(
        model, image_tensor, prompt, box_threshold, text_threshold, device
    )

    # Three panels: original, ground truth, predictions.
    _, axes = plt.subplots(1, 3, figsize=(18, 6))

    gt_image = plot_boxes_to_image(image_pil, gt_boxes, gt_labels, show_scores=False)
    pred_image = plot_boxes_to_image(image_pil, pred_boxes, pred_phrases, pred_scores)

    panels = [
        (image_pil, "Original Image"),
        (gt_image, f"Ground Truth ({len(gt_labels)} objects)"),
        (pred_image, f"Predictions ({len(pred_boxes)} detections)"),
    ]
    for ax, (picture, title) in zip(axes, panels):
        ax.imshow(picture)
        ax.set_title(title)
        ax.axis('off')

    plt.suptitle(f"Prompt: {prompt}", fontsize=12)
    plt.tight_layout()
    plt.show()

    # Text summary: labels, and (phrase, score) pairs with 3-decimal scores.
    formatted_scores = [f'{s:.3f}' for s in pred_scores]
    print(f"\nGround Truth: {gt_labels}")
    print(f"Predictions: {list(zip(pred_phrases, formatted_scores))}")


In [None]:
# Compare predictions against the ground truth for one random training sample.
if train_data:
    sample = random.choice(train_data)
    visualize_with_gt(sample, model, IMAGE_ROOT, box_threshold=0.3, text_threshold=0.25, device=device)


## 7. 阈值敏感性分析


In [None]:
def analyze_thresholds(image_path, prompt, model, device="cuda",
                       box_thresholds=(0.1, 0.2, 0.3, 0.4, 0.5), text_threshold=0.25):
    """Show how the box threshold changes the detections on one image.

    One panel is drawn per box threshold. Each panel calls
    get_grounding_output again with the same image and prompt.

    Args:
        image_path: Path of the image to analyze.
        prompt: Text prompt passed to the model.
        model: Model passed to get_grounding_output.
        device: Device used for inference.
        box_thresholds: Box confidence thresholds to compare, one panel each.
        text_threshold: Text matching threshold used for every panel.

    Returns:
        None. The figure is displayed with ``plt.show()``.

    Raises:
        ValueError: If ``box_thresholds`` is empty.
    """
    box_thresholds = list(box_thresholds)
    if not box_thresholds:
        raise ValueError("box_thresholds must contain at least one value")

    image_pil, image_tensor = load_image(image_path)

    n_panels = len(box_thresholds)
    # squeeze=False keeps axes 2-D, so a single threshold is handled the same
    # way as several.
    _, axes = plt.subplots(1, n_panels, figsize=(4*n_panels, 4), squeeze=False)

    for ax, box_thresh in zip(axes[0], box_thresholds):
        boxes, phrases, scores = get_grounding_output(
            model, image_tensor, prompt, box_thresh, text_threshold, device
        )

        image_with_boxes = plot_boxes_to_image(image_pil, boxes, phrases, scores)

        ax.imshow(image_with_boxes)
        ax.set_title(f"Box thresh: {box_thresh}\n({len(boxes)} detections)")
        ax.axis('off')

    plt.suptitle(f"Prompt: {prompt}", fontsize=12)
    plt.tight_layout()
    plt.show()

# Threshold analysis on one randomly chosen training sample.
if train_data:
    sample = random.choice(train_data)
    image_path = IMAGE_ROOT / sample['filename']

    # Prompt from the annotated phrases; "object" when there are none
    # (an empty regions list would otherwise give the prompt " .").
    objects = [r['phrase'] for r in sample.get('grounding', {}).get('regions', [])]
    prompt = ' . '.join(objects) + ' .' if objects else "object"

    if image_path.exists():
        analyze_thresholds(image_path, prompt, model, device)
    else:
        # Report the skipped sample instead of silently doing nothing.
        print(f"Image not found: {image_path}")


## 8. 自定义测试


In [None]:
# ============================================
# Custom test - edit the path and prompt below to try your own inputs
# ============================================

# Option 1: specify an image path and a custom prompt
# custom_image_path = PROJECT_ROOT / "path/to/your/image.jpg"
# custom_prompt = "bowl . plate . drawer ."

# Option 2: use an image from the training data with a custom prompt
if train_data:
    # Pick a sample
    sample = train_data[0]  # or random.choice(train_data)
    custom_image_path = IMAGE_ROOT / sample['filename']
    
    # Custom prompt - try different object names here
    custom_prompt = "bowl . plate . cup . drawer . table ."
    
    print(f"Testing with custom prompt:")
    print(f"  Image: {custom_image_path}")
    print(f"  Prompt: {custom_prompt}")
    
    # Detection only runs when the image file exists.
    if custom_image_path.exists():
        boxes, phrases, scores = visualize_detection(
            custom_image_path, custom_prompt, model,
            box_threshold=0.25,  # lower the threshold to see more detections
            text_threshold=0.2,
            device=device
        )


## 9. 保存预测结果


In [None]:
def save_prediction(image_path, prompt, model, output_path, box_threshold=0.3, 
                    text_threshold=0.25, device="cuda"):
    """Run detection on one image and write the annotated image to disk.

    Args:
        image_path: Path of the image to process.
        prompt: Text prompt passed to the model.
        model: Model passed to get_grounding_output.
        output_path: Destination passed to the PIL image's ``save``.
        box_threshold: Box confidence threshold.
        text_threshold: Text matching threshold.
        device: Device used for inference.

    Returns:
        The ``(boxes, phrases, scores)`` triple from get_grounding_output.
    """
    source_image, model_input = load_image(image_path)

    detections = get_grounding_output(
        model, model_input, prompt, box_threshold, text_threshold, device
    )
    boxes, phrases, scores = detections

    # Draw the detections on a copy of the image and write it out.
    annotated = plot_boxes_to_image(source_image, boxes, phrases, scores)
    annotated.save(output_path)
    print(f"Saved to: {output_path}")

    return boxes, phrases, scores

# Example: save annotated predictions for the first few training samples.
output_dir = SCRIPT_DIR / "visualization_outputs"
output_dir.mkdir(exist_ok=True)

if train_data:
    # Process at most 5 samples.
    num_to_save = min(5, len(train_data))
    # Count the files actually written; samples whose image is missing are
    # skipped, so this can be smaller than num_to_save.
    num_saved = 0
    for i, sample in enumerate(train_data[:num_to_save]):
        image_path = IMAGE_ROOT / sample['filename']

        # Prompt from the annotated phrases; "object" when there are none
        # (an empty regions list would otherwise give the prompt " .").
        objects = [r['phrase'] for r in sample.get('grounding', {}).get('regions', [])]
        prompt = ' . '.join(objects) + ' .' if objects else "object"

        if image_path.exists():
            output_path = output_dir / f"prediction_{i:03d}.jpg"
            save_prediction(image_path, prompt, model, output_path, device=device)
            num_saved += 1
        else:
            print(f"Image not found: {image_path}")

    print(f"\nSaved {num_saved} predictions to: {output_dir}")


In [None]:
# The contents of this cell were moved to the earlier "3. 加载模型" (load model) section.
pass
