In [24]:
import json
import os 
import note_seq
import librosa
import torchaudio
import numpy as np
import torch
import torch.nn.functional as F
from typing import Any, Callable, Mapping, Optional, Sequence, Tuple, TypeVar, MutableMapping
import dataclasses

from preprocessor.preprocessor import *

In [44]:
# Root of the MAESTRO v3.0.0 dataset. Override with the MAESTRO_PATH environment
# variable when running somewhere the default mount does not exist.
MAESTRO_PATH = os.environ.get("MAESTRO_PATH", "/import/c4dm-datasets/maestro-v3.0.0/")

def get_metadata(maestro_path: Optional[str] = None):
    """Load the MAESTRO v3.0.0 metadata JSON.

    Args:
        maestro_path: Dataset root containing ``maestro-v3.0.0.json``.
            Defaults to the module-level ``MAESTRO_PATH``.

    Returns:
        The parsed JSON object. The helpers below index it as
        ``metadata[column][record_id]``, e.g. ``metadata["midi_filename"]["1275"]``.
    """
    root = MAESTRO_PATH if maestro_path is None else maestro_path
    # JSON is UTF-8; don't depend on the platform's locale encoding.
    with open(os.path.join(root, "maestro-v3.0.0.json"), "r", encoding="utf-8") as f:
        return json.load(f)

def get_midi(id, metadata=None):
    """Load the MIDI file of a MAESTRO record as a note_seq NoteSequence.

    Args:
        id: Record key into the metadata columns (e.g. ``"1275"``).
        metadata: Parsed metadata from ``get_metadata()``. Loaded on demand
            when omitted.

    Returns:
        The result of ``note_seq.midi_to_note_sequence`` on the file's bytes.
    """
    # Resolve lazily: a `get_metadata()` default would be evaluated once at
    # definition time, reading the dataset JSON even if this is never called.
    if metadata is None:
        metadata = get_metadata()
    filename = os.path.join(MAESTRO_PATH, metadata["midi_filename"][id])
    with open(filename, 'rb') as f:
        content = f.read()
    return note_seq.midi_to_note_sequence(content)

def get_audio(id, metadata=None, sr=16000):
    """Load the audio of a MAESTRO record with librosa, resampled to ``sr``.

    Args:
        id: Record key into the metadata columns (e.g. ``"1275"``).
        metadata: Parsed metadata from ``get_metadata()``. Loaded on demand
            when omitted.
        sr: Target sample rate passed to ``librosa.load``.

    Returns:
        ``(samples, sample_rate)`` exactly as returned by ``librosa.load``.
    """
    # Resolve lazily: a `get_metadata()` default would run at definition time.
    if metadata is None:
        metadata = get_metadata()
    filename = os.path.join(MAESTRO_PATH, metadata["audio_filename"][id])
    print(filename)
    samples, sample_rate = librosa.load(filename, sr=sr)
    return samples, sample_rate

def get_frame_times(id, metadata=None, frame_rate=50):
    """Timestamps of every frame of a MAESTRO recording, from its audio header.

    The frame count is ``ceil(duration * frame_rate)``, where the duration comes
    from ``torchaudio.info``. The audio itself is not decoded.

    Args:
        id: Record key into the metadata columns (e.g. ``"1275"``).
        metadata: Parsed metadata from ``get_metadata()``. Loaded on demand
            when omitted.
        frame_rate: Frames per second.

    Returns:
        A 1-D float tensor ``[0, 1, ..., n - 1] / frame_rate``.
    """
    # Resolve lazily: a `get_metadata()` default would run at definition time.
    if metadata is None:
        metadata = get_metadata()
    filename = os.path.join(MAESTRO_PATH, metadata["audio_filename"][id])
    info = torchaudio.info(filename)
    # Integer count, so torch.arange gets an int rather than a numpy float.
    num_frames = int(np.ceil(info.num_frames / info.sample_rate * frame_rate))
    return torch.arange(num_frames) / frame_rate

