In [1]:
# ==============================================================================
# 0. Key dependency version check (debugging aid)
# ==============================================================================
print(">>> [DEBUG] 步骤 0: 检查关键库版本...")
try:
    import mmcv
    import timm
except ImportError as e:
    # Non-fatal here: the problem is only reported. Step 1 imports mmcv.Config again
    # and re-raises on failure, so a broken environment still stops the notebook there.
    print(f"!!! [ERROR] 缺少核心调试库: {e}")
else:
    print(f">>> [INFO] mmcv version: {mmcv.__version__}")
    print(f">>> [INFO] timm version: {timm.__version__}")
print(">>> [DEBUG] 步骤 0: 检查完成。\n" + "=" * 60 + "\n")

# ==============================================================================
# 1. Import the required libraries
# ==============================================================================
# Any failure here is fatal: the error and a hint are printed, then re-raised.
print(">>> [DEBUG] 步骤 1: 开始导入核心库...")
try:
    import cv2
    import torch
    import numpy as np
    from ultralytics import YOLO
    import sys
    import os
    from tqdm import tqdm
    # NOTE(review): `mmcv.Config` exists in mmcv 1.x (1.7.2 in this run); mmcv>=2 moved
    # Config to mmengine, so this import pins the environment to mmcv 1.x.
    from mmcv import Config
    print(">>> [DEBUG] 核心库 cv2, torch, numpy, ultralytics, tqdm, mmcv.Config 导入成功。")
except ImportError as e:
    print(f"!!! [ERROR] 导入核心库失败: {e}")
    print("!!! [HINT] 请确保您已经按照教程正确安装了所有依赖。")
    raise

# --- Import the Metric3D modules ---
# Metric3D is not pip-installed; its repository root is put at the front of sys.path
# so that the `mono` package resolves. The membership check keeps re-runs of this
# cell from inserting duplicate entries.
METRIC3D_PATH = '/root/autodl-tmp/Metric3D'
if METRIC3D_PATH not in sys.path:
    sys.path.insert(0, METRIC3D_PATH)
    print(f">>> [DEBUG] 已将 '{METRIC3D_PATH}' 添加到系统路径。")

try:
    from mono.model.monodepth_model import DepthModel as MonoDepthModel
    print(">>> [DEBUG] Metric3D 模块 'DepthModel' (作为 MonoDepthModel) 导入成功。")
except ImportError as e:
    print(f"!!! [ERROR] 从 Metric3D 导入模块失败: {e}")
    print(f"!!! [HINT] 请确认 Metric3D 的代码库是否存在于 '{METRIC3D_PATH}' 路径下。")
    raise

print(">>> [DEBUG] 步骤 1: 所有库导入完成。\n" + "="*60 + "\n")

# ==============================================================================
# 2. Configuration and path checks
# ==============================================================================
print(">>> [DEBUG] 步骤 2: 配置模型和文件路径...")

# Absolute paths on this machine; edit them here when moving the notebook.
YOLO_MODEL_PATH = '/root/autodl-tmp/weights/epoch30.pt'
METRIC3D_MODEL_PATH = '/root/autodl-tmp/weights/metric_depth_vit_large_800k.pth'
METRIC3D_CONFIG_PATH = '/root/autodl-tmp/Metric3D/mono/configs/HourglassDecoder/vit.raft5.large.py'
INPUT_VIDEO_PATH = '/root/autodl-tmp/kitti_videos/0002.mp4'
OUTPUT_VIDEO_PATH = '/root/autodl-tmp/output_video_with_depth2.mp4'
TRACKER_CONFIG_PATH = '/root/autodl-tmp/bytetrack.yaml'


paths_to_check = {
    "YOLOv8 权重": YOLO_MODEL_PATH,
    "Metric3D 权重": METRIC3D_MODEL_PATH,
    "Metric3D 配置": METRIC3D_CONFIG_PATH,
    "输入视频": INPUT_VIDEO_PATH,
    "跟踪器配置": TRACKER_CONFIG_PATH,
}
# Report every missing input before failing, so a single run surfaces all problems.
missing_paths = [
    (name, path) for name, path in paths_to_check.items() if not os.path.exists(path)
]
for label, bad_path in missing_paths:
    print(f"!!! [ERROR] 路径检查失败: {label} 文件未找到于 '{bad_path}'")
all_paths_ok = not missing_paths

if all_paths_ok:
    print(">>> [DEBUG] 所有文件路径检查通过。")
else:
    raise FileNotFoundError("一个或多个关键文件路径无效。请确保已创建 bytetrack.yaml 文件。")

