In [None]:
!pip install python_speech_features
!pip install hmmlearn
!pip install mido
!pip install mir-eval
!pip install pretty-midi
!pip install midi2audio



In [None]:
import json
import os
import mido
from midi2audio import FluidSynth
from concurrent.futures import ProcessPoolExecutor
from collections import defaultdict
from python_speech_features import mfcc
import scipy.io.wavfile as wav
from concurrent.futures import ThreadPoolExecutor
from hmmlearn import hmm
import numpy as np
from google.colab import drive
import IPython.display as ipd
from scipy.io import wavfile
from joblib import Parallel, delayed
from numba import njit
import pretty_midi
import mir_eval
import copy
import librosa
import seaborn as sns
import pandas as pd
import wave
import plotly.graph_objects as go

### Evaluation code provided with the dataset (computes precision, recall and F1)

In [None]:

# Onset tolerance handed to mir_eval as `onset_tolerance` by mir_eval_onset_prf
# (mir_eval documents this value in seconds).
EVAL_TOLERANCE = 0.05
# evaluate() tries every octave shift from -OCTAVE_INVARIANT_RADIUS to
# +OCTAVE_INVARIANT_RADIUS (inclusive) and keeps the best-scoring one.
OCTAVE_INVARIANT_RADIUS = 16


def trim_midi(ref_midi_data, est_midi_data):
    """Drop estimated notes whose onset lies outside the reference segment.

    The reference segment spans from the earliest note start to the latest
    note end over all non-drum instruments of ``ref_midi_data``. Using
    min/max (instead of the first/last collected note) keeps the bounds
    correct when the reference has several instruments or when its notes are
    not stored in chronological order.

    Args:
        ref_midi_data: object exposing ``instruments``; each instrument has
            ``is_drum`` and ``notes`` (notes have ``start`` and ``end``).
        est_midi_data: same structure; it is modified in place.

    Returns:
        ``est_midi_data`` (the same object), with the notes of every non-drum
        instrument filtered. Drum instruments are left untouched. If the
        reference contains no non-drum note there is no segment to trim to,
        so the estimate is returned unchanged.
    """
    ref_notes = [
        note
        for instrument in ref_midi_data.instruments
        if not instrument.is_drum
        for note in instrument.notes
    ]
    if not ref_notes:
        return est_midi_data

    segment_start = min(note.start for note in ref_notes)
    segment_end = max(note.end for note in ref_notes)

    for instrument in est_midi_data.instruments:
        if instrument.is_drum:
            continue
        # Only the onset is tested: a note starting inside the segment is kept
        # even if it ends after the segment.
        instrument.notes = [
            note for note in instrument.notes
            if segment_start <= note.start <= segment_end
        ]

    return est_midi_data


def midi_to_mir_eval(midi_data, dummy_offsets = True):
    """Convert the non-drum notes of a MIDI object into mir_eval-style arrays.

    Notes from all non-drum instruments are pooled and sorted as
    ``(start, end, pitch)`` tuples.

    Args:
        midi_data: object exposing ``instruments``; each instrument has
            ``is_drum`` and ``notes`` (notes have ``start``, ``end``, ``pitch``).
        dummy_offsets: when true (and there is at least one note), the real
            note ends are ignored: every note ends at the next note's onset
            and the last note ends one unit after its own onset.

    Returns:
        ``(intervals, pitches)`` where ``intervals`` is a float64 array with
        one ``[onset, offset]`` row per note and ``pitches`` is an int64 array.
    """
    triples = sorted(
        (note.start, note.end, note.pitch)
        for instrument in midi_data.instruments
        if not instrument.is_drum
        for note in instrument.notes
    )
    onsets = [triple[0] for triple in triples]
    if dummy_offsets and len(onsets) > 0:
        # Synthetic offsets: next onset, and onset + 1 for the final note.
        offsets = onsets[1:] + [onsets[-1] + 1]
    else:
        offsets = [triple[1] for triple in triples]
    intervals = np.stack([onsets, offsets], axis = 1).astype(np.float64)
    pitches = np.array([triple[2] for triple in triples], dtype = np.int64)
    return intervals, pitches


