basic_pitch/inference.py: 52 additions & 9 deletions (61 changes)
@@ -242,6 +242,20 @@ def get_audio_input(
        yield np.expand_dims(window, axis=0), window_time, original_length


+def get_audio_input_from_array(
+    audio_original: npt.NDArray[np.float32], overlap_len: int, hop_size: int
+) -> Iterable[Tuple[npt.NDArray[np.float32], Dict[str, float], int]]:
+    """
+    A version of get_audio_input that works on an in-memory numpy array (assumed mono, at AUDIO_SAMPLE_RATE).
+    """
+    assert overlap_len % 2 == 0, f"overlap_len must be even, got {overlap_len}"
+
+    original_length = audio_original.shape[0]
+    audio_padded = np.concatenate([np.zeros((int(overlap_len / 2),), dtype=np.float32), audio_original])
+    for window, window_time in window_audio_file(audio_padded, hop_size):
+        yield np.expand_dims(window, axis=0), window_time, original_length
+
+
def unwrap_output(
    output: npt.NDArray[np.float32],
    audio_original_length: int,
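A quick sketch of driving the new generator directly (not part of the diff). The overlap arithmetic mirrors run_inference below; n_overlapping_frames = 30 and the constants import are assumptions based on basic_pitch's constants module:

import numpy as np

from basic_pitch.constants import AUDIO_N_SAMPLES, AUDIO_SAMPLE_RATE, FFT_HOP
from basic_pitch.inference import get_audio_input_from_array

n_overlapping_frames = 30  # assumed to match run_inference's setting
overlap_len = n_overlapping_frames * FFT_HOP
hop_size = AUDIO_N_SAMPLES - overlap_len

audio = np.zeros(5 * AUDIO_SAMPLE_RATE, dtype=np.float32)  # five seconds of silence
for window, window_time, original_length in get_audio_input_from_array(audio, overlap_len, hop_size):
    # Each window is batched (leading axis of 1), ready for model.predict.
    print(window.shape, window_time, original_length)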
@@ -272,14 +286,14 @@ def unwrap_output(


def run_inference(
-    audio_path: Union[pathlib.Path, str],
+    audio_input: Union[pathlib.Path, str, npt.NDArray[np.float32]],
    model_or_model_path: Union[Model, pathlib.Path, str],
    debug_file: Optional[pathlib.Path] = None,
) -> Dict[str, np.array]:
-    """Run the model on the input audio path.
+    """Run the model on the input audio path or numpy array.

    Args:
-        audio_path: The audio to run inference on.
+        audio_input: The audio to run inference on; can be a file path or a numpy array.
        model_or_model_path: A loaded Model or path to a serialized model to load.
        debug_file: An optional path to output debug data to. Useful for testing/verification.

@@ -297,14 +311,21 @@ def run_inference(
    hop_size = AUDIO_N_SAMPLES - overlap_len

    output: Dict[str, Any] = {"note": [], "onset": [], "contour": []}
-    for audio_windowed, _, audio_original_length in get_audio_input(audio_path, overlap_len, hop_size):
+
+    # Choose the correct generator based on input type
+    if isinstance(audio_input, (pathlib.Path, str)):
+        audio_generator = get_audio_input(audio_input, overlap_len, hop_size)
+    else:  # It's a numpy array
+        audio_generator = get_audio_input_from_array(audio_input, overlap_len, hop_size)
+
+    for audio_windowed, _, audio_original_length in audio_generator:
        for k, v in model.predict(audio_windowed).items():
            output[k].append(v)

    unwrapped_output = {
        k: unwrap_output(np.concatenate(output[k]), audio_original_length, n_overlapping_frames) for k in output
    }

    if debug_file:
        with open(debug_file, "w") as f:
            json.dump(
@@ -317,7 +338,7 @@ def run_inference(
                },
                f,
            )

    return unwrapped_output


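With the dispatch above, the same entry point accepts a path or an in-memory array. A usage sketch (not part of the diff), assuming basic_pitch's bundled ICASSP_2022_MODEL_PATH; "example.wav" is a placeholder:

import numpy as np

from basic_pitch import ICASSP_2022_MODEL_PATH
from basic_pitch.constants import AUDIO_SAMPLE_RATE
from basic_pitch.inference import run_inference

# Path input: the file is loaded and windowed by get_audio_input.
output_from_file = run_inference("example.wav", ICASSP_2022_MODEL_PATH)

# Array input: any mono float32 waveform already at AUDIO_SAMPLE_RATE.
audio = np.zeros(3 * AUDIO_SAMPLE_RATE, dtype=np.float32)
output_from_array = run_inference(audio, ICASSP_2022_MODEL_PATH)

# Each result is a dict with "note", "onset", and "contour" activation arrays.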
@@ -428,6 +449,7 @@ def predict(
    minimum_frequency: Optional[float] = None,
    maximum_frequency: Optional[float] = None,
    multiple_pitch_bends: bool = False,
+    infer_onsets: bool = True,
    melodia_trick: bool = True,
    debug_file: Optional[pathlib.Path] = None,
    midi_tempo: float = DEFAULT_MINIMUM_MIDI_TEMPO,
@@ -444,24 +466,44 @@ def predict(
        onset_threshold: Minimum energy required for an onset to be considered present.
        frame_threshold: Minimum energy requirement for a frame to be considered present.
        minimum_note_length: The minimum allowed note length in milliseconds.
-        minimum_freq: Minimum allowed output frequency, in Hz. If None, all frequencies are used.
-        maximum_freq: Maximum allowed output frequency, in Hz. If None, all frequencies are used.
+        minimum_frequency: Minimum allowed output frequency, in Hz. If None, all frequencies are used.
+        maximum_frequency: Maximum allowed output frequency, in Hz. If None, all frequencies are used.
        multiple_pitch_bends: If True, allow overlapping notes in midi file to have pitch bends.
+        infer_onsets: If True, add additional onsets when there are large differences in frame amplitudes.
        melodia_trick: Use the melodia post-processing step.
        debug_file: An optional path to output debug data to. Useful for testing/verification.
        midi_tempo: The tempo for the output midi file.
    Returns:
        The model output, midi data and note events from a single prediction
    """

    with no_tf_warnings():
        print(f"Predicting MIDI for {audio_path}...")
-        model_output = run_inference(audio_path, model_or_model_path, debug_file)
+
+        # --- Simplified Workflow ---
+        # 1. Load the entire audio file into memory.
+        print("Loading audio file into memory...")
+        y, _ = librosa.load(str(audio_path), sr=AUDIO_SAMPLE_RATE, mono=True)
+        audio_duration = len(y) / AUDIO_SAMPLE_RATE
+        print(f"Audio loaded. Duration: {audio_duration:.2f} seconds.")
+
+        # 2. Add robust padding to the end of the audio.
+        # The longer padding ensures the CNN has enough context at the end of the audio stream.
+        padding_duration_s = 20.0
+        padding = np.zeros(int(AUDIO_SAMPLE_RATE * padding_duration_s), dtype=np.float32)
+        audio_to_process = np.concatenate([y, padding])
+
+        # 3. Run inference on the padded audio.
+        print("Running inference...")
+        model_output = run_inference(audio_to_process, model_or_model_path, debug_file)
+
+        # 4. Convert model output to notes.
        min_note_len = int(np.round(minimum_note_length / 1000 * (AUDIO_SAMPLE_RATE / FFT_HOP)))
        midi_data, note_events = infer.model_output_to_notes(
            model_output,
            onset_thresh=onset_threshold,
            frame_thresh=frame_threshold,
+            infer_onsets=infer_onsets,
            min_note_len=min_note_len,  # convert to frames
            min_freq=minimum_frequency,
            max_freq=maximum_frequency,
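For a sense of the numbers in steps 2 and 4, a worked sketch assuming basic_pitch's usual constants (AUDIO_SAMPLE_RATE = 22050 Hz, FFT_HOP = 256) and the default minimum_note_length of 127.70 ms; all three values are assumptions worth checking against constants.py:

AUDIO_SAMPLE_RATE = 22050  # Hz (assumed; see basic_pitch/constants.py)
FFT_HOP = 256              # samples per model frame (assumed)

# Step 2: 20 s of zero padding appended to the waveform.
padding_samples = int(AUDIO_SAMPLE_RATE * 20.0)  # 441000 samples

# Step 4: milliseconds -> frames. The model emits one frame every FFT_HOP
# samples, i.e. AUDIO_SAMPLE_RATE / FFT_HOP ~= 86.13 frames per second.
minimum_note_length = 127.70  # ms (assumed library default)
min_note_len = int(round(minimum_note_length / 1000 * (AUDIO_SAMPLE_RATE / FFT_HOP)))
assert min_note_len == 11  # ~0.1277 s * 86.13 frames/s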
@@ -470,6 +512,7 @@ def predict(
            multiple_pitch_bends=multiple_pitch_bends,
            midi_tempo=midi_tempo,
        )

+    # Write the aggregated debug data once processing is complete.
    if debug_file:
        with open(debug_file) as f:
            debug_data = json.load(f)
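Taken together, an end-to-end sketch of calling the revised predict (not part of the diff; file names are placeholders):

from basic_pitch import ICASSP_2022_MODEL_PATH
from basic_pitch.inference import predict

model_output, midi_data, note_events = predict(
    "song.wav",
    ICASSP_2022_MODEL_PATH,
    infer_onsets=True,         # new flag, threaded through to model_output_to_notes
    minimum_frequency=65.0,    # optional band limits, in Hz
    maximum_frequency=2000.0,
)
midi_data.write("song.mid")  # midi_data is a pretty_midi.PrettyMIDI object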