<a href="https://colab.research.google.com/github/tollycollins/Beat-Tracker-DL/blob/main/Joint_Beat_and_Tempo_Tracker.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
t



# Joint Beat and Tempo Tracker

## Imports

In [None]:
# link to Google Drive
from google.colab import drive
drive.mount('/content/gdrive', force_remount=False)


Mounted at /content/gdrive


In [None]:
# # madmom library - Contains some code from Bock (2019)
# !git clone --recursive https://github.com/CPJKU/madmom.git

!pip install madmom

In [None]:
# ASAP dataset
!git clone https://github.com/fosfrancesco/asap-dataset.git

In [None]:
!pip install pretty_midi

In [None]:
!pip install mir_eval

In [None]:
# !pip install compress_pickle

In [None]:
from pathlib import Path
import os
import shutil
import json
import h5py
import pickle
# from compress_pickle import dump, load
from typing import List
import copy
import random
import math

import pretty_midi as pm
import numpy as np 
import matplotlib.pyplot as plt
import librosa.display as display
import librosa
import soundfile as sf
from scipy import signal
from scipy.interpolate import interp1d
from scipy.ndimage import maximum_filter1d as maxFilt
import mir_eval

import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
from torch.nn.utils import weight_norm

# from tensorboardcolab import *
# from torch.utils.tensorboard import SummaryWriter

from madmom.features.beats import DBNBeatTrackingProcessor


# Data

In [None]:
# ASAP Metadata class - represents a dataset folder


class ASAPMetadata:
  """
  Manages a folder of pre-processed examples derived from the ASAP dataset.

  A linked folder holds one '.hdf5' file per performance plus a 'metadata.json' file
  describing the folder contents:

  metadata: {'directory': str, 
             'names': list,            (data file names, '<track>.hdf5')
             'original_names': list,   (annotation keys the files were generated from)
             'lengths': list,          (number of time frames in each file)
             'features': list, 
             'sample_rate': int, 
             'tempo_range': [int, int], 
             'tempo_target_type': str,
             'note_range': [int, int), 
             'rests': bool}

  'names', 'original_names' and 'lengths' are parallel lists (same length, same order).

  Note: Notes given in range 0-88 inclusive
  Note: Tempo given in range self.tempo_range 
  """

  DEFAULT_SAMPLE_RATE = 100
  DEFAULT_DATASET_PATH = "./asap-dataset"
  DEFAULT_BASE_PATH = "/content/gdrive/MyDrive/Colab Notebooks/QM DL for music and audio"

  DATASET_NAME = "asap-beat-tracking"
  

  def __init__(self, 
               json_data=None, 
               name=None, 
               base=None, 
               tempo_range=(10, 360), 
               tempo_target_type='continuous',
               note_range=(0, 88), 
               rests=False, 
               num_files=None):
    """
    json_data: annotation dictionary keyed by MIDI path; loaded from 
      'asap_annotations.json' in the dataset folder when not given
    name: name of the data folder (inside 'base') to link to, if it already exists
    base: directory that contains the data folders (defaults to DEFAULT_BASE_PATH)
    tempo_range: allowable tempo range in BPM for pieces (if they have tempi outside 
      this range, they will not be added to the dataset)
    tempo_target_type: 'continuous' (one value per frame, in beats per second) or 
      'discrete' (one distribution over BPM values per frame)
    note_range: permitted piano keyboard notes from the piece
    rests: if True, allocate rests to an extra first row (set when no note is playing)
    num_files: optional limit on the number of files to create
    """

    self.asap_base_path = ASAPMetadata.DEFAULT_DATASET_PATH
    self.json_data = json_data
    if not json_data:
      with open(Path(self.asap_base_path, 'asap_annotations.json')) as json_file:
        self.json_data = json.load(json_file)
    self.original_dataset_size = len(self.json_data.keys())

    # features available
    self.features = {'pianoroll': False,
                     'noteroll': False,
                     'velocityroll': False,
                     'beats': False,
                     'downbeats': False,
                     'tempo': False,
                     'beatTypes': False,
                     'timeSignatures': False,               
                     'key_signatures': False}

    # initialise with default sample rate
    self.sample_rate = ASAPMetadata.DEFAULT_SAMPLE_RATE

    self.tempo_range = tempo_range  # [low, high)
    self.tempo_target_type = tempo_target_type
    self.note_range = note_range
    self.rests = rests

    # set a limit on the number of files to create
    self.num_files = num_files

    self.metadata = {}

    # folder path and name
    self.base = ASAPMetadata.DEFAULT_BASE_PATH if not base else base
    self.name = name
    if name:
      # scan directory for existing name directory
      if self.folder_exists(name):
        self.link(name)

  
  def get_metadata(self, features=None):
    """
    Request metadata dictionary with given features

    Returns a shallow copy, so restricting the features for one consumer does not 
    alter the metadata stored on this object.

    Raises RuntimeError if no metadata is loaded or a requested feature is unavailable.
    """
    if not self.metadata:
      raise RuntimeError("Metadata is empty")

    info = dict(self.metadata)

    if features:
      # only keep desired features
      for f in features:
        if not self.features.get(f, False):
          raise RuntimeError(f"feature {f} is not available ")
      info['features'] = list(features)

    return info


  def change_base(self, base):
    """
    Change base directory attribute
    """
    self.base = base


  def link(self, name, reset_features=True):
    """
    Link class to an existing folder name and load its metadata

    reset_features: if True, set the feature flags from the folder's metadata

    Raises FileNotFoundError if the folder does not exist.
    """
    directory = self.get_path(name)
    if not self.folder_exists(name):
      raise FileNotFoundError(f"Folder {directory} does not exist")

    self.name = name

    # load metadata
    self.metadata = self.load_metadata()

    # update features
    if reset_features:
      self.reset_features(self.metadata.get('features'))

    # update sample rate (keep the current one if the folder does not record it)
    if 'sample_rate' in self.metadata:
      self.sample_rate = self.metadata['sample_rate']
    
    print(f"Data Folder Manager linked to {self.get_path(name)}")


  def new_folder(self, name, reset_features=True):
    """
    Create and link to a new folder

    reset_features: if True, switch all feature flags off for the new folder; 
      pass False to keep the current selection (used when a folder is re-created 
      during an overwrite)
    """
    directory = self.get_path(name)
    if self.folder_exists(name):
      print(f"Folder {directory} already exists")
      return

    os.makedirs(directory)
    self.name = name
    # a fresh folder has no metadata yet
    self.metadata = {}
    if reset_features:
      self.reset_features()


  def create_data_files(self, 
                        features=None, 
                        overwrite=False, 
                        name_list=None):
    """
    Dumps np arrays of examples and corresponding ground truths into the linked folder
    (one hdf5 file per piece) and writes 'metadata.json'.

    features: list of feature names to generate; if None the features recorded in the 
      folder's existing metadata are used
    overwrite: if True, wipe the folder and regenerate everything.  Forced to True when 
      the requested features or the sample rate differ from the existing metadata
    name_list: optional collection of track names (file stem) to restrict generation to

    Raises RuntimeError if no folder is linked or the folder does not exist, and 
    KeyError if an unknown feature name is requested.
    """
    # check we have a valid folder name
    if not self.name:
      raise RuntimeError("No folder is linked: call new_folder() or link() first")
    path = self.get_path(self.name)
    if not self.folder_exists(self.name):
      raise RuntimeError(f"The folder {path} does not exist")

    # load metadata and take the feature flags from it
    metadata = self.load_metadata()
    self.reset_features(metadata.get('features'))
    
    # choose features to include
    if features:
      requested = set(features)
      unknown = requested - set(self.features)
      if unknown:
        raise KeyError(f"Unknown feature(s): {sorted(unknown)}")
      current = {f for f, enabled in self.features.items() if enabled}
      # force overwrite only if the stored features really differ from the request
      if requested != current and metadata != {}:
        overwrite = True
      self.reset_features(features)

    # alter the sample rate of all data (overwrite data)
    if 'sample_rate' in metadata and metadata['sample_rate'] != self.sample_rate:
      overwrite = True

    if overwrite:
      # wipe directory, keeping the feature selection made above
      try:
        shutil.rmtree(path)
        self.new_folder(self.name, reset_features=False)
      except OSError as e:
        print("Error: %s : %s" % (path, e.strerror))

      # reset metadata
      metadata = {'directory': path, 'names': [], 'lengths': []}

    # generate and save data
    counter = 0
    num_examples = 0
    num_already = 0
    names = []
    original_names = []
    lengths = []
    for fpath, dataDict in self.gen_data(name_list=name_list):
      counter += 1

      # check if file limit has been reached
      if self.num_files and counter > self.num_files:
        break

      trackname = os.path.splitext(os.path.split(fpath)[1])[0]
      name = trackname + '.hdf5'

      # make the name unique: 'x' -> 'x_0' -> 'x_1' -> ...
      while name in names:
        if len(trackname) >= 2 and trackname[-2] == '_' and trackname[-1].isdigit():
          trackname = trackname[:-1] + str(int(trackname[-1]) + 1)
        else:
          trackname += '_0'
        name = trackname + '.hdf5'
        print(f'Duplicate {name} renamed')

      filename = os.path.join(path, name)

      # save data to file
      if not os.path.isfile(filename) or overwrite:
        with h5py.File(filename, 'w') as hf:
          hf.create_dataset('length', data=dataDict['length'])
          for key in ('pianoroll', 'beats', 'downbeats', 'tempo'):
            if key in dataDict:
              hf.create_dataset(key, shape=dataDict[key].shape, data=dataDict[key])

        print(f'Processed {counter} out of {self.original_dataset_size}. File created: {name}')
        num_examples += 1
      else:
        num_already += 1

      # record every file (new or already present) so that the three per-file
      # lists stay aligned and existing files are not dropped from the metadata
      names.append(name)
      original_names.append(fpath)
      lengths.append(dataDict['length'])

    # save metadata
    metadata['names'] = names
    metadata['original_names'] = original_names
    metadata['lengths'] = lengths
    metadata['features'] = [f for f, enabled in self.features.items() if enabled]
    metadata['sample_rate'] = self.sample_rate
    metadata['tempo_range'] = self.tempo_range
    metadata['tempo_target_type'] = self.tempo_target_type
    metadata['note_range'] = self.note_range
    metadata['rests'] = self.rests
    metadata['directory'] = self.get_path(self.name)

    save_path = os.path.join(self.get_path(self.name), 'metadata.json')
    with open(save_path, 'w') as fp:
      json.dump(metadata, fp)    
    self.metadata = metadata

    print(f"\n {num_examples} files created and {num_already} already present " \
          f"out of {self.original_dataset_size} originals")
    

  def gen_data(self, name_list=None):
    """
    Generator for MIDI piano rolls and associated ground truths from MIDI files

    Yields (path, data) where path is the annotation key of the piece and data is a 
    dictionary with 'length' (number of frames) plus one array per enabled feature 
    out of 'pianoroll', 'beats', 'downbeats' and 'tempo'.  Pieces whose tempo target 
    cannot be built, or falls outside self.tempo_range, are skipped with a message.

    name_list: optional collection of track names (file stem) to restrict generation to
    """
    # frames per second used for the piano roll and for all frame-level labels
    fs = self.sample_rate

    for path in self.json_data.keys():

      # check track name (for selecting test samples)
      if name_list is not None:
        trackname = os.path.splitext(os.path.split(path)[1])[0]
        if trackname not in name_list:
          continue

      data = {}

      # get next piano roll
      piece = pm.PrettyMIDI(os.path.join(self.asap_base_path, path)).get_piano_roll(fs)

      # slice to note range
      # NOTE(review): rows of the piano roll are indexed by MIDI pitch, so the default 
      # (0, 88) keeps pitches 0-87 - confirm this is the intended keyboard range
      piece = piece[slice(*self.note_range), :]

      if self.rests:
        # extra FIRST row which is 1 wherever no note is sounding
        rests = np.where(np.any(piece, axis=0, keepdims=True), 0, 1)
        piece = np.concatenate((rests, piece))
      
      # change data type
      piece = piece.astype(np.int8)

      if self.features['pianoroll']:
        data['pianoroll'] = piece

      num_time_points = piece.shape[1]
      data['length'] = num_time_points

      # get associated beat and downbeat labels
      beats = self.json_data[path]["performance_beats"]
      downbeats = self.json_data[path]["performance_downbeats"]
      # frame indices of the '1's in the label arrays.  Round to the NEAREST frame: 
      # truncating would turn e.g. 0.29 s * 100 = 28.999... into frame 28
      beatPositions = np.rint(np.asarray(beats, dtype=float) * fs).astype(int)
      downbeatPositions = np.rint(np.asarray(downbeats, dtype=float) * fs).astype(int)

      # make beat and downbeat labels
      if self.features['beats']:
        data['beats'] = self.make_beat_annotations(beatPositions, num_time_points)
      if self.features['downbeats']:
        data['downbeats'] = self.make_beat_annotations(downbeatPositions, num_time_points)

      if self.features['tempo']:
        # beats that fall inside the piano roll (used as interpolation knots)
        beatPositions_for_tempo = [p for p in sorted(set(beatPositions)) 
                                   if 0 <= p < num_time_points]
        if len(beatPositions) < 2 or not beatPositions_for_tempo:
          print(f"Not enough beats annotated for {path}. Example not included in dataset.")
          continue

        # get tempo labels
        tempo = np.zeros(num_time_points)

        # inter-beat-intervals (seconds)
        IBIs = np.atleast_1d(np.squeeze(np.diff(beats)))

        # piecewise-constant IBI curve: each frame holds the IBI of its beat interval
        for i in range(len(beatPositions) - 1):
          tempo[beatPositions[i]: beatPositions[i + 1]] = IBIs[i]
        # fill in ends
        tempo[: beatPositions[0]] = IBIs[0]
        tempo[beatPositions[-1]: ] = IBIs[-1]

        # smooth with a normalised 15-frame Hamming window
        window = np.hamming(15)
        smoothTempo = signal.convolve(tempo, window, mode='same')
        smoothTempo = smoothTempo / np.sum(window)

        # apply quadratic interpolation to smoothed beat values
        smoothed = smoothTempo[beatPositions_for_tempo]
        # add values for the start and end positions
        smoothed = np.concatenate(([smoothed[0]], smoothed, [smoothed[-1]]))
        tempoPositions = np.concatenate(([0], beatPositions_for_tempo, [num_time_points - 1]))

        try:
          f_interp = interp1d(tempoPositions, smoothed, kind='quadratic')
        except ValueError:
          print(f"Beats not annotated consecutively for {path}. Example not included in dataset.")
          continue
        tempo = np.arange(num_time_points)
        tempo = f_interp(tempo)
        

        if self.tempo_target_type == 'discrete':
          # convert from IBIs (in secs) to BPM
          tempo = np.round(np.divide(60.0, tempo)).astype(int)

          # map to BPM probability distribution matrix
          tempoLabels = np.zeros((self.tempo_range[1] + 2, num_time_points))
          tempoHist = [0.25, 0.5, 1, 0.5, 0.25]
          skip = False
          for j, t in enumerate(tempo):
            # check that tempi fall within valid range
            if not self.tempo_range[0] <= t <= self.tempo_range[1]:
              print(f"Tempo of {t} out of range for {path}. Example not included in dataset.")
              skip = True
              break
            else:
              # 5-bin bump centred on row t - 1 (row index = BPM - 1)
              tempoLabels[t - 3: t + 2, j] = tempoHist

          if skip:
            continue

          # trim tempo to self.tempo_range
          data['tempo'] = tempoLabels[self.tempo_range[0] - 1: self.tempo_range[1], :]

        elif self.tempo_target_type == 'continuous': 
          # convert from IBIs (in secs) to BPS; a zero IBI becomes inf and is 
          # rejected by the range check below
          with np.errstate(divide='ignore'):
            tempo = np.divide(1.0, tempo)

          if self.tempo_range is not None:
            # check that tempo is in range (range is given in BPM, tempo is in BPS)
            low = self.tempo_range[0] / 60.0
            high = self.tempo_range[1] / 60.0
            if np.min(tempo) < low or np.max(tempo) > high:
              print(f"Tempo of {60.0 * np.min(tempo)} or {60.0 * np.max(tempo)} BPM " \
                    f"out of range for {path}. Example not included in dataset.")
              continue

          data['tempo'] = tempo

        else:
          raise RuntimeError("Invalid tempo type for self.tempo_target_type.  " \
                             "Should be 'continuous' or 'discrete'.  ")

      yield path, data


  def load_metadata(self):
    """
    Load the metadata json file from the linked folder

    Returns an empty dictionary (and prints a message) if the file does not exist.
    """
    try:
      with open(Path(self.get_path(self.name), 'metadata.json')) as json_file:
        file = json.load(json_file)
    except FileNotFoundError:
      print("Metadata file could not be loaded")
      file = {}
    return file


  def folder_exists(self, name):
    """
    Return True if a folder called 'name' exists inside the base directory
    """
    return os.path.exists(self.get_path(name))


  def get_path(self, name):
    """
    Return a string to the current folder directory
    """
    return os.path.join(self.base, name)


  def reset_features(self, features=None):
    """
    features (list): 
      if None: change all features to 'False'
      if given: change these features to 'True' and all others to 'False'
    """
    # assign through the dictionary: rebinding the loop variable would change nothing
    for feature in self.features:
      self.features[feature] = False
    if features is not None:
      for f in features:
        self.features[f] = True


  @staticmethod
  def make_beat_annotations(beatPos, length):
    """
    Create beat labels given a list of beat frame indices: 1 on each beat frame and 
    0.5 on the frames either side.

    Beats outside [0, length) are ignored, neighbours that would fall outside the 
    array are dropped (no wrap-around to the other end), and the 0.5 of one beat 
    never overwrites the 1 of an adjacent beat.
    """
    beatLabels = np.zeros(length)
    for t in beatPos:
      t = int(t)
      if not 0 <= t < length:
        continue
      if t - 1 >= 0:
        beatLabels[t - 1] = max(beatLabels[t - 1], 0.5)
      beatLabels[t] = 1
      if t + 1 < length:
        beatLabels[t + 1] = max(beatLabels[t + 1], 0.5)
    return beatLabels

    