def _audio_to_frames_pytorch(
    samples: torch.Tensor,
    hop_size: int,
    frame_rate: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
  """Convert audio samples to non-overlapping frames and frame times."""
  frame_size = hop_size
  samples = F.pad(samples,
                  [0, frame_size - len(samples) % frame_size],
                  mode='constant')

  # Split audio into frames.
  frames = samples.unfold(0, frame_size, frame_size) 

  num_frames = len(samples) // frame_size
  times = torch.arange(num_frames) / frame_rate
  return frames, times

In [45]:
meta = get_metadata()
# One timestamp per 1/50 s frame of record "1275", derived from the audio header.
frame_times = get_frame_times(id="1275", metadata=meta, frame_rate=50)

In [47]:
# Number of frames reported by get_frame_times for record "1275".
frame_times.shape

torch.Size([31571])

In [37]:
# Hand-computed frame count for one recording (presumably the audio of record
# "1275"), as a cross-check of get_frame_times. Uses a distinct name so the
# metadata dict in `meta` is not overwritten.
FRAME_RATE_HZ = 50
audio_info = torchaudio.info(os.path.join(
    MAESTRO_PATH,
    "2004/MIDI-Unprocessed_XP_04_R1_2004_01-02_ORIG_MID--AUDIO_04_R1_2004_02_Track02_wav.wav"))
total_time = audio_info.num_frames / audio_info.sample_rate
# Same arithmetic as get_frame_times (`* frame_rate`, not `/ 0.02`), so that
# float rounding cannot make the two counts disagree.
num_frames = np.ceil(total_time * FRAME_RATE_HZ).astype(int)


In [38]:
# Compare with frame_times.shape from get_frame_times; the two should agree.
num_frames

31571

In [23]:
# Load the MIDI here so this cell does not depend on a later cell defining `ns`.
ns = get_midi("1275")
# `sr` must be passed by keyword: it is keyword-only in librosa 0.10, and passing
# it positionally in earlier versions emits a deprecation warning.
librosa.time_to_samples(ns.total_time, sr=16000)

  librosa.time_to_samples(ns.total_time, 16000)


10082916

In [7]:
# No visible cell defines `samples`, so load record "1275" at 16 kHz here so that
# this cell works under Restart & Run All.
SAMPLE_RATE = 16000
FRAME_RATE = 50
samples, sample_rate = get_audio("1275", sr=SAMPLE_RATE)
# Hop of sample_rate // FRAME_RATE samples (320 at 16 kHz) gives one frame per 1/50 s.
frames, frame_times = _audio_to_frames_pytorch(
    torch.as_tensor(samples, dtype=torch.float32), sample_rate // FRAME_RATE, FRAME_RATE)

In [8]:
# Frame count from the decoded audio; compare with the header-based count from get_frame_times.
frame_times.shape

torch.Size([31571])

In [20]:
# MIDI-derived (padded) sample count vs. decoded audio length.
# NOTE(review): `num_samples` is only assigned further down (in the get_midi cell),
# so this cell relies on out-of-order execution; move it below that cell.
num_samples, samples.shape

(10085440.0, (10102555,))

In [22]:
ns = get_midi("1275")  # sustain-pedal extension (note_seq.apply_sustain_control_changes) is not applied
# Audio length implied by the MIDI at 16 kHz, rounded up to a whole number of
# 320-sample hops. An already-aligned length still gains one extra hop, which
# mirrors the padding in _audio_to_frames_pytorch.
num_samples = ns.total_time * 16000
num_samples = num_samples + (320 - num_samples % 320)
num_frames = num_samples // 320
print(num_frames)

31510.0


In [7]:
# Flatten the NoteSequence into parallel lists: `times` (later multiplied by
# general_codec.steps_per_second, so presumably seconds) and `values`
# (NoteEventData entries).
times, values = note_sequence_to_onsets_and_offsets_and_programs(ns)

In [15]:
# NOTE(review): the first positional argument is presumably steps-per-second
# (general_codec.steps_per_second is read later); the meaning of 10 is not visible
# here. Confirm against vocabularies.build_codec and prefer keyword arguments.
general_codec = vocabularies.build_codec(100, 10)

In [16]:
# Encode the note events against the frame grid. This yields the flat token stream
# `events`, per-frame start/end offsets into it, and a state-event stream with
# its own index array.
(events, event_start_indices, event_end_indices,
 state_events, state_event_indices) = encode_and_index_events(
    NoteEncodingState(),
    times,
    values,
    note_event_data_to_events,
    general_codec,
    frame_times,
    note_encoding_state_to_events,
)

In [76]:
# Split the per-frame index arrays into fixed-size segments (one row per segment).
seg_start_idx, seg_end_idx, seg_state_idx = split_tokens(
    [torch.Tensor(index_array)
     for index_array in (event_start_indices, event_end_indices, state_event_indices)])

In [91]:
# torch.Tensor produced float tensors; cast the segment index grids back to
# int32 so they can serve as slice bounds.
seg_start_idx, seg_end_idx, seg_state_idx = (
    grid.to(torch.int32) for grid in (seg_start_idx, seg_end_idx, seg_state_idx)
)

In [110]:
# (number of segments, indices per segment)
seg_state_idx.shape

torch.Size([124, 256])

In [88]:
# First state index and last end index of segment 0.
seg_state_idx[0, 0], seg_end_idx[0, -1]

(tensor(0, dtype=torch.int32), tensor(638, dtype=torch.int32))

In [118]:
tie_end_token = general_codec.encode_event(event_codec.Event("tie", 0))
# Convert once rather than once per segment.
# Assumes extract_sequence_with_indices does not modify its inputs in place
# (the converted tensors are now shared by all segments) — TODO confirm.
events_tensor = torch.Tensor(events)
state_events_tensor = torch.Tensor(state_events)
# Iterate over the actual segment rows instead of a hard-coded 124, so this
# works for recordings of any length.
event_segments = [
    extract_sequence_with_indices(events_tensor,
                                  start_row[0], end_row[-1],
                                  state_events_tensor, tie_end_token,
                                  state_row)
    for start_row, end_row, state_row in zip(seg_start_idx, seg_end_idx, seg_state_idx)
]

In [99]:
# Raw event count for segment 3 (frames starting at 3*256), to compare with the
# extracted segment length.
# NOTE(review): segment 3's last frame is 4*256 - 1, but the end bound here is
# event_end_indices[4 * 256], which also covers the next frame. Confirm this is intended.
events[event_start_indices[3 * 256]:event_end_indices[4 * 256]].shape

(766,)

In [104]:
# NOTE(review): `event_segment` is not assigned in any visible cell (only
# `event_segments` is). It is presumably left over from an earlier version; confirm
# which segment was meant (e.g. event_segments[3]).
event_segment

tensor([1.1310e+03, 1.0000e+00, 1.1320e+03, 1.1300e+03, 1.0680e+03, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.1320e+03, 1.1290e+03, 1.0630e+03,
        1.1320e+03, 1.1300e+03, 1.0630e+03, 1.0000e+00, 1.1320e+03, 1.1300e+03,
        1.0730e+03, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.1320e+03,
        1.1300e+03, 1.0720e+03, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.1320e+03, 1.1290e+03, 1.0680e+03, 1.1320e+03,
        1.1300e+03, 1.0680e+03, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.1320e+03, 1.1290e+03, 1.0560e+03, 1.1320e+03, 1.1290e+03,
        1.0630e+03, 1.1320e+03, 1.1290e+03, 1.0680e+03, 1.1320e+03, 1.1290e+03,
        1.0730e+03, 1.1320e+03, 1.1290e+03, 1.0750e+03, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.1320e+

In [97]:
# NOTE(review): `event_segment` is not assigned in any visible cell; depends on hidden kernel state.
event_segment.shape

torch.Size([765])

In [73]:
# Decode token 1132 back to its Event (a program event).
general_codec.decode_event_index(1132)

Event(type='program', value=0)

In [63]:
# Capacity of the 124 x 256 segment grid vs. the number of frames. The difference
# is presumably padding added by split_tokens.
124 * 256, event_start_indices.shape

(31744, (31571,))

In [126]:
def run_length_encode_shifts(
    features,
    codec=None,
    state_change_event_ranges=(),
) -> Mapping[str, Any]:
    """Combine leading/interior shifts, trim trailing shifts.

    Walks ``features["inputs"]`` and collapses shift events. When a non-shift
    event is reached after at least one shift, the cumulative total of *all*
    shift steps seen since the start of ``inputs`` (not only those since the
    previous non-shift event) is emitted as one or more tokens, each at most
    ``codec.max_shift_steps``. Shifts after the last non-shift event are never
    emitted.

    Args:
      features: Dict with an ``"inputs"`` sequence of event indices. Its
        ``"inputs"`` entry is overwritten in place.
      codec: Object providing ``is_shift_event_index`` and
        ``max_shift_steps``. Defaults to the notebook-global ``general_codec``,
        looked up at call time.
      state_change_event_ranges: ``(min_index, max_index)`` pairs. Within each
        range, an event equal to the last event seen in that range is dropped
        as redundant. The default (empty) disables this de-duplication, which is
        the behaviour of the former hard-coded ``enumerate([])``.

    Returns:
      ``features``, with ``"inputs"`` replaced by an int32 ``tf`` tensor.
    """
    # NOTE(review): `tf` is not among the visible imports; presumably it comes
    # from the star import of preprocessor.preprocessor. Confirm.
    if codec is None:
        codec = general_codec
    events = features["inputs"]

    shift_steps = 0
    total_shift_steps = 0
    output = tf.constant([], dtype=tf.int32)

    # Last event seen in each state-change range (zeros until one is seen).
    current_state = tf.zeros(len(state_change_event_ranges), dtype=tf.int32)

    for event in events:
        # Let autograph know that the shape of 'output' will change during the
        # loop.
        tf.autograph.experimental.set_loop_options(
            shape_invariants=[(output, tf.TensorShape([None]))])
        if codec.is_shift_event_index(event):
            shift_steps += 1
            total_shift_steps += 1
            continue

        # If this event is a state change and has the same value as the current
        # state, we can skip it entirely.
        is_redundant = False
        for i, (min_index, max_index) in enumerate(state_change_event_ranges):
            if (min_index <= event) and (event <= max_index):
                if current_state[i] == event:
                    is_redundant = True
                current_state = tf.tensor_scatter_nd_update(
                    current_state, indices=[[i]], updates=[event])
        if is_redundant:
            continue

        # Once we've reached a non-shift event, RLE all previous shift events
        # before outputting the non-shift event. The emitted amount is the
        # cumulative total since the start of `inputs`, not the local run.
        if shift_steps > 0:
            shift_steps = total_shift_steps
            while shift_steps > 0:
                output_steps = tf.minimum(codec.max_shift_steps, shift_steps)
                output = tf.concat([output, [output_steps]], axis=0)
                shift_steps -= output_steps
        output = tf.concat([output, [event]], axis=0)

    features["inputs"] = output
    return features

In [127]:
# Run-length encode shifts over the full encoded event stream. The function
# overwrites the "inputs" entry of the dict it is given and returns that dict.
run_length_encode_shifts({"inputs": events})

In [125]:
# Per-frame start offsets into `events`.
event_start_indices

array([    0,     2,     4, ..., 95692, 95694, 95696])

In [120]:
# Timestamp of frame 256 (256 / 50 s).
frame_times[256]

5.12

In [98]:
# Total encoded events vs. the per-frame index arrays (one entry per frame).
events.shape, event_start_indices.shape, event_end_indices.shape

((95697,), (31571,), (31571,))

In [99]:
# Start/end offsets of the first ten frames.
event_start_indices[:10], event_end_indices[:10]

(array([ 0,  2,  4,  6,  8, 10, 12, 14, 16, 18]),
 array([ 2,  4,  6,  8, 10, 12, 14, 16, 18, 20]))

In [107]:
# Decode token 1129 back to its Event (a velocity event).
general_codec.decode_event_index(1129)

Event(type='velocity', value=0)

In [112]:
# NOTE(review): uses `idx` (and the re-sorted `values`) from the argsort cell
# further down, i.e. it depends on out-of-order execution.
values[100], times[idx[100]]

(NoteEventData(pitch=60, velocity=62, program=0, is_drum=False, instrument=None),
 10.090625)

In [101]:
# Stable-sort the events by time, reorder `values` to match, and show each
# event's time quantised to codec steps.
# NOTE(review): not idempotent. `values` is rebound to a permutation of itself,
# so re-running this cell permutes it again. Consider keeping the unsorted list
# under its own name (the later cells that read `values` would need updating).
idx = np.argsort(times, kind='stable')
values = [values[i] for i in idx]
[round(times[i] * general_codec.steps_per_second) for i in idx]

[100,
 102,
 130,
 155,
 196,
 205,
 206,
 206,
 217,
 218,
 225,
 226,
 238,
 239,
 239,
 251,
 254,
 262,
 274,
 276,
 276,
 282,
 303,
 306,
 345,
 357,
 408,
 408,
 408,
 408,
 431,
 453,
 453,
 453,
 455,
 457,
 480,
 500,
 500,
 500,
 503,
 503,
 526,
 545,
 545,
 545,
 550,
 550,
 573,
 598,
 598,
 598,
 598,
 598,
 602,
 604,
 608,
 611,
 615,
 619,
 622,
 642,
 645,
 648,
 650,
 653,
 658,
 661,
 667,
 699,
 703,
 705,
 753,
 756,
 803,
 805,
 805,
 807,
 816,
 819,
 826,
 839,
 850,
 850,
 850,
 854,
 855,
 861,
 880,
 880,
 882,
 885,
 905,
 906,
 952,
 956,
 980,
 1007,
 1007,
 1009,
 1009,
 1009,
 1022,
 1029,
 1031,
 1040,
 1047,
 1049,
 1050,
 1067,
 1073,
 1075,
 1086,
 1090,
 1098,
 1105,
 1105,
 1111,
 1111,
 1113,
 1113,
 1114,
 1116,
 1118,
 1121,
 1123,
 1132,
 1143,
 1153,
 1153,
 1169,
 1169,
 1169,
 1169,
 1172,
 1172,
 1224,
 1231,
 1231,
 1234,
 1295,
 1297,
 1305,
 1377,
 1379,
 1379,
 1384,
 1406,
 1410,
 1429,
 1433,
 1455,
 1459,
 1479,
 1480,
 1491,
 1491

In [58]:
# Encode the first entry of `values` on its own. Distinct names are used so the
# full encoded `events` array from encode_and_index_events is not clobbered.
note_events = note_event_data_to_events(None, values[0], general_codec)
encoded_events = [general_codec.encode_event(e) for e in note_events]
note_events, encoded_events

([Event(type='program', value=0),
  Event(type='velocity', value=0),
  Event(type='pitch', value=36)],
 [1132, 1129, 1037])