def extract_notes(ref_midi_file, est_midi_file):
    """Load a reference and an estimated MIDI file and return their notes.

    Both files are parsed with pretty_midi and deep-copied; the estimate is
    trimmed against the reference with ``trim_midi``, then both are converted
    with ``midi_to_mir_eval`` using the real note offsets
    (``dummy_offsets = False``).

    Returns:
        ``(ref_intervals, ref_pitches, est_intervals, est_pitches)``.
    """
    reference = copy.deepcopy(pretty_midi.PrettyMIDI(ref_midi_file))
    estimate = copy.deepcopy(pretty_midi.PrettyMIDI(est_midi_file))

    # Keep only the estimated notes that start inside the reference segment.
    estimate = trim_midi(reference, estimate)

    reference_arrays = midi_to_mir_eval(reference, dummy_offsets = False)
    estimate_arrays = midi_to_mir_eval(estimate, dummy_offsets = False)

    return reference_arrays[0], reference_arrays[1], estimate_arrays[0], estimate_arrays[1]


def _midi_to_hz(midi_pitches):
    """Convert an array of MIDI note numbers to Hz (note 69 maps to 440 Hz)."""
    return 440.0 * np.power(2, (midi_pitches.astype(np.float32) - 69) / 12)


def mir_eval_onset_prf(ref_intervals, ref_pitches, est_intervals, est_pitches, onset_tolerance = None):
    """Onset/pitch precision, recall and F1 computed with mir_eval.

    Offsets are ignored (``offset_ratio = None``); a note matches when its
    onset and pitch are both within tolerance.

    Args:
        ref_intervals, est_intervals: arrays of ``[onset, offset]`` rows.
        ref_pitches, est_pitches: numpy arrays of MIDI note numbers; they are
            converted to Hz before being handed to mir_eval.
        onset_tolerance: onset tolerance passed to mir_eval. ``None`` (the
            default) means ``EVAL_TOLERANCE``, i.e. the previous behaviour.

    Returns:
        ``(precision, recall, f1)``.
    """
    if onset_tolerance is None:
        onset_tolerance = EVAL_TOLERANCE
    # NOTE(review): mir_eval documents pitch_tolerance in cents, so 1.0 is a
    # very tight pitch match — kept as in the original evaluation code.
    p, r, f1, _ = mir_eval.transcription.precision_recall_f1_overlap(
            ref_intervals,
            _midi_to_hz(ref_pitches),
            est_intervals,
            _midi_to_hz(est_pitches),
            onset_tolerance = onset_tolerance,
            pitch_tolerance = 1.0,
            offset_ratio = None,
        )
    return p, r, f1

def mir_eval_onset_prf_pitch_only(ref_intervals, ref_pitches, est_intervals, est_pitches):
    """Same metric as ``mir_eval_onset_prf`` with a very loose onset tolerance.

    The onset tolerance is set to 5, so the score mostly reflects pitch
    agreement rather than timing accuracy.

    Returns:
        ``(precision, recall, f1)``.
    """
    return mir_eval_onset_prf(
        ref_intervals,
        ref_pitches,
        est_intervals,
        est_pitches,
        onset_tolerance = 5,
    )



def evaluate(ref_intervals, ref_pitches, est_intervals, est_pitches):
    """Octave-invariant precision / recall / F1.

    The reference pitches are transposed by every whole number of octaves in
    ``[-OCTAVE_INVARIANT_RADIUS, OCTAVE_INVARIANT_RADIUS]`` and scored with
    ``mir_eval_onset_prf``; the scores of the shift with the highest F1 are
    returned (the first such shift on ties).

    Returns:
        ``(precision, recall, f1)`` for the best octave shift.
    """
    shifts = range(-OCTAVE_INVARIANT_RADIUS, OCTAVE_INVARIANT_RADIUS + 1)
    scores = [
        mir_eval_onset_prf(
            ref_intervals,
            (shift * 12) + ref_pitches,
            est_intervals,
            est_pitches
        )
        for shift in shifts
    ]

    # np.argmax returns the first index holding the maximum F1.
    best = np.argmax([f1 for _, _, f1 in scores])
    precision, recall, f1 = scores[best]
    return (precision, recall, f1)



# We load the data

In [None]:
# Mount Google Drive so that the folders below are reachable from this runtime.
drive.mount('/content/drive')