In [None]:
# split metadata dictionary into train, validation and test
# Note: split metadata by files so samples from the same file are not in both training and test data

def split_dataset(metadata, 
                  max_examples=None,
                  test_ratio=0.2, 
                  validation_ratio=0, 
                  seed=1):
  """
  Splits metadata into train [validation, ] and test dictionaries

  The split is made by file, so samples from one file never end up in two subsets.  
  All per-file lists ('names', 'lengths' and 'original_names') are split with the same 
  indices so they stay aligned; every other entry is copied unchanged.

  metadata: {'directory': str, 
             'names': list, 
             'original_names': list,
             'lengths': list, 
             'features': list, 
             'sample_rate': int, 
             'tempo_range': [int, int], 
             'note_range': [int, int), 
             'rests': bool}
  max_examples: optional cap on the number of files considered in each split
  test_ratio: fraction of files placed in the test set
  validation_ratio: fraction of the remaining training files placed in the 
    validation set (0 for no validation set)
  seed: seed for the random generator, so the split is reproducible

  Returns (train, validation, test); validation is None if validation_ratio is 0.
  """
  per_file_keys = ('names', 'lengths', 'original_names')

  def make_split(data, ratio):
    num_files = len(data['names'])
    length = min(num_files, max_examples) if max_examples else num_files
    train_size = int((1 - ratio) * length)
    test_size = length - train_size  

    indices = np.arange(length)

    train_indices, test_indices = random_split(indices, [train_size, test_size], 
                                           generator=torch.Generator().manual_seed(seed))
    train_indices = [int(i) for i in train_indices]
    test_indices = [int(i) for i in test_indices]

    train = copy.deepcopy(data)
    test = copy.deepcopy(data)

    for key in per_file_keys:
      # only split lists that are really parallel to 'names'
      if key not in data or len(data[key]) != num_files:
        continue
      train[key] = [data[key][i] for i in train_indices]
      test[key] = [data[key][i] for i in test_indices]

    return train, test

  metatrain, metatest = make_split(metadata, test_ratio)

  if not validation_ratio:
    return metatrain, None, metatest
  
  metatrain, metavalidation = make_split(metatrain, validation_ratio)

  return metatrain, metavalidation, metatest



In [None]:
# Dataset class
# https://github.com/cheriell/ICASSP2021-A2S/blob/main/audio2score/data/prtransdataset.py
# https://pytorch.org/tutorials/beginner/data_loading_tutorial.html


class myDataset(Dataset):
  """
  Abstraction of data file system.

  Provides pitch and time augmentation of data for each request.
  Randomly selects a sample from within a given window.

  Parameters:
    metadata: Dictionary describing data files:
              {'directory': str, 
              'names': list, 
              'original_names': list,
              'lengths': list, 
              'features': list, 
              'sample_rate': int, 
              'tempo_range': [int, int], 
              'note_range': [int, int), 
              'rests': bool}

    sample_length: number of time samples per example

    pitch_aug_range: range of number of semitones for random pitch augmentation: tuple(int, int)

    time_aug_range: range of ratios for random time warping, 
      ratios should be within in the range (0.5, 2): tuple(float, float).
  """
  def __init__(self, 
               metadata: dict, 
               sample_length: int, 
               pitch_aug_range=None, 
               time_aug_range=None):
    
    super().__init__()

    self.metadata = metadata

    self.sample_length = sample_length
    self.pitch_aug_range = pitch_aug_range
    self.time_aug_range = time_aug_range

    # {index: (filename, (min_start, max_start), file_length, original_name)}
    self.index_dict = self.indices_to_positions()


  def __len__(self):
    return len(self.index_dict)


  def __getitem__(self, index: int) -> dict:
    """
    Returns a sample from the dataset corresponding to the given 'index'.  
    Beginning time point of the sample is selected randomly from the range between successive hops, 
      according to a uniform distribution.  
    Data is augmented by time or pitch adjustment, according to the respective attributes.  

    The returned dictionary holds one float Tensor per feature in metadata['features'], 
    plus 'name', 'original_name' and 'start' (first frame of the slice).

    Raises RuntimeError if the stored piano roll length disagrees with the metadata, 
    and KeyError if a requested feature is not stored in the file.
    """
    f_name, start_range, expected_length, original_name = self.index_dict[index]
    filename = os.path.join(self.metadata['directory'], f_name)

    data = {}

    # load file
    with h5py.File(filename, "r") as hf:
      data['length'] = hf['length'][()]
      if 'pianoroll' in hf:
        data['pianoroll'] = hf['pianoroll'][:, :]
      for key in ('beats', 'downbeats', 'tempo'):
        if key in hf:
          data[key] = hf[key][:]

    # consistency check (only possible when a piano roll is stored)
    if 'pianoroll' in data and data['pianoroll'].shape[1] != expected_length:
      raise RuntimeError(f"file {filename} incorrect length: \n" \
        f"pianoroll length: {data['pianoroll'].shape[1]}, dictionary length: {expected_length}")

    # start index for slicing track and labels
    # start index is selected randomly in range between hop start points
    start = random.randint(*start_range)

    # set augmentation parameters
    # pitch augmentation
    shift = 0
    if self.pitch_aug_range is not None:
      # choose augmentation parameter
      shift = random.randint(*self.pitch_aug_range)
    # time augmentation
    num_samples = self.sample_length
    if self.time_aug_range is not None:
      num_samples, start = self.augment_time(start, index)

    # initialise example dictionary
    example = {}
    # iterate through features and add to dictionary
    for feature in self.metadata['features']:
      item = data[feature]

      # work on a 2D (rows, time) array, also for 1D label sequences
      item = np.squeeze(item)
      if len(item.shape) == 1:
        item = np.expand_dims(item, 0)

      item = item[:, start: start + num_samples]

      # resample if it was time-augmented
      if item.shape[1] != self.sample_length:
        item = signal.resample(item, self.sample_length, axis=1)

      # apply pitch augmentation
      if shift and (feature in ['pianoroll', 'noteroll', 'velocityroll', 'key_signatures']):
        item = self.augment_pitch(item, feature, shift)
      
      # convert to Tensor
      item = torch.from_numpy(item).float()
      item = torch.squeeze(item)

      # add to example dictionary
      example[feature] = item

    # add name information (always present, even with an empty feature list)
    example['name'] = f_name
    example['original_name'] = original_name

    # start sample information for testing
    example['start'] = start

    return example


  def indices_to_positions(self):
    """
    Creates a lookup dict for each index as 
      {index: (filename, (min_pos, max_pos), length, original_name)}
    The min_pos and max_pos define an interval in which the starting position of 
      the example can be taken.  
    original_name is an empty string when the metadata has no 'original_names' list 
      aligned with 'names'.
    """
    names = self.metadata['names']
    original_names = self.metadata.get('original_names')
    # only trust the original names if they are parallel to the file names
    if original_names is None or len(original_names) != len(names):
      original_names = None

    positions = {}
    index = 0
    for i, name in enumerate(names):
      length = self.metadata['lengths'][i]
      original_name = original_names[i] if original_names is not None else ''

      # check that piece is long enough
      if length < self.sample_length:
        continue

      # number of examples in this file
      num = int(length / self.sample_length)

      # create an entry for each example
      for hop in range(num):
        start = hop * self.sample_length

        # find last possible start point, avoiding the slice going out of range
        end = min(start + self.sample_length - 1, length - self.sample_length)

        positions[index] = (name, (start, end), length, original_name)

        index += 1

    return positions


  def augment_time(self, start, index):
    """
    Augment time by choosing a different number of time samples to take from the database, 
      and this will then be resampled to the correct number of points.  

    Returns (num_samples, start) as integers, adjusted so that the slice stays 
      inside the file.
    """
    # choose augmentation parameter
    warp = random.uniform(*self.time_aug_range)
    # new sample length required
    num_samples = self.sample_length
    num_samples = (1 / warp) * num_samples
    # move start if necessary
    start = min(start, self.index_dict[index][2] - num_samples)
    # shorten length if necessary
    if start < 0:
      num_samples = num_samples + start
      start = 0
    
    return int(num_samples), int(start)


  def augment_pitch(self, item, feature, shift):
    """
    Transpose a (rows, time) roll by 'shift' rows (positive = towards higher row 
      indices).  Rows shifted out of range are dropped and the gap is zero-filled, 
      so the shape is unchanged.

    If metadata['rests'] is set, the first row is treated as the rests row 
      (ASAPMetadata.gen_data puts it first) and is left in place.
    """
    if shift == 0:
      return item

    # pianoroll
    if feature in ['pianoroll', 'noteroll', 'velocityroll']:
      rests = None
      notes = item
      if self.metadata.get('rests', False):
        # keep the rests row 2D so it can be stacked back on
        rests = item[:1, :]
        notes = item[1:, :]

      num_rows = notes.shape[0]
      shifted = np.zeros(notes.shape)
      if shift > 0:
        if shift < num_rows:
          shifted[shift:, :] = notes[:num_rows - shift, :]
      else:
        if -shift < num_rows:
          shifted[:num_rows + shift, :] = notes[-shift:, :]

      item = shifted if rests is None else np.concatenate((rests, shifted))
        
    # key_signatures (potential extension)

    return item



