In [11]:
# %%
from pathlib import Path
import re

# === Required: input locations ===
# CVAT XML export, e.g. /mnt/data/annotations.xml or E:\annotations.xml
XML_PATH = Path(r"E:\bridgev2_DATA\extracted_first_test_2\batch_1_annotations.xml")
# Image root directory (flat, no sub-folders)
IMG_DIR = Path(r"E:\bridgev2_DATA\extracted_first_test_2\batch_00001")

# === Where results go: next to the source images, or in a separate directory ===
SAVE_IN_PLACE = False
# Only used when SAVE_IN_PLACE is False
OUT_DIR = Path(r"E:\bridgev2_DATA\test_traj\batch_00001")

# Colours are BGR tuples because OpenCV uses BGR channel order.
SKYBLUE = tuple(reversed((135, 206, 235)))  # sky blue, given as RGB (135, 206, 235) and flipped to BGR
RED = (0, 0, 255)

# Shaft width and diamond arrow-head geometry (pixels)
LINE_THICKNESS = 3
HEAD_LEN = 13        # distance from the arrow tip back to the diamond's widest point (waist)
HEAD_WIDTH = 6       # half-width of the diamond at its waist (offset to each side of the shaft)
DIAMOND_RATIO = 0.6  # how far the diamond's tail sits behind the waist, as a fraction of HEAD_LEN (>0)


In [12]:
# %%
import xml.etree.ElementTree as ET

def parse_cvat_xml_four_points(xml_path: Path):
    """Read a CVAT XML export and collect the 4-vertex polyline of each image.

    For every ``<image>`` element, the first ``<polyline>`` beneath it is used.
    Its ``points`` attribute is expected to look like
    ``"x1,y1;x2,y2;x3,y3;x4,y4"``.

    Args:
        xml_path: path to the CVAT annotations XML file.

    Returns:
        A dict mapping each image's ``name`` attribute to a list of four
        ``(x, y)`` float tuples ``[A, B, C, D]``, in the order they appear in
        ``points``.

        An image is skipped when it has no polyline (silently), or with a
        ``[WARN]`` message when it has:

        - an empty or missing ``name``;
        - a vertex count other than 4;
        - a vertex that is not a pair of finite numbers.

        If the same name appears twice, the later entry wins.

    Raises:
        xml.etree.ElementTree.ParseError: if the file is not well-formed XML.
        OSError: if the file cannot be opened.
    """
    import math  # local import keeps this cell self-contained

    tree = ET.parse(str(xml_path))
    root = tree.getroot()

    name2pts = {}
    for img_tag in root.findall(".//image"):
        name = img_tag.attrib.get("name", "").strip()
        if not name:
            # An empty key would later resolve to the image directory itself.
            print(f"[WARN] <image id={img_tag.attrib.get('id', '?')}> has no name. Skip.")
            continue
        poly = img_tag.find(".//polyline")
        if poly is None:
            # No polyline annotation for this image: skip it.
            continue
        pts_str = poly.attrib.get("points", "").strip()
        # Expected format: "x1,y1;x2,y2;x3,y3;x4,y4"
        pairs = [p.strip() for p in pts_str.split(";") if p.strip()]
        if len(pairs) != 4:
            print(f"[WARN] {name}: points != 4 ({len(pairs)}). Skip.")
            continue
        pts = []
        for p in pairs:
            try:
                x_str, y_str = p.split(",")
                x, y = float(x_str), float(y_str)
            except ValueError as e:
                # Wrong number of comma-separated fields, or a non-numeric field.
                print(f"[WARN] {name}: bad point '{p}': {e}. Skip.")
                break
            if not (math.isfinite(x) and math.isfinite(y)):
                # float() accepts "nan"/"inf", which would break pixel rounding when drawing.
                print(f"[WARN] {name}: non-finite point '{p}'. Skip.")
                break
            pts.append((x, y))
        else:
            # for/else: only reached when every vertex parsed without `break`.
            name2pts[name] = pts  # [(Ax,Ay),(Bx,By),(Cx,Cy),(Dx,Dy)]
    return name2pts

# Image name -> [A, B, C, D] polyline vertices; images without a usable 4-point polyline are absent.
name2pts = parse_cvat_xml_four_points(xml_path=XML_PATH)
num_parsed = len(name2pts)
print(f"Parsed {num_parsed} images from XML.")


Parsed 516 images from XML.


In [13]:
# %%
import cv2
import numpy as np

