In [1]:
import os
import json
import pretty_midi
import pandas as pd
from pathlib import Path
from collections import defaultdict, Counter

In [2]:
# Configuration
MIDI_FOLDER = "data"
OUTPUT_FOLDER = "raw_data"

In [3]:
print("🎵 MIDI Dataset Creation Pipeline")
print("=" * 50)

🎵 MIDI Dataset Creation Pipeline


In [4]:
# Create output directory
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

In [5]:
# Discover MIDI files
midi_folder = Path(MIDI_FOLDER)
midi_files = list(midi_folder.glob("*.mid")) + list(midi_folder.glob("*.midi"))

In [6]:
print(f"📁 Found {len(midi_files)} MIDI files in '{MIDI_FOLDER}' folder")
if midi_files:
    print("Files to process:")
    for i, file in enumerate(midi_files[:10], 1):  # Show first 10
        print(f"  {i}. {file.name}")
    if len(midi_files) > 10:
        print(f"  ... and {len(midi_files) - 10} more files")
else:
    print("⚠️  No MIDI files found! Please check your data folder.")

📁 Found 13 MIDI files in 'data' folder
Files to process:
  1. ducktalesremasteredbosstheme_by_sb.mid
  2. SM-StudiopolisZoneAct2.mid
  3. SM-MirageSaloonZoneAct1K.mid
  4. SM-OilOceanZoneAct2.mid
  5. SM-FlyingBatteryZoneAct1.mid
  6. SM-MirageSaloonZoneAct1ST.mid
  7. SM-PressGardenZoneAct2.mid
  8. EggReverie.mid
  9. SM-MirageSaloonZoneAct2.mid
  10. SM-PressGardenZoneAct1.mid
  ... and 3 more files


In [7]:
def extract_notes_from_instrument(instrument):
    """Extract notes from a MIDI instrument as tuples."""
    notes_data = []
    for note in instrument.notes:
        note_tuple = [
            note.pitch,
            round(note.start, 2),
            round(note.end, 2),
            note.velocity
        ]
        notes_data.append(note_tuple)
    return notes_data

In [8]:
def process_single_midi(midi_file_path):
    """Process a single MIDI file and return instrument data."""
    try:
        midi_data = pretty_midi.PrettyMIDI(str(midi_file_path))
        
        file_stats = {
            'filename': midi_file_path.name,
            'instruments': [],
            'total_notes': 0,
            'duration': midi_data.get_end_time()
        }
        
        for instrument in midi_data.instruments:
            notes = extract_notes_from_instrument(instrument)
            
            instrument_info = {
                'program': instrument.program,
                'is_drum': instrument.is_drum,
                'notes': notes,
                'note_count': len(notes)
            }
            
            file_stats['instruments'].append(instrument_info)
            file_stats['total_notes'] += len(notes)
        
        return file_stats, True
    except Exception as e:
        print(f"❌ Error processing {midi_file_path.name}: {str(e)}")
        return None, False

In [9]:
print("\n🔄 Processing MIDI files...")
print("-" * 30)
# Track dataset statistics
dataset_stats = {
    'files_processed': 0,
    'files_failed': 0,
    'total_instruments': 0,
    'unique_programs': set(),
    'total_notes': 0,
    'instrument_data': defaultdict(list)  # program -> list of all notes
}
# Process each file
for midi_file in midi_files:
    print(f"Processing: {midi_file.name}")
    
    file_stats, success = process_single_midi(midi_file)
    
    if success:
        dataset_stats['files_processed'] += 1
        dataset_stats['total_notes'] += file_stats['total_notes']
        
        # Process each instrument in the file
        for instrument_info in file_stats['instruments']:
            program = instrument_info['program']
            notes = instrument_info['notes']
            
            # Add to our dataset
            dataset_stats['instrument_data'][program].extend(notes)
            dataset_stats['unique_programs'].add(program)
            
            print(f"  → Program {program}: {len(notes)} notes")
        
        dataset_stats['total_instruments'] += len(file_stats['instruments'])
    else:
        dataset_stats['files_failed'] += 1


