<a href="https://colab.research.google.com/github/oujianshen/Audero-Audio-Player/blob/master/yodo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **YODO - You Only Drum Once**

YODO is an Automatic Drum Transcription (ADT) system, designed to transform audio recordings of drums/percussion instruments into musical notation, specifically as a MIDI file. More info at: https://github.com/varsaav/yodo

To use our system, go to the upper menu bar and press:

**`Runtime -> Run all`**

You'll be asked to upload the audio file(s) you want to transcribe. Further down you can also adjust the confidence threshold of each instrument, but that's entirely optional. When everything's done, the MIDI transcription(s) will download automatically. Have fun!

In [None]:
#@title Install dependencies
from google.colab import files, output
import os, shutil

# Dependencies
!pip install demucs
!pip install mido

# Clone YOLO implementation
!git clone https://github.com/AlexeyAB/darknet

# Setup for YOLO installation
%cd darknet
!sed -i 's/OPENCV=0/OPENCV=1/' Makefile
!sed -i 's/GPU=0/GPU=1/' Makefile
!sed -i 's/CUDNN=0/CUDNN=1/' Makefile
!sed -i 's/CUDNN_HALF=0/CUDNN_HALF=1/' Makefile

# In some edge cases labels may be outside the boundaries of the image
# It's usually just a pixel off, but it can cause problems, so...
# I commented the code line that checks for that
!sed -i 's/assert(x < m.w && y < m.h && c < m.c)/\/\/assert(x < m.w && y < m.h && c < m.c)/' src/image.c

# YOLO installation
!make --quiet
%cd ..

# Pull source code
!git clone https://ValenVS@github.com/varsaav/yodo
%cd yodo/model
!gdown https://drive.google.com/uc?id=1y8aDvM807gsPwPnjFBMVI4A8Otgv3QSj
%cd ../..

# Cleanup storage in Colab VM
shutil.rmtree('/content/sample_data')
os.mkdir('/content/separated')
os.mkdir('/content/spectrograms')
os.mkdir('/content/cropped')
os.mkdir('/content/labels')
os.mkdir('/content/midi')

output.clear()

In [None]:
# Upload one or more audio files
# (the next cell looks for the uploaded files in the current working
# directory, so the working directory must not change in between)
files.upload()
# Clear this cell's output once the upload call has returned
output.clear()

In [None]:
#@title Separate drum tracks

from pathlib import Path

for filepath in os.listdir('./'):
    # Remove special symbols to avoid path problems later
    special = ["'",'"','#','¡','!','$','%','&','¿','?']
    new_name = filepath.replace(' ', '_')
    for symbol in special:
        new_name = new_name.replace(symbol, '')
    os.rename(filepath, new_name)

    if os.path.isfile(new_name):
        !python -m demucs.separate -o separated -n mdx_extra "$new_name"

        filename = Path(new_name).stem
        shutil.move(f'separated/mdx_extra/{filename}/drums.wav',
            f'separated/{filename}.wav')
        shutil.rmtree('separated/mdx_extra')

output.clear()

In [None]:
#@title Generate spectrograms
import imageio
import numpy as np

from yodo import audio
from yodo import spectrogram as sp
from yodo.utils import rescale

# Superlets parameters
foi = sp.generate_foi()
orders = np.round(np.linspace(1, 10, len(foi)), 0)
n_cycles = 3
superlets = {}

# Time resolution of the spectrograms: one column every 10 ms
ms_per_frame = 10
fps = int(1000 / ms_per_frame)

# Audio is processed in chunks of this many seconds (for large audio files)
chunk_seconds = 30

for filename in os.listdir('separated'):
    # ipynb checkpoints must be ignored
    if filename.startswith('.'):
        continue

    # Strip only the extension; split('.')[0] would truncate a dotted
    # name ('my.song.wav' -> 'my') and then look for a file that
    # does not exist
    name = os.path.splitext(filename)[0]
    samplerate, samples = audio.read_wav(f'separated/{filename}')

    # Multichannel to mono
    if np.ndim(samples) > 1:
        samples = np.mean(samples, axis=1)

    # Maps values to range [-1, 1]
    samples = audio.normalize(samples)

    # Save wavelets for reuse (speeds up computation)
    if samplerate not in superlets:
        superlets[samplerate] = sp.Superlets(
            samplerate, foi, n_cycles, orders)
    slt = superlets[samplerate]

    # Calculate spectrograms chunk by chunk
    chunk = samplerate * chunk_seconds
    specs = []
    for begin in range(0, len(samples), chunk):
        end = min(begin + chunk, len(samples))

        spec = sp.faslt(samples[begin:end], orders, slt, fps)[0]
        # presumably rescale() maps the dB values to [0, 1], since the
        # result is multiplied by 255 below - TODO confirm
        norm = rescale(audio.power_to_db(spec))
        specs.append(norm)

    # Merge chunks along the time axis and save as an 8-bit image
    gray = (np.hstack(specs) * 255).astype('uint8')
    imageio.imwrite(f'spectrograms/{name}.png', gray)

