# 1. Generate fully annotated results

In [2]:
import os
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from scipy.ndimage import convolve
from ultralytics import YOLO

# ========== Parameters ==========
# Directory scanned for input images (.jpg / .png / .jpeg).
source_dir = './samples/'
# Directory the annotated images are written to, under their original file names.
output_dir = './pitch_prediction/'
# Weights file handed to the YOLO constructor below.
model_path = './trained_models/yolov8m_500ep.pt'
os.makedirs(output_dir, exist_ok=True)

# Gray values strictly below this are treated as "black" (ink) pixels.
gray_threshold = 250
# Width of the single-row, all-ones kernel used to measure horizontal runs of ink.
kernel_width = 100
# A pixel is a line candidate when its convolution response exceeds this value.
response_thresh = 10
# A row must contain at least this fraction of the image width in black pixels.
min_black_ratio = 0.5
# Detected rows closer than this (in pixels) to the previously kept row are dropped.
min_gap_between_lines = 4
# NOTE(review): not referenced anywhere in the visible code — confirm whether
# grouping was meant to use it.
max_gap_in_group = 30
# Pitch assigned to each of the five lines of a staff, from the topmost line
# (smallest y) to the bottom one.
pitch_labels = ['F5', 'D5', 'B4', 'G4', 'E4']

# Detector and the default PIL font used for every text label drawn on the image.
model = YOLO(model_path)
font = ImageFont.load_default()

# ========== Staff-line detection ==========
def detect_staff_lines_np(img_np, *, threshold=None, width=None,
                          min_response=None, black_ratio=None, min_gap=None):
    """Return the y coordinates of rows that look like horizontal staff lines.

    A row is accepted when (a) at least one pixel in it has a horizontal
    ink-run response above ``min_response`` and (b) the row contains at least
    ``black_ratio * image_width`` black pixels.  Accepted rows that lie closer
    than ``min_gap`` pixels to the previously kept row are dropped, so a thick
    line spanning several pixel rows is reported once (by its topmost row).

    Args:
        img_np: 2-D grayscale image array indexed as ``[y, x]``.
        threshold: Gray values strictly below this count as black.
            Defaults to the module-level ``gray_threshold``.
        width: Width of the 1-row all-ones convolution kernel.
            Defaults to the module-level ``kernel_width``.
        min_response: Convolution response a pixel must exceed.
            Defaults to the module-level ``response_thresh``.
        black_ratio: Minimum fraction of black pixels per row.
            Defaults to the module-level ``min_black_ratio``.
        min_gap: Minimum distance in pixels between two reported rows.
            Defaults to the module-level ``min_gap_between_lines``.

    Returns:
        Ascending list of row indices (plain ``int``).
    """
    # The module-level settings are only looked up when no override is given.
    threshold = gray_threshold if threshold is None else threshold
    width = kernel_width if width is None else width
    min_response = response_thresh if min_response is None else min_response
    black_ratio = min_black_ratio if black_ratio is None else black_ratio
    min_gap = min_gap_between_lines if min_gap is None else min_gap

    binary = (img_np < threshold).astype(np.uint8)
    kernel = np.ones((1, width), dtype=np.uint8)
    # Request an int32 result: by default the output takes the uint8 dtype of
    # the input, which cannot represent counts for kernels wider than 255.
    response = convolve(binary, kernel, output=np.int32)

    has_long_run = (response > min_response).any(axis=1)
    min_black = int(black_ratio * img_np.shape[1])
    is_dense = binary.sum(axis=1) >= min_black
    candidate_rows = np.flatnonzero(has_long_run & is_dense).tolist()

    # candidate_rows is already ascending, so comparing with the last kept
    # row is enough to collapse each thick line into a single entry.
    lines = []
    for y in candidate_rows:
        if not lines or y - lines[-1] >= min_gap:
            lines.append(y)
    return lines

# ========== Group lines into staves of exactly 5 lines each ==========
def group_staff_lines(y_coords, *, max_gap=None):
    """Split staff-line y coordinates into groups of exactly five lines.

    The coordinates are sorted and collected in ascending order; every time
    five lines have been collected they are emitted as one group.  Lines left
    over at the end (fewer than five) are discarded.

    Args:
        y_coords: Iterable of line y coordinates, in any order.
        max_gap: Optional maximum vertical distance between two consecutive
            lines of the same group.  When given, a larger jump discards the
            partially collected group and starts a new one at the current
            line, so an isolated stray line cannot shift every following
            group.  ``None`` (the default) disables the check and groups
            purely by count.

    Returns:
        List of groups; each group is an ascending list of five y coordinates.
    """
    groups = []
    current = []
    for y in sorted(y_coords):
        # A jump larger than max_gap means the lines collected so far do not
        # belong with this one; drop them instead of mixing them in.
        if max_gap is not None and current and y - current[-1] > max_gap:
            current = []
        current.append(y)
        if len(current) == 5:
            groups.append(current)
            current = []

    # Any remaining lines (fewer than 5) are intentionally dropped.
    return groups

# ========== Main processing loop ==========
# For every image in source_dir: detect the staff lines, group them into
# staves, run YOLO on each pair of staves, label every detection with the
# pitch of the nearest staff line, and save the annotated image to output_dir.

