In [350]:
from preprocessor.preprocessor import tokenize

import dataclasses
import note_seq
import numpy as np
import torch
from typing import Any, Callable, MutableMapping, Optional, Sequence, Tuple, TypeVar, MutableSet
import torch.nn.functional as F

from preprocessor import vocabularies
from preprocessor.event_codec import Codec, Event
from preprocessor.preprocessor import *
import pandas as pd
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [213]:
filename = "/import/c4dm-datasets/maestro-v3.0.0/2013/ORIG-MIDI_02_7_6_13_Group__MID--AUDIO_06_R1_2013_wav--4.midi"
codec = vocabularies.build_codec(100, 256/50)
ns = read_midi(filename)
ns = note_seq.apply_sustain_control_changes(ns)
num_frames = np.ceil(ns.total_time * 50)
frame_times = torch.arange(num_frames) / 50 
times, values = note_sequence_to_onsets_and_offsets_and_programs(ns)

In [328]:
d = np.array(times)
d_idx = np.nonzero(np.logical_and(d > 59 * 5.12, d < 60 * 5.12))
e = np.array(values)[d_idx]

In [214]:
def quantize_time(times, frame_length):
    steps = np.round(np.array(times) / frame_length)
    return steps.astype(int)

In [244]:
quantized_time = quantize_time(times, 0.01)
time_stamps = np.unique(quantized_time) 

In [343]:
events = {}
state_events = {0: [codec.encode_event(Event("tie", 0))]}
ds = NoteEncodingState()
num_segments = np.ceil(time_stamps.max() / 512).astype(int)
for time in time_stamps:
    segment_num = time // 512
    event_idx = np.nonzero(quantized_time == time)[0]
    event_values = [values[i] for i in event_idx]
    event = events.get(segment_num, [])
    tokens = [codec.encode_event(token) for event in event_values for token in note_event_data_to_events(ds, event, codec)]
    tokens.insert(0, time % 512)
    event.extend(tokens)
    events[segment_num] = event
    state_event = [codec.encode_event(s) for s in note_encoding_state_to_events(ds)]
    state_events[segment_num+1] = state_event

In [345]:
results = torch.zeros(len(events), 2048).int()
for k, v in events.items():
    all_events = state_events[k] + v
    results[k][:len(all_events)] = torch.Tensor(all_events).int()

In [354]:

def decode_events(
    state: DS,
    tokens: np.ndarray,
    start_time: int,
    max_time: Optional[int],
    codec: Codec,
    decode_event_fn: Callable[[DS, float, Event, Codec],
                              None],
) -> Tuple[int, int]:
  """Decode a series of tokens, maintaining a decoding state object.

  Args:
    state: Decoding state object; will be modified in-place.
    tokens: event tokens to convert.
    start_time: offset start time if decoding in the middle of a sequence.
    max_time: Events at or beyond this time will be dropped.
    codec: An event_codec.Codec object that maps indices to Event objects.
    decode_event_fn: Function that consumes an Event (and the current time) and
        updates the decoding state.

  Returns:
    invalid_events: number of events that could not be decoded.
    dropped_events: number of events dropped due to max_time restriction.
  """
  invalid_events = 0
  dropped_events = 0
  cur_steps = 0
  cur_time = start_time
  token_idx = 0
  for token_idx, token in enumerate(tokens):
    try:
      event = codec.decode_event_index(token)
    except ValueError:
      invalid_events += 1
      continue
    if event.type == 'shift':
      cur_steps += event.value
      cur_time = start_time + cur_steps / codec.steps_per_second
      if max_time and cur_time > max_time:
        dropped_events = len(tokens) - token_idx
        break
    else:
      cur_steps = 0
      try:
        decode_event_fn(state, cur_time, event, codec)
      except ValueError:
        invalid_events += 1
        print(
            f"Got invalid event when decoding event {event} at time {cur_time}. Invalid event counter now at {invalid_events}.")
        continue
  return invalid_events, dropped_events

In [351]:
def _add_note_to_sequence(
    ns: note_seq.NoteSequence,
    start_time: float, end_time: float, pitch: int, velocity: int,
    program: int = 0, is_drum: bool = False
) -> None:
  end_time = max(end_time, start_time + 0.01)
  ns.notes.add(
      start_time=start_time, end_time=end_time,
      pitch=pitch, velocity=velocity, program=program, is_drum=is_drum)
  ns.total_time = max(ns.total_time, end_time)

In [352]:
@dataclasses.dataclass
class NoteDecodingState:
  """Decoding state for note transcription."""
  current_time: float = 0.0
  # velocity to apply to subsequent pitch events (zero for note-off)
  current_velocity: int = 100 
  # program to apply to subsequent pitch events
  current_program: int = 0
  # onset time and velocity for active pitches and programs
  active_pitches: MutableMapping[Tuple[int, int],
                                 Tuple[float, int]] = dataclasses.field(
                                     default_factory=dict)
  # pitches (with programs) to continue from previous segment
  tied_pitches: MutableSet[Tuple[int, int]] = dataclasses.field(
      default_factory=set)
  # whether or not we are in the tie section at the beginning of a segment
  is_tie_section: bool = False
  # partially-decoded NoteSequence
  note_sequence: note_seq.NoteSequence = dataclasses.field(
      default_factory=lambda: note_seq.NoteSequence(ticks_per_quarter=220))