In [None]:
#@title Crop spectrograms
import cv2

# Size of the images written for the detector and the number of columns
# shared by two consecutive crops
crop_width, crop_height = 1024, 128
overlap = 100
stride = crop_width - overlap

for filename in os.listdir('spectrograms'):
    # ipynb checkpoints must be ignored
    if filename.startswith('.'):
        continue

    # Strip only the extension so dotted names are not truncated
    name = os.path.splitext(filename)[0]

    image = imageio.v2.imread(f'spectrograms/{filename}')
    height, width = np.shape(image)

    os.makedirs(f'cropped/{name}', exist_ok=True)

    begin = 0
    i = 1
    while begin < width:
        end = min(begin + crop_width, width)
        crop = image[:, begin:end]

        # The last crop is usually narrower than crop_width. Pad it with
        # zeros (the same value used to pad the spectrogram when onsets
        # are estimated) instead of letting the resize stretch it:
        # predictions are mapped back with x = relative_x * 1024 + offset,
        # which only holds if every crop spans exactly 1024 columns.
        missing = crop_width - (end - begin)
        if missing > 0:
            crop = np.pad(crop, ((0, 0), (0, missing)))

        # Only the height changes here (width is already crop_width)
        cropped = cv2.resize(crop, (crop_width, crop_height))
        index = str(i).zfill(3)
        imageio.imwrite(f'cropped/{name}/{index}.png', cropped)

        if end == width:
            break
        begin += stride
        i += 1

In [None]:
#@title Predict onsets

# If locale is not UTF-8 some commands fail
import locale
locale.getpreferredencoding = lambda: "UTF-8"

cfg = 'yodo/model/tri2.cfg'
data = 'yodo/model/classes.data'
weights = 'yodo/model/tri2-80k.weights'

for dirname in os.listdir('cropped'):
    # Save image paths
    paths = []
    for filename in sorted(os.listdir(f'cropped/{dirname}')):
        paths.append(f'cropped/{dirname}/{filename}')

    with open(f'labels/{dirname}.input', 'w') as file:
        file.write('\n'.join(paths))

    image_paths = f'labels/{dirname}.input'
    labels_path = f'labels/{dirname}.json'
    !darknet/darknet detector test $data $cfg $weights \
        < $image_paths -thresh 0.25 -dont_show -out $labels_path

    for filename in os.listdir(f'cropped/{dirname}'):
        if filename.endswith('.txt'):
            shutil.move(
                f'cropped/{dirname}/{filename}',
                f'labels/{dirname}/{filename}')

try:
    os.remove('bad.list')
    os.remove('predictions.jpg')
except:
    pass

output.clear()

In [None]:
#@title Parse predictions
import json

# Geometry of the crops fed to the detector: each crop is 1024 columns
# wide and starts 1024 - 100 columns after the previous one
crop_width = 1024
stride = 1024 - 100
# Vertical scale used to convert relative y/height into pixels
# (presumably the height of the uncropped spectrogram - TODO confirm)
spec_height = 120

# rects[track name][class id] -> list of (confidence, x, y, w, h),
# with x expressed in columns of the full (uncropped) spectrogram
rects = {}
for filename in os.listdir('labels'):
    if not filename.endswith('.json'):
        continue

    # Strip only the extension so dotted track names are not truncated
    dirname = os.path.splitext(filename)[0]
    with open(f'labels/{filename}') as file:
        predictions = json.load(file)

    by_class = rects[dirname] = {}
    for subspec in predictions:
        # frame_id is treated as 1-based: the first crop gets offset 0
        index = int(subspec['frame_id'])
        offset = stride * (index - 1)

        for prediction in subspec['objects']:
            Class = prediction['class_id']
            rect = prediction['relative_coordinates']
            prob = prediction['confidence']

            # Boxes are given by their center; convert to top-left corner
            left_x = rect['center_x'] - (rect['width'] / 2)
            top_y = rect['center_y'] - (rect['height'] / 2)

            x = int(left_x * crop_width) + offset
            y = int(top_y * spec_height)
            w = int(rect['width'] * crop_width)
            h = int(rect['height'] * spec_height)

            by_class.setdefault(Class, []).append((prob, x, y, w, h))

