# MIDI Transcription with SoundLab

This notebook demonstrates how to use SoundLab's audio-to-MIDI transcription powered by Basic Pitch.

**What you'll learn:**
- Transcribe audio to MIDI notes
- Configure transcription parameters
- Visualize piano rolls
- Work with note events
- Export MIDI files
- Best practices for transcription

## Setup

Import necessary modules and configure the environment.

In [None]:
import soundlab
from soundlab.transcription import MIDITranscriber, TranscriptionConfig
from soundlab.io import load_audio, save_midi, load_midi
from pathlib import Path

# For visualization and audio playback
from IPython.display import Audio, display
import matplotlib.pyplot as plt
import numpy as np

print(f"SoundLab version: {soundlab.__version__}")

## 1. Basic MIDI Transcription

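Let's start by transcribing a short audio file to MIDI. The example below uses a 3-second 440 Hz sine tone from the test fixtures (440 Hz is A4, MIDI pitch 69), so a clean transcription should contain very few notes: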

In [None]:
# Initialize the transcriber with default settings
transcriber = MIDITranscriber()

# Path to your audio file (monophonic melody works best)
# Replace with your own audio file or use test fixtures
input_file = "../../tests/fixtures/audio/sine_440hz_3s.wav"

# Listen to the input
print("Input audio:")
display(Audio(input_file))

# Transcribe to MIDI
print("\nTranscribing...")
result = transcriber.transcribe(input_file)

print(f"\nTranscription complete in {result.processing_time_seconds:.2f} seconds")
print(f"Notes detected: {result.note_count}")
print(f"Duration: {result.duration:.2f} seconds")

### Examining Note Events

The transcription result contains a list of `NoteEvent` objects with detailed information:

In [None]:
# Display first 10 notes
print("First 10 notes:\n")
for i, note in enumerate(result.notes[:10], 1):
    print(f"{i:2d}. {note.pitch_name:4s} | "
          f"Time: {note.start_time:6.3f}s - {note.end_time:6.3f}s | "
          f"Duration: {note.duration_ms:6.1f}ms | "
          f"Velocity: {note.velocity:3d} | "
          f"Freq: {note.frequency:6.1f}Hz | "
          f"Confidence: {note.confidence:.2f}")

if result.note_count > 10:
    print(f"\n... and {result.note_count - 10} more notes")
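
Under standard tuning, a note's name and frequency both follow from its MIDI pitch number. The sketch below shows that relationship, assuming twelve-tone equal temperament with A4 = 440 Hz and the convention that MIDI pitch 60 is C4 (some tools label it C3). The helpers `midi_to_hz` and `midi_to_name` are illustrative names and are not part of SoundLab; they are only meant to help sanity-check the `pitch_name` and `frequency` values printed above.

```python
NOTE_NAMES = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"]


def midi_to_hz(pitch: int, a4_hz: float = 440.0) -> float:
    """Equal-temperament frequency of a MIDI pitch (69 is A4)."""
    return a4_hz * 2.0 ** ((pitch - 69) / 12.0)


def midi_to_name(pitch: int) -> str:
    """Note name with octave, assuming MIDI pitch 60 is C4."""
    octave = pitch // 12 - 1
    return f"{NOTE_NAMES[pitch % 12]}{octave}"


for pitch in (60, 69, 72):
    print(f"{pitch:3d} -> {midi_to_name(pitch):3s} {midi_to_hz(pitch):7.2f} Hz")
```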

### Analyzing Note Statistics

In [None]:
if result.note_count > 0:
    # Find the lowest and highest notes by MIDI pitch
    min_note = min(result.notes, key=lambda n: n.pitch)
    max_note = max(result.notes, key=lambda n: n.pitch)
    
    print("Transcription Statistics:\n")
    print(f"Total notes: {result.note_count}")
    print(f"Duration: {result.duration:.2f}s")
    print(f"Pitch range: {min_note.pitch_name} ({min_note.frequency:.1f} Hz) to "
          f"{max_note.pitch_name} ({max_note.frequency:.1f} Hz)")
    print(f"Average velocity: {result.average_velocity:.1f}")
    
    # Calculate note density
    if result.duration > 0:
        note_density = result.note_count / result.duration
        print(f"Note density: {note_density:.1f} notes/second")

## 2. Visualizing the Piano Roll

Let's create a piano roll visualization of the transcribed notes:

In [None]:
def plot_piano_roll(midi_result, max_notes=None, figsize=(14, 6)):
    """
    Plot a piano roll visualization of MIDI notes.
    
    Args:
        midi_result: MIDIResult object
        max_notes: Maximum number of notes to display (None for all)
        figsize: Figure size tuple
    """
    # Compare against None explicitly so that max_notes=0 is not treated as "show all"
    notes = midi_result.notes if max_notes is None else midi_result.notes[:max_notes]
    
    if not notes:
        print("No notes to display")
        return
    
    fig, ax = plt.subplots(figsize=figsize)
    
    # Plot each note as a rectangle
    for note in notes:
        # Color by velocity
        color = plt.cm.viridis(note.velocity / 127.0)
        
        # Draw note rectangle
        rect = plt.Rectangle(
            (note.start_time, note.pitch - 0.4),  # center the 0.8-high bar on its pitch
            note.duration,
            0.8,
            facecolor=color,
            edgecolor='black',
            linewidth=0.5,
            alpha=0.8
        )
        ax.add_patch(rect)
    
    # Set labels and limits
    ax.set_xlabel('Time (seconds)', fontsize=12)
    ax.set_ylabel('MIDI Pitch', fontsize=12)
    ax.set_title('Piano Roll Visualization', fontsize=14, fontweight='bold')
    
    # Set y-axis to show all notes with some padding
    min_pitch, max_pitch = midi_result.pitch_range
    ax.set_ylim(min_pitch - 2, max_pitch + 2)
    ax.set_xlim(0, midi_result.duration)
    
    # Add grid
    ax.grid(True, alpha=0.3, linestyle='--')
    
    # Add colorbar for velocity
    sm = plt.cm.ScalarMappable(cmap=plt.cm.viridis, 
                                norm=plt.Normalize(vmin=0, vmax=127))
    sm.set_array([])
    cbar = plt.colorbar(sm, ax=ax)
    cbar.set_label('Velocity', fontsize=10)
    
    plt.tight_layout()
    plt.show()

# Visualize the transcription
plot_piano_roll(result)

## 3. Configuring Transcription Parameters

Fine-tune the transcription with custom parameters:

In [None]:
# Create custom configuration
custom_config = TranscriptionConfig(
    # Detection thresholds (0.1-0.9)
    onset_thresh=0.5,      # Higher = fewer false positives, may miss quiet notes
    frame_thresh=0.3,      # Higher = more confident note detection
    
    # Note filtering
    minimum_note_length=58.0,  # Minimum note duration in milliseconds (10-200)
    
    # Frequency range (Hz)
    minimum_frequency=32.7,    # C1 (20-500)
    maximum_frequency=2093.0,  # C7 (1000-8000)
    
    # Advanced options
    include_pitch_bends=False,  # Include pitch bend information
    melodia_trick=True,         # Use melodia trick for better monophonic results
    
    # Processing
    device="auto",  # "auto", "cuda", or "cpu"
)

print("Custom Transcription Configuration:\n")
print(f"  Onset threshold: {custom_config.onset_thresh}")
print(f"  Frame threshold: {custom_config.frame_thresh}")
print(f"  Min note length: {custom_config.minimum_note_length}ms")
print(f"  Frequency range: {custom_config.minimum_frequency:.1f} - {custom_config.maximum_frequency:.1f} Hz")
print(f"  Melodia trick: {custom_config.melodia_trick}")
print(f"  Device: {custom_config.device}")
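
The frequency bounds in the configuration are annotated with note names (C1, C7). A small sketch of how a frequency in Hz maps to a MIDI pitch number, assuming equal temperament with A4 = 440 Hz. The helper `hz_to_midi` is an illustrative name, not a SoundLab function; it can help when choosing bounds for a particular instrument.

```python
import math


def hz_to_midi(freq_hz: float, a4_hz: float = 440.0) -> float:
    """Fractional MIDI pitch for a frequency (69.0 is A4)."""
    if freq_hz <= 0:
        raise ValueError("freq_hz must be positive")
    return 69.0 + 12.0 * math.log2(freq_hz / a4_hz)


# The bounds used in the custom configuration above
for label, freq in (("minimum_frequency", 32.7), ("maximum_frequency", 2093.0)):
    print(f"{label}: {freq:7.1f} Hz is close to MIDI pitch {round(hz_to_midi(freq))}")
```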

### Comparing Different Configurations

Let's compare conservative vs. sensitive detection:

In [None]:
# Conservative: Fewer false positives, may miss quiet notes
conservative_config = TranscriptionConfig(
    onset_thresh=0.7,
    frame_thresh=0.5,
    minimum_note_length=100.0,
)

# Sensitive: More notes detected, may have false positives
sensitive_config = TranscriptionConfig(
    onset_thresh=0.3,
    frame_thresh=0.2,
    minimum_note_length=30.0,
)

print("Configuration Comparison:\n")
print("Conservative (fewer false positives):")
print(f"  Onset: {conservative_config.onset_thresh}, Frame: {conservative_config.frame_thresh}")
print(f"  Min length: {conservative_config.minimum_note_length}ms\n")

print("Sensitive (more notes detected):")
print(f"  Onset: {sensitive_config.onset_thresh}, Frame: {sensitive_config.frame_thresh}")
print(f"  Min length: {sensitive_config.minimum_note_length}ms")

# Uncomment to test different configurations:
# transcriber_conservative = MIDITranscriber(config=conservative_config)
# result_conservative = transcriber_conservative.transcribe(input_file)
# print(f"\nConservative notes: {result_conservative.note_count}")

# transcriber_sensitive = MIDITranscriber(config=sensitive_config)
# result_sensitive = transcriber_sensitive.transcribe(input_file)
# print(f"Sensitive notes: {result_sensitive.note_count}")
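
Note counts alone do not show which notes differ between two runs. One way to compare results is to list the notes that appear in one transcription but have no counterpart in the other. The sketch below works on plain `(start_s, end_s, midi_pitch)` tuples rather than SoundLab objects; the function name `unmatched_notes` and the 50 ms onset tolerance are assumptions chosen for illustration.

```python
from typing import List, Tuple

Note = Tuple[float, float, int]  # (start_s, end_s, midi_pitch)


def unmatched_notes(candidates: List[Note], reference: List[Note],
                    onset_tol: float = 0.05) -> List[Note]:
    """Notes in `candidates` with no same-pitch note in `reference`
    whose onset lies within `onset_tol` seconds."""
    return [
        c for c in candidates
        if not any(c[2] == r[2] and abs(c[0] - r[0]) <= onset_tol
                   for r in reference)
    ]


# Made-up example data standing in for two transcription results
conservative = [(0.00, 0.50, 60), (0.50, 1.00, 64)]
sensitive = [(0.01, 0.48, 60), (0.30, 0.34, 72), (0.50, 1.00, 64)]

extra = unmatched_notes(sensitive, conservative)
print(f"Notes found only by the sensitive settings: {extra}")
```

To apply this to real results, you could build the tuples from each note's `start_time`, `end_time`, and `pitch` attributes.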

## 4. Saving MIDI Files

Export transcription results to standard MIDI format:

In [None]:
# Create output directory
output_dir = Path("./output/midi")
output_dir.mkdir(parents=True, exist_ok=True)

# Save MIDI file
midi_path = output_dir / "transcription.mid"

# The transcribe method can save directly
result_with_save = transcriber.transcribe(
    input_file,
    output_midi_path=str(midi_path)
)

print(f"MIDI file saved to: {midi_path}")
print(f"File size: {midi_path.stat().st_size / 1024:.1f} KB")

# You can also save an existing result
# save_midi(midi_path, result.notes)

### Loading MIDI Files

Load and inspect saved MIDI files:

In [None]:
# Load the MIDI file back
if midi_path.exists():
    loaded_notes = load_midi(midi_path)
    
    print(f"Loaded {len(loaded_notes)} notes from MIDI file\n")
    
    # Display first few notes
    print("First 5 notes:")
    for i, note in enumerate(loaded_notes[:5], 1):
        print(f"{i}. {note.pitch_name} at {note.start_time:.3f}s, "
              f"duration: {note.duration_ms:.1f}ms")

## 5. Working with Time Ranges

Extract notes from specific time ranges:

In [None]:
if result.note_count > 0:
    # Get notes from first 2 seconds
    first_section = result.get_notes_in_range(0.0, 2.0)
    print(f"Notes in first 2 seconds: {len(first_section)}")
    
    # Get notes from middle section
    mid_point = result.duration / 2
    middle_section = result.get_notes_in_range(mid_point - 1.0, mid_point + 1.0)
    print(f"Notes in middle 2 seconds: {len(middle_section)}")
    
    # Get notes from last second
    last_section = result.get_notes_in_range(result.duration - 1.0, result.duration)
    print(f"Notes in last second: {len(last_section)}")
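
"Notes in a range" can mean two different things: notes that sound at any point inside the window, or notes that begin inside it. The sketch below shows both on plain `(start_s, end_s, midi_pitch)` tuples. It does not claim to reproduce what `get_notes_in_range` does internally; the function names are illustrative.

```python
from typing import List, Tuple

Note = Tuple[float, float, int]  # (start_s, end_s, midi_pitch)


def notes_overlapping(notes: List[Note], start: float, end: float) -> List[Note]:
    """Notes that sound at any point within [start, end)."""
    return [n for n in notes if n[0] < end and n[1] > start]


def notes_starting_in(notes: List[Note], start: float, end: float) -> List[Note]:
    """Notes whose onset falls within [start, end)."""
    return [n for n in notes if start <= n[0] < end]


# Made-up example data
notes = [(0.0, 1.5, 60), (1.8, 2.2, 62), (2.5, 3.0, 64)]
print("Overlapping 1.0-2.0s:", notes_overlapping(notes, 1.0, 2.0))
print("Starting in 1.0-2.0s:", notes_starting_in(notes, 1.0, 2.0))
```

The two definitions differ for held notes: the first note above starts before the window but is still sounding inside it.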

## 6. Advanced Analysis

Perform more detailed analysis of the transcription:

In [None]:
def analyze_transcription(midi_result):
    """
    Perform detailed analysis of transcription result.
    """
    if midi_result.note_count == 0:
        print("No notes to analyze")
        return
    
    notes = midi_result.notes
    
    # Calculate statistics
    durations = [note.duration for note in notes]
    velocities = [note.velocity for note in notes]
    pitches = [note.pitch for note in notes]
    
    print("Detailed Transcription Analysis\n")
    print("=" * 50)
    
    # Duration statistics
    print("\nNote Durations:")
    print(f"  Average: {np.mean(durations):.3f}s ({np.mean(durations)*1000:.1f}ms)")
    print(f"  Median: {np.median(durations):.3f}s")
    print(f"  Min: {np.min(durations):.3f}s, Max: {np.max(durations):.3f}s")
    print(f"  Std dev: {np.std(durations):.3f}s")
    
    # Velocity statistics
    print("\nVelocities:")
    print(f"  Average: {np.mean(velocities):.1f}")
    print(f"  Range: {np.min(velocities)} - {np.max(velocities)}")
    print(f"  Std dev: {np.std(velocities):.1f}")
    
    # Pitch statistics
    print("\nPitch Distribution:")
    unique_pitches = len(set(pitches))
    print(f"  Unique pitches: {unique_pitches}")
    most_common_pitch = max(set(pitches), key=pitches.count)
    most_common_name = next(n.pitch_name for n in notes if n.pitch == most_common_pitch)
    print(f"  Most common pitch: {most_common_pitch} ({most_common_name})")
    
    # Calculate intervals between notes
    if len(notes) > 1:
        intervals = [notes[i+1].start_time - notes[i].end_time 
                    for i in range(len(notes)-1)]
        positive_intervals = [i for i in intervals if i > 0]
        
        print("\nNote Gaps:")
        if positive_intervals:
            print(f"  Average gap: {np.mean(positive_intervals):.3f}s")
            print(f"  Notes with gaps: {len(positive_intervals)} "
                  f"({100*len(positive_intervals)/len(intervals):.1f}%)")
        else:
            print("  No gaps between notes (legato)")
    
    # Confidence analysis
    confidences = [note.confidence for note in notes]
    print("\nConfidence Scores:")
    print(f"  Average: {np.mean(confidences):.3f}")
    print(f"  Min: {np.min(confidences):.3f}, Max: {np.max(confidences):.3f}")
    print(f"  High confidence (>0.8): {sum(1 for c in confidences if c > 0.8)} notes")
    print(f"  Low confidence (<0.5): {sum(1 for c in confidences if c < 0.5)} notes")

# Analyze the transcription
analyze_transcription(result)

## 7. Batch Transcription

Transcribe multiple audio files:

In [None]:
def batch_transcribe(input_dir, output_dir, config=None):
    """
    Transcribe all audio files in a directory to MIDI.
    
    Args:
        input_dir: Directory containing audio files
        output_dir: Directory for MIDI outputs
        config: Optional TranscriptionConfig
    """
    transcriber = MIDITranscriber(config=config)
    input_path = Path(input_dir)
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    
    # Find all audio files
    audio_extensions = ['.wav', '.mp3', '.flac', '.ogg', '.m4a']
    audio_files = []
    for ext in audio_extensions:
        audio_files.extend(input_path.glob(f'*{ext}'))
    
    print(f"Found {len(audio_files)} audio files\n")
    
    results = []
    for i, audio_file in enumerate(audio_files, 1):
        print(f"[{i}/{len(audio_files)}] Transcribing: {audio_file.name}")
        
        midi_file = output_path / f"{audio_file.stem}.mid"
        
        try:
            result = transcriber.transcribe(
                str(audio_file),
                output_midi_path=str(midi_file)
            )
            results.append((audio_file.name, result))
            print(f"  ✓ {result.note_count} notes in {result.processing_time_seconds:.1f}s\n")
        except Exception as e:
            print(f"  ✗ Error: {e}\n")
            continue
    
    # Summary
    print("\nBatch Transcription Summary:")
    print("=" * 50)
    total_notes = sum(result.note_count for _, result in results)
    total_time = sum(result.processing_time_seconds for _, result in results)
    print(f"Files processed: {len(results)}")
    print(f"Total notes: {total_notes}")
    print(f"Total processing time: {total_time:.1f}s")
    # Guard against division by zero when no file was transcribed successfully
    if results:
        print(f"Average notes per file: {total_notes/len(results):.1f}")
    
    return results

# Example usage (uncomment to run):
# results = batch_transcribe(
#     input_dir="./audio_files",
#     output_dir="./output/midi_batch",
#     config=TranscriptionConfig(onset_thresh=0.5)
# )

print("Batch transcription function defined!")

## Best Practices and Tips

### 1. Audio Preparation

- **Best results**: Monophonic melodies (single note at a time)
- **Good results**: Polyphonic music with clear melody
- **Challenging**: Dense polyphonic music, percussion
- Clean, high-quality audio produces better transcriptions
- For vocal melodies, consider extracting the vocal stem first

### 2. Parameter Tuning

**Onset Threshold (`onset_thresh`):**
- Lower (0.3): More notes, may include false positives
- Higher (0.7): Fewer notes, only confident detections
- Default (0.5): Good balance for most cases

**Frame Threshold (`frame_thresh`):**
- Controls note sustain detection
- Lower values capture longer notes better
- Higher values for staccato passages

**Minimum Note Length:**
- Increase to filter out very short notes
- Decrease for fast passages
- Default (58ms) works well for most music

### 3. Use Cases

- **Vocal melodies**: Extract the vocal stem first, then use default settings
- **Piano**: Use sensitive settings, may need manual cleanup
- **Guitar**: Works best with clean tones
- **Synth leads**: Usually excellent results
- **Bass lines**: Consider limiting frequency range

### 4. Performance Optimization

- Enable GPU with `device="cuda"` for faster processing
- Process shorter segments for long files
- Use batch processing for multiple files
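
Splitting a long file into segments needs two pieces: the segment boundaries, and a way to move each segment's notes back onto the full timeline. A minimal sketch, assuming fixed-length windows with a small overlap so that notes at a boundary are not cut off. The helper names, the 30-second window, and the 1-second overlap are illustrative choices, not SoundLab defaults.

```python
from typing import List, Tuple

Note = Tuple[float, float, int]  # (start_s, end_s, midi_pitch)


def segment_bounds(total_s: float, segment_s: float = 30.0,
                   overlap_s: float = 1.0) -> List[Tuple[float, float]]:
    """Split [0, total_s] into windows of segment_s that overlap by overlap_s."""
    if overlap_s < 0 or segment_s <= overlap_s:
        raise ValueError("segment_s must be greater than overlap_s >= 0")
    total_s = float(total_s)
    step = segment_s - overlap_s
    bounds = []
    start = 0.0
    while start < total_s:
        end = min(start + segment_s, total_s)
        bounds.append((start, end))
        if end >= total_s:
            break
        start += step
    return bounds


def shift_notes(notes: List[Note], offset_s: float) -> List[Note]:
    """Move segment-relative notes onto the full file's timeline."""
    return [(start + offset_s, end + offset_s, pitch)
            for start, end, pitch in notes]


bounds = segment_bounds(70.0, segment_s=30.0, overlap_s=1.0)
print(bounds)
print(shift_notes([(0.5, 1.0, 60)], offset_s=bounds[1][0]))
```

Notes that fall inside an overlap region may be detected in both neighbouring segments, so the merged result usually needs a de-duplication pass.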

### 5. Common Issues

- **Too many notes**: Increase thresholds or minimum note length
- **Missing notes**: Lower thresholds, check frequency range
- **Wrong octave**: Check input audio sample rate
- **Choppy notes**: Lower `frame_thresh` so sustained notes are not cut short; if held notes are being split into several, try raising `onset_thresh`

### 6. Post-Processing

- Use a MIDI editor for final cleanup
- Quantize notes to musical grid if needed
- Adjust velocities for more musical expression
- Consider splitting into multiple tracks by pitch range
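
Quantization can also be done in code before opening a MIDI editor. The sketch below snaps note boundaries to a fixed grid, assuming a constant, known tempo and plain `(start_s, end_s, midi_pitch)` tuples. The function name `quantize_notes` and its parameters are illustrative, not part of SoundLab.

```python
from typing import List, Tuple

Note = Tuple[float, float, int]  # (start_s, end_s, midi_pitch)


def quantize_notes(notes: List[Note], bpm: float,
                   subdivisions_per_beat: int = 4) -> List[Note]:
    """Snap note starts and ends to the nearest grid line.

    With subdivisions_per_beat=4 the grid is sixteenth notes in 4/4.
    Every note keeps a length of at least one grid step.
    """
    grid = 60.0 / bpm / subdivisions_per_beat
    quantized = []
    for start, end, pitch in notes:
        q_start = round(start / grid) * grid
        q_end = max(round(end / grid) * grid, q_start + grid)
        quantized.append((q_start, q_end, pitch))
    return quantized


# Made-up example data: slightly off-grid notes at 120 BPM (grid step = 0.125 s)
raw = [(0.02, 0.27, 60), (0.49, 0.52, 62)]
print(quantize_notes(raw, bpm=120.0))
```

Hard quantization removes expressive timing, so it tends to suit material that will be re-sequenced rather than performances meant to keep their feel.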

## Summary

In this notebook, you learned how to:

✓ Transcribe audio to MIDI notes  
✓ Work with note events and their properties  
✓ Visualize piano rolls  
✓ Configure transcription parameters  
✓ Compare different detection sensitivities  
✓ Save and load MIDI files  
✓ Extract notes from time ranges  
✓ Perform detailed transcription analysis  
✓ Batch process multiple files  
✓ Apply best practices for transcription  

**Next Steps:**
- Combine with stem separation for better results
- Learn audio analysis in notebook 03
- Apply effects processing in notebook 04