In [2]:
from typing import Union, List, Dict, Tuple
import numpy as np
import jax
import jax.numpy as jnp
import haiku as hk
import librosa
import pathlib


In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
class Constants:
    #!/usr/bin/env python
    # encoding: utf-8
    #
    # Copyright 2022 Spotify AB
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.

    FFT_HOP = 256
    N_FFT = 8 * FFT_HOP

    NOTES_BINS_PER_SEMITONE = 1
    CONTOURS_BINS_PER_SEMITONE = 3
    # base frequency of the CENTRAL bin of the first semitone (i.e., the
    # second bin if annotations_bins_per_semitone is 3)
    ANNOTATIONS_BASE_FREQUENCY = 27.5  # lowest key on a piano
    ANNOTATIONS_N_SEMITONES = 88  # number of piano keys
    AUDIO_SAMPLE_RATE = 22050
    AUDIO_N_CHANNELS = 1
    N_FREQ_BINS_NOTES = ANNOTATIONS_N_SEMITONES * NOTES_BINS_PER_SEMITONE
    N_FREQ_BINS_CONTOURS = ANNOTATIONS_N_SEMITONES * CONTOURS_BINS_PER_SEMITONE

    AUDIO_WINDOW_LENGTH = 2  # duration in seconds of training examples - original 1

    ANNOTATIONS_FPS = AUDIO_SAMPLE_RATE // FFT_HOP
    ANNOTATION_HOP = 1.0 / ANNOTATIONS_FPS

    # ANNOT_N_TIME_FRAMES is the number of frames in the time-frequency representations we compute
    ANNOT_N_FRAMES = ANNOTATIONS_FPS * AUDIO_WINDOW_LENGTH

    # AUDIO_N_SAMPLES is the number of samples in the (clipped) audio that we use as input to the models
    AUDIO_N_SAMPLES = AUDIO_SAMPLE_RATE * AUDIO_WINDOW_LENGTH - FFT_HOP

    DATASET_SAMPLING_FREQUENCY = {
        "MAESTRO": 5,
        "GuitarSet": 2,
        "MedleyDB-Pitch": 2,
        "iKala": 2,
        "slakh": 2,
    }


    def _freq_bins(bins_per_semitone: int, base_frequency: float, n_semitones: int) -> np.array:
        d = 2.0 ** (1.0 / (12 * bins_per_semitone))
        bin_freqs = base_frequency * d ** np.arange(bins_per_semitone * n_semitones)
        return bin_freqs


    FREQ_BINS_NOTES = _freq_bins(NOTES_BINS_PER_SEMITONE, ANNOTATIONS_BASE_FREQUENCY, ANNOTATIONS_N_SEMITONES)
    FREQ_BINS_CONTOURS = _freq_bins(CONTOURS_BINS_PER_SEMITONE, ANNOTATIONS_BASE_FREQUENCY, ANNOTATIONS_N_SEMITONES)



In [5]:

def get_audio_input(
    audio_path: Union[pathlib.Path, str], overlap_len: int, hop_size: int
) -> Tuple[jnp.ndarray, List[Dict[str, int]], int]:
    """
    Read wave file (as mono), pad appropriately, and return as
    windowed signal, with window length = AUDIO_N_SAMPLES

    Returns:
        audio_windowed: ndarray with shape (n_windows, AUDIO_N_SAMPLES, 1)
            audio windowed into fixed length chunks
        window_times: list of {'start':.., 'end':...} objects (times in seconds)
        audio_original_length: int
            length of original audio file, in frames, BEFORE padding.

    """
    assert overlap_len % 2 == 0, "overlap_length must be even, got {}".format(overlap_len)

    audio_original, _ = librosa.load(str(audio_path), sr=Constants.AUDIO_SAMPLE_RATE, mono=True)

    original_length = audio_original.shape[0]
    audio_original = jnp.concatenate([jnp.zeros((int(overlap_len / 2),), dtype=jnp.float32), audio_original])
    audio_windowed, window_times = window_audio_file(audio_original, hop_size)
    return audio_windowed, window_times, original_length

def window_audio_file(audio_original: jnp.ndarray, hop_size: int) -> Tuple[jnp.ndarray, List[Dict[str, int]]]:
    """
    Pad appropriately an audio file, and return as
    windowed signal, with window length = AUDIO_N_SAMPLES

    Returns:
        audio_windowed: ndarray with shape (n_windows, AUDIO_N_SAMPLES, 1)
            audio windowed into fixed length chunks
        window_times: list of {'start':.., 'end':...} objects (times in seconds)

    """

    audio_windowed = jnp.expand_dims(
        frame(audio_original, Constants.AUDIO_N_SAMPLES, hop_size, pad_end=True, pad_value=0),
        axis=-1,
    )
    window_times = [
        {
            "start": t_start,
            "end": t_start + (Constants.AUDIO_N_SAMPLES / Constants.AUDIO_SAMPLE_RATE),
        }
        for t_start in jnp.arange(audio_windowed.shape[0]) * hop_size / Constants.AUDIO_SAMPLE_RATE
    ]
    return audio_windowed, window_times

