Skip to content

Commit

Permalink
do not save overlapping pitch bends to midi file by default
Browse files Browse the repository at this point in the history
  • Loading branch information
jvbalen committed May 12, 2022
1 parent 5e043d7 commit cf15da3
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 13 deletions.
6 changes: 6 additions & 0 deletions basic_pitch/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ def predict(
minimum_note_length: float = 58,
minimum_frequency: Optional[float] = None,
maximum_frequency: Optional[float] = None,
multiple_pitch_bends: bool = False,
melodia_trick: bool = True,
debug_file: Optional[pathlib.Path] = None,
) -> Tuple[Dict[str, np.array], pretty_midi.PrettyMIDI, List[Tuple[float, float, int, float, Optional[List[int]]]]]:
Expand All @@ -274,6 +275,7 @@ def predict(
minimum_note_length: The minimum allowed note length in frames.
minimum_freq: Minimum allowed output frequency, in Hz. If None, all frequencies are used.
maximum_freq: Maximum allowed output frequency, in Hz. If None, all frequencies are used.
multiple_pitch_bends: If True, allow overlapping notes in midi file to have pitch bends.
melodia_trick: Use the melodia post-processing step.
debug_file: An optional path to output debug data to. Useful for testing/verification.
Returns:
Expand All @@ -299,6 +301,7 @@ def predict(
min_note_len=min_note_len, # convert to frames
min_freq=minimum_frequency,
max_freq=maximum_frequency,
multiple_pitch_bends=multiple_pitch_bends,
melodia_trick=melodia_trick,
)

Expand Down Expand Up @@ -342,6 +345,7 @@ def predict_and_save(
minimum_note_length: float = 58,
minimum_frequency: Optional[float] = None,
maximum_frequency: Optional[float] = None,
multiple_pitch_bends: bool = False,
melodia_trick: bool = True,
debug_file: Optional[pathlib.Path] = None,
) -> None:
Expand All @@ -360,6 +364,7 @@ def predict_and_save(
minimum_note_length: The minimum allowed note length in frames.
minimum_freq: Minimum allowed output frequency, in Hz. If None, all frequencies are used.
maximum_freq: Maximum allowed output frequency, in Hz. If None, all frequencies are used.
multiple_pitch_bends: If True, allow overlapping notes in midi file to have pitch bends.
melodia_trick: Use the melodia post-processing step.
debug_file: An optional path to output debug data to. Useful for testing/verification.
"""
Expand All @@ -376,6 +381,7 @@ def predict_and_save(
minimum_note_length,
minimum_frequency,
maximum_frequency,
multiple_pitch_bends,
melodia_trick,
debug_file,
)
Expand Down
49 changes: 37 additions & 12 deletions basic_pitch/note_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import astuple, dataclass
import pathlib
from collections import defaultdict, namedtuple
from typing import Dict, List, Optional, Tuple, Union
import mir_eval
import librosa
Expand Down Expand Up @@ -51,6 +53,7 @@ def model_output_to_notes(
min_freq: Optional[float] = None,
max_freq: Optional[float] = None,
include_pitch_bends: bool = True,
multiple_pitch_bends: bool = False,
melodia_trick: bool = True,
) -> Tuple[pretty_midi.PrettyMIDI, List[Tuple[float, float, int, float, Optional[List[int]]]]]:
"""Convert model output to MIDI
Expand All @@ -68,6 +71,8 @@ def model_output_to_notes(
min_note_len: The minimum allowed note length in frames.
min_freq: Minimum allowed output frequency, in Hz. If None, all frequencies are used.
max_freq: Maximum allowed output frequency, in Hz. If None, all frequencies are used.
include_pitch_bends: If True, include pitch bends.
multiple_pitch_bends: If True, allow overlapping notes in midi file to have pitch bends.
melodia_trick: Use the melodia post-processing step.
Returns:
Expand Down Expand Up @@ -99,7 +104,7 @@ def model_output_to_notes(
(times_s[note[0]], times_s[note[1]], note[2], note[3], note[4]) for note in estimated_notes_with_pitch_bend
]

return note_events_to_midi(estimated_notes_time_seconds), estimated_notes_time_seconds
return note_events_to_midi(estimated_notes_time_seconds, multiple_pitch_bends), estimated_notes_time_seconds


def sonify_midi(midi: pretty_midi.PrettyMIDI, save_path: Union[pathlib.Path, str]) -> None:
Expand Down Expand Up @@ -205,32 +210,37 @@ def get_pitch_bends(


def note_events_to_midi(
note_events_with_pitch_bends: List[Tuple[float, float, int, float, Optional[List[int]]]]
note_events_with_pitch_bends: List[Tuple[float, float, int, float, Optional[List[int]]]],
multiple_pitch_bends: bool = False,
) -> pretty_midi.PrettyMIDI:
"""Create a pretty_midi object from note events
Args:
note_events : list of tuples [(start_time_seconds, end_time_seconds, pitch_midi, amplitude)]
where amplitude is a number between 0 and 1
save_path : path to save midi file. If None, no midi file is saved
multiple_pitch_bends : If True, allow overlapping notes to have pitch bends
Note: this will assign each pitch to its own midi instrument, as midi does not yet
support per-note pitch bends
Returns:
pretty_midi.PrettyMIDI() object
"""
mid = pretty_midi.PrettyMIDI()
if not multiple_pitch_bends:
note_events_with_pitch_bends = drop_overlapping_pitch_bends(note_events_with_pitch_bends)

piano_program = pretty_midi.instrument_name_to_program("Electric Piano 1")
instrument_per_piano_key = {i: pretty_midi.Instrument(program=piano_program) for i in range(21, 109)}
instruments = defaultdict(lambda: pretty_midi.Instrument(program=piano_program))
for start_time, end_time, note_number, amplitude, pitch_bend in note_events_with_pitch_bends:
instrument = instruments[note_number] if multiple_pitch_bends else instruments[0]
note = pretty_midi.Note(
velocity=int(np.round(127 * amplitude)),
pitch=note_number,
start=start_time,
end=end_time,
)

instrument_per_piano_key[note_number].notes.append(note)
instrument.notes.append(note)
if not pitch_bend:
continue
pitch_bend_times = np.linspace(start_time, end_time, len(pitch_bend))
Expand All @@ -240,16 +250,31 @@ def note_events_to_midi(
pitch_bend_midi_ticks[pitch_bend_midi_ticks > N_PITCH_BEND_TICKS - 1] = N_PITCH_BEND_TICKS - 1
pitch_bend_midi_ticks[pitch_bend_midi_ticks < -N_PITCH_BEND_TICKS] = -N_PITCH_BEND_TICKS
for pb_time, pb_midi in zip(pitch_bend_times, pitch_bend_midi_ticks):
instrument_per_piano_key[note_number].pitch_bends.append(pretty_midi.PitchBend(pb_midi, pb_time))

for inst in instrument_per_piano_key.values():
# only add instruments with notes
if len(inst.notes) > 0:
mid.instruments.append(inst)
instrument.pitch_bends.append(pretty_midi.PitchBend(pb_midi, pb_time))
mid.instruments.extend(instruments.values())

return mid


def drop_overlapping_pitch_bends(
    note_events_with_pitch_bends: List[Tuple[float, float, int, float, Optional[List[int]]]]
) -> List[Tuple[float, float, int, float, Optional[List[int]]]]:
    """Strip the pitch bend from every note that overlaps another note in time.

    Two notes overlap when each one starts strictly before the other ends, so
    notes that merely touch (one ends exactly where the next starts) keep
    their pitch bends.

    Args:
        note_events_with_pitch_bends: note events as tuples whose first two
            fields are the start and end time and whose last field is the
            pitch bend.

    Returns:
        A new list of tuples in the same order; the last field of every
        overlapping note is replaced by None. The input list and its tuples
        are left untouched.

    TODO: naive N^2 implementation!
    """
    # Work on mutable per-note copies so the pitch-bend slot can be cleared.
    events = [list(event) for event in note_events_with_pitch_bends]

    for position, current in enumerate(events):
        current_start, current_end = current[:2]
        # Each unordered pair is visited once: compare only against earlier notes.
        for earlier in events[:position]:
            earlier_start, earlier_end = earlier[:2]
            overlaps = earlier_end > current_start and earlier_start < current_end
            if not overlaps:
                continue
            # The pitch bend is the last field; drop it on both notes of the pair.
            current[-1] = None
            earlier[-1] = None

    return [tuple(event) for event in events]


def get_infered_onsets(onsets: np.array, frames: np.array, n_diff: int = 2) -> np.array:
"""Infer onsets from large changes in frame amplitudes.
Expand Down
9 changes: 8 additions & 1 deletion basic_pitch/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,14 @@ def main() -> None:
default=None,
help="The maximum allowed note frequency, in Hz.",
)
# Boolean flag: `store_true` takes no value, so it must not be combined with
# `type=` (argparse raises TypeError for that combination when the parser is built).
parser.add_argument(
    "--multiple-pitch-bends",
    action="store_true",
    help="Allow overlapping notes in midi file to have pitch bends. Note: this will map each "
    "pitch to its own instrument",
)
parser.add_argument("--debug-file", default=None, help="Optional file for debug output for inference.")
#
parser.add_argument("--no-melodia", default=False, action="store_true", help="Skip the melodia trick.")
args = parser.parse_args()

Expand Down Expand Up @@ -121,6 +127,7 @@ def main() -> None:
args.minimum_note_length,
args.minimum_frequency,
args.maximum_frequency,
args.multiple_pitch_bends,
not args.no_melodia,
pathlib.Path(args.debug_file) if args.debug_file else None,
)
Expand Down

0 comments on commit cf15da3

Please sign in to comment.