In [9]:
import os
import json
from ultralytics import YOLO

#1. Load Trained YOLO model
model = YOLO("./trained_models/yolov8m_500ep.pt")

#2. Set sample image path
source_dir = "./canon_samples/"
image_files = [os.path.join(source_dir, f) for f in os.listdir(source_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]

#3. YOLO predict sample image
results = model(image_files)

#4. Define JSON file save path
save_json = "./canon_samples/midi_notes.json"

#5. Save pitch parse relevant code
output = []

for result in results:
    img_path = result.path
    boxes = result.boxes.data
    confs = result.boxes.conf
    cls_ids = result.boxes.cls

    notes = []
    for box, conf, cls_id in zip(boxes, confs, cls_ids):
        label = result.names[int(cls_id)]
        x1, y1, x2, y2 = map(float, box[:4])
        note_data = {
            "label": label,
            "confidence": float(conf),
            "bbox": [x1, y1, x2, y2],
            "center": [(x1 + x2) / 2, (y1 + y2) / 2]
        }
        notes.append(note_data)

    output.append({
        "filename": os.path.basename(img_path),
        "notes": notes
    })

#6. Save result as a JSON file
with open(save_json, 'w') as f:
    json.dump(output, f, indent=2)

print(f"🎼 Done! MIDI-relevant predictions saved to {save_json}")



0: 192x992 1 brace, 1 clefG, 1 clefF, 4 timeSig4s, 15 noteheadBlackOnLines, 21 noteheadBlackInSpaces, 4 noteheadWholeOnLines, 6 noteheadWholeInSpaces, 2 staffs, 12.8ms
Speed: 2.6ms preprocess, 12.8ms inference, 1.9ms postprocess per image at shape (1, 3, 192, 992)
🎼 Done! MIDI-relevant predictions saved to ./canon_samples/midi_notes.json


In [None]:
import json
import mido
from mido import Message, MidiFile, MidiTrack
import os
import math

def c_major_transform_clefG(i):
    dic = {
        6: 81,
        5: 79,
        4: 77,
        3: 76,
        2: 74,
        1: 72,
        0: 71,
        -1: 69,
        -2: 67,
        -3: 65,
        -4: 64,
        -5: 62,
        -6: 60
    }
    return dic.get(i, 60)

def c_major_transform_clefF(i):
    dic = {
        6: 60,
        5: 59,
        4: 57,
        3: 55,
        2: 53,
        1: 52,
        0: 50,
        -1: 48,
        -2: 47,
        -3: 45,
        -4: 43,
        -5: 41,
        -6: 40
    }
    return dic.get(i, 60)


def get_note_name(midi_note):
    """Returns the note name for a MIDI note number (e.g., 60 → C4)"""
    note_names = {
        0: 'C', 1: 'C#', 2: 'D', 3: 'D#', 4: 'E', 5: 'F',
        6: 'F#', 7: 'G', 8: 'G#', 9: 'A', 10: 'A#', 11: 'B'
    }
    octave = midi_note // 12 - 1  # MIDI note 60 is C4
    note = midi_note % 12
    return f"{note_names[note]}{octave}"

def get_duration_from_symbol(symbol_type, has_beam=False):
    """
    Returns the MIDI duration in ticks for a given note or rest type
    
    Args:
        symbol_type: The label of the symbol (e.g., 'noteheadBlack')
        has_beam: Whether the note has a beam attached, indicating eighth note
    """
    if "Whole" in symbol_type:
        return 1920  # 4 beats * 480 ticks
    elif "Half" in symbol_type:
        return 960   # 2 beats * 480 ticks
    elif "Black" in symbol_type and has_beam:
        return 240   # 0.5 beat * 480 ticks (eighth note)
    elif "Black" in symbol_type:
        return 480   # 1 beat * 480 ticks (quarter note)
    elif "8th" in symbol_type or "Eighth" in symbol_type:
        return 240   # 0.5 beat * 480 ticks
    elif "16th" in symbol_type:
        return 120   # 0.25 beat * 480 ticks
    else:
        # Default to quarter note
        return 480

def assign_elements_to_staves(elements, staves):
    """
    Assign elements (notes, beams, flags, etc.) to their appropriate staff
    
    Args:
        elements: List of elements with bbox and center coordinates
        staves: List of staff elements with bbox coordinates
    
    Returns:
        Dictionary mapping staff indices to lists of elements belonging to that staff
    """
    # Sort staves by vertical position (top to bottom)
    sorted_staves = sorted(staves, key=lambda x: x.get("center", [0, 0])[1])
    
    # Initialize result dictionary
    staff_elements = {i: [] for i in range(len(sorted_staves))}
    
    # For each element, find the closest staff
    for element in elements:
        element_y = element.get("center", [0, 0])[1]
        
        # Find the closest staff by vertical distance
        closest_staff_idx = None
        min_distance = float('inf')
        
        for i, staff in enumerate(sorted_staves):
            staff_y = staff.get("center", [0, 0])[1]
            distance = abs(element_y - staff_y)
            
            if distance < min_distance:
                min_distance = distance
                closest_staff_idx = i
        
        # Assign the element to the closest staff
        if closest_staff_idx is not None:
            staff_elements[closest_staff_idx].append(element)
    
    return staff_elements

def is_note_beamed(note, beams):
    """
    Check if a note is connected to a beam within the same staff
    
    Args:
        note: The note object with bbox and center
        beams: List of beam objects with bbox and center for the same staff
    
    Returns:
        Boolean indicating if the note is connected to a beam
    """
    note_x = note.get("center", [0, 0])[0]
    note_bbox = note.get("bbox", [0, 0, 0, 0])
    threshold = note_bbox[3] - note_bbox[1]
    # Only check horizontal alignment with each beam (vertical is handled by staff assignment)
    for beam in beams:
        beam_bbox = beam.get("bbox", [0, 0, 0, 0])
        beam_x1, beam_y1, beam_x2, beam_y2 = beam_bbox
        
        # Check if note's x-coordinate is within the beam's x-range
        if beam_x1 - threshold <= note_x <= beam_x2 + threshold:  # Small tolerance
            return True
    
    return False

def is_note_flagged(note, flags):
    """
    Check if a note has a flag (8th, 16th, etc.) within the same staff
    
    Args:
        note: The note object with bbox and center
        flags: List of flag objects with bbox and center for the same staff
    
    Returns:
        Boolean indicating if the note has a flag
    """
    note_x = note.get("center", [0, 0])[0]
    note_bbox = note.get("bbox", [0, 0, 0, 0])
    threshold = note_bbox[3] - note_bbox[1]
    
    # Check proximity to each flag (only horizontal alignment matters)
    for flag in flags:
        flag_x = flag.get("center", [0, 0])[0]
        
        # Check horizontal proximity
        if abs(flag_x - note_x) < threshold:
            return True
    
    return False

In [None]:

def json_to_midi(json_path, midi_path):
    # Read the JSON file
    with open(json_path, 'r') as f:
        data = json.load(f)
    
    # Create a new MIDI file with standard resolution
    mid = MidiFile(ticks_per_beat=480)
    
    # Create tracks based on detected staves
    tracks = {}
    track_info = {}  # Store clef and other info for each track
    
    # Process each page in the JSON data
    for page in data:
        filename = page.get("filename", "unknown")
        notes = page.get("notes", [])
        
        print(f"Processing {filename} with {len(notes)} elements")
        
        # Extract all elements by type
        staves = [note for note in notes if note.get("label") == "staff"]
        beams = [note for note in notes if note.get("label") == "beam"]
        flags = [note for note in notes if "flag" in note.get("label", "").lower()]
        clefs = [note for note in notes if note.get("label") in ["clefG", "clefF", "clefC"]]
        time_sigs = [note for note in notes if note.get("label", "").startswith("timeSig")]
        
        # Assign elements to staves
        staff_beams = assign_elements_to_staves(beams, staves)
        staff_flags = assign_elements_to_staves(flags, staves)
        staff_clefs = assign_elements_to_staves(clefs, staves)
        
        print(f"Found {len(staves)} staves, {len(beams)} beams, and {len(flags)} flags")
        
        # Sort staves by vertical position (top to bottom)
        sorted_staves = sorted(staves, key=lambda x: x.get("center", [0, 0])[1])
        
        # Create tracks for each staff
        for i, staff in enumerate(sorted_staves):
            staff_id = f"staff_{i}"
            staff_center_y = staff.get("center", [0, 0])[1]
            
            # Get clefs for this staff
            staff_clef_candidates = staff_clefs.get(i, [])
            staff_clef = None
            if staff_clef_candidates:
                # Use the first clef found for this staff
                staff_clef = staff_clef_candidates[0].get("label")
            
            # Create track if it doesn't exist
            if staff_id not in tracks:
                tracks[staff_id] = MidiTrack()
                mid.tracks.append(tracks[staff_id])
                # Add track name
                tracks[staff_id].append(mido.MetaMessage('track_name', name=f"Staff {i+1}"))
                # Default to piano sound (General MIDI program 0)
                tracks[staff_id].append(mido.Message('program_change', program=0, time=0))
                
                # Store staff information
                track_info[staff_id] = {
                    "clef": staff_clef,
                    "staff_y": staff_center_y,
                    "bbox": staff.get("bbox"),
                    "time_signature": (4, 4)  # Default 4/4 time
                }
                
                print(f"Created track for staff {i+1} with clef: {staff_clef}")
    
        # Assign all musical elements to staves
        all_musical_elements = [note for note in notes if "notehead" in note.get("label", "") or "rest" in note.get("label", "")]
        staff_musical_elements = assign_elements_to_staves(all_musical_elements, staves)
        
        # Process musical elements for each staff
        for staff_idx, staff_elements in staff_musical_elements.items():
            staff_id = f"staff_{staff_idx}"
            if staff_id not in track_info:
                continue  # Skip if we don't have track info
                
            staff_data = track_info[staff_id]
            
            # Get beams and flags for this staff
            this_staff_beams = staff_beams.get(staff_idx, [])
            this_staff_flags = staff_flags.get(staff_idx, [])
            
            # Filter notes and rests
            notes_and_rests = []
            for element in staff_elements:
                element_label = element.get("label", "")
                
                # Check if the element is a note or a rest
                is_note = "notehead" in element_label
                is_rest = "rest" in element_label
                
                if is_note or is_rest:
                    # Check if the note is beamed or flagged (for eighth notes)
                    if is_note and "Black" in element_label:
                        has_beam = is_note_beamed(element, this_staff_beams)
                        has_flag = is_note_flagged(element, this_staff_flags)
                        
                        # Store this information with the note
                        element['has_beam'] = has_beam
                        element['has_flag'] = has_flag
                        
                        if has_beam:
                            print(f"  Note at ({element.get('center', [0, 0])[0]:.1f}, {element.get('center', [0, 0])[1]:.1f}) is beamed")
                        if has_flag:
                            print(f"  Note at ({element.get('center', [0, 0])[0]:.1f}, {element.get('center', [0, 0])[1]:.1f}) has a flag")
                    
                    notes_and_rests.append(element)
            
            # Sort by horizontal position (left to right)
            notes_and_rests.sort(key=lambda x: x.get("center", [0, 0])[0])
            
            # Group into temporal events (chords or single notes/rests)
            horizontal_tolerance = 10  # Pixels of tolerance for horizontal alignment
            
            # Group symbols into temporal segments
            temporal_groups = []
            current_group = []
            
            for symbol in notes_and_rests:
                x_pos = symbol.get("center", [0, 0])[0]
                
                if not current_group:
                    # First symbol in a group
                    current_group.append(symbol)
                elif abs(x_pos - current_group[0].get("center", [0, 0])[0]) <= horizontal_tolerance:
                    # Symbol horizontally aligned with current group
                    current_group.append(symbol)
                else:
                    # New horizontal position, start a new group
                    temporal_groups.append(current_group)
                    current_group = [symbol]
            
            # Add the last group if it exists
            if current_group:
                temporal_groups.append(current_group)
            
            print(f"Processing {len(notes_and_rests)} symbols in {len(temporal_groups)} groups for {staff_id}")
            
            # Process each group (chord, single note, or rest)
            current_time = 0  # Accumulated time in ticks
            
            for group in temporal_groups:
                # Determine if this group contains notes, rests, or both
                notes_in_group = [s for s in group if "notehead" in s.get("label", "")]
                rests_in_group = [s for s in group if "rest" in s.get("label", "")]
                
                # If the group contains both notes and rests, prioritize notes
                if notes_in_group:
                    # This is a chord or a single note
                    # Sort notes by vertical position (low to high)
                    notes_in_group.sort(key=lambda x: x.get("center", [0, 0])[1], reverse=True)
                    
                    # Determine duration based on the first note in the group
                    note_type = notes_in_group[0].get("label", "")
                    has_beam = notes_in_group[0].get("has_beam", False)
                    has_flag = notes_in_group[0].get("has_flag", False)
                    duration = get_duration_from_symbol(note_type, has_beam or has_flag)
                    
                    # Process each note in the chord
                    chord_pitches = []
                    for note in notes_in_group:
                        note_center_y = note.get("center", [0, 0])[1]
                        
                        # Calculate pitch based on vertical position relative to staff
                        staff_bbox = staff_data["bbox"]
                        line_height = (staff_bbox[3] - staff_bbox[1]) / 4  # 5 lines = 4 spaces
                        distance_from_center = (staff_data["staff_y"] - note_center_y) / (line_height/2)
                        
                        # Round to nearest staff position
                        staff_position = round(distance_from_center)
                        
                        # Determine pitch based on clef
                        if staff_data["clef"] == "clefG":
                            pitch = c_major_transform_clefG(staff_position)
                        elif staff_data["clef"] == "clefF":
                            pitch = c_major_transform_clefF(staff_position)
                        else:
                            # Default handling for unknown clef
                            pitch = 60 + staff_position  # Middle C + offset
                        
                        # Skip if note is outside reasonable MIDI range
                        if 0 <= pitch <= 127:
                            chord_pitches.append(pitch)
                            duration_type = "quarter note"
                            if has_beam or has_flag:
                                duration_type = "eighth note"
                            print(f"  Added note: {get_note_name(pitch)} (MIDI {pitch}) ({duration_type})")
                    
                    # Play all notes in the chord simultaneously
                    velocity = 64  # Default medium velocity
                    
                    # Add note_on events with the accumulated time for the first note
                    first_note = True
                    for pitch in chord_pitches:
                        if first_note:
                            tracks[staff_id].append(Message('note_on', note=pitch, velocity=velocity, time=current_time))
                            first_note = False
                        else:
                            tracks[staff_id].append(Message('note_on', note=pitch, velocity=velocity, time=0))
                    
                    # Reset current_time since we've used it
                    current_time = 0
                    
                    # Add note_off events
                    # First note_off has the full duration, others immediately follow
                    for i, pitch in enumerate(chord_pitches):
                        if i == 0:
                            tracks[staff_id].append(Message('note_off', note=pitch, velocity=0, time=duration))
                        else:
                            tracks[staff_id].append(Message('note_off', note=pitch, velocity=0, time=0))
                    
                elif rests_in_group:
                    # This is a rest - just add a time delay
                    rest_type = rests_in_group[0].get("label", "")
                    rest_duration = get_duration_from_symbol(rest_type)
                    
                    print(f"  Added rest: {rest_type} with duration {rest_duration}")
                    
                    # Accumulate time for the rest
                    current_time += rest_duration
    
    # Check if tracks are empty and create a default one if needed
    if not tracks:
        print("No valid tracks found, creating default track")
        default_track = MidiTrack()
        mid.tracks.append(default_track)
        default_track.append(mido.MetaMessage('track_name', name="Default Track"))
        default_track.append(mido.Message('program_change', program=0, time=0))
    
    # Save the MIDI file
    mid.save(midi_path)
    print(f"MIDI file saved to {midi_path}")
    return True

In [11]:
json_to_midi(
    json_path="./canon_samples/midi_notes.json",
    midi_path="./canon_samples/output_pitch.mid"
)

{'filename': 'pair1.png', 'notes': [{'label': 'staff', 'confidence': 0.9698613286018372, 'bbox': [18.59806251525879, 19.029415130615234, 774.5816040039062, 47.28985595703125], 'center': [396.5898332595825, 33.15963554382324]}, {'label': 'clefG', 'confidence': 0.9628314971923828, 'bbox': [24.083467483520508, 8.088375091552734, 40.81251907348633, 57.00587844848633], 'center': [32.44799327850342, 32.54712677001953]}, {'label': 'clefF', 'confidence': 0.9568300247192383, 'bbox': [23.65345001220703, 92.07353210449219, 41.29993438720703, 113.18228149414062], 'center': [32.47669219970703, 102.6279067993164]}, {'label': 'staff', 'confidence': 0.9553788900375366, 'bbox': [17.80859375, 91.9994888305664, 775.2376708984375, 120.0471420288086], 'center': [396.52313232421875, 106.0233154296875]}, {'label': 'noteheadBlackOnLine', 'confidence': 0.9367202520370483, 'bbox': [325.5791320800781, 116.30025482177734, 334.3127136230469, 123.43281555175781], 'center': [329.9459228515625, 119.86653518676758]}, 

True

In [8]:
import pygame
import time

midi_file = "./samples/output_pitch.mid"
pygame.init()
pygame.mixer.init()
pygame.mixer.music.load(midi_file)
pygame.mixer.music.play()

while pygame.mixer.music.get_busy():
    time.sleep(0.5)


pygame 2.6.1 (SDL 2.28.4, Python 3.10.16)
Hello from the pygame community. https://www.pygame.org/contribute.html


error: XDG_RUNTIME_DIR not set in the environment.


error: ALSA: Couldn't open audio device: Host is down