def frame(signal, frame_length, frame_step, pad_end=False, pad_value=0, axis=-1):
    """
    equivalent of tf.signal.frame
    """
    signal_length = signal.shape[axis]
    if pad_end:
        frames_overlap = frame_length - frame_step
        rest_samples = jnp.abs(signal_length - frames_overlap) % jnp.abs(frame_length - frames_overlap)
        pad_size = int(frame_length - rest_samples)
        if pad_size != 0:
            pad_axis = [0] * signal.ndim
            pad_axis[axis] = pad_size
            signal = jnp.pad(signal, pad_axis, "constant", constant_values=pad_value)
    frames=signal.unfold(axis, frame_length, frame_step)
    return frames



In [6]:
n_overlapping_frames = 30
overlap_len = n_overlapping_frames * Constants.FFT_HOP
hop_size = Constants.AUDIO_N_SAMPLES - overlap_len
audio_path = "test.m4a"
# audio_windowed, _, audio_original_length = get_audio_input(audio_path, overlap_len, hop_size)
audio = librosa.load(audio_path, sr=11025)
audio

  return f(*args, **kwargs)


(array([0.        , 0.        , 0.        , ..., 0.00098911, 0.00120955,
        0.        ], dtype=float32),
 11025)

In [28]:
from cqt_and_hs import harmonic_stacking, load_and_cqt
audio_path = "test.m4a"
audio_tensor = load_and_cqt(audio_path)
print(audio_tensor)
def new_f(audio_tensor, is_training):
    bn = hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.9, name="bn")
    normed = bn(audio_tensor, is_training)
    hs = harmonic_stacking(normed)
    return hs
model = hk.transform_with_state(new_f)
rng = jax.random.PRNGKey(0)
params, state = model.init(rng, audio_tensor, True)
out = model.apply(params, state, rng=rng, audio_tensor=audio_tensor, is_training=True)
print(out)
out = model.apply(params, state, rng=rng, audio_tensor=audio_tensor, is_training=False)

  return f(*args, **kwargs)


[[[[4.4095845e-04]
   [1.0635717e-03]
   [2.0095906e-03]
   ...
   [6.9588743e-02]
   [7.1580395e-02]
   [7.6359101e-02]]

  [[1.7780378e-04]
   [6.4286956e-04]
   [1.4196741e-03]
   ...
   [2.1535021e-01]
   [2.0113581e-01]
   [1.8197703e-01]]

  [[6.1625236e-05]
   [3.4143028e-04]
   [9.7072939e-04]
   ...
   [2.6900047e-01]
   [2.4872115e-01]
   [2.2054163e-01]]

  ...

  [[0.0000000e+00]
   [1.6356005e-07]
   [3.2832439e-09]
   ...
   [1.1908718e-03]
   [1.7928916e-03]
   [1.1816506e-03]]

  [[0.0000000e+00]
   [1.1190465e-07]
   [1.7555957e-09]
   ...
   [2.1679401e-03]
   [3.6151591e-04]
   [8.0618769e-04]]

  [[0.0000000e+00]
   [6.7336671e-08]
   [1.8380945e-09]
   ...
   [2.9860791e-03]
   [1.6819689e-04]
   [3.9096191e-04]]]]
(DeviceArray([[[[ 0.00000000e+00, -3.64653379e-01, -3.44890505e-01,
                -3.50106508e-01, -3.52928787e-01, -3.44141394e-01,
                -3.40306103e-01, -3.45698833e-01]],

              [[ 0.00000000e+00, -3.64976168e-01, -3.60547245e-01,

In [9]:
from new_model_in_jax import PosteriorgramModel
import haiku as hk
import jax


audio_tensor = jnp.array(audio[0])
# print(audio_tensor.shape)
# audio_tensor = audio_tensor.reshape(2, 2, 14171)
def f(x, is_training):
    a = PosteriorgramModel()
    # a = hk.IdentityCore()
    return a(x, is_training)
model = hk.transform_with_state(f)
rng = jax.random.PRNGKey(0)
params, state = model.init(rng, out, False)
model.apply(params, state, rng=rng, x=out, is_training=False)

ValueError: No value for 'average' in 'posteriorgram_model/~/top_branch/~/bn/~/mean_ema', perhaps set an init function?

In [None]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [33]:

from jax import value_and_grad
from loss import loss

from cqt_and_hs import harmonic_stacking, load_and_cqt
audio_path = "test.m4a"
audio_tensor = load_and_cqt(audio_path)

def new_f(audio_tensor, is_training):
    bn = hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.9, name="bn")
    normed = bn(audio_tensor, is_training)
    hs = harmonic_stacking(normed)
    return hs
model = hk.transform_with_state(new_f)
rng = jax.random.PRNGKey(0)
params, state = model.init(rng, audio_tensor, True)

epochs = 1000
learning_rate = jnp.array(0.001)

def UpdateWeights(weights,gradients):
    return weights - learning_rate * gradients

def loss_wrapper(params, x, y):
    nonlocal state
    out, state = model.apply(params, state, rng=rng, x=x, is_training=True)
    loss_fns = loss()
    return loss_fns["contour"](y[0], out[0]) + loss_fns["note"](y[1], out[1]) + loss_fns["onset"](y[2], out[2])


for i in range(1, epochs+1):
    loss, param_grads = value_and_grad(loss_wrapper)(params, out, out)
    params = jax.tree_map(UpdateWeights, params, param_grads)

    if i%100 == 0:
        print("MSE : {:.2f}".format(loss))

  return f(*args, **kwargs)


SyntaxError: no binding for nonlocal 'state' found (3984634688.py, line 24)