# Models

## General Modules

In [None]:
# To remove padding added for dilated convolution
# https://github.com/locuslab/TCN/blob/master/TCN/tcn.py [Bai]


class Chomp1d(nn.Module):
  """
  Removes padding: drops the last 'padding' time steps of a (batch, channels, time) input.

  A padding of 0 leaves the input unchanged (a plain ': -0' slice would return an 
  empty tensor).
  """
  def __init__(self, padding):
    super(Chomp1d, self).__init__()

    self.chomp = padding


  def forward(self, x):

    if self.chomp == 0:
      return x.contiguous()

    x = x[:, :, : -self.chomp].contiguous()
    return x



In [None]:
# utility function for initialising layers

def init_layer(layer):
    """Initialize a Linear or Convolutional layer. 
    Ref: He, Kaiming, et al. "Delving deep into rectifiers: Surpassing 
    human-level performance on imagenet classification." Proceedings of the 
    IEEE international conference on computer vision. 2015.

    The weight is drawn uniformly from [-scale, scale] with 
    scale = sqrt(2 / n) * sqrt(3), where n is the fan-in (number of input 
    values feeding one output unit).  The bias, if there is one, is set to 0.

    Raises ValueError if the weight has fewer than 2 dimensions, because no 
    fan-in can be derived from it.
    """
    weight = layer.weight
    if weight.ndimension() < 2:
        raise ValueError(
            f"init_layer expects a weight with at least 2 dimensions, "
            f"got {weight.ndimension()}")

    # fan-in: n_in for Linear, n_in * (kernel elements) for convolutions
    n = weight[0].numel()

    std = math.sqrt(2. / n)
    scale = std * math.sqrt(3.)
    layer.weight.data.uniform_(-scale, scale)

    if layer.bias is not None:
        layer.bias.data.fill_(0.)

        

In [None]:
# Convolution blocks for input

class ConvBlock2D(nn.Module):
  """
  2D convolution block for the network input.

  Applies, in this order: convolution -> dropout (if enabled) -> max pooling 
  (if a pool size is given) -> activation.
  """

  def __init__(self, 
               in_channels, 
               out_channels, 
               kernel, 
               padding, 
               maxpool=None, 
               activation=nn.ELU, 
               dropout=0.1):
    """
    in_channels / out_channels: channels in and out of the convolution
    kernel, padding: passed on to nn.Conv2d
    maxpool: pool size for nn.MaxPool2d, or None for no pooling
    activation: class of the activation module (instantiated without arguments)
    dropout: dropout probability; a falsy value disables dropout
    """
    super(ConvBlock2D, self).__init__()

    self.maxpool = maxpool

    self.conv = nn.Conv2d(in_channels, out_channels, kernel, padding=padding)
    self.act = activation()
    # holds the Dropout module when enabled, otherwise the falsy value passed in
    self.dropout = nn.Dropout(dropout) if dropout else dropout
    if maxpool is not None:
      self.pool = nn.MaxPool2d(maxpool) 
    
    self.init_weights()


  def init_weights(self):
    """He-initialise the convolution."""
    init_layer(self.conv)


  def forward(self, x):
    y = self.conv(x)
    if self.dropout:
      y = self.dropout(y)
    if self.maxpool is not None:
      y = self.pool(y)
    return self.act(y)



In [None]:
#temporal convolution modules
# https://www.kaggle.com/code/ceshine/pytorch-temporal-convolutional-networks/script
# https://github.com/locuslab/TCN/blob/master/TCN/tcn.py [Bai]
# https://github.com/ben-hayes/beat-tracking-tcn/blob/master/beat_tracking_tcn/datasets/ballroom_dataset.py
# https://github.com/CPJKU/madmom/blob/main/madmom/ml/nn/layers.py [Boeck]


class DilatedSection(nn.Module):
  """
  Dilated convolution sequential module: 
    dilated Conv1d -> (chomp, if causal) -> dropout -> activation
  """
  def __init__(self, 
               num_inputs,
               num_outputs, 
               kernel_size, 
               dilation, 
               padding, 
               activation, 
               dropout=0.1, 
               weightnorm=True, 
               causal=False):
    """
    num_inputs / num_outputs: channels in and out of the convolution
    kernel_size, dilation: passed on to nn.Conv1d
    padding: single integer, applied by nn.Conv1d to both ends of the time axis.  
      In causal mode the same number of steps is removed again from the end of 
      the output.
    activation: class of the activation module (instantiated without arguments)
    dropout: dropout probability
    weightnorm: if True, wrap the convolution in weight normalisation
    causal: if True, chomp the trailing padding so no output depends on later inputs
    """
    super(DilatedSection, self).__init__()

    self.conv = nn.Conv1d(num_inputs, num_outputs, kernel_size=kernel_size, padding=padding, 
                          dilation=dilation)

    # Initialise BEFORE applying weight norm.  weight_norm replaces 'weight' by the
    # parameters 'weight_g' / 'weight_v' and recomputes 'weight' from them on every 
    # forward pass, so initialising 'weight' afterwards would have no effect.
    self.init_weights()

    if weightnorm:
      self.conv = weight_norm(self.conv)
    
    self.chomp = None
    self.causal = causal
    if causal:
      self.chomp = Chomp1d(padding)

    self.act = activation()
    self.dropout = nn.Dropout(dropout)


  def init_weights(self):
    """
    He-initialise the convolution.  Called from __init__ before weight norm is applied; 
    calling it again on a weight-normalised layer only resets the bias.
    """
    init_layer(self.conv)
  

  def forward(self, x):

    out = self.conv(x)
    if self.causal:
      out = self.chomp(out)
    out = self.dropout(out)
    out = self.act(out)
    return out


class TemporalBlock(nn.Module):
  """
  TCN block

  Option for skip connection after dilated convolutions

  Sequence of n dilations summed with residual

  Option for 1x1 convolution (res_conv) when there is no downsampling of layers
  """
  def __init__(self, 
               num_inputs, 
               num_outputs, 
               kernel_size, 
               dilation, 
               padding, 
               dilated_act=nn.ReLU, 
               block_act=nn.ReLU, 
               dropout=0.1, 
               weightnorm=True, 
               skip=False,
               res_conv=True, 
               causal=False):
    """
    num_inputs: channels entering the block
    num_outputs: channels produced by every dilated section and by the block (int)
    kernel_size, dilation, padding: sequences with one entry per dilated section
    dilated_act: activation class used inside the dilated sections
    block_act: activation class applied to the sum of dilated output and residual
    dropout, weightnorm, causal: passed on to every DilatedSection
    skip: if True, add a 1x1 convolution after the dilated sections and also return 
      its output as the skip connection
    res_conv: if True, pass the residual through a 1x1 convolution; always enabled 
      when num_inputs != num_outputs

    Note: kernel_size, dilation and padding must all have the same length
    """

    super(TemporalBlock, self).__init__()

    self.skip = skip
    self.res_conv = res_conv
    # force 1D convolution True if downsampling required
    if num_inputs != num_outputs:
      self.res_conv = True

    # temporal convolution sections
    dilated_layers = []
    for i in range(len(kernel_size)):
      # only the first section sees the block input; the rest see num_outputs channels
      in_channels = num_inputs if i == 0 else num_outputs
      dilated_layers += [DilatedSection(in_channels, num_outputs, kernel_size[i],
                                       dilation[i], padding[i], dilated_act, 
                                       dropout=dropout, weightnorm=weightnorm, causal=causal)]
    self.dilated = nn.Sequential(*dilated_layers)
    
    # optional skip connection output after dilated layers
    if skip:
      self.conv1d_skip = nn.Conv1d(num_outputs, num_outputs, 1)
      self.dilated = nn.Sequential(self.dilated, self.conv1d_skip)
    
    # optional 1D convolution for internal skip connection
    # also provides downsampling if number of channels changes
    self.downsample = nn.Conv1d(num_inputs, num_outputs, 1) if self.res_conv else None

    self.block_act = block_act()

    self.init_weights()


  def init_weights(self):
    """He-initialise the 1x1 convolutions that exist in this configuration."""
    if self.skip:
      init_layer(self.conv1d_skip)
    if self.downsample is not None:
      init_layer(self.downsample)
  
  
  def forward(self, x):
    """
    Returns (out, skip): the block output and the skip-connection output 
      (None when skip is disabled).
    """
    out = self.dilated(x)
    res = x if self.downsample is None else self.downsample(x)
    skip = out if self.skip else None

    out = self.block_act(out + res)
    return out, skip



## Sections


In [None]:
# convolutional section mirroring application for audio in Bock(2019)

class Conv_original(nn.Module):
  """
  Stack of ConvBlock2D layers applied one after another.

  conv_channels, kernels, padding and maxpool hold one entry per layer; each layer 
  takes the previous layer's channel count as its input (in_channels for the first).
  """
  
  def __init__(self, 
               in_channels=1, 
               conv_channels=(16, 16, 16), 
               kernels=((3, 3), (3, 3), (8, 1)), 
               padding=((0, 1), (0, 1), (0, 0)), 
               maxpool=((3, 1), (3, 1), None), 
               dropout=0.1, 
               act=nn.ELU):
    
    super(Conv_original, self).__init__()

    # channels entering each layer: the block input, then each layer's output
    channels_in = [in_channels] + list(conv_channels[:-1])

    blocks = [
      ConvBlock2D(channels_in[idx], conv_channels[idx], kernels[idx], 
                  padding=padding[idx], maxpool=maxpool[idx], 
                  activation=act, dropout=dropout)
      for idx in range(len(conv_channels))
    ]
    
    self.net = nn.Sequential(*blocks)
    

  def forward(self, x):
    
    return self.net(x)



In [None]:
# Original TCN section