## **Confidence Thresholds**

If you notice there are a lot of false positives for a given instrument, you can **raise** the threshold to only allow the best predictions to remain.

If the system fails to detect most instances of a given instrument, you can **lower** the threshold to allow more predictions, even if less precise.

The default values were determined experimentally, but they might not be the best for your use case. *Try moving things around!*

In [None]:
#@title Re-run to go back to the default thresholds

import ipywidgets as widgets

# Default thresholds for tri2-80k, evaluated on the validation set,
# listed as (slider label, default value)
default_thresholds = [
    ('Kick', 65),
    ('Snare', 70),
    ('Cross Stick', 68),
    ('High Tom', 78),
    ('Mid Tom', 82),
    ('Low Tom', 72),
    ('Closed Hihat', 52),
    ('Open Hihat', 67),
    ('Ride', 60),
    ('Bell', 41),
    ('Crash', 79),
]

# One slider per entry, each ranging up to 100
sliders = [
    widgets.IntSlider(value=default, max=100, description=label)
    for (label, default) in default_thresholds
]

for slider in sliders:
    display(slider)

IntSlider(value=65, description='Kick')

IntSlider(value=70, description='Snare')

IntSlider(value=68, description='Cross Stick')

IntSlider(value=78, description='High Tom')

IntSlider(value=82, description='Mid Tom')

IntSlider(value=72, description='Low Tom')

IntSlider(value=52, description='Closed Hihat')

IntSlider(value=67, description='Open Hihat')

IntSlider(value=60, description='Ride')

IntSlider(value=41, description='Bell')

IntSlider(value=79, description='Crash')

In [None]:
#@title Prepare data for onset and velocity estimation

# One row per class, in class-id order:
#   name, velocity estimation method, exponential regression parameters
#   (A, B) for velocity estimation, MIDI pitch, template width
class_table = [
    ('Kick',          2,  0.46, 1290.06, 35,  28),
    ('Snare',         1,  3.02,  187.41, 40,  21),
    ('Cross Stick',   1,  9.14,   85.18, 37,  17),
    ('High Tom',      1,  4.87,  110.87, 48,  95),
    ('Mid Tom',       1, 14.09,   22.13, 45, 104),
    ('Low Tom',       1,  4.67,  113.53, 43,  94),
    ('Closed Hi-Hat', 1, 14.49,   52.11, 42,  15),
    ('Open Hi-Hat',   1,  7.01,  294.61, 46,  22),
    ('Ride Cymbal',   2,  9.78,  103.04, 51,  14),
    ('Ride Bell',     2,  9.23,   97.75, 53,  32),
    ('Crash Cymbal',  2,  5.21,  193.98, 49,  60),
]

# Split the table into the per-class lists used by the following cell
classes, methods, A, B, pitches, widths = (
    list(column) for column in zip(*class_table))

# Templates: first channel of each image, scaled from 0-255 to 0-1
templates = [
    imageio.v2.imread(f'yodo/templates/{c}.png')[:, :, 0]/255.
    for c in range(11)
]

# Adjust crash cymbal: subtract the kick template, clamping at zero
templates[10] = np.maximum(templates[10] - templates[0], 0)

# Adjust template width: keep only the leading columns of each template
templates = [
    template[:, 0:width] for (template, width) in zip(templates, widths)
]

# Binary masks: 1 where the template is at least 0.4, 0 elsewhere
masks = [
    (template >= 0.4).astype(template.dtype) for template in templates
]

In [None]:
#@title Estimate onset/velocity and export to MIDI
from mido import Message, MidiFile, MidiTrack
from scipy.signal import correlate2d as conv2d
import pandas as pd

# Confidence thresholds (in percent), indexed by class id
thresholds = [slider.value for slider in sliders]

# Timing: spectrogram columns -> MIDI ticks
bpm = 120
beats_per_sec = bpm / 60
ticks_per_beat = 960
ticks_per_sec = ticks_per_beat * beats_per_sec
fps = 100

# Number of extra columns searched when aligning a template to a box
margin = 5

