# MIDI Inference

- Song-level inference
- Clip-level inference

In [2]:
import os
from pathlib import Path
import torchaudio
import pandas as pd
import torch
from tqdm import tqdm
import pickle
import pretty_midi
import sys
sys.path.append("..")
from midi_cls.midi_helper.remi.midi2event import analyzer, corpus, event
from midi_cls.midi_helper.magenta.processor import _divide_note, _control_preprocess, _note_preprocess, _make_time_sift_events, _snote2events

In [28]:
# Show the entries of the dataset directory (the cell output lists them).
os.listdir("../dataset/")

['split',
 'remi_midi_test',
 'magenta_midi_test',
 'resample22050',
 'magenta_midi',
 'remi_midi',
 'remi_key_midi',
 'sample_data',
 'matlab_feature']

In [29]:
# Load one saved item from the "magenta_midi" folder for manual inspection.
# NOTE(review): the structure of the stored object is not visible here.
test = "Q2_OCqMDeD6Fmc_3.pt"
temp = torch.load(os.path.join("../dataset/magenta_midi",test))

In [30]:
# Load the file with the same name from "magenta_midi_test"
# (presumably to compare it against `temp` -- TODO confirm).
temp2 = torch.load(os.path.join("../dataset/magenta_midi_test",test))

In [22]:
# Location of the pickled REMI vocabulary shipped with the project helpers.
path_data_root = "../midi_cls/midi_helper/remi/"
path_dictionary = os.path.join(path_data_root, 'dictionary.pkl')

# NOTE: unpickling can execute arbitrary code; only load dictionary files
# that come from a trusted source.
# The context manager closes the file handle as soon as loading is done.
with open(path_dictionary, "rb") as dictionary_file:
    midi_dictionary = pickle.load(dictionary_file)

# The first element maps integer token ids to event names.
int_to_event = midi_dictionary[0]

In [36]:
# Display the id -> event-name mapping (e.g. 0 -> 'Bar_None').
int_to_event

