Skip to content

Commit

Permalink
allow truncating performances by time when extracting (#1066)
Browse files Browse the repository at this point in the history
  • Loading branch information
iansimon committed Feb 23, 2018
1 parent 949f1ec commit 7240ec1
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 2 deletions.
12 changes: 10 additions & 2 deletions magenta/music/performance_lib.py
Expand Up @@ -745,7 +745,8 @@ def performance_pitch_histogram_sequence(performance, window_size_seconds,

def extract_performances(
quantized_sequence, start_step=0, min_events_discard=None,
max_events_truncate=None, num_velocity_bins=0, split_instruments=False):
max_events_truncate=None, max_steps_truncate=None, num_velocity_bins=0,
split_instruments=False):
"""Extracts one or more performances from the given quantized NoteSequence.
Args:
Expand All @@ -755,6 +756,8 @@ def extract_performances(
discarded.
max_events_truncate: Maximum length of tracks in events. Longer tracks are
truncated.
max_steps_truncate: Maximum length of tracks in quantized time steps. Longer
tracks are truncated.
num_velocity_bins: Number of velocity bins to use. If 0, velocity events
will not be included at all.
split_instruments: If True, will extract a performance for each instrument.
Expand All @@ -769,7 +772,7 @@ def extract_performances(

stats = dict([(stat_name, statistics.Counter(stat_name)) for stat_name in
['performances_discarded_too_short',
'performances_truncated',
'performances_truncated', 'performances_truncated_timewise',
'performances_discarded_more_than_1_program']])

if sequences_lib.is_absolute_quantized_sequence(quantized_sequence):
Expand Down Expand Up @@ -811,6 +814,11 @@ def extract_performances(
num_velocity_bins=num_velocity_bins,
instrument=instrument)

if (max_steps_truncate is not None and
performance.num_steps > max_steps_truncate):
performance.set_length(max_steps_truncate)
stats['performances_truncated_timewise'].increment()

if (max_events_truncate is not None and
len(performance) > max_events_truncate):
performance.truncate(max_events_truncate)
Expand Down
10 changes: 10 additions & 0 deletions magenta/music/performance_lib_test.py
Expand Up @@ -433,6 +433,11 @@ def testExtractPerformances(self):
self.assertEqual(1, len(perfs))
self.assertEqual(3, len(perfs[0]))

perfs, _ = performance_lib.extract_performances(
quantized_sequence, max_steps_truncate=100)
self.assertEqual(1, len(perfs))
self.assertEqual(100, perfs[0].num_steps)

def testExtractPerformancesMultiProgram(self):
testing_lib.add_track_to_sequence(
self.note_sequence, 0,
Expand Down Expand Up @@ -480,6 +485,11 @@ def testExtractPerformancesRelativeQuantized(self):
self.assertEqual(1, len(perfs))
self.assertEqual(3, len(perfs[0]))

perfs, _ = performance_lib.extract_performances(
quantized_sequence, max_steps_truncate=100)
self.assertEqual(1, len(perfs))
self.assertEqual(100, perfs[0].num_steps)

def testExtractPerformancesSplitInstruments(self):
testing_lib.add_track_to_sequence(
self.note_sequence, 0, [(60, 100, 0.0, 4.0)])
Expand Down

0 comments on commit 7240ec1

Please sign in to comment.