for (filename, predictions) in rects.items():
    image = imageio.v2.imread(f'spectrograms/{filename}.png')
    height, width = np.shape(image)

    # Zero-padded copy of the spectrogram: 5 columns on the left, 105 on
    # the right, so the slices taken below do not run past the border
    padded = np.zeros((height, width + 110))
    padded[:, 5:-105] = image

    onsets = []
    class_ids = []
    velocities = []
    for (c, labels) in predictions.items():
        c = int(c)
        template_width = np.shape(templates[c])[1]

        # Keep boxes that reach the class threshold.
        # Sometimes predictions are outside the expected range, so boxes
        # whose template would not fit inside the image are dropped, as
        # are boxes with no rows left after clamping (an empty slice
        # would make the correlation below fail).
        boxes = [(max(x, 0), x + w, max(y, 0), y + h)
            for (p, x, y, w, h) in labels
            if float(p)*100 >= thresholds[c]
            and x + template_width + 5 < width
            and max(y, 0) < min(y + h, height)]

        if len(boxes) == 0:
            continue

        # Onset estimation: slide the template over a window that is
        # `margin` columns wider than the template and keep the best
        # matching horizontal offset
        X1 = np.asarray([x1 for (x1, x2, y1, y2) in boxes])
        X2 = X1 + template_width + margin
        Y1 = np.asarray([y1 for (x1, x2, y1, y2) in boxes])
        Y2 = np.asarray([y2 for (x1, x2, y1, y2) in boxes])

        convs = [conv2d(padded[y1:y2, x1:x2], templates[c][y1:y2],
            mode='valid')[0] for x1, x2, y1, y2 in zip(X1, X2, Y1, Y2)]

        offsets = np.argmax(convs, axis=1)
        X1 += np.asarray(offsets)
        X2 = X1 + template_width

        # Loudness: peak of the aligned region, weighted either by the
        # template itself (method 1) or by its binary mask (otherwise)
        weights_c = templates[c] if methods[c] == 1 else masks[c]
        L = [np.max(padded[y1:y2, x1:x2] * weights_c[y1:y2])
            for x1, x2, y1, y2 in zip(X1, X2, Y1, Y2)]

        # Adjust to 0-1 range, since pixel values are 0-255
        L = np.asarray(L)/255.

        # Exponential regression, limited to the valid note-on range.
        # The lower bound is 1 because a note_on with velocity 0 is
        # interpreted as a note_off, which would silently drop the hit.
        V = np.clip(A[c] * (B[c] ** L), 1, 127)
        C = np.zeros(len(V), dtype=int) + c

        onsets.extend(X1)
        class_ids.extend(C)
        velocities.extend(V)

    # Explicit dtypes keep the columns numeric even with no detections
    df = pd.DataFrame({
        'x': np.asarray(onsets, dtype=int),
        'class': np.asarray(class_ids, dtype=int),
        'velocity': np.asarray(velocities, dtype=float),
    })

    # Remove duplicates caused by overlapping images
    # (for the same column and class, the highest velocity is kept)
    df = df.sort_values(by=['x','velocity'])
    df = df.drop_duplicates(subset=['x','class'], keep='last')

    # X coordinates to MIDI ticks.
    # Absolute ticks are rounded to integers BEFORE taking differences:
    # truncating every delta separately would lose a fraction of a tick
    # per note and make the transcription drift away from the audio.
    abs_ticks = (df['x'] * ticks_per_sec / fps).round().astype(int)
    df['abs_ticks'] = abs_ticks
    # The first delta is the absolute time of the first onset, so the
    # MIDI file stays aligned with the start of the recording. This also
    # works with a single note, where shift()/bfill() left a NaN behind.
    df['rel_ticks'] = abs_ticks.diff().fillna(abs_ticks).astype(int)

    # Classes to MIDI pitches
    df['pitch'] = df['class'].apply(lambda x: pitches[x])

    # Export as MIDI (single track, percussion channel)
    mid = MidiFile(type=0, ticks_per_beat=ticks_per_beat)
    track = MidiTrack()
    mid.tracks.append(track)
    track.append(Message('program_change', channel=9, program=0, time=0))

    # NOTE(review): only note_on messages are written, no matching
    # note_off - confirm this is acceptable for the target players
    midi_data = zip(df['rel_ticks'], df['pitch'], df['velocity'])
    for (delta, pitch, vel) in midi_data:
        track.append(Message('note_on', channel=9,
            note=int(pitch), velocity=int(vel), time=int(delta)))

    mid.save(f'midi/{filename}.mid')
    files.download(f'midi/{filename}.mid')