In [89]:
import os
import json
from ultralytics import YOLO
from PIL import Image

# Load the YOLO music-symbol detection model
model = YOLO("YOLOv8x_Symbols.pt")

# Collect input images. Sorted so that the JSON page order (and therefore the
# order in which downstream cells process pages) does not depend on the
# arbitrary ordering returned by os.listdir.
source_dir = "./samples/"
image_files = sorted(
    os.path.join(source_dir, f)
    for f in os.listdir(source_dir)
    if f.lower().endswith(('.jpg', '.jpeg', '.png'))
)
if not image_files:
    raise FileNotFoundError(f"No .jpg/.jpeg/.png images found in {source_dir}")

# Labels directly relevant to MIDI generation
note_classes = {
    'noteheadBlack', 'noteheadHalf', 'noteheadWhole',
    'restQuarter', 'restHalf', 'restWhole',
    'noteheadBlackOnLine', 'noteheadBlackInSpace',
    'noteheadHalfOnLine', 'noteheadWholeInSpace', 'rest8th'
}
# 'staff' must also be kept: it is needed to match noteheads to pitches
extra_classes = {'staff'}

# When True, only labels in note_classes | extra_classes are written to the JSON.
# Off by default: the MIDI-conversion cell also looks up clef labels
# (clefG/clefF/clefC), which are in neither set and would otherwise be dropped.
FILTER_TO_MIDI_CLASSES = False
keep_labels = note_classes | extra_classes

# Run inference on all images
results = model(image_files)

# Output path for the predictions
save_json = "./samples/midi_notes.json"
output = []

for result in results:
    img_path = result.path
    notes = []
    # xyxy / conf / cls are parallel per-detection arrays; tolist() yields plain floats
    for (x1, y1, x2, y2), conf, cls_id in zip(
        result.boxes.xyxy.tolist(),
        result.boxes.conf.tolist(),
        result.boxes.cls.tolist(),
    ):
        label = result.names[int(cls_id)]
        if FILTER_TO_MIDI_CLASSES and label not in keep_labels:
            continue  # neither a note/rest nor a staff: skip

        notes.append({
            "label": label,
            "confidence": float(conf),
            "bbox": [x1, y1, x2, y2],
            "center": [(x1 + x2) / 2, (y1 + y2) / 2],
        })

    output.append({
        "filename": os.path.basename(img_path),
        "notes": notes,
    })

# Save the predictions as JSON
with open(save_json, 'w', encoding='utf-8') as f:
    json.dump(output, f, indent=2)

print(f"🎼 Done! MIDI-relevant predictions saved to {save_json}")



0: 992x992 1 brace, 1 clefG, 1 clefF, 4 timeSig4s, 15 noteheadBlackOnLines, 21 noteheadBlackInSpaces, 4 noteheadWholeOnLines, 6 noteheadWholeInSpaces, 2 staffs, 1296.2ms
1: 992x992 1 brace, 1 clefG, 1 clefF, 16 noteheadBlackOnLines, 21 noteheadBlackInSpaces, 8 noteheadWholeOnLines, 6 noteheadWholeInSpaces, 1 restQuarter, 1 beam, 1 tie, 2 staffs, 1296.2ms
Speed: 5.8ms preprocess, 1296.2ms inference, 2.2ms postprocess per image at shape (1, 3, 992, 992)
🎼 Done! MIDI-relevant predictions saved to ./samples/midi_notes.json


In [90]:
import os
import json
from PIL import Image, ImageDraw, ImageFont