🔄 Processing MIDI files...
------------------------------
Processing: ducktalesremasteredbosstheme_by_sb.mid
  → Program 36: 910 notes
  → Program 61: 700 notes
  → Program 61: 324 notes
  → Program 65: 424 notes
  → Program 30: 910 notes
  → Program 58: 910 notes
  → Program 81: 148 notes
  → Program 18: 348 notes
  → Program 18: 348 notes
  → Program 16: 293 notes
  → Program 16: 347 notes
  → Program 16: 375 notes
  → Program 16: 37 notes
  → Program 16: 523 notes
  → Program 16: 116 notes
Processing: SM-StudiopolisZoneAct2.mid
  → Program 61: 866 notes
  → Program 80: 31 notes
  → Program 55: 294 notes
  → Program 48: 314 notes
  → Program 2: 1492 notes
  → Program 81: 318 notes
  → Program 88: 156 notes
  → Program 89: 348 notes
  → Program 89: 348 notes
  → Program 33: 430 notes
  → Program 24: 1393 notes
Processing: SM-MirageSaloonZoneAct1K.mid




  → Program 24: 109 notes
  → Program 28: 86 notes
  → Program 105: 73 notes
  → Program 27: 112 notes
  → Program 22: 100 notes
  → Program 59: 101 notes
  → Program 59: 48 notes
  → Program 2: 152 notes
  → Program 78: 24 notes
  → Program 81: 46 notes
  → Program 3: 152 notes
  → Program 61: 108 notes
  → Program 18: 52 notes
  → Program 47: 32 notes
  → Program 33: 189 notes
  → Program 24: 696 notes
Processing: SM-OilOceanZoneAct2.mid
  → Program 29: 304 notes
  → Program 30: 76 notes
  → Program 90: 608 notes
  → Program 81: 232 notes
  → Program 74: 94 notes
  → Program 15: 116 notes
  → Program 111: 80 notes
  → Program 48: 90 notes
  → Program 4: 72 notes
  → Program 4: 72 notes
  → Program 38: 428 notes
  → Program 24: 1885 notes
Processing: SM-FlyingBatteryZoneAct1.mid
  → Program 62: 320 notes
  → Program 80: 188 notes
  → Program 17: 952 notes
  → Program 29: 204 notes
  → Program 61: 380 notes
  → Program 87: 76 notes
  → Program 50: 24 notes
  → Program 11: 96 notes
  → 

In [10]:
print(f"\n💾 Saving dataset to JSON files...")
print("-" * 30)
for program, notes_list in dataset_stats['instrument_data'].items():
    json_filename = f"{program}.json"
    json_filepath = os.path.join(OUTPUT_FOLDER, json_filename)
    
    # Save to JSON
    with open(json_filepath, 'w') as f:
        json.dump(notes_list, f, indent=2)
    
    print(f"✅ Saved {len(notes_list)} notes for program {program} → {json_filename}")


💾 Saving dataset to JSON files...
------------------------------
✅ Saved 1232 notes for program 36 → 36.json
✅ Saved 4922 notes for program 61 → 61.json
✅ Saved 424 notes for program 65 → 65.json
✅ Saved 1435 notes for program 30 → 30.json
✅ Saved 910 notes for program 58 → 58.json
✅ Saved 2566 notes for program 81 → 81.json
✅ Saved 1094 notes for program 18 → 18.json
✅ Saved 5362 notes for program 16 → 16.json
✅ Saved 1511 notes for program 80 → 80.json
✅ Saved 334 notes for program 55 → 55.json
✅ Saved 1550 notes for program 48 → 48.json
✅ Saved 4468 notes for program 2 → 2.json
✅ Saved 540 notes for program 88 → 88.json
✅ Saved 1900 notes for program 89 → 89.json
✅ Saved 5203 notes for program 33 → 33.json
✅ Saved 17111 notes for program 24 → 24.json
✅ Saved 506 notes for program 28 → 28.json
✅ Saved 73 notes for program 105 → 105.json
✅ Saved 112 notes for program 27 → 27.json
✅ Saved 100 notes for program 22 → 22.json
✅ Saved 149 notes for program 59 → 59.json
✅ Saved 156 notes f