# Prefer CUDA; the CPU fallback works but is very slow for the ViT-L depth model.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f">>> [DEBUG] 将要使用的设备: {DEVICE}")
if DEVICE.type == 'cpu':
    print("!!! [WARNING] 未检测到 CUDA 设备，将使用 CPU 运行。速度会很慢！")

print(">>> [DEBUG] 步骤 2: 配置完成。\n" + "=" * 60 + "\n")

# ==============================================================================
# 3. Model loading
# ==============================================================================
print(">>> [DEBUG] 步骤 3: 开始加载深度学习模型...")

# --- Load YOLOv8 (tracking later goes through ultralytics' `track` + ByteTrack config) ---
try:
    print(">>> [DEBUG] 正在加载 YOLOv8 模型...")
    yolo_model = YOLO(YOLO_MODEL_PATH)
    print(">>> [DEBUG] YOLOv8 模型加载成功！")

    # --- Resolve the ID of the class to track from the model's own label map ---
    TARGET_CLASS_NAME = 'Car'
    # Print all classes so that a wrong TARGET_CLASS_NAME is easy to spot.
    print(f">>> [INFO] YOLOv8 模型所有类别: {yolo_model.names}")
    # -1 is the "not found" sentinel checked just below.
    TARGET_CLASS_ID = next(
        (class_id for class_id, class_name in yolo_model.names.items()
         if class_name == TARGET_CLASS_NAME),
        -1,
    )

    if TARGET_CLASS_ID != -1:
        print(f">>> [INFO] 目标类别 '{TARGET_CLASS_NAME}' 已找到, ID为: {TARGET_CLASS_ID}")
    else:
        raise ValueError(f"错误：目标类别 '{TARGET_CLASS_NAME}' 在模型中未找到。")

except Exception as e:
    print(f"!!! [ERROR] 加载 YOLOv8 模型或查找类别ID时失败: {e}")
    raise

# --- Load Metric3Dv2 ---
try:
    print(">>> [DEBUG] 正在加载 Metric3Dv2 模型...")

    cfg = Config.fromfile(METRIC3D_CONFIG_PATH)
    print(">>> [DEBUG] Part A: 配置加载成功。")

    # Disable the backbone's mask token. NOTE(review): presumably only needed for
    # masked-image training; confirm against the Metric3D config.
    cfg.model.backbone.use_mask_token = False

    metric3d_model = MonoDepthModel(cfg).to(DEVICE)
    print(">>> [DEBUG] Part B: 模型初始化成功。")

    checkpoint = torch.load(METRIC3D_MODEL_PATH, map_location=DEVICE)

    # Checkpoints store the weights under different keys depending on how they were saved.
    if 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint
    # Strip a DataParallel/DDP 'module.' prefix so the keys line up with the bare model.
    state_dict = {
        (key[len('module.'):] if key.startswith('module.') else key): value
        for key, value in state_dict.items()
    }

    # strict=False tolerates harmless extra or missing entries. It must not hide a checkpoint
    # that matches nothing, which would leave the model randomly initialised and make
    # every depth meaningless. Report the mismatch, and fail if no weight was loaded.
    load_result = metric3d_model.load_state_dict(state_dict, strict=False)
    missing_keys = list(load_result.missing_keys)
    unexpected_keys = list(load_result.unexpected_keys)
    n_model_keys = len(metric3d_model.state_dict())
    if missing_keys:
        print(f"!!! [WARNING] 模型中有 {len(missing_keys)}/{n_model_keys} 个参数未在权重文件中找到, 例如: {missing_keys[:5]}")
    if unexpected_keys:
        print(f"!!! [WARNING] 权重文件中有 {len(unexpected_keys)} 个多余的键被忽略, 例如: {unexpected_keys[:5]}")
    if n_model_keys > 0 and len(missing_keys) == n_model_keys:
        raise RuntimeError("权重文件与 Metric3D 模型没有任何匹配的参数, 请检查配置文件与权重文件是否对应。")
    print(">>> [DEBUG] Part C: 权重加载成功 (已忽略不匹配的键)。")

    metric3d_model.eval()
    # Report the actual device: the model may be on CPU when CUDA is unavailable.
    print(f">>> [SUCCESS] Metric3Dv2 模型加载并移动到 {DEVICE} 成功！")
except Exception as e:
    print(f"!!! [FATAL ERROR] 加载 Metric3Dv2 模型时出错: {e}")
    import traceback
    traceback.print_exc()
    raise