def visualize_predictions(json_path, image_dir, save_dir):
    """Draw every detection stored in a predictions JSON onto its source image.

    Args:
        json_path: JSON file containing a list of pages, each of the form
            ``{"filename": str, "notes": [{"label", "confidence", "bbox",
            "center"}, ...]}`` with ``bbox`` as ``[x1, y1, x2, y2]``.
        image_dir: directory containing the images named by ``filename``.
        save_dir: output directory (created if needed). Each annotated image
            is saved there under its original filename.

    Colours by label: contains 'staff' -> blue, 'notehead' -> green,
    'rest' -> red, anything else -> orange. Pages whose image file does not
    exist are reported and skipped.
    """
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    os.makedirs(save_dir, exist_ok=True)

    # The font does not depend on the page, so load it once.
    try:
        font = ImageFont.truetype("arial.ttf", size=20)
    except (OSError, ImportError):
        # arial.ttf is missing on many non-Windows systems (OSError); truetype()
        # may also be unavailable if Pillow lacks FreeType. Fall back to the
        # built-in bitmap font.
        font = ImageFont.load_default()

    label_height = 20   # labels are drawn this many px above their box
    marker_radius = 3   # radius of the dot marking each box centre

    for page in data:
        filename = page["filename"]
        image_path = os.path.join(image_dir, filename)

        if not os.path.exists(image_path):
            print(f"❗ Image {filename} not found!")
            continue

        # The context manager closes the file; convert() returns an independent copy.
        with Image.open(image_path) as src:
            img = src.convert("RGB")
        draw = ImageDraw.Draw(img)

        for note in page["notes"]:
            label = note["label"]
            conf = note["confidence"]
            x1, y1, x2, y2 = note["bbox"]
            cx, cy = note["center"]

            # Colour by symbol family
            if 'staff' in label:
                color = (0, 0, 255)    # blue
            elif 'notehead' in label:
                color = (0, 255, 0)    # green
            elif 'rest' in label:
                color = (255, 0, 0)    # red
            else:
                color = (255, 165, 0)  # orange: everything else

            # Bounding box
            draw.rectangle([(x1, y1), (x2, y2)], outline=color, width=2)
            # Label; clamped to the top edge so boxes near y=0 keep a visible label
            draw.text((x1, max(0, y1 - label_height)), f"{label} {conf:.2f}",
                      fill=color, font=font)
            # Centre marker
            draw.ellipse([(cx - marker_radius, cy - marker_radius),
                          (cx + marker_radius, cy + marker_radius)], fill=color)

        save_path = os.path.join(save_dir, filename)
        img.save(save_path)
        print(f"✅ Saved visualization: {save_path}")

# Usage: annotate every page listed in the predictions JSON
visualize_predictions("./samples/midi_notes.json", "./samples/", "./samples/visualized/")


✅ Saved visualization: ./samples/visualized/canon_partial.png
✅ Saved visualization: ./samples/visualized/canon_partial2.png


In [91]:
import json
import os
from PIL import Image, ImageDraw

def visualize_staffs(json_path, image_dir, save_dir="./samples/staff_visualization"):
    """Draw the detected staff boxes of each page and print their extents.

    For every page in the predictions JSON (list of ``{"filename", "notes"}``),
    staff detections (label == "staff") are sorted top to bottom and numbered
    from 1. Each is drawn as a red box with a blue line through its centre y,
    and its top/bottom/centre y coordinates are printed. The result is saved as
    ``staff_<filename>`` in ``save_dir`` (created if needed). Pages whose image
    does not exist are reported and skipped.
    """
    os.makedirs(save_dir, exist_ok=True)

    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    for page in data:
        filename = page["filename"]
        staff_notes = sorted(
            (n for n in page["notes"] if n["label"] == "staff"),
            key=lambda n: n["center"][1],
        )

        image_path = os.path.join(image_dir, filename)
        if not os.path.exists(image_path):
            print(f"❗ Image {filename} not found!")
            continue

        # The context manager closes the file; convert() returns an independent copy.
        with Image.open(image_path) as src:
            image = src.convert("RGB")
        draw = ImageDraw.Draw(image)

        print(f"\n🖼 Image: {filename}")
        for number, staff in enumerate(staff_notes, start=1):
            x1, y1, x2, y2 = staff["bbox"]
            cy = staff["center"][1]
            draw.rectangle([x1, y1, x2, y2], outline="red", width=3)
            draw.line([(x1, cy), (x2, cy)], fill="blue", width=1)
            # Clamp so the label of a staff near the top edge stays visible
            draw.text((x1, max(0, y1 - 10)), f"staff {number}", fill="green")
            print(f"staff {number}: top={y1:.1f}, bottom={y2:.1f}, center={cy:.1f}")

        out_path = os.path.join(save_dir, f"staff_{filename}")
        image.save(out_path)
        print(f"✅ Saved visual to {out_path}")

# Example usage: staff overlays go to the default save_dir
visualize_staffs("./samples/midi_notes.json", "./samples")


🖼 Image: canon_partial.png
staff 1: top=19.0, bottom=47.5, center=33.3
staff 2: top=91.7, bottom=120.1, center=105.9
✅ Saved visual to ./samples/staff_visualization\staff_canon_partial.png

🖼 Image: canon_partial2.png
staff 1: top=30.8, bottom=57.9, center=44.4
staff 2: top=100.1, bottom=127.2, center=113.7
✅ Saved visual to ./samples/staff_visualization\staff_canon_partial2.png