def decode_note_event(
    state: NoteDecodingState,
    time: float,
    event: Event,
    codec: Codec
) -> None:
  """Process note event and update decoding state."""
  if time < state.current_time:
    raise ValueError('event time < current time, %f < %f' % (
        time, state.current_time))
  state.current_time = time
  if event.type == 'pitch':
    pitch = event.value
    if state.is_tie_section:
      # "tied" pitch
      if (pitch, state.current_program) not in state.active_pitches:
        raise ValueError('inactive pitch/program in tie section: %d/%d' %
                         (pitch, state.current_program))
      if (pitch, state.current_program) in state.tied_pitches:
        raise ValueError('pitch/program is already tied: %d/%d' %
                         (pitch, state.current_program))
      state.tied_pitches.add((pitch, state.current_program))
    elif state.current_velocity == 0:
      # note offset
      if (pitch, state.current_program) not in state.active_pitches:
        raise ValueError('note-off for inactive pitch/program: %d/%d' %
                         (pitch, state.current_program))
      onset_time, onset_velocity = state.active_pitches.pop(
          (pitch, state.current_program))
      _add_note_to_sequence(
          state.note_sequence, start_time=onset_time, end_time=time,
          pitch=pitch, velocity=onset_velocity, program=state.current_program)
    else:
      # note onset
      if (pitch, state.current_program) in state.active_pitches:
        # The pitch is already active; this shouldn't really happen but we'll
        # try to handle it gracefully by ending the previous note and starting a
        # new one.
        onset_time, onset_velocity = state.active_pitches.pop(
            (pitch, state.current_program))
        _add_note_to_sequence(
            state.note_sequence, start_time=onset_time, end_time=time,
            pitch=pitch, velocity=onset_velocity, program=state.current_program)
      state.active_pitches[(pitch, state.current_program)] = (
          time, state.current_velocity)
  elif event.type == 'drum':
    # drum onset (drums have no offset)
    if state.current_velocity == 0:
      raise ValueError('velocity cannot be zero for drum event')
    offset_time = time + 0.01
    _add_note_to_sequence(
        state.note_sequence, start_time=time, end_time=offset_time,
        pitch=event.value, velocity=state.current_velocity, is_drum=True)
  elif event.type == 'velocity':
    # velocity change
    num_velocity_bins = vocabularies.num_velocity_bins_from_codec(codec)
    velocity = vocabularies.bin_to_velocity(event.value, num_velocity_bins)
    state.current_velocity = velocity
  elif event.type == 'program':
    # program change
    state.current_program = event.value
  elif event.type == 'tie':
    # end of tie section; end active notes that weren't declared tied
    if not state.is_tie_section:
      raise ValueError('tie section end event when not in tie section')
    for (pitch, program) in list(state.active_pitches.keys()):
      if (pitch, program) not in state.tied_pitches:
        onset_time, onset_velocity = state.active_pitches.pop((pitch, program))
        _add_note_to_sequence(
            state.note_sequence,
            start_time=onset_time, end_time=state.current_time,
            pitch=pitch, velocity=onset_velocity, program=program)
    state.is_tie_section = False
  else:
    raise ValueError('unexpected event type: %s' % event.type)



In [355]:
decoding_state = NoteDecodingState()
invalid_ids, dropped_events = decode_events(
    state=decoding_state, tokens=results[0], start_time=0, max_time=None,
    codec=codec, decode_event_fn=decode_note_event)

Got invalid event when decoding event %s at time %f. Invalid event counter now at %d. Event(type='tie', value=tensor(0.)) 0 1
Got invalid event when decoding event %s at time %f. Invalid event counter now at %d. Event(type='pitch', value=tensor(68.)) tensor(2.3800) 2
Got invalid event when decoding event %s at time %f. Invalid event counter now at %d. Event(type='pitch', value=tensor(59.)) tensor(2.3900) 3
Got invalid event when decoding event %s at time %f. Invalid event counter now at %d. Event(type='pitch', value=tensor(64.)) tensor(2.3900) 4
Got invalid event when decoding event %s at time %f. Invalid event counter now at %d. Event(type='pitch', value=tensor(64.)) tensor(2.8100) 5
Got invalid event when decoding event %s at time %f. Invalid event counter now at %d. Event(type='pitch', value=tensor(68.)) tensor(2.8100) 6
Got invalid event when decoding event %s at time %f. Invalid event counter now at %d. Event(type='pitch', value=tensor(59.)) tensor(2.8200) 7
Got invalid event when

In [None]:
644, 572, 644, 577, 644, 580, 644, 584, 644, 589, 644, 592, 644, 596, 643

In [276]:
(events_, event_start_indices, event_end_indices, state_events_, state_event_indices, ) = encode_and_index_events(
    NoteEncodingState(), times, values, note_event_data_to_events, codec, frame_times, note_encoding_state_to_events)