In [11]:
print(f"\n📊 Dataset Creation Summary")
print("=" * 50)
print(f"Files processed successfully: {dataset_stats['files_processed']}")
print(f"Files failed: {dataset_stats['files_failed']}")
print(f"Total unique instruments: {len(dataset_stats['unique_programs'])}")
print(f"Total notes in dataset: {dataset_stats['total_notes']}")


📊 Dataset Creation Summary
Files processed successfully: 13
Files failed: 0
Total unique instruments: 56
Total notes in dataset: 67769


In [12]:
# Instrument distribution
program_counts = {prog: len(notes) for prog, notes in dataset_stats['instrument_data'].items()}
sorted_programs = sorted(program_counts.items(), key=lambda x: x[1], reverse=True)

print(f"\n🎹 Instrument Distribution (Top 10):")
for i, (program, count) in enumerate(sorted_programs[:10], 1):
    print(f"  {i:2d}. Program {program:3d}: {count:,} notes")


🎹 Instrument Distribution (Top 10):
   1. Program  24: 17,111 notes
   2. Program  16: 5,362 notes
   3. Program  33: 5,203 notes
   4. Program  61: 4,922 notes
   5. Program   2: 4,468 notes
   6. Program  38: 3,633 notes
   7. Program  81: 2,566 notes
   8. Program  89: 1,900 notes
   9. Program   0: 1,724 notes
  10. Program  90: 1,624 notes


In [13]:
def analyze_dataset_quality():
    """Analyze the quality and characteristics of our dataset."""
    print(f"\n🔍 Dataset Quality Analysis")
    print("-" * 30)
    
    # Load and analyze each instrument file
    for program in sorted(dataset_stats['unique_programs']):
        json_filepath = os.path.join(OUTPUT_FOLDER, f"{program}.json")
        
        with open(json_filepath, 'r') as f:
            notes = json.load(f)
        
        if not notes:
            continue
            
        # Convert to DataFrame for analysis
        df = pd.DataFrame(notes, columns=['pitch', 'start', 'end', 'velocity'])
        
        # Basic statistics
        duration_stats = (df['end'] - df['start']).describe()
        pitch_range = (df['pitch'].min(), df['pitch'].max())
        velocity_range = (df['velocity'].min(), df['velocity'].max())
        
        print(f"Program {program}:")
        print(f"  Notes: {len(notes)}")
        print(f"  Pitch range: {pitch_range[0]}-{pitch_range[1]}")
        print(f"  Velocity range: {velocity_range[0]}-{velocity_range[1]}")
        print(f"  Avg note duration: {duration_stats['mean']:.2f}s")
        print()

In [14]:
# Run quality analysis
analyze_dataset_quality()


🔍 Dataset Quality Analysis
------------------------------
Program 0:
  Notes: 1724
  Pitch range: 28-86
  Velocity range: 100-100
  Avg note duration: 0.35s

Program 1:
  Notes: 236
  Pitch range: 60-104
  Velocity range: 100-100
  Avg note duration: 0.15s

Program 2:
  Notes: 4468
  Pitch range: 52-76
  Velocity range: 100-100
  Avg note duration: 0.33s

Program 3:
  Notes: 400
  Pitch range: 62-89
  Velocity range: 75-100
  Avg note duration: 0.14s

Program 4:
  Notes: 226
  Pitch range: 55-77
  Velocity range: 100-100
  Avg note duration: 1.65s

Program 7:
  Notes: 698
  Pitch range: 54-69
  Velocity range: 100-100
  Avg note duration: 0.08s

Program 9:
  Notes: 68
  Pitch range: 62-86
  Velocity range: 100-100
  Avg note duration: 0.54s

Program 10:
  Notes: 90
  Pitch range: 67-92
  Velocity range: 100-100
  Avg note duration: 0.66s

Program 11:
  Notes: 96
  Pitch range: 57-66
  Velocity range: 100-100
  Avg note duration: 1.00s

Program 12:
  Notes: 460
  Pitch range: 65-76
  V

In [15]:
print("🎉 Dataset creation complete!")
print(f"📁 Dataset files saved in: {OUTPUT_FOLDER}/")
print("🔬 Ready for research and analysis!")

🎉 Dataset creation complete!
📁 Dataset files saved in: raw_data/
🔬 Ready for research and analysis!