In [92]:
import json
import mido
from mido import Message, MidiFile, MidiTrack
import os
import math
def c_major_transform_clefG(i):
    """Map a treble-clef staff position to a MIDI pitch in C major (no accidentals).

    ``i`` counts diatonic steps (alternating lines and spaces) from the middle
    staff line, which is B4 (MIDI 71) in treble clef; positive is higher.
    Positions beyond the staff (ledger lines) are handled by the same rule,
    e.g. -6 -> C4 (60), 7 -> B5 (83), -8 -> A3 (57).
    """
    # Semitones above C for the white keys C D E F G A B
    scale = (0, 2, 4, 5, 7, 9, 11)
    # Diatonic index = 7 * octave + letter (C=0 .. B=6); B4 -> 7*4 + 6 = 34
    octave, letter = divmod(34 + int(i), 7)
    # MIDI 60 is C4, so octave n starts at 12 * (n + 1)
    return 12 * (octave + 1) + scale[letter]
def c_major_transform_clefF(i):
    """Map a bass-clef staff position to a MIDI pitch in C major (no accidentals).

    ``i`` counts diatonic steps (alternating lines and spaces) from the middle
    staff line, which is D3 (MIDI 50) in bass clef; positive is higher.
    Positions beyond the staff (ledger lines) are handled by the same rule,
    e.g. 6 -> C4 (60), -7 -> D2 (38), -8 -> C2 (36).
    """
    # Semitones above C for the white keys C D E F G A B
    scale = (0, 2, 4, 5, 7, 9, 11)
    # Diatonic index = 7 * octave + letter (C=0 .. B=6); D3 -> 7*3 + 1 = 22
    octave, letter = divmod(22 + int(i), 7)
    # MIDI 60 is C4, so octave n starts at 12 * (n + 1)
    return 12 * (octave + 1) + scale[letter]
def get_note_name(midi_note):
    """Spell a MIDI note number in scientific pitch notation using sharps.

    Middle C (MIDI 60) is "C4"; e.g. 61 -> "C#4", 0 -> "C-1".
    """
    pitch_class_names = dict(enumerate(
        ('C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B')
    ))
    pitch_class = midi_note % 12
    octave_number = midi_note // 12 - 1  # octave -1 starts at MIDI 0
    return f"{pitch_class_names[pitch_class]}{octave_number}"
def _note_duration_ticks(label):
    """Return the duration in ticks (at 480 ticks per beat) implied by a label.

    Matching is by substring: "Whole" -> whole, "Half" -> half, "Quarter" ->
    quarter, "16th" -> sixteenth, "Eighth"/"8th" (e.g. ``rest8th``) -> eighth.
    Anything else, including ``noteheadBlack*`` (whose real value depends on
    stems/flags/beams that are not inspected), defaults to a quarter note.
    """
    if "Whole" in label:
        return 1920  # 4 beats * 480 ticks
    if "Half" in label:
        return 960   # 2 beats * 480 ticks
    if "Quarter" in label:
        return 480   # 1 beat * 480 ticks
    if "16th" in label:
        return 120   # 0.25 beat * 480 ticks
    if "Eighth" in label or "8th" in label:
        return 240   # 0.5 beat * 480 ticks
    return 480       # default to a quarter note


def _nearest_clef(staff_center_y, clefs, max_distance=30):
    """Return the label of the clef vertically closest to a staff, or None.

    Only clefs whose centre y is strictly within ``max_distance`` px of
    ``staff_center_y`` qualify. Picking the nearest (not the first) match lets
    pages with several systems, i.e. several clefs of one type, pair each
    staff with its own clef.
    """
    best_label, best_distance = None, max_distance
    for clef in clefs:
        distance = abs(clef.get("center", [0, 0])[1] - staff_center_y)
        if distance < best_distance:
            best_label, best_distance = clef.get("label"), distance
    return best_label


def _assign_to_staves(elements, staves, bbox_margin=20, center_margin=40):
    """Bucket each element under the single closest staff that accepts it.

    A staff accepts an element whose centre y is within ``bbox_margin`` px of
    the staff's [top, bottom] range, or strictly within ``center_margin`` px
    of its centre y. When several staves accept an element, only the one with
    the nearest centre gets it, so no notehead is played on two tracks.
    Elements no staff accepts are dropped. Returns lists parallel to ``staves``.
    """
    buckets = [[] for _ in staves]
    for element in elements:
        y = element.get("center", [0, 0])[1]
        best_index, best_distance = None, None
        for index, staff in enumerate(staves):
            distance = abs(y - staff["center_y"])
            in_window = (staff["top"] - bbox_margin <= y <= staff["bottom"] + bbox_margin
                         or distance < center_margin)
            if in_window and (best_distance is None or distance < best_distance):
                best_index, best_distance = index, distance
        if best_index is not None:
            buckets[best_index].append(element)
    return buckets