# Compare extensions in lower case so files such as 'page.PNG' are not skipped,
# and sort the listing so the processing order is deterministic.
image_files = sorted(
    f for f in os.listdir(source_dir)
    if f.lower().endswith(('.jpg', '.png', '.jpeg'))
)
for image_file in image_files:
    img_path = os.path.join(source_dir, image_file)
    # convert() returns a separate, fully loaded image, so the file handle
    # can be released as soon as the conversion is done.
    with Image.open(img_path) as opened:
        full_img = opened.convert('RGB')
    np_gray = np.array(full_img.convert('L'))
    draw = ImageDraw.Draw(full_img)

    staff_lines = detect_staff_lines_np(np_gray)
    staff_groups = group_staff_lines(staff_lines)

    # Staves are processed two at a time; the last pair may hold one staff.
    for i in range(0, len(staff_groups), 2):
        groups_in_pair = staff_groups[i:i+2]

        # Vertical extent of the pair, padded by 40 px and clamped to the image.
        all_ys = [y for group in groups_in_pair for y in group]
        y_min = max(0, min(all_ys) - 40)
        y_max = min(np_gray.shape[0], max(all_ys) + 40)
        cropped_img = full_img.crop((0, y_min, full_img.width, y_max))

        # YOLO inference on the cropped strip.
        results = model(cropped_img)
        result = results[0]
        boxes = result.boxes.data if result.boxes is not None else []

        # Map each staff-line y coordinate to its pitch label; lines are
        # paired with pitch_labels from the top of the staff downwards and
        # zip() stops at the shorter of the two sequences.
        line_pitch_map = {
            y: label
            for group in groups_in_pair
            for y, label in zip(sorted(group), pitch_labels)
        }
        if not line_pitch_map:
            # Without any labelled line no detection can be assigned a pitch.
            boxes = []

        # Annotate the predicted boxes.
        for box in boxes:
            x1, y1, x2, y2 = (int(v) for v in box[:4])
            # Box coordinates are relative to the crop.  The crop starts at
            # x = 0, so only the y values need shifting back to full-image
            # coordinates.
            y_center = (y1 + y2) // 2 + y_min
            y1 += y_min
            y2 += y_min

            closest_y = min(line_pitch_map, key=lambda line_y: abs(line_y - y_center))
            assigned_pitch = line_pitch_map[closest_y]

            # Class name; column 5 of the box row is read as the class id.
            cls_id = int(box[5].item()) if len(box) > 5 else -1
            class_name = result.names[cls_id] if cls_id in result.names else 'unknown'

            # Draw the box and its "<class> - <pitch>" label.
            label_text = f"{class_name} - {assigned_pitch}"
            draw.rectangle([(x1, y1), (x2, y2)], outline='green', width=2)
            draw.text((x1, y1 - 15), label_text, fill='green', font=font)

        # Draw the staff lines and their labels.
        for group_idx, group in enumerate(groups_in_pair):
            for line_idx, y in enumerate(sorted(group)):
                pitch = pitch_labels[line_idx] if line_idx < len(pitch_labels) else 'Unknown'
                draw.line([(0, y), (full_img.width, y)], fill='red', width=1)
                draw.text((100, y - 10), f"Group {i+group_idx+1} - Line {line_idx+1} - {pitch} (y={y})", fill='blue', font=font)

    # Save the annotated image under its original file name.
    save_path = os.path.join(output_dir, image_file)
    full_img.save(save_path)
    print(f"[✓] Saved: {save_path}")



0: 192x992 1 brace, 1 clefG, 1 clefF, 4 timeSig4s, 15 noteheadBlackOnLines, 21 noteheadBlackInSpaces, 4 noteheadWholeOnLines, 6 noteheadWholeInSpaces, 2 staffs, 128.7ms
Speed: 4.4ms preprocess, 128.7ms inference, 5.0ms postprocess per image at shape (1, 3, 192, 992)
[✓] Saved: ./pitch_prediction/line1.png

0: 224x992 1 brace, 1 clefG, 1 clefF, 4 timeSig4s, 15 noteheadBlackOnLines, 21 noteheadBlackInSpaces, 4 noteheadWholeOnLines, 6 noteheadWholeInSpaces, 2 staffs, 154.1ms
Speed: 0.8ms preprocess, 154.1ms inference, 0.8ms postprocess per image at shape (1, 3, 224, 992)

0: 224x992 1 brace, 1 clefG, 1 clefF, 16 noteheadBlackOnLines, 21 noteheadBlackInSpaces, 8 noteheadWholeOnLines, 6 noteheadWholeInSpaces, 1 restQuarter, 1 slur, 1 beam, 2 staffs, 167.4ms
Speed: 1.0ms preprocess, 167.4ms inference, 0.9ms postprocess per image at shape (1, 3, 224, 992)

0: 224x992 1 brace, 1 clefG, 1 clefF, 24 noteheadBlackOnLines, 26 noteheadBlackInSpaces, 2 noteheadHalfOnLines, 2 flag8thUps, 2 flag8thDo