def draw_diamond_arrow(img, p0, p1, color, thickness=2, head_len=24, head_half_width=12, diamond_ratio=0.6):
    """Draw an arrow from ``p0`` to ``p1`` on ``img`` with a filled diamond head.

    ``img`` is modified in place through ``cv2.line`` and ``cv2.fillPoly``, and
    nothing is returned. Arrows shorter than 1e-6 px draw nothing.

    Geometry, with distances measured back from the tip along the arrow:
      - tip:   ``p1``.
      - waist: ``head_len`` behind the tip. The diamond is widest here,
        ``head_half_width`` to each side of the shaft.
      - tail:  ``head_len * (1 + diamond_ratio)`` behind the tip.

    The shaft runs from ``p0`` to the waist, and the filled head covers its end.
    If the arrow is shorter than the head, the head is scaled down uniformly
    (length and width together) so its tail lands on ``p0``. This keeps the
    shaft from being drawn backwards past the start point.

    Args:
        img: image array handed to OpenCV (modified in place).
        p0, p1: ``(x, y)`` start and end points, ints or floats.
        color: colour tuple in the image's channel order (BGR for ``cv2.imread`` images).
        thickness: shaft line thickness in pixels.
        head_len: tip-to-waist distance of the diamond head.
        head_half_width: half of the diamond's width at the waist.
        diamond_ratio: tail-behind-waist distance as a fraction of ``head_len``
            (expected >= 0; larger values give a longer diamond).
    """
    def _px(pt):
        # Round a point to integer pixel coordinates (Python round: half-to-even, same as np.rint).
        return int(round(float(pt[0]))), int(round(float(pt[1])))

    x0, y0 = float(p0[0]), float(p0[1])
    x1, y1 = float(p1[0]), float(p1[1])
    v = np.array([x1 - x0, y1 - y0], dtype=np.float32)
    L = np.linalg.norm(v)
    if L < 1e-6:
        return
    u = v / L  # unit direction p0 -> p1
    n = np.array([-u[1], u[0]], dtype=np.float32)  # unit normal

    # Shrink the head when it would not fit on the arrow; otherwise the waist
    # (and thus the shaft end) would fall behind p0.
    head_extent = head_len * (1.0 + diamond_ratio)  # tip-to-tail distance
    if head_extent > L:
        shrink = L / head_extent
        head_len = head_len * shrink
        head_half_width = head_half_width * shrink

    # Diamond vertices: tip -> right waist -> tail -> left waist.
    tip = np.array([x1, y1], dtype=np.float32)
    waist = tip - u * head_len
    tail = tip - u * head_len * (1.0 + diamond_ratio)
    p_right = waist + n * head_half_width
    p_left = waist - n * head_half_width

    # Shaft stops at the waist; the filled diamond covers the joint.
    cv2.line(
        img,
        _px((x0, y0)),
        _px(waist),
        color=color,
        thickness=thickness,
        lineType=cv2.LINE_AA
    )

    # Round (not truncate) so the head lines up with the rounded shaft end.
    poly = np.rint(np.array([tip, p_right, tail, p_left], dtype=np.float32)).astype(np.int32).reshape(-1, 1, 2)
    cv2.fillPoly(img, [poly], color=color, lineType=cv2.LINE_AA)


In [14]:
# %%
# Matches names like "id12_step3.png" (case-insensitive); group 1 captures the "id12" prefix.
ID_STEP_RE = re.compile(r"^(id\d+)_step\d+\.png$", re.IGNORECASE)


def make_output_name(input_name: str) -> str:
    """Derive the output file name for an annotated image.

    ``id<N>_step<M>.png`` (any case) becomes ``id<N>_traj_raw.png``, so every
    step of the same id maps to one output name. Any other name keeps its stem
    and extension, with ``_traj_raw`` inserted before the extension. A name
    without an extension gets ``.png``.
    """
    match = ID_STEP_RE.match(input_name)
    if match is not None:
        return match.group(1) + "_traj_raw.png"
    # Fallback: append the suffix to the stem and keep the original extension.
    as_path = Path(input_name)
    extension = as_path.suffix if as_path.suffix else ".png"
    return as_path.stem + "_traj_raw" + extension


In [15]:
# %%
from tqdm import tqdm

if not SAVE_IN_PLACE:
    OUT_DIR.mkdir(parents=True, exist_ok=True)

num_ok, num_skip = 0, 0
# Output path -> source image name. Several inputs can map to one output
# (e.g. id1_step0.png and id1_step5.png both -> id1_traj_raw.png); track them
# so overwrites are reported instead of silently counted as extra saves.
written_from = {}

for name, (A, B, C, D) in tqdm(name2pts.items(), desc="Drawing arrows"):
    img_path = IMG_DIR / name
    if not img_path.exists():
        # tqdm.write keeps the progress bar intact, unlike print().
        tqdm.write(f"[WARN] image not found: {img_path}")
        num_skip += 1
        continue

    img = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
    if img is None:
        tqdm.write(f"[WARN] failed to read: {img_path}")
        num_skip += 1
        continue

    # A->B in sky blue, then C->D in red (red is drawn on top where they overlap).
    for src, dst, color in ((A, B, SKYBLUE), (C, D, RED)):
        draw_diamond_arrow(
            img, src, dst, color=color,
            thickness=LINE_THICKNESS,
            head_len=HEAD_LEN,
            head_half_width=HEAD_WIDTH,
            diamond_ratio=DIAMOND_RATIO
        )

    out_name = make_output_name(name)
    out_path = (IMG_DIR / out_name) if SAVE_IN_PLACE else (OUT_DIR / out_name)
    if out_path in written_from:
        tqdm.write(f"[WARN] {out_path} already written from {written_from[out_path]}; overwriting with {name}")
    ok = cv2.imwrite(str(out_path), img)
    if not ok:
        tqdm.write(f"[WARN] failed to save: {out_path}")
        num_skip += 1
    else:
        num_ok += 1
        written_from[out_path] = name

print(f"[DONE] saved={num_ok}, skipped={num_skip}")
if len(written_from) < num_ok:
    print(f"[WARN] only {len(written_from)} distinct output files; some images overwrote earlier outputs.")


Drawing arrows:   0%|          | 0/516 [00:00<?, ?it/s]

Drawing arrows: 100%|██████████| 516/516 [00:16<00:00, 31.58it/s]

[DONE] saved=516, skipped=0