def _group_by_x(elements, tolerance=10):
    """Split x-sorted elements into columns (single notes, chords, rests).

    An element joins the current column when its centre x is within
    ``tolerance`` px of that column's first element; otherwise it starts a new
    column. ``elements`` must already be sorted by centre x.
    """
    columns = []
    for element in elements:
        x = element.get("center", [0, 0])[0]
        if columns and abs(x - columns[-1][0].get("center", [0, 0])[0]) <= tolerance:
            columns[-1].append(element)
        else:
            columns.append([element])
    return columns


def _staff_position_to_pitch(note_center_y, staff, half_space):
    """Convert a notehead's centre y to a MIDI pitch on ``staff``.

    ``staff`` is a dict with ``center_y`` and ``clef`` keys; ``half_space`` is
    the pixel size of one diatonic step (half the line spacing). Position 0 is
    the middle staff line; positive positions are higher on the page.
    """
    position = round((staff["center_y"] - note_center_y) / half_space)
    if staff["clef"] == "clefG":
        return c_major_transform_clefG(position)
    if staff["clef"] == "clefF":
        return c_major_transform_clefF(position)
    # Unknown clef (None or clefC): offset from middle C
    return 60 + position


def json_to_midi(json_path, midi_path):
    """Convert YOLO score detections (JSON) into a multi-track MIDI file.

    Each staff index becomes one track ("staff_0" = top staff of a page, ...).
    Pages are processed in JSON order and appended to the same tracks, so
    page 2 continues page 1. For every page:

    1. staves (label "staff", with a non-degenerate bbox) are sorted top to
       bottom and each is paired with its nearest clef. Staff geometry and
       clef are taken from the current page, never from an earlier one;
    2. every notehead and rest is assigned to exactly one staff;
    3. each staff's elements are grouped into x-aligned columns. A column
       with noteheads is played as a chord (duplicate pitches merged; its
       duration comes from the lowest notehead's label). A rest-only column,
       or a chord whose pitches are all outside 0..127, becomes silence that
       delays the next note on that track.

    Args:
        json_path: predictions JSON, a list of ``{"filename", "notes"}`` pages.
        midi_path: output ``.mid`` path.

    Returns:
        True once the file has been written.
    """
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Durations from _note_duration_ticks assume 480 ticks per beat
    mid = MidiFile(ticks_per_beat=480)

    tracks = {}           # staff_id -> MidiTrack, shared across pages
    pending_silence = {}  # staff_id -> rest ticks to put before the next note_on
    velocity = 64         # default medium velocity

    for page in data:
        filename = page.get("filename", "unknown")
        notes = page.get("notes", [])

        print(f"Processing {filename} with {len(notes)} elements")

        # First pass: staves of this page, top to bottom. Degenerate boxes are
        # skipped because their height is used as a divisor below.
        staves = []
        for note in notes:
            bbox = note.get("bbox")
            if (note.get("label") == "staff" and bbox and len(bbox) == 4
                    and bbox[3] > bbox[1]):
                staves.append(note)
        staves.sort(key=lambda n: n.get("center", [0, 0])[1])

        clefs = [n for n in notes if n.get("label") in ("clefG", "clefF", "clefC")]

        page_staves = []
        for i, staff in enumerate(staves):
            staff_id = f"staff_{i}"
            staff_center_y = staff.get("center", [0, 0])[1]
            staff_clef = _nearest_clef(staff_center_y, clefs)

            if staff_id not in tracks:
                track = MidiTrack()
                mid.tracks.append(track)
                track.append(mido.MetaMessage('track_name', name=f"Staff {i+1}"))
                # Default to piano sound (General MIDI program 0)
                track.append(mido.Message('program_change', program=0, time=0))
                tracks[staff_id] = track
                pending_silence[staff_id] = 0
                print(f"Created track for staff {i+1} with clef: {staff_clef}")

            page_staves.append({
                "id": staff_id,
                "clef": staff_clef,
                "center_y": staff_center_y,
                "top": staff["bbox"][1],
                "bottom": staff["bbox"][3],
            })

        # Second pass: noteheads and rests of this page, each on one staff
        elements = [
            n for n in notes
            if "notehead" in n.get("label", "") or n.get("label", "").startswith("rest")
        ]
        buckets = _assign_to_staves(elements, page_staves)

        for staff, staff_elements in zip(page_staves, buckets):
            staff_id = staff["id"]
            track = tracks[staff_id]
            # Five lines span four spaces; one diatonic step is half a space.
            half_space = (staff["bottom"] - staff["top"]) / 8

            staff_elements.sort(key=lambda n: n.get("center", [0, 0])[0])
            columns = _group_by_x(staff_elements)

            print(f"Processing {len(staff_elements)} notes in {len(columns)} groups for {staff_id}")

            for column in columns:
                # Noteheads ordered low to high (larger y = lower on the page)
                heads = sorted(
                    (n for n in column if "notehead" in n.get("label", "")),
                    key=lambda n: n.get("center", [0, 0])[1],
                    reverse=True,
                )
                if not heads:
                    # Rest-only column: silence as long as its longest rest
                    pending_silence[staff_id] += max(
                        _note_duration_ticks(n.get("label", "")) for n in column
                    )
                    continue

                # Assume every note of a chord shares the lowest note's value
                duration = _note_duration_ticks(heads[0].get("label", ""))

                chord_pitches = []
                for head in heads:
                    pitch = _staff_position_to_pitch(
                        head.get("center", [0, 0])[1], staff, half_space
                    )
                    # Drop out-of-range pitches and duplicate detections of one notehead
                    if 0 <= pitch <= 127 and pitch not in chord_pitches:
                        chord_pitches.append(pitch)
                        print(f"  Added note: {get_note_name(pitch)} (MIDI {pitch}) in group")

                if not chord_pitches:
                    # Nothing playable: keep the time so the tracks stay aligned
                    pending_silence[staff_id] += duration
                    continue

                # All note_on events start together; the first carries any
                # accumulated rest as its delta time.
                for k, pitch in enumerate(chord_pitches):
                    delay = pending_silence[staff_id] if k == 0 else 0
                    track.append(Message('note_on', note=pitch, velocity=velocity, time=delay))
                pending_silence[staff_id] = 0

                # The first note_off carries the full duration; the others follow at delta 0
                for k, pitch in enumerate(chord_pitches):
                    track.append(Message('note_off', note=pitch, velocity=0,
                                         time=duration if k == 0 else 0))

    # Make sure the file has at least one track
    if not tracks:
        print("No valid tracks found, creating default track")
        default_track = MidiTrack()
        mid.tracks.append(default_track)
        default_track.append(mido.MetaMessage('track_name', name="Default Track"))
        default_track.append(mido.Message('program_change', program=0, time=0))

    # Save the MIDI file
    mid.save(midi_path)
    print(f"MIDI file saved to {midi_path}")
    return True