class TCN_original(nn.Module):
  """
  Full temporal convolutional section: a chain of TemporalBlocks.

  Layer i is a TemporalBlock with num_channels[i] output channels, built from the 
  kernel sizes and dilations listed in kernel_sizes[i] / dilations[i] (one entry per 
  dilated section inside the block).
  """
  def __init__(self, 
               num_inputs=16, 
               num_channels=None, 
               kernel_sizes=None, 
               dilations=None, 
               dilated_act=nn.ELU, 
               block_act=nn.ELU, 
               dropout=0.1, 
               weightnorm=True, 
               skip=False, 
               res_conv=True, 
               causal=False):
    """
    Note: num_channels, kernel_sizes and dilations must all have the same length

    Defaults (used when the argument is None): 10 layers of 16 channels, a single 
    kernel of size 5 per layer, and dilations 1, 2, 4, ..., 512.

    NOTE(review): `block_act` is accepted but not forwarded to TemporalBlock, exactly 
    as in the original code - confirm whether that is intended.
    """
    super(TCN_original, self).__init__()

    if num_channels is None:
      num_channels = [16] * 10
    if kernel_sizes is None:
      kernel_sizes = [[5]] * 10
    if dilations is None:
      dilations = [[2 ** n] for n in range(10)]

    self.layers = []
    for i, out_channels in enumerate(num_channels):
      # every layer after the first consumes the previous layer's output channels 
      # (the original `i > 1` test wrongly gave layer 1 `num_inputs` as well)
      in_channels = num_channels[i - 1] if i > 0 else num_inputs

      sections = len(dilations[i])

      # total padding needed to keep the time length unchanged for each dilated section
      total_pad = [(kernel_sizes[i][n] - 1) * dilations[i][n] for n in range(sections)]
      if causal: 
        # the full amount is handed to the block for causal convolutions
        padding = list(total_pad)
      else:
        # split evenly between both sides for non-causal convolutions
        padding = [p // 2 for p in total_pad]

      self.layers += [TemporalBlock(in_channels, out_channels, kernel_sizes[i], 
                                    dilations[i], padding, dilated_act=dilated_act, dropout=dropout, 
                                    weightnorm=weightnorm, skip=skip, res_conv=res_conv, causal=causal)]

    # register the blocks once, after the loop; this also guarantees that self.net 
    # exists when num_channels is empty
    self.net = nn.ModuleList(self.layers)

  def forward(self, x):
    """
    Pass `x` through every block in order.

    Returns:
      (out, skips): the final block's output and the list of per-block skip outputs 
      (each entry is whatever the block returned as its second value, possibly None).
    """
    out = x
    skips = []
    for layer in self.net:
      out, skip = layer(out)
      skips.append(skip)
      
    return out, skips



In [None]:
# Convolutional block for tempo network

class Tempo_conv_block(nn.Module):
  """
  1D convolutional stack for the tempo network.

  For every entry of `num_channels` the stack appends, in order: a DilatedSection, 
  BatchNorm1d, Conv1d, MaxPool1d, Dropout and the activation `act`.
  """
  def __init__(self, 
               num_inputs=16, 
               num_channels=(16, 16, 16, 16), 
               kernel_sizes=(3, 3, 3, 3), 
               dilations=(1, 1, 1, 1), 
               padding=(1, 1, 1, 1), 
               maxpool=(2, 2, 2, 1),
               act=nn.ELU, 
               dropout=0.1, 
               weightnorm=True):
    """
    Note: num_channels, kernel_sizes, dilations, padding and maxpool are all indexed 
    per layer, so they must all have (at least) the length of num_channels
    """
    super(Tempo_conv_block, self).__init__()

    layers = []
    in_channels = num_inputs
    for i, out_channels in enumerate(num_channels):
      # Each layer reads the previous layer's output channels. The original always 
      # passed `num_inputs` here, which only works when every channel count is equal.
      layers += [DilatedSection(in_channels, out_channels, kernel_sizes[i], dilations[i], 
                                padding[i], act, dropout=dropout, weightnorm=weightnorm)]

      # Everything after the DilatedSection operates on its output, so these layers are 
      # sized with out_channels (assumes DilatedSection maps in_channels -> out_channels, 
      # as its (in, out, ...) call signature suggests). With uniform channel counts this 
      # builds exactly the same modules as before.
      layers += [nn.BatchNorm1d(out_channels), 
                 nn.Conv1d(out_channels, out_channels, kernel_sizes[i], padding=padding[i], 
                           dilation=dilations[i]), 
                 nn.MaxPool1d(maxpool[i]), 
                 nn.Dropout(dropout),
                 act()]

      in_channels = out_channels

    self.net = nn.Sequential(*layers)


  def forward(self, x):
    """Apply the whole stack to `x`; the max-pooling layers shorten the time axis."""
    out = self.net(x)
      
    return out



## Networks

In [None]:
# TCN beat tracker
# https://github.com/CPJKU/madmom/tree/main/madmom


class TCNbeatTracker(nn.Module):
  """
  Beat (and optionally downbeat) tracker.

  A 2D convolutional front-end feeds a TCN; time-distributed 1x1 convolutions then 
  produce one logit per frame for beats and, if enabled, for downbeats. No sigmoid is 
  applied - the loss functions returned by `loss_functions` work on logits.
  """
  
  def __init__(self, 
               downbeats=False, 
               conv=Conv_original, 
               conv_channels=(16, 16, 16), 
               conv_kwargs={}, 
               tcn=TCN_original, 
               tcn_args=[], 
               tcn_kwargs=None, 
               dropout=0.1):
    """
    Args:
      downbeats: also build and evaluate a downbeat output head.
      conv: class of the convolutional front-end; called as 
        conv(conv_channels=conv_channels, **conv_kwargs).
      conv_channels: channel counts of the front-end; the last one is the TCN input size.
      conv_kwargs: extra keyword arguments for `conv` (only read, never mutated).
      tcn: class of the temporal section; called as 
        tcn(conv_channels[-1], *tcn_args, dropout=dropout, **tcn_kwargs).
      tcn_args: extra positional arguments for `tcn` (only read, never mutated).
      tcn_kwargs: keyword arguments for `tcn`. Must contain 'num_channels', whose last 
        entry sizes the output heads. Defaults to {'num_channels': [16]*10}.
      dropout: dropout probability, used both inside the TCN and before the heads.
    """
    super(TCNbeatTracker, self).__init__()

    if tcn_kwargs is None:
      tcn_kwargs={'num_channels': [16]*10}

    # convolution feature extractor
    self.conv = conv(conv_channels=conv_channels, **conv_kwargs)  

    # TCN section with a wide receptive field
    self.tcn = tcn(conv_channels[-1], *tcn_args, dropout=dropout, **tcn_kwargs)

    self.dropout = nn.Dropout(dropout)

    # Time-distributed fully-connected layer
    self.beat_time_dist = nn.Conv1d(tcn_kwargs['num_channels'][-1], 1, 1)

    self.downbeats = downbeats
    if downbeats:
      # Time-distributed fully-connected layer
      self.downbeat_time_dist = nn.Conv1d(tcn_kwargs['num_channels'][-1], 1, 1)

    
  def forward(self, x):
    """
    Args:
      x: 3D input unpacked as (batch, notes, time_len).

    Returns:
      (beats, downbeats, None): per-frame logits with all singleton dimensions 
      squeezed away; `downbeats` is None unless the downbeat head is enabled. The third 
      element is always None (this network has no tempo output).
    """
    (_, notes, time_len) = x.shape
    # reshape the input into 4D format (batch_size, channels, notes, time_len)
    x = x.view(-1, 1, notes, time_len)

    out = self.conv(x)

    # drop dimension 2 of the 4D conv output (the view only works if it has size 1), 
    # giving (batch, channels, time)
    out = out.view(-1, out.shape[1], out.shape[3])
    out, skips = self.tcn(out)

    out = self.dropout(out)

    beats = self.beat_time_dist(out)
    # torch.squeeze removes every size-1 dimension, not just the channel one
    beats = torch.squeeze(beats)

    downbeats = None
    if self.downbeats: 
      downbeats = self.downbeat_time_dist(out)
      downbeats = torch.squeeze(downbeats)

    return beats, downbeats, None


  @staticmethod
  def loss_functions():
    """
    Build the loss callables for this network.

    Returns:
      dict with keys 'loss' (combined), 'beat_loss' and 'downbeat_loss'. All of them 
      use binary cross-entropy on logits; the downbeat losses use twice the beat 
      `pos_weight`.
    """

    def _downbeat_pos_weight(pos_weight):
      # downbeats use double the beat weighting; None means "unweighted" and must stay None
      return None if pos_weight is None else pos_weight * 2

    def loss(beats, beat_labels, downbeats=None, downbeat_labels=None, downbeat_weight=0.5, 
             tempo=None, tempo_labels=None, tempo_weight=0, pos_weight=None):
      # the tempo arguments are accepted for a uniform call signature but are unused here
      beat_loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(beats, beat_labels)
      downbeat_loss = 0
      # compare with None: truth-testing a multi-element tensor raises a RuntimeError
      if downbeats is not None:
        downbeat_loss = nn.BCEWithLogitsLoss(
            pos_weight=_downbeat_pos_weight(pos_weight))(downbeats, downbeat_labels)
      return beat_loss + downbeat_weight * downbeat_loss 

    def beat_loss(beats, beat_labels, pos_weight=None):
      return nn.BCEWithLogitsLoss(pos_weight=pos_weight)(beats, beat_labels)

    def downbeat_loss(downbeats, downbeat_labels, pos_weight=None):
      return nn.BCEWithLogitsLoss(
          pos_weight=_downbeat_pos_weight(pos_weight))(downbeats, downbeat_labels)

    return {'loss': loss, 'beat_loss': beat_loss, 'downbeat_loss': downbeat_loss}
    


In [None]:
# CNN tempo classifier

class CNNtempoTracker(nn.Module):
  """
  Tempo-only network: a 2D convolutional front-end followed by a 1D convolutional 
  stack and a time-distributed 1x1 convolution that emits one tempo value per 
  (downsampled) frame.
  """

  def __init__(self, 
               downbeats=False, 
               conv=Conv_original, 
               conv_kwargs={'conv_channels': (16, 16, 16)}, 
               conv2=Tempo_conv_block, 
               conv2_kwargs={'num_channels': (16, 16, 16, 16)}, 
               tempo_target_type='continuous'):
    """
    Args:
      downbeats: unused by this network; accepted so all models share a signature.
      conv: class of the 2D front-end, called as conv(**conv_kwargs).
      conv_kwargs: must contain 'conv_channels'; its last entry is the input size of 
        `conv2`. Only read, never mutated.
      conv2: class of the 1D stack, called as conv2(<front-end channels>, **conv2_kwargs).
      conv2_kwargs: must contain 'num_channels'; its last entry sizes the output layer. 
        Only read, never mutated.
      tempo_target_type: only 'continuous' is implemented.

    Raises:
      NotImplementedError: for tempo_target_type='discrete', which has no output layer 
        yet (previously this surfaced as an AttributeError inside init_weights).
      ValueError: for any other unknown tempo_target_type.
    """
    super(CNNtempoTracker, self).__init__()

    # convolution feature extractor
    self.conv = conv(**conv_kwargs)  
    
    # second, 1D convolutional section operating along time
    self.conv2 = conv2(conv_kwargs['conv_channels'][-1], **conv2_kwargs)

    if tempo_target_type == 'continuous':
      self.time_dist = nn.Conv1d(conv2_kwargs['num_channels'][-1], 1, 1)

    elif tempo_target_type == 'discrete':
      raise NotImplementedError("tempo_target_type='discrete' is not implemented")

    else:
      raise ValueError(f"Unknown tempo_target_type: {tempo_target_type!r}")

    self.init_weights()


  def init_weights(self):
    """Initialise the output layer with the shared `init_layer` helper."""
    init_layer(self.time_dist)
  

  def forward(self, x):
    """
    Args:
      x: 3D input unpacked as (batch, notes, time_len).

    Returns:
      (None, None, tempo): beats and downbeats are always None for this network; 
      `tempo` is the raw output of the final 1x1 convolution (not squeezed).
    """
    (_, notes, time_len) = x.shape
    # reshape the input into 4D format (batch_size, channels, notes, time_len)
    x = x.view(-1, 1, notes, time_len)

    out = self.conv(x)
    # drop dimension 2 of the 4D conv output, giving (batch, channels, time)
    out = out.view(-1, out.shape[1], out.shape[3])

    out = self.conv2(out)

    out = self.time_dist(out)

    return None, None, out


  @staticmethod
  def loss_functions():
    """
    Returns:
      dict with the single key 'loss': mean squared error between tempo output and 
      tempo labels. All beat/downbeat arguments are accepted for signature 
      compatibility with the other networks and ignored.
    """
    
    def loss(beats=None, beat_labels=None, downbeats=None, downbeat_labels=None, downbeat_weight=0, 
             tempo=None, tempo_labels=None, tempo_weight=0, pos_weight=None):
      # third parameter was misspelled 'dpwnbeats'; renamed so keyword calls match the 
      # other networks' loss signatures (positional callers are unaffected)
      return nn.MSELoss()(tempo, tempo_labels)

    return {'loss': loss}



In [None]:
# Multitask TCN for joint beat and tempo tracking

class TCN_Multitask(nn.Module):
  """
  Multitask network for joint beat and tempo tracking.

  A shared convolutional front-end and TCN feed two branches: the beat (and optional 
  downbeat) heads read the TCN output, while the tempo head reads the sum of the TCN's 
  skip outputs after average pooling along time.
  """

  def __init__(self, 
               downbeats=False, 
               conv=Conv_original, 
               conv_kwargs={'conv_channels': (16, 16, 16)}, 
               tcn=TCN_original, 
               tcn_kwargs={'num_channels': [16] * 10, 'skip': True},
               meanpool_size=8, 
               dropout=0.1, 
               tempo_dropout=0.5):
    """
    Args:
      downbeats: also build and evaluate a downbeat output head.
      conv: class of the front-end, called as conv(**conv_kwargs).
      conv_kwargs: must contain 'conv_channels'; its last entry is the TCN input size. 
        Only read, never mutated.
      tcn: class of the shared temporal section, called as 
        tcn(<front-end channels>, **tcn_kwargs).
      tcn_kwargs: must contain 'num_channels' (last entry sizes all heads). The TCN has 
        to return a skip tensor for every layer, hence 'skip': True in the default. 
        Only read, never mutated.
      meanpool_size: kernel size of the average pooling in the tempo branch.
      dropout: dropout probability before the beat/downbeat heads.
      tempo_dropout: dropout probability in the tempo branch.
    """
    super(TCN_Multitask, self).__init__()

    # convolution feature extractor
    self.conv = conv(**conv_kwargs)  
    
    # shared TCN section
    self.tcn = tcn(conv_kwargs['conv_channels'][-1], **tcn_kwargs)

    # beat tracking branch
    self.beat_dropout = nn.Dropout(dropout)

    # Time-distributed fully-connected layer
    self.beat_time_dist = nn.Conv1d(tcn_kwargs['num_channels'][-1], 1, 1)

    self.downbeats = downbeats
    if downbeats:
      # Time-distributed fully-connected layer
      self.downbeat_time_dist = nn.Conv1d(tcn_kwargs['num_channels'][-1], 1, 1)

    # tempo branch
    self.tempo_meanpool = nn.AvgPool1d(meanpool_size)
    self.tempo_dropout = nn.Dropout(tempo_dropout)
    self.tempo_time_dist = nn.Conv1d(tcn_kwargs['num_channels'][-1], 1, 1)

    self.init_weights()

  def init_weights(self):
    """Initialise every output head with the shared `init_layer` helper."""
    init_layer(self.beat_time_dist)
    if self.downbeats:
      init_layer(self.downbeat_time_dist)
    init_layer(self.tempo_time_dist)


  def forward(self, x):
    """
    Args:
      x: 3D input unpacked as (batch, notes, time_len).

    Returns:
      (beat_pred, downbeat_pred, tempo): beat/downbeat logits with singleton dimensions 
      squeezed away (downbeat_pred is None when the head is disabled) and the tempo 
      output at the pooled time resolution (not squeezed).

    Raises:
      ValueError: if the TCN did not return a skip tensor for every layer.
    """
    (_, notes, time_len) = x.shape
    # reshape the input into 4D format (batch_size, channels, notes, time_len)
    x = x.view(-1, 1, notes, time_len)

    out = self.conv(x)
    # drop dimension 2 of the 4D conv output, giving (batch, channels, time)
    out = out.view(-1, out.shape[1], out.shape[3])

    beats, skips = self.tcn(out)

    # the tempo branch is built from the skip outputs, so they must all be present
    if any(skip is None for skip in skips):
      raise ValueError("TCN_Multitask needs skip outputs from every TCN layer "
                       "(construct the TCN with skip=True)")
    tempo = torch.stack(skips).sum(dim=0)

    tempo = self.tempo_meanpool(tempo)
    tempo = self.tempo_dropout(tempo)
    tempo = self.tempo_time_dist(tempo)

    # shared dropout
    beats = self.beat_dropout(beats)

    # separate beats dense layer
    beat_pred = self.beat_time_dist(beats)
    # torch.squeeze removes every size-1 dimension, not just the channel one
    beat_pred = torch.squeeze(beat_pred)

    downbeat_pred = None
    if self.downbeats: 
      # separate downbeats dense layer
      downbeat_pred = self.downbeat_time_dist(beats)
      downbeat_pred = torch.squeeze(downbeat_pred)

    return beat_pred, downbeat_pred, tempo


  @staticmethod
  def loss_functions():
    """
    Build the loss callables for this network.

    Returns:
      dict with keys 'loss' (weighted sum of all tasks), 'beat_loss', 'downbeat_loss' 
      and 'tempo_loss'. Beat/downbeat losses are binary cross-entropy on logits (the 
      downbeat ones with twice the beat `pos_weight`); the tempo loss is mean squared 
      error.
    """

    def _downbeat_pos_weight(pos_weight):
      # downbeats use double the beat weighting; None means "unweighted" and must stay None
      return None if pos_weight is None else pos_weight * 2
    
    def loss(beats=None, beat_labels=None, downbeats=None, downbeat_labels=None, downbeat_weight=0.5, 
             tempo=None, tempo_labels=None, tempo_weight=0.00025, pos_weight=None):
      beat_loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(beats, beat_labels)
      downbeat_loss = 0
      # compare with None: truth-testing a multi-element tensor raises a RuntimeError
      if downbeats is not None:
        downbeat_loss = nn.BCEWithLogitsLoss(
            pos_weight=_downbeat_pos_weight(pos_weight))(downbeats, downbeat_labels)
      tempo_loss = nn.MSELoss()(tempo, tempo_labels)
      return beat_loss + downbeat_weight * downbeat_loss + tempo_weight * tempo_loss

    def beat_loss(beats, beat_labels, pos_weight=None):
      return nn.BCEWithLogitsLoss(pos_weight=pos_weight)(beats, beat_labels)

    def downbeat_loss(downbeats, downbeat_labels, pos_weight=None):
      return nn.BCEWithLogitsLoss(
          pos_weight=_downbeat_pos_weight(pos_weight))(downbeats, downbeat_labels)
    
    def tempo_loss(tempo, tempo_labels, pos_weight=None):
      # pos_weight is accepted only so all per-task losses can be called the same way
      return nn.MSELoss()(tempo, tempo_labels)

    return {'loss': loss, 'beat_loss': beat_loss, 'downbeat_loss': downbeat_loss, 
            'tempo_loss': tempo_loss}



## Network for training loss weights

In [None]:
class MultiTaskLoss(nn.Module):
  """
  Multitask loss with one trainable weight per active task.

  Each task loss is multiplied by the absolute value of its trainable weight and by a 
  fixed manual weight; regularisation terms on the weights keep the optimiser from 
  simply driving them to zero.
  """

  def __init__(self, 
               device,
               beats_loss=nn.BCEWithLogitsLoss,
               downbeats_loss=None, 
               tempo_loss=None, 
               weights_init=None, 
               beat_manual_weight=1,
               downbeat_manual_weight=1, 
               tempo_manual_weight=1, 
               beat_pos_weight=None, 
               reg_factor_ind=1, 
               reg_factor_sum=1):
    """
    Args:
      device: device the fixed manual weights are moved to.
      beats_loss: loss class for beats, constructed with `pos_weight=`; None disables 
        the task. The default used to be nn.BCELoss, which does not accept `pos_weight` 
        and therefore could never be constructed here; nn.BCEWithLogitsLoss does.
      downbeats_loss: loss class for downbeats, constructed with twice the beat 
        `pos_weight`; None disables the task.
      tempo_loss: loss class for tempo, constructed without arguments; None disables 
        the task.
      weights_init: initial trainable weights (tensor); defaults to ones, one per 
        active task.
      beat_manual_weight, downbeat_manual_weight, tempo_manual_weight: fixed 
        multipliers applied on top of the trainable weights.
      beat_pos_weight: positive-class weight for the beat loss, or None for no 
        weighting.
      reg_factor_ind: factor for the per-weight regulariser.
      reg_factor_sum: factor for the regulariser on the sum of the weights.
    """
    super().__init__()

    # torch.tensor(None) raises, so leave an absent pos_weight as None
    if beat_pos_weight is None:
      pos_weight = None
    elif torch.is_tensor(beat_pos_weight):
      pos_weight = beat_pos_weight.detach().clone()
    else:
      pos_weight = torch.tensor(beat_pos_weight, requires_grad=False)
    downbeat_pos_weight = None if pos_weight is None else pos_weight * 2

    self.beats_loss = None if beats_loss is None else beats_loss(pos_weight=pos_weight)
    self.downbeats_loss = None if downbeats_loss is None else downbeats_loss(pos_weight=downbeat_pos_weight)
    self.tempo_loss = None if tempo_loss is None else tempo_loss()

    if weights_init is None:
      # one trainable weight per active task
      num_weights = sum(1 for task_loss in (beats_loss, downbeats_loss, tempo_loss) 
                        if task_loss is not None)
      self.weights = nn.Parameter(torch.ones(num_weights))

    else: 
      self.weights = nn.Parameter(weights_init)

    # NOTE(review): the beat manual weight is always included, even when beats_loss is 
    # None - kept as in the original; confirm that beats are never disabled in practice.
    manual_weights = [beat_manual_weight]
    if downbeats_loss is not None:
      manual_weights.append(downbeat_manual_weight)
    if tempo_loss is not None:
      manual_weights.append(tempo_manual_weight) 
    
    # for manual (non-learned) loss weights
    self.manual_weights = torch.tensor(manual_weights).to(device)

    # factors for regularisation of loss weights
    self.reg_factor_ind = reg_factor_ind
    self.reg_factor_sum = reg_factor_sum

  
  def forward(self, 
              beats_output=None, 
              beats_labels=None, 
              downbeats_output=None,
              downbeats_labels=None,
              tempo_output=None,
              tempo_labels=None):
    """
    Compute the individual task losses and the weighted, regularised total.

    Returns:
      (losses, total_loss): `losses` is the list of active task losses in the order 
      beats, downbeats, tempo; `total_loss` is a scalar tensor.
    """
    losses = []

    if self.beats_loss is not None:
      losses += [self.beats_loss(beats_output, beats_labels)]
    if self.downbeats_loss is not None:
      losses += [self.downbeats_loss(downbeats_output, downbeats_labels)]
    if self.tempo_loss is not None:
      losses += [self.tempo_loss(tempo_output, tempo_labels)]

    # The tasks are weighted by |w|, so the regularisers use |w| too: taking the log of 
    # a raw weight that became negative would turn the whole loss into NaN. For 
    # positive weights this is identical to using the raw weights.
    abs_weights = torch.abs(self.weights)

    # create weighted multitask loss
    total_loss = torch.stack(losses) * abs_weights * self.manual_weights
    # add regularisation to prevent degenerate solution
    total_loss = total_loss + torch.abs(torch.log(abs_weights)) * self.reg_factor_ind
    # NOTE(review): the sum is compared against the constant 2.0 regardless of how many 
    # tasks are active - kept as in the original; confirm for three-task setups.
    total_loss = total_loss.sum() + torch.abs(torch.log(abs_weights.sum() / 2.0)) * self.reg_factor_sum

    return losses, total_loss



# Utilities

In [None]:
# post-processing

def times_from_labels(labels, sr=100.0):
  """
  Convert a binary frame-wise label sequence into event times in seconds.

  Args:
    labels: sequence in which frames holding the value 1 mark an event.
    sr: frame rate of `labels` in frames per second (default 100, the value that 
      used to be hard-coded).

  Returns:
    numpy array with the time of every frame equal to 1 (frame index along the 
    first axis divided by `sr`).
  """
  labels = np.array(labels)
  positions = np.where(labels == 1)[0]

  return positions / float(sr)


def peak_picker(beat_prob_scores, sr=100, mode='max_filt', min_bpm=10.0, max_bpm=360.0, 
                max_filt_length=15, threshold=0.25):
  """
  Post-processing for outputs from a beat tracking network.  

  Args:
    beat_prob_scores: frame-wise beat activation scores.
    sr: frame rate of the scores in frames per second.
    mode: 'max_filt' keeps frames that are the maximum of a sliding window of 
      `max_filt_length` frames and exceed `threshold`; 'dbn' hands the scores to 
      madmom's DBNBeatTrackingProcessor.
    min_bpm, max_bpm: tempo limits, used by the 'dbn' mode only.
    max_filt_length: window length in frames, used by the 'max_filt' mode only.
    threshold: minimum score of a peak, used by the 'max_filt' mode only.

  Returns:
    Beat times in seconds.

  Raises:
    ValueError: if `mode` is not one of the supported modes (this used to surface as 
      an UnboundLocalError).
  """
  if mode == 'max_filt':
    scores = np.asarray(beat_prob_scores)

    # local peak pick: keep only the frames that equal the maximum of their window
    max_values = maxFilt(scores, max_filt_length)
    peaks = np.where(scores == max_values, scores, 0)

    # apply threshold and change frame positions to times using the actual frame rate 
    # (the conversion previously assumed 100 frames per second whatever `sr` was)
    positions = np.where(peaks > threshold)[0]
    return positions / float(sr)

  if mode == 'dbn':
    dbn = DBNBeatTrackingProcessor(min_bpm=min_bpm, max_bpm=max_bpm, fps=sr)
    return dbn(beat_prob_scores)

  raise ValueError(f"Unknown peak picking mode: {mode!r} (expected 'max_filt' or 'dbn')")




In [None]:
# Graph pianoroll, beat positions or tempo curve

def graph(pianoroll=None, 
          beats=None, 
          beats_pred=None,
          downbeats=None, 
          downbeats_pred=None, 
          tempo=None, 
          tempo_pred=None,
          min_tempo=6,
          sr=100, 
          note_range=(0, 88), 
          time_range=None, 
          tempo_mode='continuous', 
          tempo_hop=8):
  """
  Utility function for visualising data.  

  Every argument that is not None is drawn in its own figure: the piano roll, the 
  tempo curve (with predictions, if given) and the beat / downbeat positions.

  Args:
    pianoroll: 2D array shown with librosa's specshow.
    beats, beats_pred: x positions of reference / predicted beats, passed to 
      plt.vlines unchanged.
    downbeats, downbeats_pred: frame-wise arrays; non-zero frames are drawn at 
      frame index / sr.
    tempo, tempo_pred: reference and predicted tempo.
    min_tempo: offset added to the argmax of a discrete tempo target.
    sr: frame rate in frames per second.
    note_range: only note_range[0] is used, as the lowest pitch of the piano roll axis.
    time_range: optional (start, stop) in seconds; multiplied by `sr` and used as a 
      slice on the plotted arrays.
    tempo_mode: 'continuous' or 'discrete'.
    tempo_hop: spacing of continuous tempo frames in units of 1/sr seconds (default 8, 
      the value that used to be hard-coded).

  NOTE(review): the same frame-based `time_range` slice is applied to the beat arrays 
  and to the tempo curves, which are at a different resolution - kept as in the 
  original; confirm this is the intended behaviour.
  """
  # change time range to samples; slice bounds must be integers, so round the result 
  # (fractional seconds used to produce float bounds, which cannot be used for slicing)
  if time_range is not None:
    time_range = tuple(int(round(t * sr)) for t in time_range)

  # plot piano roll
  if pianoroll is not None:
    if time_range is not None:
      pianoroll = pianoroll[:, slice(*time_range)]
    plt.figure(figsize=(15, 5))
    display.specshow(pianoroll, hop_length=1, sr=sr, x_axis='time', y_axis='cqt_note',
                     fmin=pm.note_number_to_hz(note_range[0]))
    plt.show()
  
  # plot tempo
  if tempo is not None:
    if tempo_mode == 'discrete':
      # NOTE(review): the axes used for slicing / argmax in this branch are kept 
      # exactly as in the original; the expected array layout is not visible here.
      if time_range is not None:
        tempo = tempo[:, slice(*time_range)]
      if tempo_pred is not None:
        time = np.arange(np.array(tempo_pred.shape[1])) / sr
        if time_range is not None:
          tempo_pred = tempo_pred[slice(*time_range)]
          time = time[slice(*time_range)]
        tempo = np.squeeze(np.add(np.argmax(tempo, axis=1), min_tempo))
        plt.plot(time, tempo, 'k')
        plt.plot(time, tempo_pred, 'ro')
      else:
        plt.figure(figsize=(15, 5))
        display.specshow(tempo, hop_length=1, sr=sr, x_axis='time', y_axis='tempo')
      plt.show()
    
    elif tempo_mode == 'continuous':
      plt.figure(figsize=(15, 5))
      # The time axis is derived from the prediction when there is one; without a 
      # prediction fall back to the reference curve (len(None) used to raise here). 
      # The length is taken before any slicing, as before.
      num_frames = len(tempo_pred) if tempo_pred is not None else len(tempo)
      if time_range is not None:
        tempo = tempo[slice(*time_range)]
      time = np.arange(num_frames) / sr * tempo_hop
      if time_range is not None:
        time = time[slice(*time_range)]
      plt.plot(time, tempo, 'k')
      if tempo_pred is not None:
        if time_range is not None:
          tempo_pred = tempo_pred[slice(*time_range)]
        plt.plot(time, tempo_pred, 'ro')      
      plt.show()      

  # plot beats and downbeats
  if beats is not None:
    if time_range is not None:
      beats = beats[slice(*time_range)]
    plt.figure(figsize=(15, 5))
    # one vlines call draws every beat; the original repeated this identical call once 
    # per beat, drawing nothing new each time
    if len(beats) > 0:
      plt.vlines(beats, ymin=0, ymax=0.5, colors='teal', ls='--')
    if beats_pred is not None:
      if time_range is not None:
        beats_pred = beats_pred[slice(*time_range)]
      plt.vlines(beats_pred, ymin=0.5, ymax=1.0, colors='navy', ls=':')
    if downbeats is not None:
      if time_range is not None:
        downbeats = downbeats[slice(*time_range)]
      plt.vlines(np.nonzero(downbeats)[0] / sr, ymin=0, ymax=0.75, colors='deeppink')
      if downbeats_pred is not None:
        if time_range is not None:
          downbeats_pred = downbeats_pred[slice(*time_range)]
        plt.vlines(np.nonzero(downbeats_pred)[0] / sr, ymin=0.25, ymax=1.0, 
                   colors='darkmagenta', ls=':')
    plt.show()



In [None]:
# set up Dataset and DataLoader

def get_dataloaders(metadata, 
                    sample_len=4349,
                    batch_size=4, 
                    test_ratio=0.1, 
                    val_ratio=0.15, 
                    max_examples=None, 
                    seed=1, 
                    pitch_aug=None, 
                    time_aug=None):
  """
  Split `metadata` into train / validation / test parts and wrap each in a DataLoader.

  The split is delegated to `split_dataset`; every part becomes a `myDataset` with the 
  same window length and augmentation ranges, and every loader shuffles, pins memory 
  and uses 4 workers. The number of sampling windows of each part is printed.

  Returns:
    (train_loader, validate_loader, test_loader); validate_loader is None when the 
    split yields no validation metadata.
  """

  def _dataset(meta):
    # all three partitions share the same dataset settings
    return myDataset(metadata=meta, 
                     sample_length=sample_len, 
                     pitch_aug_range=pitch_aug, 
                     time_aug_range=time_aug)

  def _loader(dataset):
    # all three partitions share the same loader settings
    return DataLoader(dataset, batch_size=batch_size, num_workers=4, pin_memory=True, 
                      shuffle=True)

  # split into train, validate and test sets
  metatrain, metavalidate, metatest = split_dataset(metadata, 
                                                    test_ratio=test_ratio, 
                                                    validation_ratio=val_ratio, 
                                                    seed=seed, 
                                                    max_examples=max_examples)

  # training partition
  train_set = _dataset(metatrain)
  print(f"Training data: {len(train_set)} sampling windows")
  train_loader = _loader(train_set)

  # validation partition (optional)
  if metavalidate is None:
    validate_loader = None
  else:
    validate_set = _dataset(metavalidate)
    print(f"Validation data: {len(validate_set)} sampling windows")
    validate_loader = _loader(validate_set)

  # test partition
  test_set = _dataset(metatest)
  print(f"Test data: {len(test_set)} sampling windows")
  test_loader = _loader(test_set)

  return train_loader, validate_loader, test_loader



In [None]:
# training loop

def train(num_epochs, 
          model, 
          optimizer,
          train_loader, 
          device, 
          scheduler=None,
          loss_network=None,
          loss_optimizer=None,
          save_name=None,
          log_name=None,
          update_frequency=100,
          evaluation_frequency=5,
          evaluation_loader=None,
          beats=True,
          downbeats=False,
          downbeat_weight=0.5,
          tempo=False, 
          tempo_weight=0.5, 
          tempo_range=(10, 360), 
          peak_mode='max_filt', 
          peak_threshold=0.25,
          max_filt_length=7,
          beat_pos_weight=None):
  """
  Train `model` for `num_epochs`, evaluating periodically and saving the best weights.

  Args:
    num_epochs: number of passes over `train_loader`.
    model: network returning (beats, downbeats, tempo) and offering `loss_functions()`.
    optimizer: optimizer for the model parameters.
    train_loader: loader yielding dicts with 'pianoroll' and the required label entries.
    device: device the model and the batches are moved to.
    scheduler: optional scheduler, stepped with the validation loss after every 
      evaluation.
    loss_network: optional trainable loss module; when given it replaces the static 
      losses from `model.loss_functions()`.
    loss_optimizer: optimizer for `loss_network` (required when it is given).
    save_name: file name for the best model weights and, with '.pkl' appended, for the 
      training log. When None, no weights are saved.
    log_name: name of the training log, only used when `save_name` is None (otherwise 
      the log is named after `save_name`, as before).
    update_frequency: print the running loss every this many batches (falsy disables).
    evaluation_frequency: evaluate every this many epochs (falsy disables).
    evaluation_loader: loader used for evaluation; when None, evaluation is skipped.
    beats, downbeats, tempo: which tasks are trained and logged.
    downbeat_weight, tempo_weight: static loss weights (used without a loss network).
    tempo_range, peak_mode, peak_threshold, max_filt_length: forwarded to `evaluate`.
    beat_pos_weight: optional positive-label weight for the beat loss.

  Returns:
    dict of lists with the losses / scores recorded at every evaluation.
  """

  # best validation loss so far; the first evaluation always counts as an improvement
  best_validation_loss = float('inf')

  # get loss functions
  if loss_network is None:
    criteria = model.loss_functions()
    overall_criterion = criteria['loss']

    if beats:
      beat_criterion = criteria['beat_loss']
    if downbeats:
      downbeat_criterion = criteria['downbeat_loss']
    if tempo and beats:
      tempo_criterion = criteria['tempo_loss']
    print("Using static loss weights")
  else:
    loss_network.train()
    loss_network.to(device)
    print("Using trainable loss network")

  # set positive label loss weightings
  pos_weight = None
  if beat_pos_weight is not None:
    # as_tensor keeps the given value(s); torch.Tensor(n) with an integer n would 
    # instead allocate an uninitialised tensor of n elements
    pos_weight = torch.as_tensor(beat_pos_weight, dtype=torch.float).to(device)

  model.train()
  model.to(device)

  # for graphing
  output_data = {}
  output_data['train_losses'] = []
  output_data['validation_losses'] = []
  if beats:
    output_data['beat_f1s'] = []
    if loss_network is not None:
      output_data['beat loss weight'] = []
    if tempo or downbeats:
      output_data['beat_train_losses'] = []
      output_data['beat_val_losses'] = []
  if downbeats:
    output_data['downbeat_f1s'] = []
    output_data['downbeat_train_losses'] = []
    output_data['downbeat_val_losses'] = []
    if loss_network is not None:
      output_data['downbeat loss weight'] = []
  if tempo:
    output_data['tempo_RMSE'] = []
    if loss_network is not None:
      output_data['tempo loss weight'] = []
    if beats:
      output_data['tempo_train_losses'] = []
      output_data['tempo_val_losses'] = []
    

  for epoch in range(num_epochs):

    # evaluate() puts the model (and the loss network) into eval mode, so make sure 
    # every epoch trains in train mode, whatever state the evaluation left behind
    model.train()
    if loss_network is not None:
      loss_network.train()

    running_loss = 0
    running_beat_loss = 0
    running_downbeat_loss = 0
    running_tempo_loss = 0

    running_batch_loss = 0

    for batch_idx, data in enumerate(train_loader):

      # zero the parameter gradients
      optimizer.zero_grad()

      if loss_network is not None:
        loss_optimizer.zero_grad()

      # get samples and ground truth labels
      x = data['pianoroll']
      x = x.to(device)

      # get labels
      beat_labels = None
      if beats:
        beat_labels = data['beats']
        beat_labels = torch.squeeze(beat_labels)
        beat_labels = beat_labels.to(device)
      downbeat_labels = None
      if downbeats:
        # NOTE(review): the key is spelled 'downbeast' here and in evaluate(); it is 
        # kept because the dataset definition is not visible - confirm the real key.
        downbeat_labels = data['downbeast']
        # keep the squeezed tensor (the result used to be discarded), mirroring the 
        # beat labels
        downbeat_labels = torch.squeeze(downbeat_labels)
        downbeat_labels = downbeat_labels.to(device)
      tempo_labels = None
      if tempo:
        tempo_labels = data['tempo']
        # keep the squeezed tensor (the result used to be discarded); the labels are 
        # moved to the device only after they have been resampled below
        tempo_labels = torch.squeeze(tempo_labels)
      
      # get outputs
      beats_pred, downbeats_pred, tempo_pred = model(x)

      if beats:
        beats_pred = torch.squeeze(beats_pred)
      if downbeats:
        downbeats_pred = torch.squeeze(downbeats_pred)
      if tempo:
        tempo_pred = torch.squeeze(tempo_pred)
      
      # downsample tempo labels to the time resolution of the tempo output
      if tempo:
        tempo_labels = signal.resample(tempo_labels, tempo_pred.shape[1], axis=1)
        tempo_labels = torch.Tensor(tempo_labels).to(device)

      # calculate losses
      if loss_network is None:
        overall_loss = overall_criterion(beats_pred, beat_labels, downbeats_pred, downbeat_labels, 
                                        downbeat_weight, tempo_pred, tempo_labels, tempo_weight, 
                                        pos_weight=pos_weight)
      else:
        losses, overall_loss = loss_network(beats_pred, beat_labels, downbeats_pred, downbeat_labels, 
                                        tempo_pred, tempo_labels)
      
      # weight the batch loss by the size of the first dimension, so that dividing by 
      # the dataset size at the end of the epoch gives a per-sample average
      running_loss += overall_loss.item() * (beats_pred.shape[0] if beats else tempo_pred.shape[0])
      running_batch_loss += overall_loss.item()

      if downbeats:
        if loss_network is None:
          beat_loss = beat_criterion(beats_pred, beat_labels, pos_weight=pos_weight)
          downbeat_loss = downbeat_criterion(downbeats_pred, downbeat_labels,  
                                             pos_weight=pos_weight)
        else: 
          beat_loss = losses[0].detach()
          downbeat_loss = losses[1].detach()       
        running_beat_loss += beat_loss.item() * beats_pred.shape[0]
        running_downbeat_loss += downbeat_loss.item() * downbeats_pred.shape[0]
      
      if tempo and beats:
        if not downbeats:
          if loss_network is None:
            beat_loss = beat_criterion(beats_pred, beat_labels, pos_weight=pos_weight)
          else:
            beat_loss = losses[0].detach()
          running_beat_loss += beat_loss.item() * beats_pred.shape[0]

        if loss_network is None:
          tempo_loss = tempo_criterion(tempo_pred, tempo_labels, pos_weight=pos_weight)
        else:
          # the tempo loss is the last entry: index 2 with downbeats, index 1 without
          tempo_loss = losses[2].detach() if downbeats else losses[1].detach()
        running_tempo_loss += tempo_loss.item() * tempo_pred.shape[0]
      
      # backward + optimize
      overall_loss.backward()
      optimizer.step()

      if loss_network is not None:
        loss_optimizer.step()

      # update prints
      if update_frequency and (batch_idx + 1) % update_frequency == 0:
        # running_batch_loss sums batch losses that are not scaled by the batch size, 
        # so the mean over the last `update_frequency` batches is just sum / count 
        # (it used to be divided by the average batch size as well)
        print(f"loss after batch {batch_idx + 1} of epoch {epoch + 1}: " \
              f"{running_batch_loss / update_frequency: .8f}")
        running_batch_loss = 0
    
    num_samples = len(train_loader.dataset)

    # end of epoch prints
    print(f"===== Training loss for epoch {epoch + 1}: {running_loss / num_samples: .8f} =====")

    if downbeats:
      print(f"Beat loss: {running_beat_loss / num_samples: .8f}")
      print(f"Downbeat loss: {running_downbeat_loss / num_samples: .8f}")
    if tempo and beats:
      if not downbeats:
        print(f"Beat loss: {running_beat_loss / num_samples: .8f}")
      print(f"Tempo loss: {running_tempo_loss / num_samples: .8f}")

    # evaluate on validation (skipped when no evaluation loader was supplied)
    if evaluation_frequency and evaluation_loader is not None \
       and (epoch + 1) % evaluation_frequency == 0:
      val_loss, val_beat_loss, val_downbeat_loss, val_tempo_loss, \
      val_beat_score, val_downbeat_score, val_tempo_score = evaluate(model, 
                                                                    device,
                                                                    evaluation_loader, 
                                                                    loss_network=loss_network,
                                                                    beats=beats, 
                                                                    downbeats=downbeats, 
                                                                    downbeat_weight=downbeat_weight, 
                                                                    tempo=tempo, 
                                                                    tempo_weight=tempo_weight, 
                                                                    tempo_range=tempo_range, 
                                                                    peak_mode=peak_mode, 
                                                                    peak_threshold=peak_threshold, 
                                                                    beat_pos_weight=beat_pos_weight, 
                                                                    max_filt_length=max_filt_length)
      print(f'\t===== Validation loss: {val_loss: .8f} =====')
      output_data['train_losses'].append(running_loss / num_samples)
      output_data['validation_losses'].append(val_loss)
      if beats:
        print(f'\tValidation beat f-measure: {val_beat_score: .4f}')
        output_data['beat_f1s'].append(val_beat_score)
        # same condition as the one that creates these lists above (they were left 
        # empty for downbeat runs without tempo)
        if tempo or downbeats:
          output_data['beat_train_losses'].append(running_beat_loss / num_samples)
          output_data['beat_val_losses'].append(val_beat_loss)
      if downbeats:
        print(f"\tValidation beat loss: {val_beat_loss: .8f}")
        print(f'\tValidation downbeat f-measure: {val_downbeat_score: .4f}')
        print(f"\tValidation downbeat loss: {val_downbeat_loss: .8f}")
        output_data['downbeat_f1s'].append(val_downbeat_score)
        output_data['downbeat_train_losses'].append(running_downbeat_loss / num_samples)
        # log the loss here; the f-measure is already stored in 'downbeat_f1s'
        output_data['downbeat_val_losses'].append(val_downbeat_loss)
      if tempo:
        print(f"\tValidation tempo RMSE: {val_tempo_score: .8f}")
        output_data['tempo_RMSE'].append(val_tempo_score)
        if beats:
          if not downbeats:
            print(f"\tValidation beat loss: {val_beat_loss: .8f}")
          print(f"\tValidation tempo loss: {val_tempo_loss: .8f}")
          output_data['tempo_train_losses'].append(running_tempo_loss / num_samples)
          output_data['tempo_val_losses'].append(val_tempo_loss)
      if loss_network is not None:
        for name, param in loss_network.state_dict().items():
          if name == 'weights':
            print(f"Trainable loss parameters: {name}, {param}")
            param = param.cpu()
            output_data['beat loss weight'].append(param[0])
            if downbeats:
              output_data['downbeat loss weight'].append(param[1])
            if tempo:
              output_data['tempo loss weight'].append(param[-1])

      # if the best validation performance so far, save the network to file 
      if val_loss < best_validation_loss:
        best_validation_loss = val_loss
        # without a file name there is nowhere to save to (os.path.join would raise)
        if save_name is not None:
          # get full path
          directory = "/content/gdrive/MyDrive/Colab Notebooks/QM DL for music and audio/SavedModels"
          save_path = os.path.join(directory, save_name)
          print('Saving best model')
          torch.save(model.state_dict(), save_path) 

      # update scheduler
      if scheduler is not None:
        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(val_loss)
        lr_after = optimizer.param_groups[0]['lr']
        if lr_before != lr_after:
          print(f"Learning rate reduced to {lr_after}")

  # save output data; the log is named after save_name as before, falling back to 
  # log_name, and is skipped when neither is given
  log_root = "/content/gdrive/MyDrive/Colab Notebooks/QM DL for music and audio/training_data"
  log_base = save_name if save_name is not None else log_name
  if log_base is not None:
    log_path = os.path.join(log_root, log_base + '.pkl')

    with open(log_path, 'wb') as f:
      pickle.dump(output_data, f)

  print(list(output_data.keys()))

  return output_data



In [None]:
# evaluation loop

def evaluate(model, 
             device, 
             test_loader, 
             loss_network=None,
             beats=True,
             downbeats=False,
             downbeat_weight=0.5,
             tempo=False, 
             tempo_weight=0.5, 
             tempo_range=(10, 360), 
             peak_mode='max_filt', 
             peak_threshold=0.25, 
             max_filt_length=7,
             beat_pos_weight=None):
  """
  Run the model over `test_loader` without gradients and return averaged
  losses and scores.

  Args:
    model: network called as `model(x)`, returning the tuple
        (beats_pred, downbeats_pred, tempo_pred).  When `loss_network` is
        None its `loss_functions()` dict supplies the criteria
        ('loss', 'beat_loss', 'downbeat_loss', 'tempo_loss').
    device: device the inputs, labels and positive weights are moved to.
    test_loader: iterable of batch dicts with key 'pianoroll' plus 'beats',
        'downbeast' and 'tempo' as enabled; must expose `.dataset` with a length.
    loss_network: optional module called as
        `loss_network(beats_pred, beat_labels, downbeats_pred, downbeat_labels,
        tempo_pred, tempo_labels)` and returning `(losses, overall_loss)`.
    beats, downbeats, tempo: which outputs to evaluate.  The per-task loss
        bookkeeping for downbeats and tempo uses the beat criterion, so it
        presumably expects `beats=True` as well - TODO confirm.
    downbeat_weight, tempo_weight: forwarded to the overall criterion.
    tempo_range: (min_bpm, max_bpm) forwarded to `peak_picker`.
    peak_mode, peak_threshold, max_filt_length: forwarded to `peak_picker`
        for both beats and downbeats.
    beat_pos_weight: optional scalar or sequence turned into the `pos_weight`
        tensor handed to the criteria (only used when `loss_network` is None).

  Returns:
    Tuple (loss, beat_loss, downbeat_loss, tempo_loss, beat_score,
    downbeat_score, tempo_score), each divided by `len(test_loader.dataset)`.
    The beat / downbeat scores are `mir_eval.beat.f_measure` values; the tempo
    score is the mean absolute difference between predicted and resampled
    label tempo, multiplied by 60.
  """

  def _squeeze_keep_batch(t):
    # Drop singleton dimensions but keep the leading batch dimension, so a
    # batch holding a single sample can still be indexed as [i, :] and still
    # reports its batch size in shape[0].
    squeezed = torch.squeeze(t)
    if t.dim() > 0 and t.shape[0] == 1:
      squeezed = squeezed.unsqueeze(0)
    return squeezed

  def _f_measure_sum(logits, targets):
    # Sum of per-sample F-measures for one batch of activations.
    probs = torch.sigmoid(logits).cpu()
    targets = targets.cpu()
    total = 0
    for i in range(targets.shape[0]):
      # convert frame labels to times
      ref_times = times_from_labels(targets[i, :])
      # peak picking to get predicted times
      est_times = peak_picker(probs[i, :], sr=100, mode=peak_mode, min_bpm=tempo_range[0], 
                              max_bpm=tempo_range[1], threshold=peak_threshold, 
                              max_filt_length=max_filt_length)
      # samples without any detected peak contribute a score of 0
      if est_times.size != 0:
        total += mir_eval.beat.f_measure(ref_times, est_times)
    return total

  model.eval()

  running_loss = 0
  running_beat_loss = 0
  running_downbeat_loss = 0
  running_tempo_loss = 0

  running_beat_score = 0
  running_downbeat_score = 0
  running_tempo_score = 0

  pos_weight = None
  if loss_network is None:
    # get loss functions
    criteria = model.loss_functions()
    overall_criterion = criteria['loss']

    if beats:
      beat_criterion = criteria['beat_loss']
    if downbeats:
      downbeat_criterion = criteria['downbeat_loss']
    if tempo and beats:
      tempo_criterion = criteria['tempo_loss']

    # set positive label loss weightings; as_tensor copes with scalars as well
    # as sequences (torch.Tensor(5) would allocate 5 uninitialised values)
    if beat_pos_weight is not None:
      pos_weight = torch.as_tensor(beat_pos_weight, 
                                   dtype=torch.get_default_dtype()).to(device)
  else:
    loss_network.eval()

  with torch.no_grad():
    for data in test_loader:

      # get samples
      x = data['pianoroll']
      x = x.to(device)
      batch_size = x.shape[0]

      # get labels
      beat_labels = None
      if beats:
        beat_labels = _squeeze_keep_batch(data['beats']).to(device)
      downbeat_labels = None
      if downbeats:
        # NOTE(review): the key is spelled 'downbeast' here and in test();
        # confirm it matches the key produced by the dataset.
        downbeat_labels = _squeeze_keep_batch(data['downbeast']).to(device)
      tempo_labels = None
      if tempo:
        # stays on the CPU for now: it is resampled with scipy below
        tempo_labels = _squeeze_keep_batch(data['tempo'])
      
      # get outputs
      beats_pred, downbeats_pred, tempo_pred = model(x)

      if beats:
        beats_pred = _squeeze_keep_batch(beats_pred)
      if downbeats:
        downbeats_pred = _squeeze_keep_batch(downbeats_pred)
      if tempo:
        tempo_pred = _squeeze_keep_batch(tempo_pred)

      # downsample tempo labels to the length of the tempo prediction
      if tempo:
        resampled = signal.resample(tempo_labels.cpu().numpy(), tempo_pred.shape[1], axis=1)
        tempo_labels = torch.Tensor(resampled).to(device)

      # calculate losses
      if loss_network is None:
        overall_loss = overall_criterion(beats_pred, beat_labels, downbeats_pred, downbeat_labels, 
                                        downbeat_weight, tempo_pred, tempo_labels, tempo_weight, 
                                        pos_weight=pos_weight)
      else:
        losses, overall_loss = loss_network(beats_pred, beat_labels, downbeats_pred, downbeat_labels, 
                                        tempo_pred, tempo_labels)
      
      # weight by batch size so the final division by the dataset size
      # yields a per-sample average
      running_loss += overall_loss.item() * batch_size

      if downbeats:
        if loss_network is None:
          beat_loss = beat_criterion(beats_pred, beat_labels, pos_weight=pos_weight)
          downbeat_loss = downbeat_criterion(downbeats_pred, downbeat_labels, 
                                             pos_weight=pos_weight)
        else: 
          beat_loss = losses[0].detach()
          downbeat_loss = losses[1].detach()       
        running_beat_loss += beat_loss.item() * batch_size
        running_downbeat_loss += downbeat_loss.item() * batch_size
      
      if tempo and beats:
        if not downbeats:
          # beat loss has not been accumulated by the downbeat branch above
          if loss_network is None:
            beat_loss = beat_criterion(beats_pred, beat_labels, pos_weight=pos_weight)
          else:
            beat_loss = losses[0].detach()
          running_beat_loss += beat_loss.item() * batch_size

        if loss_network is None:
          tempo_loss = tempo_criterion(tempo_pred, tempo_labels, pos_weight=pos_weight)
        else:
          # tempo is the last entry: index 2 with downbeats, otherwise 1
          tempo_loss = losses[2].detach() if downbeats else losses[1].detach()
        running_tempo_loss += tempo_loss.item() * batch_size
      
      # calculate f-scores for beats and downbeats
      if beats:
        running_beat_score += _f_measure_sum(beats_pred, beat_labels)

      if downbeats:
        running_downbeat_score += _f_measure_sum(downbeats_pred, downbeat_labels)

      if tempo:
        running_tempo_score += torch.sum(torch.abs(tempo_pred - tempo_labels)).item() / tempo_pred.shape[1]
      
  num_items = len(test_loader.dataset)

  running_loss /= num_items
  running_beat_loss /= num_items
  running_downbeat_loss /= num_items
  running_tempo_loss /= num_items

  running_beat_score /= num_items
  running_downbeat_score /= num_items
  # NOTE(review): the factor 60 presumably converts beats-per-second to BPM - confirm
  running_tempo_score = running_tempo_score / num_items * 60

  return running_loss, running_beat_loss, running_downbeat_loss, running_tempo_loss, \
          running_beat_score, running_downbeat_score, running_tempo_score
    


In [None]:
# load model
def load_state(model, load_name, directory=None, map_location='cpu'):
  """
  Load a saved state dict from file into `model`.

  Args:
    model: module whose `load_state_dict` receives the stored weights.
    load_name: file name of the checkpoint inside `directory`.
    directory: folder holding the checkpoint; defaults to the notebook's
        Google Drive 'SavedModels' folder.
    map_location: forwarded to `torch.load`.  Defaults to 'cpu' so that a
        checkpoint written on a GPU runtime can also be restored on a
        CPU-only runtime; `load_state_dict` then copies the values into the
        model's existing parameters.

  Returns:
    The same `model` object, with the loaded weights.
  """
  if directory is None:
    directory = "/content/gdrive/MyDrive/Colab Notebooks/QM DL for music and audio/SavedModels"
  path = os.path.join(directory, load_name)

  state_dict = torch.load(path, map_location=map_location)
  model.load_state_dict(state_dict)

  return model


# load model output data
def load_data(name, directory=None):
  """
  Load a pickled training log.

  Args:
    name: log name without extension; '.pkl' is appended.
    directory: folder holding the log; defaults to the notebook's Google
        Drive 'training_data' folder.

  Returns:
    The unpickled object stored in the file.
  """
  if directory is None:
    directory = "/content/gdrive/MyDrive/Colab Notebooks/QM DL for music and audio/training_data"
  path = os.path.join(directory, name + '.pkl')

  # NOTE: unpickling can execute arbitrary code - only load files this
  # notebook wrote itself.
  with open(path, 'rb') as f:
    data = pickle.load(f)

  return data

  

In [None]:
# For plotting learning curves and evaluation metrics

def plot_training(update_rate, data, metrics, trim=0):
  """
  Plot logged metrics against epochs, one subplot per run.

  Args:
    update_rate: number of epochs between two logged values; used to scale
        the x axis.
    data: a single mapping of metric name -> sequence of values, or a
        list / tuple of such mappings (one subplot each).
    metrics: names of the metrics to draw on every subplot.
    trim: number of leading logged values to leave out of each curve.

  Raises:
    ValueError: if `data` is an empty list / tuple.
  """
  runs = list(data) if isinstance(data, (list, tuple)) else [data]
  num_plots = len(runs)
  if num_plots == 0:
    raise ValueError("plot_training needs at least one run to plot")

  figwidth = min(20, 5 * num_plots)

  # squeeze=False always gives a 2-D array of axes, so one run and many runs
  # share the same code path
  _, axes = plt.subplots(1, num_plots, figsize=(figwidth, 4), squeeze=False)

  for run, axis in zip(runs, axes[0]):
    for metric in metrics:
      values = run[metric]
      # x axis built per series, so metrics logged a different number of
      # times can share a subplot
      x_points = np.arange(len(values)) * update_rate
      axis.plot(x_points[trim:], values[trim:], label=metric)

    axis.set_xlabel('epochs')
    if len(metrics) > 0:
      axis.legend()

  plt.show()



In [None]:
# Testing

def test(model, 
         device, 
         test_loader, 
         tempo_model=None,
         sr=22050,
         num_samples=1, 
         beats=True, 
         downbeats=False, 
         tempo=True, 
         tempo_range=(10, 360), 
         peak_mode='max_filt', 
         peak_threshold=0.25, 
         max_filt_length=7, 
         folder=None):

  """
  Test a model on input samples.  
  Synthesize to audio (with a click track on the predicted beats), save the 
  audio, print metrics and plot results.  

  Assumes a batch size of 1.  

  Args:
    model: network called as `model(x)`, returning the tuple
        (beats_pred, downbeats_pred, tempo_pred).
    device: device the models and the input are moved to.
    test_loader: iterable of batch dicts with keys 'pianoroll',
        'original_name', 'start', 'name' and, as enabled, 'beats',
        'downbeast' and 'tempo'.
    tempo_model: optional second network; when given, its tempo output
        replaces the tempo output of `model`.
    sr: audio sample rate used for synthesis, clicks and the saved file.
    num_samples: number of batches to process; a falsy value processes all.
    beats, downbeats, tempo: which outputs to score and plot.
    tempo_range: (min_bpm, max_bpm) forwarded to `peak_picker`.
    peak_mode, peak_threshold, max_filt_length: forwarded to `peak_picker`.
    folder: optional sub-folder of the audio output directory.
  """

  # ASAP dataset root
  asap_directory = "./asap-dataset"

  # NOTE(review): 100 is used both as the `sr` of the activation functions and
  # to convert 'start' to seconds - presumably the frame rate of the pianoroll.
  frame_rate = 100
  # NOTE(review): hard-coded excerpt length in frames - confirm it matches the
  # length of the dataset samples.
  excerpt_frames = 4349

  def _show(**content):
    # Single place for the graph() defaults; callers name only what to draw.
    options = dict(pianoroll=None, 
                   beats=None, 
                   beats_pred=None,
                   downbeats=None, 
                   downbeats_pred=None, 
                   tempo=None, 
                   tempo_pred=None,
                   min_tempo=10,
                   sr=100, 
                   note_range=(0, 88), 
                   time_range=None)
    options.update(content)
    graph(**options)

  def _pick_peaks(logits):
    # Same preparation as in evaluate(): probabilities of a single sample,
    # on the CPU, before peak picking.
    probs = torch.sigmoid(torch.squeeze(logits)).cpu()
    return peak_picker(probs, sr=frame_rate, mode=peak_mode, min_bpm=tempo_range[0], 
                       max_bpm=tempo_range[1], threshold=peak_threshold, 
                       max_filt_length=max_filt_length)

  model.eval()
  model.to(device)

  if tempo_model is not None:
    tempo_model.eval()
    tempo_model.to(device)

  with torch.no_grad():
    for idx, data in enumerate(test_loader):

      if num_samples and idx >= num_samples:
        break
      
      x = data['pianoroll']
      x = x.to(device)

      assert x.shape[0] == 1, "Batch size must be 1 for testing"

      # get predictions
      beats_pred, downbeats_pred, tempo_pred = model(x)

      if tempo_model is not None:
        _, _, tempo_pred = tempo_model(x)

      # get labels
      beat_labels = None
      if beats:
        beat_labels = torch.squeeze(data['beats']).numpy()
      downbeat_labels = None
      if downbeats:
        # NOTE(review): the key is spelled 'downbeast' here and in evaluate();
        # confirm it matches the key produced by the dataset.
        downbeat_labels = torch.squeeze(data['downbeast']).numpy()
      tempo_labels = None
      if tempo:
        tempo_labels = data['tempo']

      # get MIDI track
      path = data['original_name'][0]
      print(f"Testing track {path} ... ")
      piece = pm.PrettyMIDI(os.path.join(asap_directory, path))

      # synthesize audio and cut out the excerpt belonging to this sample
      waveform = piece.synthesize(fs=sr, wave=signal.square)
      start = data['start'][0].item()
      start = int(start / frame_rate * sr)
      length = int(excerpt_frames / frame_rate * sr)
      waveform = waveform[start: start + length]

      output = waveform
      if beats:
        # get click track
        beat_preds = _pick_peaks(beats_pred)
        audio_len = len(waveform)
        if beat_preds.size != 0:
          clicks = librosa.clicks(times=beat_preds, sr=sr)
        else:
          # no beat detected: nothing to click
          clicks = np.zeros(audio_len)

        # put click onto audio
        if len(clicks) > audio_len:
          clicks = clicks[:audio_len]
        elif len(clicks) < audio_len:
          pad = np.zeros(audio_len - len(clicks))
          clicks = np.concatenate((clicks, pad))
        output = waveform + clicks

      # save audio track
      save_dir = "/content/gdrive/MyDrive/Colab Notebooks/QM DL for music and audio/audio"
      if folder:
        save_dir = os.path.join(save_dir, folder)
      os.makedirs(save_dir, exist_ok=True)
      # swap whatever extension the sample name has for '.wav'
      save_name = os.path.splitext(data['name'][0])[0] + '.wav'
      save_path = os.path.join(save_dir, save_name)
      sf.write(save_path, output, sr)

      # calculate metrics
      if beats:
        labels = times_from_labels(beat_labels)
        beat_f1 = mir_eval.beat.f_measure(labels, beat_preds)
        print(f"Beat f1: {beat_f1}")

      if downbeats:
        dblabels = times_from_labels(downbeat_labels)
        downbeat_preds = _pick_peaks(downbeats_pred)
        downbeat_f1 = mir_eval.beat.f_measure(dblabels, downbeat_preds)       
        print(f"Downbeat f1: {downbeat_f1}")
      
      if tempo:
        tempo_pred = torch.squeeze(tempo_pred).cpu().numpy()
        # bring the labels to the length of the tempo prediction
        tempo_labels = signal.resample(tempo_labels.cpu().numpy(), tempo_pred.shape[0], axis=1)
        # mean absolute difference between prediction and labels
        tempo_accuracy = np.mean(np.abs(tempo_pred - tempo_labels))
        print(f"Tempo Accuracy: {tempo_accuracy}")

        # NOTE(review): the factor 60 presumably converts beats-per-second to BPM - confirm
        tempo_labels = np.squeeze(tempo_labels) * 60
        tempo_pred = np.squeeze(tempo_pred) * 60

      x = x.cpu().numpy()
      x = np.squeeze(x)

      # show graphs
      _show(pianoroll=x)

      if beats:
        _show(beats=labels, beats_pred=beat_preds)

      if downbeats:
        # label times, matching what is passed for beats above
        _show(downbeats=dblabels, downbeats_pred=downbeat_preds)

      if tempo:
        _show(tempo=tempo_labels, tempo_pred=tempo_pred)