{0: 'Bar_None',
 1: 'Beat_0',
 2: 'Beat_1',
 3: 'Beat_10',
 4: 'Beat_11',
 5: 'Beat_12',
 6: 'Beat_13',
 7: 'Beat_14',
 8: 'Beat_15',
 9: 'Beat_2',
 10: 'Beat_3',
 11: 'Beat_4',
 12: 'Beat_5',
 13: 'Beat_6',
 14: 'Beat_7',
 15: 'Beat_8',
 16: 'Beat_9',
 17: 'Chord_A#_+',
 18: 'Chord_A#_/o7',
 19: 'Chord_A#_7',
 20: 'Chord_A#_M',
 21: 'Chord_A#_M7',
 22: 'Chord_A#_m',
 23: 'Chord_A#_m7',
 24: 'Chord_A#_o',
 25: 'Chord_A#_o7',
 26: 'Chord_A#_sus2',
 27: 'Chord_A#_sus4',
 28: 'Chord_A_+',
 29: 'Chord_A_7',
 30: 'Chord_A_M',
 31: 'Chord_A_M7',
 32: 'Chord_A_m',
 33: 'Chord_A_m7',
 34: 'Chord_A_o',
 35: 'Chord_A_sus2',
 36: 'Chord_A_sus4',
 37: 'Chord_B_+',
 38: 'Chord_B_/o7',
 39: 'Chord_B_7',
 40: 'Chord_B_M',
 41: 'Chord_B_M7',
 42: 'Chord_B_m',
 43: 'Chord_B_m7',
 44: 'Chord_B_o',
 45: 'Chord_B_o7',
 46: 'Chord_B_sus2',
 47: 'Chord_B_sus4',
 48: 'Chord_C#_+',
 49: 'Chord_C#_/o7',
 50: 'Chord_C#_7',
 51: 'Chord_C#_M',
 52: 'Chord_C#_M7',
 53: 'Chord_C#_m',
 54: 'Chord_C#_m7',
 55: 'Chord

In [14]:
# Run on the first CUDA GPU when one is available; otherwise fall back to the
# CPU so the notebook still works on machines without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [20]:
def batch_encode(file_path, segments = 3):
    """Encode a MIDI file into integer event sequences, one per time segment.

    The notes of all instruments are merged, split into individual note
    events by ``_divide_note`` and sorted by time. The timeline is then cut
    into consecutive half-open windows ``[k * segments, (k + 1) * segments)``
    and the events of each window are collected into their own list.

    Notes lying exactly on a window boundary (including time 0) belong to the
    window that starts there; they are not dropped.

    The time-shift / velocity state is carried across window boundaries, so
    the first time-shift of a window is measured from the last note of the
    previous non-empty window (or from 0 for the very first note).

    Args:
        file_path: Path of the MIDI file, passed to ``pretty_midi.PrettyMIDI``.
        segments: Window length, in the same unit as the note times
            (presumably seconds -- TODO confirm). Must be positive.

    Returns:
        A list with one entry per non-empty window, in chronological order.
        Each entry is a list of integer event ids (``event.to_int()``).
        Returns an empty list when the file yields no notes.

    Raises:
        ValueError: If ``segments`` is not positive.
    """
    if segments <= 0:
        raise ValueError("segments must be positive, got {!r}".format(segments))

    mid = pretty_midi.PrettyMIDI(midi_file=file_path)

    notes = []
    for inst in mid.instruments:
        # Only controller 64 (the sustain pedal in the MIDI standard) is
        # passed to the note pre-processing.
        ctrls = _control_preprocess([ctrl for ctrl in inst.control_changes if ctrl.number == 64])
        notes += _note_preprocess(ctrls, inst.notes)

    dnotes = _divide_note(notes)
    if not dnotes:
        return []
    dnotes.sort(key=lambda x: x.time)

    events = []
    bins = []          # events of the window currently being filled
    cur_bin = None     # index of that window
    cur_time = 0
    cur_vel = 0

    # dnotes is sorted, so window indices are non-decreasing: a change of
    # index means the previous window is complete.
    for snote in dnotes:
        bin_idx = int(snote.time // segments)
        if bin_idx != cur_bin:
            if bins:
                events.append([e.to_int() for e in bins])
            bins = []
            cur_bin = bin_idx

        bins += _make_time_sift_events(prev_time=cur_time, post_time=snote.time)
        bins += _snote2events(snote=snote, prev_vel=cur_vel)
        cur_time = snote.time
        cur_vel = snote.velocity

    if bins:
        events.append([e.to_int() for e in bins])
    return events

def _get_predictions(model, midi_path, segments=5):
    """Run ``model`` on every time segment of a MIDI file.

    The file is encoded with ``batch_encode``; each segment's event ids are
    turned into a ``LongTensor`` with a leading batch dimension of 1, moved
    to the module-level ``device`` and passed to ``model``.

    NOTE(review): the model is used as-is; switching it to evaluation mode
    (``model.eval()``) is left to the caller.

    Args:
        model: Callable taking the token tensor and returning a tensor.
        midi_path: Path of the MIDI file to encode.
        segments: Segment length forwarded to ``batch_encode``.

    Returns:
        A list with one NumPy array per segment: the model output with its
        leading dimension squeezed.
    """
    quantize_midi = batch_encode(midi_path, segments=segments)
    song_level = []
    # Inference only: disable gradient tracking so no autograd graph is built
    # for each forward pass.
    with torch.no_grad():
        for one_segment in quantize_midi:
            torch_midi = torch.LongTensor(one_segment).unsqueeze(0).to(device)
            prediction = model(torch_midi)
            song_level.append(prediction.squeeze(0).detach().cpu().numpy())
    return song_level

In [17]:
# Encode a sample MIDI file using a segment length of 30.
# Despite its name, `song_level` holds the encoded event-id lists returned by
# batch_encode here, not model predictions.
test_midi = "../dataset/sample_data/Sakamoto_MerryChristmasMr_Lawrence.mid"
song_level = batch_encode(test_midi, segments=30)

# Audio Inference

- Song-level inference
- Frame-level inference

In [21]:
# Audio lets the notebook embed a playable audio widget in a cell output.
from IPython.display import Audio
# Path of the sample recording used in the audio section.
test_mp3 = "../dataset/sample_data/Sakamoto_MerryChristmasMr_Lawrence.mp3"

In [22]:
# Render a player for the sample MP3 as this cell's output.
Audio(test_mp3)