In [93]:
# Convert the saved predictions into a MIDI file (the returned True is the cell output)
json_to_midi("./samples/midi_notes.json", "./samples/output_pitch.mid")

Processing canon_partial.png with 55 elements
Created track for staff 1 with clef: clefG
Created track for staff 2 with clef: clefF
Processing 10 notes in 9 groups for staff_0
  Added note: E5 (MIDI 76) in group
  Added note: D5 (MIDI 74) in group
  Added note: C5 (MIDI 72) in group
  Added note: B4 (MIDI 71) in group
  Added note: A4 (MIDI 69) in group
  Added note: G4 (MIDI 67) in group
  Added note: A4 (MIDI 69) in group
  Added note: B4 (MIDI 71) in group
  Added note: C5 (MIDI 72) in group
  Added note: E5 (MIDI 76) in group
Processing 36 notes in 36 groups for staff_1
  Added note: C3 (MIDI 48) in group
  Added note: E3 (MIDI 52) in group
  Added note: G3 (MIDI 55) in group
  Added note: C4 (MIDI 60) in group
  Added note: G2 (MIDI 43) in group
  Added note: B2 (MIDI 47) in group
  Added note: D3 (MIDI 50) in group
  Added note: G3 (MIDI 55) in group
  Added note: A2 (MIDI 45) in group
  Added note: C3 (MIDI 48) in group
  Added note: E3 (MIDI 52) in group
  Added note: A3 (MIDI 

True

In [94]:
import pygame
import time

midi_file = "./samples/output_pitch.mid"
pygame.init()
pygame.mixer.init()
try:
    pygame.mixer.music.load(midi_file)
    pygame.mixer.music.play()

    # Block until playback finishes; the cell stays busy meanwhile.
    while pygame.mixer.music.get_busy():
        time.sleep(0.5)
finally:
    # Stop playback and release the audio device even if loading fails or
    # the cell is interrupted (KeyboardInterrupt from the kernel), so the
    # music does not keep playing and a re-run starts from a clean mixer.
    pygame.mixer.music.stop()
    pygame.quit()
