# Setup

In [1]:
%load_ext autoreload
%autoreload 2

In [7]:
!pip install dm-haiku
!pip install pretty_midi
!pip install optax
!pip install basic_pitch

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pretty_midi
  Downloading pretty_midi-0.2.9.tar.gz (5.6 MB)
[K     |████████████████████████████████| 5.6 MB 5.4 MB/s 
Collecting mido>=1.1.16
  Downloading mido-1.2.10-py2.py3-none-any.whl (51 kB)
[K     |████████████████████████████████| 51 kB 7.6 MB/s 
Building wheels for collected packages: pretty-midi
  Building wheel for pretty-midi (setup.py) ... [?25l[?25hdone
  Created wheel for pretty-midi: filename=pretty_midi-0.2.9-py3-none-any.whl size=5591954 sha256=0f0c1097f95cfce63465c81536c74c7896b90bb6ee2922554d8dbd1b601bb22a
  Stored in directory: /root/.cache/pip/wheels/2a/5a/e3/30eeb9a99350f3f7e21258fcb132743eef1a4f49b3505e76b6
Successfully built pretty-midi
Installing collected packages: mido, pretty-midi
Successfully installed mido-1.2.10 pretty-midi-0.2.9
Looking in ind

## Imports

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [63]:
from typing import Union, List, Dict, Tuple

import numpy as np
# import tensorflow as tf
import random

import jax
from jax import value_and_grad
import jax.numpy as jnp
import optax
import haiku as hk

import librosa
import pathlib
from tqdm import tqdm

import matplotlib.pyplot as plt
from pretty_midi import PrettyMIDI

from IPython.display import display, Audio

import sys
import os
FOLDER_PATH = "/content/drive/MyDrive/badpitches/v3"
sys.path.insert(0, FOLDER_PATH)

from loss import loss_dict
from constants import *
from new_model_in_jax import PosteriorgramModel
from cqt_and_hs import load_cqt_window, harmonic_stacking, load_and_cqt, cqt_windowed
from note_creation import model_output_to_notes, sonify_midi

# Training Attempt
Maybe consider adding a requirements.txt

In [9]:
rng = jax.random.PRNGKey(0)
audio_path = "drive/MyDrive/badpitches/v3/test.m4a"
audio_tensor = load_and_cqt(audio_path)
noisy_audio = audio_tensor + jax.random.normal(rng, audio_tensor.shape)

epochs = 1000
learning_rate = 0.01
optimizer = optax.adam(learning_rate)


def update_weights(weights,gradients):
    return optimizer.update(gradients, weights)

def loss_wrapper(params, state, x, y):
    out, new_state = model.apply(params, state, rng=rng, audio_tensor=x, is_training=True)
    loss_fns = loss_dict()
    print([(w, y[w].shape) for w in y], [o.shape for o in out])
    loss_yp = jnp.sum(loss_fns["contour"](y["contour"], out[0]))
    loss_yn = jnp.sum(loss_fns["note"](y["note"], out[1]))
    loss_yo = jnp.sum(loss_fns["onset"](y["onset"], out[2]))
    loss = loss_yp + loss_yn + loss_yo
    return loss, (loss, new_state)

def step(params, opt_state, state, x, y):
    # loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)
    grads, (loss, state) = jax.grad(loss_wrapper, has_aux=True)(params, state, x, y)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, state, loss

def model_fn(audio_tensor, is_training):
    bn = hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.9, name="bn")
    normed = bn(audio_tensor, is_training)
    hs = harmonic_stacking(normed)
    pgram = PosteriorgramModel()(hs, is_training)
    
    return pgram



audio_array:  (113367,)
cqt:  (84, 222)
reshaped cqt:  (1, 84, 222)


In [10]:
def audio_info(path, window_len=ANNOT_N_FRAMES, window_num=0):
  """
  Returns the onset, contour, and note matrices as a dictionary for the appropriate
  window of the MIDI file located at the given path.
  """

  midi_fps = ANNOTATIONS_FPS

  pm = PrettyMIDI(path)

  start_frame = window_num * window_len
  end_frame = (window_num + 1) * window_len
  note_matrix = pm.get_piano_roll(fs=midi_fps)[21:109, start_frame:end_frame] / 127 #21 to 109 is piano bins

  onsets = np.array(pm.get_onsets() * midi_fps)
  logical_indices = np.logical_and(start_frame < onsets, onsets < end_frame)
  windowed_onsets = np.floor(onsets[logical_indices] - start_frame)

  onset_matrix = np.zeros_like(note_matrix, dtype=bool)
  
  for frame in windowed_onsets.astype(int):
    vec = note_matrix[: , frame]
    prev_vec = note_matrix[: , max(0, frame - 1)]
    onset_notes = np.argwhere(vec - prev_vec).ravel()
    for note in onset_notes:
      onset_matrix[note][frame] = True

  contour_array = [(row if i == 1 else np.zeros(row.shape)) for i in range(3) for row in note_matrix]

  return {
    "onset": jnp.array(onset_matrix),
    "contour": jnp.array(contour_array),
    "note": jnp.array(note_matrix)
  }

In [11]:
data_path = "/content/drive/MyDrive/badpitches_data/test_new/"
val_files = os.listdir(data_path)
files = [va.split(".")[0] for va in val_files if va.split(".")[-1] == 'midi']

### Training loop

In [15]:
BATCH_SIZE = 1 #unbatched until batch works

batched_inputs = np.zeros((BATCH_SIZE, 84, 87, 1))
batched_outputs = {
    'note': np.zeros((BATCH_SIZE, N_FREQ_BINS_NOTES, ANNOT_N_FRAMES)),
    'contour': np.zeros((BATCH_SIZE, N_FREQ_BINS_CONTOURS, ANNOT_N_FRAMES)),
    'onset': np.zeros((BATCH_SIZE, N_FREQ_BINS_NOTES, ANNOT_N_FRAMES))
}

model = hk.transform_with_state(model_fn)
for i in tqdm(range(BATCH_SIZE)):
      rand_file_name = files[random.randint(0, len(files) - 1)]
      midi_file = data_path + rand_file_name + ".midi"
      wav_file = data_path + rand_file_name + ".wav"
      audio, window_num = load_cqt_window(wav_file)
      batched_inputs[i,:,:,:] = audio
print(batched_inputs.shape)
params, state = model.init(rng, batched_inputs, True) # pass in augmented data here to train
opt_state = optimizer.init(params)
NUM_EPOCHS = 5
for _ in tqdm(range(1, NUM_EPOCHS+1)):
    
    for i in range(BATCH_SIZE):
        rand_file_name = files[random.randint(0, len(files) - 1)]
        midi_file = data_path + rand_file_name + ".midi"
        wav_file = data_path + rand_file_name + ".wav"
        audio, window_num = load_cqt_window(wav_file)
        matrices = audio_info(midi_file, window_num=window_num)
        batched_inputs[i,:,:,:] = audio
        batched_outputs['note'][i,:,:] = matrices['note']
        batched_outputs['contour'][i,:,:] = matrices['contour']
        batched_outputs['onset'][i,:,:] = matrices['onset']
    print('note', matrices['note'].shape)
    print('cont', matrices['contour'].shape)
    print('ons', matrices['onset'].shape)
    print(batched_outputs['note'].shape)
    params, opt_state, state, loss = step(params, opt_state, state, batched_inputs, batched_outputs)
    print(loss)

100%|██████████| 1/1 [00:01<00:00,  1.25s/it]

audio_array:  (44100,)
cqt:  (84, 87)
reshaped cqt:  (1, 84, 87)
(1, 84, 87, 1)



  param = init(shape, dtype)
  param = init(shape, dtype)


(1, 84, 264, 1)
(1, 84, 264, 16)
(1, 84, 264, 8)
(1, 84, 264, 1)


  0%|          | 0/5 [00:00<?, ?it/s]

audio_array:  (44100,)
cqt:  (84, 87)
reshaped cqt:  (1, 84, 87)
note (88, 172)
cont (264, 172)
ons (88, 172)
(1, 88, 172)
(1, 84, 264, 1)
(1, 84, 264, 16)
(1, 84, 264, 8)
(1, 84, 264, 1)
[('note', (1, 88, 172)), ('contour', (1, 264, 172)), ('onset', (1, 88, 172))] [(1, 84, 264, 1), (1, 84, 88, 1), (1, 84, 88, 1)]


 20%|██        | 1/5 [00:12<00:49, 12.28s/it]

6258465.5
audio_array:  (44100,)
cqt:  (84, 87)
reshaped cqt:  (1, 84, 87)
note (88, 172)
cont (264, 172)
ons (88, 172)
(1, 88, 172)
(1, 84, 264, 1)
(1, 84, 264, 16)
(1, 84, 264, 8)
(1, 84, 264, 1)
[('note', (1, 88, 172)), ('contour', (1, 264, 172)), ('onset', (1, 88, 172))] [(1, 84, 264, 1), (1, 84, 88, 1), (1, 84, 88, 1)]


 40%|████      | 2/5 [00:16<00:22,  7.34s/it]

5762605.0
audio_array:  (44100,)
cqt:  (84, 87)
reshaped cqt:  (1, 84, 87)
note (88, 172)
cont (264, 172)
ons (88, 172)
(1, 88, 172)
(1, 84, 264, 1)
(1, 84, 264, 16)
(1, 84, 264, 8)
(1, 84, 264, 1)
[('note', (1, 88, 172)), ('contour', (1, 264, 172)), ('onset', (1, 88, 172))] [(1, 84, 264, 1), (1, 84, 88, 1), (1, 84, 88, 1)]


 60%|██████    | 3/5 [00:19<00:10,  5.36s/it]

5424894.0
audio_array:  (44100,)
cqt:  (84, 87)
reshaped cqt:  (1, 84, 87)
note (88, 172)
cont (264, 172)
ons (88, 172)
(1, 88, 172)
(1, 84, 264, 1)
(1, 84, 264, 16)
(1, 84, 264, 8)
(1, 84, 264, 1)
[('note', (1, 88, 172)), ('contour', (1, 264, 172)), ('onset', (1, 88, 172))] [(1, 84, 264, 1), (1, 84, 88, 1), (1, 84, 88, 1)]


 80%|████████  | 4/5 [00:23<00:04,  4.92s/it]

5280781.0
audio_array:  (44100,)
cqt:  (84, 87)
reshaped cqt:  (1, 84, 87)
note (88, 172)
cont (264, 172)
ons (88, 172)
(1, 88, 172)
(1, 84, 264, 1)
(1, 84, 264, 16)
(1, 84, 264, 8)
(1, 84, 264, 1)
[('note', (1, 88, 172)), ('contour', (1, 264, 172)), ('onset', (1, 88, 172))] [(1, 84, 264, 1), (1, 84, 88, 1), (1, 84, 88, 1)]


100%|██████████| 5/5 [00:27<00:00,  5.47s/it]

5315520.5





## Visualize / audiate predictions
Here we will choose a sample and see the spectrogram for the sample's expected output compared to our model's, and be able to hear them as well!

The plot is the actual notes of the MIDI (digitized sheet music) file, including duration and pitch.

<!-- The middle plot illustrates the "contours," or nuances of notes (because we're only working with a piano dataset here, there is little new information to be found). The last plot describes "onsets" of notes, or when notes are first played.  -->

In [59]:
sample_index = 12 # Change as you desire! This reflects the audio file in the corresponding folder
duration = 30 # Number of seconds to generate

### Raw audio, expected audio and MIDI

In [61]:
assert 0 <= sample_index < len(files)
sample_name = files[sample_index]

midi_file = data_path + sample_name + ".midi"
wav_file = data_path + sample_name + ".wav"

plt.figure(figsize=(15,7))
mats = audio_info(midi_file, window_num=0, window_len=ANNOTATIONS_FPS * duration)

plt.matshow(mats["note"], fignum=1, aspect='auto', origin='lower')
_ = plt.axis('off')

raw_audio, _ = librosa.load(wav_file, duration=duration)
raw = Audio(raw_audio, rate=22050)

pm = PrettyMIDI(midi_file)
expected_audio = pm.synthesize(fs=22050)[:22050 * duration]
expected = Audio(expected_audio, rate=22050)

print("raw audio\nexpected audio")
display(raw, expected)

Output hidden; open in https://colab.research.google.com to view.

### Predicted audio and MIDI

In [65]:
preprocessed_input = cqt_windowed(raw_audio)
output = model.apply(params, state, rng, preprocessed_input, is_training=False)
np_mats = {
    'note': np.asarray(output[1]).T,
    'onset': np.asarray(output[2]).T,
    'contour': np.asarray(output[0]).T
}

plt.matshow(np_mats['note'], fignum=1, aspect='auto', origin='lower')
_ = plt.axis('off')

pm_predicted, _ = model_output_to_notes(np_mats, 0.5, 0.3, include_pitch_bends=False, min_note_len=0.0, melodia_trick=False)

predicted_audio = pm_predicted.synthesize(fs=22050)[:22050 * duration]
predicted = Audio(predicted_audio, rate=22050)
display(predicted)

ValueError: ignored