# Folders used throughout the notebook: MIDI files, WAV files, and the
# per-file onset files for the train and test splits.
midi_data_path = '/content/drive/My Drive/MLSP_PROJECT/midi_data'
wav_data_path = '/content/drive/My Drive/MLSP_PROJECT/wav_data_sync_with_midi'
onset_train_data_path = '/content/drive/My Drive/TRAIN'
onset_test_data_path = '/content/drive/My Drive/TEST'

# Quick check that both data folders are visible (prints every file name).
print('MIDI data:', os.listdir(midi_data_path))
print('WAV data:', os.listdir(wav_data_path))


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
MIDI data: ['F02_0399_0001_1_D.mid', 'F02_0399_0002_1_D.mid', 'F02_0399_0002_2_D.mid', 'F02_0399_0001_2_D.mid', 'F02_0400_0001_2_D.mid', 'F02_0401_0001_2_D.mid', 'F02_0400_0002_1_D.mid', 'F02_0400_0002_2_D.mid', 'F02_0401_0001_1_D.mid', 'F02_0401_0002_1_D.mid', 'F02_0403_0001_1_D.mid', 'F02_0402_0001_1_D.mid', 'F02_0402_0002_2_D.mid', 'F02_0402_0002_1_D.mid', 'F02_0401_0002_2_D.mid', 'F02_0402_0001_2_D.mid', 'F02_0403_0002_1_D.mid', 'F02_0404_0001_1_D.mid', 'F02_0403_0002_2_D.mid', 'F02_0403_0003_2_D.mid', 'F02_0403_0003_1_D.mid', 'F02_0403_0001_2_D.mid', 'F02_0404_0003_2_D.mid', 'F02_0404_0002_1_D.mid', 'F02_0404_0003_1_D.mid', 'F02_0404_0001_2_D.mid', 'F02_0404_0002_2_D.mid', 'F02_0405_0001_1_D.mid', 'F02_0405_0002_1_D.mid', 'F02_0405_0001_2_D.mid', 'F02_0406_0001_1_D.mid', 'F02_0406_0001_2_D.mid', 'F02_0405_0002_2_D.mid', 'F02_0406_0002_1_D.mid', 'F02_0407

In [None]:

json_file_path = '/content/drive/My Drive/MLSP_PROJECT/train_valid_test_keys.json'

# The JSON file maps 'TRAIN' / 'VALID' / 'TEST' to the keys (file names
# without extension) belonging to each split.
with open(json_file_path, 'r') as f:
    data = json.load(f)

train_files = data['TRAIN']
valid_files = data['VALID']
test_files = data['TEST']


def _files_in_split(file_names, split_keys):
    """Return, sorted, the names in `file_names` whose stem is one of `split_keys`."""
    keys = set(split_keys)  # constant-time membership instead of scanning a list
    return sorted(name for name in file_names if os.path.splitext(name)[0] in keys)


# List each folder once (listing a mounted Drive folder is slow) and reuse the
# listing for every split.
_midi_listing = os.listdir(midi_data_path)
_wav_listing = os.listdir(wav_data_path)
_onset_train_listing = os.listdir(onset_train_data_path)
_onset_test_listing = os.listdir(onset_test_data_path)

midi_train = _files_in_split(_midi_listing, train_files)
midi_valid = _files_in_split(_midi_listing, valid_files)
midi_test = _files_in_split(_midi_listing, test_files)

wav_train = _files_in_split(_wav_listing, train_files)
wav_valid = _files_in_split(_wav_listing, valid_files)
wav_test = _files_in_split(_wav_listing, test_files)

onset_train = _files_in_split(_onset_train_listing, train_files)
onset_test = _files_in_split(_onset_test_listing, test_files)

# We build the HMM

In [None]:
def calculate_note_counts(file, midi_data_path):
  """Count pitch classes and pitch-class transitions in one MIDI file.

  Only non-meta ``note_on`` messages with a non-zero velocity are considered,
  in the order produced by iterating over the ``mido.MidiFile``. Notes are
  reduced to their pitch class (``note % 12``).

  Args:
    file: MIDI file name, relative to ``midi_data_path``.
    midi_data_path: folder containing the file.

  Returns:
    ``(note_counts, transition_counts)``: ``note_counts[pc]`` is the number of
    occurrences of pitch class ``pc``; ``transition_counts[a][b]`` is the
    number of times ``b`` directly followed ``a``. Both are plain dicts.
  """
  midi = mido.MidiFile(os.path.join(midi_data_path, file))
  pitch_class_counts = {}
  bigram_counts = {}
  previous = None
  for message in midi:
    if message.is_meta or message.type != 'note_on' or message.velocity == 0:
      continue
    current = message.note % 12
    pitch_class_counts[current] = pitch_class_counts.get(current, 0) + 1
    if previous is not None:
      followers = bigram_counts.setdefault(previous, {})
      followers[current] = followers.get(current, 0) + 1
    previous = current
  return pitch_class_counts, bigram_counts
def calculate_probabilities(midi_files, midi_data_path):
  """Estimate pitch-class start and transition probabilities from MIDI files.

  Each file is processed by ``calculate_note_counts`` in a process pool and
  the per-file counts are summed.

  Args:
    midi_files: iterable of MIDI file names.
    midi_data_path: folder containing the files.

  Returns:
    ``(initial_probs, transition_probs, unique_notes)``:
    ``initial_probs[pc]`` is the relative frequency of pitch class ``pc``
    over all counted notes; ``transition_probs[a][b]`` is the fraction of
    transitions leaving ``a`` that go to ``b`` (a nested defaultdict, so
    unseen pairs read as missing/0.0); ``unique_notes`` is the number of
    distinct pitch classes seen.

  Raises:
    AssertionError: if the initial probabilities do not sum to 1.
  """
  note_counts = defaultdict(int)
  transition_counts = defaultdict(lambda: defaultdict(int))
  with ProcessPoolExecutor() as executor:
    pending = [executor.submit(calculate_note_counts, name, midi_data_path) for name in midi_files]
    # Results are consumed in submission order and merged into the totals.
    for job in pending:
      per_file_notes, per_file_transitions = job.result()
      for pitch_class, occurrences in per_file_notes.items():
        note_counts[pitch_class] += occurrences
      for source, targets in per_file_transitions.items():
        for target, occurrences in targets.items():
          transition_counts[source][target] += occurrences

  # Normalise every row of transition counts by its own total.
  transition_probs = defaultdict(lambda: defaultdict(float))
  for source, targets in transition_counts.items():
    outgoing = sum(targets.values())
    for target, occurrences in targets.items():
      transition_probs[source][target] = occurrences / outgoing

  total_notes = sum(note_counts.values())
  initial_probs = {}
  for pitch_class, occurrences in note_counts.items():
    initial_probs[pitch_class] = occurrences / total_notes
  assert abs(sum(initial_probs.values()) - 1) < 1e-6, "La somme des probabilités initiales n'est pas égale à 1"
  unique_notes = len(note_counts)
  return initial_probs, transition_probs, unique_notes
# Estimate the start / transition statistics from the training MIDI files.
initial_probs, transition_probs, unique_notes = calculate_probabilities(midi_train, midi_data_path)

In [None]:
def calculate_average_note_length(file, data_path=None):
    """Return the average time per note of a MIDI file, in seconds.

    The value is the total playback time divided by the number of ``note_on``
    messages with a non-zero velocity, so silences between notes are included
    in the average.

    The file is read by iterating over the ``mido.MidiFile`` itself: mido then
    merges all tracks in playback order and yields delta times already
    converted to seconds using the tempo changes found anywhere in the file.
    Walking each track separately with its own default tempo would ignore
    tempo events stored in another track and would add up the duration of
    every track.

    NOTE(review): mido cannot merge the tracks of type 2 (asynchronous)
    files, so iterating such a file raises — confirm the dataset has none.

    Args:
        file: MIDI file name, relative to ``data_path``.
        data_path: folder containing the file; defaults to the notebook-level
            ``midi_data_path``.

    Returns:
        Seconds per note, or 0 when the file contains no note.
    """
    if data_path is None:
        data_path = midi_data_path
    mid = mido.MidiFile(os.path.join(data_path, file))
    total_time = 0.0
    note_count = 0
    for msg in mid:
        total_time += msg.time
        if msg.type == 'note_on' and msg.velocity != 0:
            note_count += 1
    return total_time / note_count if note_count > 0 else 0
# Compute the per-file average for every training MIDI file (in parallel with
# joblib), then average those values. The result is used below as the default
# MFCC window length and step.
average_note_lengths = Parallel(n_jobs=-1)(delayed(calculate_average_note_length)(file) for file in midi_train)
average_of_averages = sum(average_note_lengths) / len(average_note_lengths)
print(average_of_averages)


0.4696460204397361


In [None]:
def calculate_mfcc_for_file(file, data_path, numcep=1, winlen=average_of_averages, winstep=average_of_averages):
  """Compute the MFCC features of one WAV file.

  Args:
    file: WAV file name, relative to ``data_path``.
    data_path: folder containing the file.
    numcep: number of cepstral coefficients kept per frame.
    winlen: analysis window length in seconds.
    winstep: step between successive windows in seconds.

  Returns:
    The array returned by ``python_speech_features.mfcc`` (one row per frame).
  """
  rate, sig = wav.read(os.path.join(data_path, file))
  # The FFT must cover a whole analysis window, whose size in samples depends
  # on the sample rate of *this* file (previously hard-coded to 44100 Hz).
  nfft = int(rate * winlen) + 1
  mfcc_features = mfcc(sig, rate, numcep=numcep, winlen=winlen, winstep=winstep, nfft=nfft)
  return mfcc_features
def calculate_mfcc(audio_files, data_path, numcep=1, num_threads=8, winlen=average_of_averages, winstep=average_of_averages):
  """Compute the MFCC features of several WAV files with joblib.

  Args:
    audio_files: iterable of WAV file names, relative to ``data_path``.
    data_path: folder containing the files.
    numcep, winlen, winstep: forwarded to ``calculate_mfcc_for_file``.
    num_threads: value passed to joblib as ``n_jobs``.

  Returns:
    A list with one feature array per input file, in input order.
  """
  jobs = (
    delayed(calculate_mfcc_for_file)(name, data_path, numcep, winlen, winstep)
    for name in audio_files
  )
  return Parallel(n_jobs=num_threads)(jobs)

In [None]:
# MFCC features of every training recording, using the default window settings.
mfcc_train = calculate_mfcc(wav_train, wav_data_path)

In [None]:
# One hidden state per pitch class seen in training. Sorting (instead of
# relying on set iteration order) makes the state <-> pitch-class mapping
# deterministic.
valid_keys = sorted(set(initial_probs.keys()).union(set(transition_probs.keys())))
n_components = len(valid_keys)

# init_params='mc': fit() only initialises means and covariances, so the
# start/transition probabilities set below are the starting point of training.
# random_state makes that initialisation reproducible between runs.
model = hmm.GaussianHMM(n_components=n_components, covariance_type="full", init_params='mc', random_state=0)

startprob = np.array([initial_probs.get(key, 0.0) for key in valid_keys], dtype=float)
model.startprob_ = startprob / startprob.sum()

transmat = np.array(
    [[transition_probs.get(i, {}).get(j, 0) for j in valid_keys] for i in valid_keys],
    dtype=float,
)
# A pitch class that never has a successor in the training data yields an
# all-zero row, which is not a valid probability distribution: fall back to a
# uniform row, then renormalise every row so it sums to 1.
row_sums = transmat.sum(axis=1)
no_successor = row_sums == 0
transmat[no_successor] = 1.0 / n_components
row_sums[no_successor] = 1.0
model.transmat_ = transmat / row_sums[:, None]

# Every file is an independent observation sequence: pass the per-file lengths
# so training does not learn transitions across file boundaries.
observations = np.concatenate(mfcc_train).ravel().reshape(-1, 1)
sequence_lengths = [np.size(features) for features in mfcc_train if np.size(features) > 0]
model.fit(observations, lengths=sequence_lengths)



# Test Part

Computing est_intervals (we get the onset times using Librosa):




In [None]:
# Duration, in seconds, of every test recording (same order as wav_test).
wav_length = []

for wav_file in wav_test:
    with wave.open(os.path.join(wav_data_path, wav_file), 'r') as f:
        wav_length.append(f.getnframes() / float(f.getframerate()))

onset_test_data_path = '/content/drive/My Drive/TEST'
onset_files = sorted(os.listdir(onset_test_data_path))

# The k-th onset file is paired with the k-th test recording, so both listings
# must have the same size (and, presumably, the same file stems).
if len(onset_files) != len(wav_length):
    raise ValueError(
        f"Found {len(onset_files)} onset files but {len(wav_length)} test recordings"
    )

est_intervals = []

for j, file in enumerate(onset_files):
    with open(os.path.join(onset_test_data_path, file), 'r') as f:
        # One onset time per line; blank lines are skipped.
        times = [float(line) for line in f if line.strip()]

    # Each note lasts from its onset until just before the next distinct onset.
    intervals = []
    for i in range(len(times) - 1):
        if times[i] != times[i+1]:
            intervals.append([times[i], times[i+1] - 0.001])

    # The last note runs until the end of the recording. A file with a single
    # (or only repeated) onset yields just that final interval; an empty file
    # yields no interval at all.
    if intervals:
        intervals.append([intervals[-1][1] + 0.001, wav_length[j]])
    elif times:
        intervals.append([times[-1], wav_length[j]])
    est_intervals.append(intervals)


Computing est_pitches (using the HMM):

In [None]:
def predict_notes_for_file(model, file, data_path, valid_keys):
    """Predict one note label per MFCC observation of a WAV file.

    Args:
        model: object whose ``predict`` maps a column of observations to a
            sequence of state indices.
        file: WAV file name, relative to ``data_path``.
        data_path: folder containing the file.
        valid_keys: sequence mapping a state index to its note label.

    Returns:
        A list with ``valid_keys[state]`` for every predicted state.
    """
    features = calculate_mfcc_for_file(file, data_path)
    # Flatten the feature rows into a single column of observations.
    observations = np.concatenate(features).reshape(-1, 1)
    return [valid_keys[state] for state in model.predict(observations)]

# One array of predicted notes per test recording, in the order of wav_test.
est_pitches = []

for test_file in wav_test:
    predicted_notes = predict_notes_for_file(model, test_file, wav_data_path, valid_keys)
    est_pitches.append(np.array(predicted_notes))


Getting ref_pitches and ref_intervals (we get them from the midi files)

In [None]:
def extract_pitches_and_intervals(midi_file, data_path=None):
    """Read the reference pitch classes and note intervals of a MIDI file.

    Notes of every non-drum instrument are used (the same selection as
    ``midi_to_mir_eval``), ordered by ``(start, end, pitch class)`` so the
    result is chronological however the notes are stored in the file.

    Args:
        midi_file: MIDI file name, relative to ``data_path``.
        data_path: folder containing the file; defaults to the notebook-level
            ``midi_data_path``.

    Returns:
        ``(pitches, intervals)``: an int64 array of pitch classes
        (``pitch % 12``) and a float64 array of ``[start, end]`` rows. For a
        file without notes the arrays have shapes ``(0,)`` and ``(0, 2)``.
    """
    if data_path is None:
        data_path = midi_data_path
    mid = pretty_midi.PrettyMIDI(os.path.join(data_path, midi_file))

    notes = sorted(
        (note.start, note.end, note.pitch % 12)
        for instrument in mid.instruments
        if not instrument.is_drum
        for note in instrument.notes
    )

    pitches = np.array([pitch_class for _, _, pitch_class in notes], dtype=np.int64)
    # reshape keeps the two-column layout even when there is no note at all.
    intervals = np.array([[start, end] for start, end, _ in notes], dtype=np.float64).reshape(-1, 2)

    return pitches, intervals

# Reference pitch classes and note intervals of every test MIDI file.
ref_intervals = []
ref_pitches = []

# executor.map yields the results in the order of midi_test, so position k in
# both lists corresponds to midi_test[k].
with ProcessPoolExecutor() as executor:
    for pitches, intervals in executor.map(extract_pitches_and_intervals, midi_test):
        ref_pitches.append(pitches)
        ref_intervals.append(intervals)


In [None]:
def compare_notes(predicted_notes, true_notes):
    """Position-wise agreement between two note sequences.

    The sequences are compared element by element over the length of the
    shorter one; a position counts as correct when the two values differ by
    at most 1.

    NOTE(review): the difference is not circular, so when the values are
    pitch classes, 11 and 0 count as 11 apart — confirm this is intended.

    Args:
        predicted_notes: sequence of predicted note numbers.
        true_notes: sequence of reference note numbers.

    Returns:
        The fraction of compared positions that agree, or 0.0 when either
        sequence is empty (nothing can be compared).
    """
    min_length = min(len(predicted_notes), len(true_notes))
    if min_length == 0:
        return 0.0
    # zip stops at the shorter sequence, i.e. after min_length pairs.
    matches = sum(1 for p, t in zip(predicted_notes, true_notes) if abs(p - t) <= 1)
    return matches / min_length

# Position-wise accuracy of the predicted notes for every test file.
similarities = [compare_notes(p, r) for p, r in zip(est_pitches, ref_pitches)]

# Show the best-scoring files first.
similarities.sort(reverse=True)

fig = go.Figure(data=[go.Bar(
            x=list(range(len(similarities))), y=similarities,
            text=similarities,
            textposition='auto',
            marker_color='rgb(255, 165, 0)'
        )])

fig.update_layout(
    title=dict(
        text='Comparison of predicted notes and reference notes for each test file',
        font=dict(
            size=22,
            color='black'
        )
    ),
    xaxis=dict(
        title='Test files',
        tickmode='linear',
        tick0=0,
        # About ten ticks along the axis; the step must stay positive, which
        # the plain integer division does not guarantee below ten files.
        dtick=max(1, len(similarities)//10)
    ),
    yaxis=dict(
        title='Accuracy',
        range=[0, 1],
    ),
    autosize=False,
    width=1000,
    height=500,
)

fig.show()


In [None]:
# Make sure every per-file entry is a numpy array before scoring.
ref_intervals_np = list(map(np.array, ref_intervals))
ref_pitches_np = list(map(np.array, ref_pitches))
est_intervals_np = list(map(np.array, est_intervals))
est_pitches_np = list(map(np.array, est_pitches))

p_values, r_values, f1_values = [], [], []

# Score each test file with the strict onset tolerance.
per_file_inputs = zip(ref_intervals_np, ref_pitches_np, est_intervals_np, est_pitches_np)
for ref_interval, ref_pitch, est_interval, est_pitch in per_file_inputs:
    p, r, f1 = mir_eval_onset_prf(ref_interval, ref_pitch, est_interval, est_pitch)
    p_values += [p]
    r_values += [r]
    f1_values += [f1]

# Average of the per-file scores.
p_mean, r_mean, f1_mean = np.mean(p_values), np.mean(r_values), np.mean(f1_values)

print("Mean of p :", p_mean)
print("Mean of r :", r_mean)
print("Mean of f1 :", f1_mean)


Mean of p : 0.2296980704567195
Mean of r : 0.20850351166069522
Mean of f1 : 0.21784706198464035


In [None]:
# Make sure every per-file entry is a numpy array before scoring.
ref_intervals_np = list(map(np.array, ref_intervals))
ref_pitches_np = list(map(np.array, ref_pitches))
est_intervals_np = list(map(np.array, est_intervals))
est_pitches_np = list(map(np.array, est_pitches))

p_values, r_values, f1_values = [], [], []

# Score each test file with the loose onset tolerance of the pitch-only metric.
per_file_inputs = zip(ref_intervals_np, ref_pitches_np, est_intervals_np, est_pitches_np)
for ref_interval, ref_pitch, est_interval, est_pitch in per_file_inputs:
    p, r, f1 = mir_eval_onset_prf_pitch_only(ref_interval, ref_pitch, est_interval, est_pitch)
    p_values += [p]
    r_values += [r]
    f1_values += [f1]

# Average of the per-file scores.
p_mean, r_mean, f1_mean = np.mean(p_values), np.mean(r_values), np.mean(f1_values)

print("Mean of p without considering time intervals :", p_mean)
print("Mean of r without considering time intervals :", r_mean)
print("Mean of f1 without considering time intervals :", f1_mean)


Mean of p without considering time intervals : 0.6701747325151787
Mean of r without considering time intervals : 0.6166848980345934
Mean of f1 without considering time intervals : 0.640361435387222