print(">>> [DEBUG] 步骤 3: 所有模型加载完成。\n" + "="*60 + "\n")

# ==============================================================================
# 4. 视频处理主函数
# ==============================================================================
print(">>> [DEBUG] 步骤 4: 定义视频处理函数...")
def process_video_debug(input_path, output_path):
    print("\n--- 开始视频处理 ---")
    cap = cv2.VideoCapture(input_path)
    if not cap.isOpened():
        print(f"!!! [ERROR] 无法打开视频文件: {input_path}")
        return

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    
    metric3d_input_size = (cfg.data_basic['vit_size'][1], cfg.data_basic['vit_size'][0])
    print(f">>> [INFO] Metric3D 模型输入尺寸 (宽, 高): {metric3d_input_size}")

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    
    print(f">>> [INFO] 输入视频信息: {width}x{height} @ {fps:.2f} FPS, 共 {total_frames} 帧。")
    print(f">>> [INFO] 处理后的视频将保存至: {output_path}")

    with tqdm(total=total_frames, desc="视频处理进度") as pbar:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            # --- [核心修改 1] ---
            # 使用 classes 参数指定只跟踪 'Car' 类别
            track_results = yolo_model.track(
                frame, 
                persist=True, 
                verbose=False, 
                tracker=TRACKER_CONFIG_PATH,
                classes=[TARGET_CLASS_ID] 
            )
            
            # --- [核心修改 2] ---
            # 不再使用 track_results[0].plot()，改为手动绘制
            # 首先创建一个当前帧的副本用于绘制
            annotated_frame = frame.copy()

            # 深度估计部分 (逻辑不变)
            with torch.no_grad():
                rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                rgb_frame_resized = cv2.resize(rgb_frame, metric3d_input_size)
                rgb_torch = torch.from_numpy(rgb_frame_resized).permute(2, 0, 1).unsqueeze(0).float().to(DEVICE) / 255.0
                pred_output = metric3d_model(data={'input': rgb_torch})
                pred_depth = pred_output[0]
                pred_depth_np = pred_depth.squeeze().cpu().numpy()
                pred_depth_resized = cv2.resize(pred_depth_np, (width, height))

            # 获取跟踪结果
            boxes = track_results[0].boxes.xyxy.cpu().numpy()
            track_ids = []
            if track_results[0].boxes.id is not None:
                track_ids = track_results[0].boxes.id.int().cpu().tolist()

            # 循环遍历每个被跟踪到的目标，手动绘制
            if len(track_ids) > 0:
                for box, track_id in zip(boxes, track_ids):
                    x1, y1, x2, y2 = map(int, box)
                    
                    # 1. 绘制检测框 (绿色，粗细为2)
                    cv2.rectangle(annotated_frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
                    
                    # 2. 计算并绘制ID和深度信息 (逻辑不变)
                    box_w, box_h = x2 - x1, y2 - y1
                    roi_w, roi_h = int(box_w * 0.5), int(box_h * 0.5)
                    roi_x1 = max(x1 + (box_w - roi_w) // 2, 0)
                    roi_y1 = max(y1 + (box_h - roi_h) // 2, 0)
                    roi_x2 = min(roi_x1 + roi_w, width)
                    roi_y2 = min(roi_y1 + roi_h, height)

                    depth_roi = pred_depth_resized[roi_y1:roi_y2, roi_x1:roi_x2]
                    
                    if depth_roi.size > 0:
                        sorted_depths = np.sort(depth_roi.flatten())
                        cut_off = int(len(sorted_depths) * 0.05)
                        
                        if len(sorted_depths) > 2 * cut_off:
                            filtered_depths = sorted_depths[cut_off:-cut_off]
                            avg_depth = np.mean(filtered_depths) if filtered_depths.size > 0 else 0
                        else:
                            avg_depth = np.mean(sorted_depths) if sorted_depths.size > 0 else 0
                        
                        depth_text = f"ID:{track_id} D:{avg_depth:.2f}m"
                        (text_w, text_h), _ = cv2.getTextSize(depth_text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)
                        cv2.rectangle(annotated_frame, (x1, y1 - 25), (x1 + text_w + 5, y1 - 5), (0, 100, 0), -1)
                        cv2.putText(annotated_frame, depth_text, (x1 + 2, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)

            out.write(annotated_frame)
            pbar.update(1)

    cap.release()
    out.release()
    cv2.destroyAllWindows()
    print(f"\n--- 视频处理完成！ ---")
    print(f">>> [SUCCESS] 输出视频已成功保存到: {output_path}")

print(">>> [DEBUG] 步骤 4: 视频处理函数定义完成。\n" + "="*60 + "\n")

# ==============================================================================
# 5. Run the pipeline
# ==============================================================================
print(">>> [DEBUG] 步骤 5: 开始执行主程序...")
try:
    process_video_debug(INPUT_VIDEO_PATH, OUTPUT_VIDEO_PATH)
except Exception as e:
    print(f"!!! [FATAL ERROR] 在视频处理过程中发生严重错误: {e}")
    print("!!! [HINT] 请检查上面的错误信息。可能的原因包括：CUDA内存不足、模型与输入数据维度不匹配等。")
    # Re-raise so the cell is marked as failed and "Restart & Run All" stops here,
    # instead of printing the "finished" line below. Jupyter shows the full traceback.
    raise

print(">>> [DEBUG] 步骤 5: 主程序执行完毕。\n" + "="*60)

>>> [DEBUG] 步骤 0: 检查关键库版本...


  from pkg_resources import packaging  # type: ignore[attr-defined]
  from .autonotebook import tqdm as notebook_tqdm


>>> [INFO] mmcv version: 1.7.2
>>> [INFO] timm version: 0.6.12
>>> [DEBUG] 步骤 0: 检查完成。

>>> [DEBUG] 步骤 1: 开始导入核心库...
>>> [DEBUG] 核心库 cv2, torch, numpy, ultralytics, tqdm, mmcv.Config 导入成功。
>>> [DEBUG] 已将 '/root/autodl-tmp/Metric3D' 添加到系统路径。
>>> [DEBUG] Metric3D 模块 'DepthModel' (作为 MonoDepthModel) 导入成功。
>>> [DEBUG] 步骤 1: 所有库导入完成。

>>> [DEBUG] 步骤 2: 配置模型和文件路径...
>>> [DEBUG] 所有文件路径检查通过。
>>> [DEBUG] 将要使用的设备: cuda
>>> [DEBUG] 步骤 2: 配置完成。

>>> [DEBUG] 步骤 3: 开始加载深度学习模型...
>>> [DEBUG] 正在加载 YOLOv8 模型...
>>> [DEBUG] YOLOv8 模型加载成功！
>>> [INFO] YOLOv8 模型所有类别: {0: 'Car', 1: 'Pedestrian', 2: 'Cyclist'}
>>> [INFO] 目标类别 'Car' 已找到, ID为: 0
>>> [DEBUG] 正在加载 Metric3Dv2 模型...
>>> [DEBUG] Part A: 配置加载成功。
>>> [DEBUG] Part B: 模型初始化成功。
>>> [DEBUG] Part C: 权重加载成功 (已忽略不匹配的键)。
>>> [SUCCESS] Metric3Dv2 模型加载并移动到 GPU 成功！
>>> [DEBUG] 步骤 3: 所有模型加载完成。

>>> [DEBUG] 步骤 4: 定义视频处理函数...
>>> [DEBUG] 步骤 4: 视频处理函数定义完成。

>>> [DEBUG] 步骤 5: 开始执行主程序...

--- 开始视频处理 ---
>>> [INFO] Metric3D 模型输入尺寸 (宽, 高): (1064, 616)
>>> [INFO] 输入视频信息

视频处理进度:   3%|▎         | 6/233 [00:02<01:43,  2.19it/s]


KeyboardInterrupt: 

In [5]:
# NOTE(review): this is an unpinned upgrade that was run after the main cell (see the
# out-of-order execution counts). Pin the version that works, move it to a setup cell
# at the top of the notebook, and restart the kernel afterwards, as pip's output advises.
%pip install -U ultralytics

Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
[0mNote: you may need to restart the kernel to use updated packages.


In [2]:
# `lap` is a linear-assignment solver. NOTE(review): it is presumably needed by the
# ultralytics ByteTrack tracker used via `yolo_model.track(..., tracker=TRACKER_CONFIG_PATH)`;
# TODO confirm. This run installed lap-0.5.12: pin it (`lap==0.5.12`) and move it to a
# setup cell at the top of the notebook.
%pip install lap

Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Collecting lap
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/2f/ab/070be2dc9e56b368031168710c848be203523c7c83d9d22ce7fde6a167fe/lap-0.5.12-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.7/1.7 MB[0m [31m40.6 MB/s[0m  [33m0:00:00[0m
Installing collected packages: lap
Successfully installed lap-0.5.12
[0mNote: you may need to restart the kernel to use updated packages.