In [277]:
seg_start_idx, seg_end_idx, seg_state_idx, segment_times = split_tokens(
    [event_start_indices, event_end_indices, state_event_indices, frame_times],
    segment_length=256,
)

In [279]:
state_events_

tensor([643, 644, 565,  ..., 644, 553, 643], dtype=torch.int32)

In [289]:
state_events_

tensor([644, 572, 644, 577, 643, 644, 553, 644, 572, 644, 577, 643, 644, 553,
        644, 577, 643, 644, 553, 643], dtype=torch.int32)

In [290]:
seg_state_idx[-1]

tensor([104416, 104416, 104416, 104416, 104416, 104416, 104416, 104416, 104416,
        104416, 104416, 104416, 104416, 104416, 104440, 104440, 104476, 104489,
        104489, 104489, 104489, 104489, 104489, 104489, 104489, 104489, 104489,
        104489, 104489, 104489, 104489, 104489, 104489, 104489, 104489, 104489,
        104489, 104489, 104489, 104489, 104489, 104489, 104489, 104489, 104489,
        104489, 104652, 104757, 104772, 104772, 104772, 104772, 104772, 104772,
        104772, 104772, 104772, 104772, 104772, 104772, 104772, 104772, 104772,
        104772, 104772, 104772, 104772, 104772, 104772, 104772, 104772, 104772,
        104772, 104772, 104772, 104772, 104772, 104772, 104772, 104772, 104772,
        104772, 104772, 104772, 104772, 104772, 104772, 104852, 104852, 104852,
        104852, 104852, 104852, 104852, 104852, 104877, 104901, 104901, 104961,
        104964, 104964, 104973, 105013, 105013, 105013, 105028, 105028, 105028,
        105028, 105028, 105028, 105028, 

In [318]:
id = 59
c = extract_sequence_with_indices(events_, seg_start_idx[id, 0], seg_end_idx[id, -1], 
        torch.Tensor(state_events_), codec.encode_event(Event("tie", 0)), seg_state_idx[id])[:40]
count_shift_and_pad(c, 200, codec)

tensor([644., 548., 644., 560., 644., 571., 643.,   3., 644., 641., 571.,   4.,
        644., 641., 548.,   5., 644., 642., 572.,  15., 644., 642., 573.,   0.,
          0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
          0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
          0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
          0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
          0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
          0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
          0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
          0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
          0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
          0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
          0.,   0.,   0.,   0.,   0.,   

In [299]:
550 in c

True

In [259]:
seg_start_idx[0]

tensor([  0,   2,   4,   6,   8,  10,  12,  14,  16,  18,  20,  22,  24,  26,
         28,  30,  32,  34,  36,  38,  40,  42,  44,  46,  48,  50,  52,  54,
         56,  58,  60,  62,  64,  66,  68,  70,  72,  74,  76,  78,  80,  82,
         84,  86,  88,  90,  92,  94,  96,  98, 103, 108, 110, 112, 114, 116,
        118, 120, 122, 124, 126, 128, 130, 132, 134, 136, 138, 140, 142, 144,
        146, 148, 150, 152, 154, 156, 158, 160, 162, 164, 166, 168, 170, 172,
        174, 176, 178, 180, 188, 193, 195, 197, 199, 201, 203, 205, 207, 209,
        211, 213, 215, 217, 219, 221, 223, 225, 227, 229, 231, 233, 235, 237,
        239, 241, 243, 245, 247, 249, 251, 253, 273, 275, 277, 279, 281, 283,
        285, 287, 289, 291, 293, 295, 297, 299, 301, 303, 305, 307, 309, 311,
        313, 327, 335, 337, 339, 341, 343, 345, 347, 349, 351, 353, 355, 357,
        359, 361, 363, 365, 367, 369, 371, 373, 375, 383, 397, 399, 401, 403,
        405, 407, 409, 411, 413, 415, 417, 419, 421, 423, 425, 4

In [146]:
tie_end_token = codec.encode_event(Event("tie", 0))
extract_sequence_with_indices(events, seg_end_idx[2, 0], seg_end_idx[2, -1], 
        torch.Tensor(state_events), tie_end_token, seg_state_idx[2])

tensor([644, 548, 644, 560, 644, 567, 644, 572, 644, 577, 643,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1, 644, 642, 562,
        644, 642, 571, 644, 642, 550,   1,   1,   1,   1, 644, 641, 548, 644,
        641, 560, 644, 641, 567, 644, 641, 572, 644, 641, 577,   1, 644, 642,
        567,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1, 644, 641,
        567,   1,   1,   1, 644, 641, 571,   1,   1,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1, 644, 642, 571,
          1,   1,   1,   1,   1, 644, 642, 567,   1,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1, 644, 6

In [102]:
event_seg

tensor([  1,   1, 644,  ..., 596,   1,   1], dtype=torch.int32)

In [6]:
original_result = count_shift_and_pad(event_seg, 2048, codec)

In [7]:
torch_result = count_shift_and_pad_torch(event_seg, 2048, codec)

In [8]:
(original_result == torch_result).all()

tensor(True)