In [None]:
!pip install piano_transcription_inference pretty_midi librosa gradio torch numpy


Collecting piano_transcription_inference
  Downloading piano_transcription_inference-0.0.5-py3-none-any.whl.metadata (2.9 kB)
Collecting pretty_midi
  Downloading pretty_midi-0.2.10.tar.gz (5.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.6/5.6 MB[0m [31m34.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting gradio
  Downloading gradio-5.7.1-py3-none-any.whl.metadata (16 kB)
Collecting mido (from piano_transcription_inference)
  Downloading mido-1.3.3-py3-none-any.whl.metadata (6.4 kB)
Collecting torchlibrosa (from piano_transcription_inference)
  Downloading torchlibrosa-0.1.0-py3-none-any.whl.metadata (3.5 kB)
Collecting aiofiles<24.0,>=22.0 (from gradio)
  Downloading aiofiles-23.2.1-py3-none-any.whl.metadata (9.7 kB)
Collecting fastapi<1.0,>=0.115.2 (from gradio)
  Downloading fastapi-0.115.5-py3-none-any.whl.metadata (27 kB)
Collecting ffmpy (from gradio)
  Downloading ffmpy-0.4.0-py3-none-any.whl.metadat

In [None]:
import os
import tempfile
import time

import gradio as gr
import librosa
import matplotlib.pyplot as plt
import numpy as np
import pretty_midi
import torch
from piano_transcription_inference import PianoTranscription, sample_rate, load_audio


In [None]:
class AMTProcessor:
    """Audio-to-MIDI helper built around ``PianoTranscription``.

    Holds one transcription model instance and offers two steps: transcribing
    an audio file to a MIDI file, and reading a MIDI file back as a piano roll.
    Both steps report failures by printing a message and returning ``None``
    values instead of raising.
    """

    def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
        """Create the transcription model on ``device``.

        Args:
            device: Device string handed to ``PianoTranscription``. The default
                expression is evaluated once, when the class is defined.
        """
        self.device = device
        print(f"Using device: {self.device}")
        self.transcriptor = PianoTranscription(device=self.device)
        self.sample_rate = sample_rate

    def process_audio(self, audio_path):
        """Transcribe an audio file and write the result as a MIDI file.

        Args:
            audio_path: Path of the audio file to load.

        Returns:
            ``(midi_path, transcribed_dict)`` on success, where ``midi_path`` is
            the written MIDI file and ``transcribed_dict`` is whatever
            ``PianoTranscription.transcribe`` returned. ``(None, None)`` if
            loading the audio, creating the output directory, or transcribing
            fails; the error is printed.
        """
        try:
            # Loading happens inside the try so that an unreadable or missing
            # file is reported the same way as a transcription failure.
            print("Loading audio file...")
            audio, _ = librosa.load(audio_path, sr=self.sample_rate, mono=True)

            # mkdtemp guarantees a fresh directory per call; a timestamp-based
            # name would be shared by two calls made within the same second.
            output_dir = tempfile.mkdtemp(prefix='amt_output_')
            print(f"Created output directory: {output_dir}")
            midi_path = os.path.join(output_dir, 'transcribed_output.mid')

            print("Transcribing audio to MIDI...")
            transcribed_dict = self.transcriptor.transcribe(audio, midi_path)
            print(f"Successfully transcribed audio to: {midi_path}")
            return midi_path, transcribed_dict
        except Exception as e:
            print(f"Error during transcription: {str(e)}")
            return None, None

    def create_piano_roll(self, midi_path):
        """Load a MIDI file and return its piano roll.

        Args:
            midi_path: Path of the MIDI file to read.

        Returns:
            The result of ``PrettyMIDI.get_piano_roll()`` called with its
            default arguments, or ``None`` if the file could not be read or
            converted (the error is printed).
        """
        try:
            print("Creating piano roll visualization...")
            midi_data = pretty_midi.PrettyMIDI(midi_path)
            piano_roll = midi_data.get_piano_roll()
            return piano_roll
        except Exception as e:
            print(f"Error creating piano roll: {str(e)}")
            return None

In [None]:
# Shared processor, created on first use so the transcription model is loaded
# once instead of on every request.
_PROCESSOR = None


def _get_processor():
    """Return the shared AMTProcessor, constructing it on the first call."""
    global _PROCESSOR
    if _PROCESSOR is None:
        _PROCESSOR = AMTProcessor()
    return _PROCESSOR


def process_and_visualize(audio_file):
    """Gradio handler: transcribe an uploaded recording and plot its piano roll.

    Args:
        audio_file: Path of the uploaded audio file, or ``None``/empty when the
            user submitted without uploading anything.

    Returns:
        ``(midi_path, figure)`` for the File and Plot outputs. ``figure`` is
        ``None`` when the MIDI file was written but the piano roll could not be
        created; a warning is shown to the user in that case.

    Raises:
        gr.Error: If no file was supplied or the transcription failed. Gradio
            displays the message to the user; returning an error string in the
            File or Plot slot would instead make the output component fail.
    """
    print("Starting audio processing...")

    if not audio_file:
        raise gr.Error("Please upload a piano recording first.")

    processor = _get_processor()

    # The handler receives the upload as a path string, so it is passed through as-is.
    midi_path, _ = processor.process_audio(audio_file)
    if midi_path is None:
        raise gr.Error("Error during transcription")

    piano_roll = processor.create_piano_roll(midi_path)
    if piano_roll is None:
        # The MIDI file is still valid, so return it and leave the plot empty.
        gr.Warning("Error creating visualization")
        return midi_path, None

    fig, ax = plt.subplots(figsize=(12, 8))
    image = ax.imshow(piano_roll, aspect='auto', origin='lower', cmap='hot')
    ax.set_title('Piano Roll Representation')
    ax.set_xlabel('Time (frames)')
    ax.set_ylabel('MIDI Note Number')
    fig.colorbar(image, ax=ax, label='Velocity')

    # Deregister the figure from pyplot so figures do not pile up across
    # requests; the Figure object itself stays usable for rendering.
    plt.close(fig)

    return midi_path, fig

In [None]:
# Cell 5 - Build the Gradio UI around process_and_visualize and start serving it
iface = gr.Interface(
    title="AI Piano Tutor - Audio to MIDI Transcription",
    description="""
    Upload a piano recording to convert it into a MIDI file and visualize the transcription.
    The model uses high-resolution piano transcription to accurately capture timing and dynamics.
    """,
    theme="default",
    fn=process_and_visualize,
    # One upload-only audio input; type="filepath" hands the handler a path string.
    inputs=[gr.Audio(type="filepath", sources=["upload"], label="Upload Piano Recording")],
    # Output order matches the (midi_path, figure) pair returned by the handler.
    outputs=[
        gr.File(label="Download MIDI File"),
        gr.Plot(label="Piano Roll Visualization"),
    ],
    # No bundled example recordings, so there is nothing to cache.
    examples=[],
    cache_examples=False,
)

# debug=True keeps this cell running so errors and logs stay visible in the notebook.
iface.launch(debug=True)

Running Gradio in a Colab notebook requires sharing enabled. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().

Could not create share link. Please check your internet connection or our status page: https://status.gradio.app.


2024/12/01 19:05:44 [W] [service.go:132] login to server failed: dial tcp 44.237.78.176:7000: connect: connection refused


<IPython.core.display.Javascript object>