In [37]:
import pickle
from dataclasses import dataclass
from pathlib import Path

import miditok
import numpy as np
import torch
import torch.nn as nn
from midi_player import MIDIPlayer
from miditok import REMI, TokenizerConfig
from miditok.pytorch_data import DataCollator, DatasetMIDI
from torch.utils.data import DataLoader
import os
from tqdm import tqdm

## Check original dataset format

In [None]:
# 'Beat_0': 1,
event2idx, idx2event = pickle.load(open('pickles/remi_vocab.pkl', 'rb'))
list(event2idx.keys())

['Bar_None',
 'Beat_0',
 'Beat_1',
 'Beat_10',
 'Beat_11',
 'Beat_12',
 'Beat_13',
 'Beat_14',
 'Beat_15',
 'Beat_2',
 'Beat_3',
 'Beat_4',
 'Beat_5',
 'Beat_6',
 'Beat_7',
 'Beat_8',
 'Beat_9',
 'Chord_A#_+',
 'Chord_A#_/o7',
 'Chord_A#_7',
 'Chord_A#_M',
 'Chord_A#_M7',
 'Chord_A#_m',
 'Chord_A#_m7',
 'Chord_A#_o',
 'Chord_A#_o7',
 'Chord_A#_sus2',
 'Chord_A#_sus4',
 'Chord_A_+',
 'Chord_A_/o7',
 'Chord_A_7',
 'Chord_A_M',
 'Chord_A_M7',
 'Chord_A_m',
 'Chord_A_m7',
 'Chord_A_o',
 'Chord_A_o7',
 'Chord_A_sus2',
 'Chord_A_sus4',
 'Chord_B_+',
 'Chord_B_/o7',
 'Chord_B_7',
 'Chord_B_M',
 'Chord_B_M7',
 'Chord_B_m',
 'Chord_B_m7',
 'Chord_B_o',
 'Chord_B_o7',
 'Chord_B_sus2',
 'Chord_B_sus4',
 'Chord_C#_+',
 'Chord_C#_/o7',
 'Chord_C#_7',
 'Chord_C#_M',
 'Chord_C#_M7',
 'Chord_C#_m',
 'Chord_C#_m7',
 'Chord_C#_o',
 'Chord_C#_o7',
 'Chord_C#_sus2',
 'Chord_C#_sus4',
 'Chord_C_+',
 'Chord_C_/o7',
 'Chord_C_7',
 'Chord_C_M',
 'Chord_C_M7',
 'Chord_C_m',
 'Chord_C_m7',
 'Chord_C_o',
 'Chord

In [21]:
len(list(event2idx.keys()))

332

In [10]:
len(pickle.load(open('pickles/train_pieces.pkl', 'rb')))
pickle.load(open('pickles/train_pieces.pkl', 'rb'))[0:10]

['1398.pkl',
 '954.pkl',
 '848.pkl',
 '1574.pkl',
 '128.pkl',
 '802.pkl',
 '1255.pkl',
 '1682.pkl',
 '330.pkl',
 '443.pkl']

In [11]:
ids, tokens = pickle.load(open('remi_dataset/1.pkl', 'rb'))

# [idx2event[_] for _ in ids if _ in idx2event]
# ids[-10:-1]
ids[0:10]
# tokens

[0, 54, 105, 143, 188, 247, 302, 361, 414, 465]

In [12]:
tokens

[{'name': 'Bar', 'value': None},
 {'name': 'Beat', 'value': 0},
 {'name': 'Chord', 'value': 'N_N'},
 {'name': 'Tempo', 'value': np.int64(119)},
 {'name': 'Note_Pitch', 'value': 45},
 {'name': 'Note_Velocity', 'value': np.int64(60)},
 {'name': 'Note_Duration', 'value': 1440},
 {'name': 'Beat', 'value': 2},
 {'name': 'Note_Pitch', 'value': 52},
 {'name': 'Note_Velocity', 'value': np.int64(62)},
 {'name': 'Note_Duration', 'value': 720},
 {'name': 'Beat', 'value': 4},
 {'name': 'Tempo', 'value': np.int64(65)},
 {'name': 'Note_Pitch', 'value': 56},
 {'name': 'Note_Velocity', 'value': np.int64(58)},
 {'name': 'Note_Duration', 'value': 600},
 {'name': 'Beat', 'value': 5},
 {'name': 'Note_Pitch', 'value': 61},
 {'name': 'Note_Velocity', 'value': np.int64(64)},
 {'name': 'Note_Duration', 'value': 240},
 {'name': 'Beat', 'value': 7},
 {'name': 'Note_Pitch', 'value': 64},
 {'name': 'Note_Velocity', 'value': np.int64(70)},
 {'name': 'Note_Duration', 'value': 600},
 {'name': 'Beat', 'value': 8},
 {

## Tokenize Jazz dataset

There are some event string format difference between the `remi_dataset` and event string produced from the miditok

- 'Beat' vs 'Position'
- 'Note_Pitch' vs 'Pitch'
- 'Chord_B:7aug' vs 'Chord_A#_+'

In [14]:
tokenizer = REMI(
    TokenizerConfig(
        use_tempos=True, 
        use_pitchdrum_tokens=False, 
        chord_tokens_with_root_note=True,
        use_chords=True,
        beat_res={(0, 4): 16, (4, 12): 8},
    )
)

event2idx = tokenizer.vocab
idx2event = {v: k for k, v in event2idx.items()}

with open('jazz/meta/remi_vocab.pkl', 'wb') as handle:
    pickle.dump((event2idx, idx2event), handle, protocol=pickle.HIGHEST_PROTOCOL)

tokenizer.save("jazz/meta/remi_tokenizer.json")

In [18]:
len(tokenizer.vocab)

517

In [15]:
tokenizer.vocab

{'PAD_None': 0,
 'BOS_None': 1,
 'EOS_None': 2,
 'MASK_None': 3,
 'Bar_None': 4,
 'Pitch_21': 5,
 'Pitch_22': 6,
 'Pitch_23': 7,
 'Pitch_24': 8,
 'Pitch_25': 9,
 'Pitch_26': 10,
 'Pitch_27': 11,
 'Pitch_28': 12,
 'Pitch_29': 13,
 'Pitch_30': 14,
 'Pitch_31': 15,
 'Pitch_32': 16,
 'Pitch_33': 17,
 'Pitch_34': 18,
 'Pitch_35': 19,
 'Pitch_36': 20,
 'Pitch_37': 21,
 'Pitch_38': 22,
 'Pitch_39': 23,
 'Pitch_40': 24,
 'Pitch_41': 25,
 'Pitch_42': 26,
 'Pitch_43': 27,
 'Pitch_44': 28,
 'Pitch_45': 29,
 'Pitch_46': 30,
 'Pitch_47': 31,
 'Pitch_48': 32,
 'Pitch_49': 33,
 'Pitch_50': 34,
 'Pitch_51': 35,
 'Pitch_52': 36,
 'Pitch_53': 37,
 'Pitch_54': 38,
 'Pitch_55': 39,
 'Pitch_56': 40,
 'Pitch_57': 41,
 'Pitch_58': 42,
 'Pitch_59': 43,
 'Pitch_60': 44,
 'Pitch_61': 45,
 'Pitch_62': 46,
 'Pitch_63': 47,
 'Pitch_64': 48,
 'Pitch_65': 49,
 'Pitch_66': 50,
 'Pitch_67': 51,
 'Pitch_68': 52,
 'Pitch_69': 53,
 'Pitch_70': 54,
 'Pitch_71': 55,
 'Pitch_72': 56,
 'Pitch_73': 57,
 'Pitch_74': 58,
 'Pitc

In [35]:
def token_obj(token_str):
    name, value = token_str.split("_")
    # try:
    #     value = float(value)
    # except (ValueError, TypeError):
    #     pass

    # to the same chord format as remi_dataset codes
    # Chord_G#_m7
    # Chord_B:9maj

    # if "Pitch" in name:
    #     name = name.replace("Pitch", "Note_Pitch")

    # if "Velocity" in name:
    #     name = name.replace("Velocity", "Note_Velocity")

    # if "Duration" in name:
    #     name = name.replace("Duration", "Note_Duration")

    # if "Chord" in name:
    #     name = name.replace(":", "_")

    # if "Position" in name:
    #     name = name.replace("Position", "Beat")

    return { "name": name, "value": value }


def to_pickle(src, dest):
    z = tokenizer.encode(src)[0]
    tokens = [token_obj(_) for _ in z.tokens]
    bar_ids = [i for i, x in enumerate(z.ids) if x == tokenizer.vocab["Bar_None"]]

    with open(dest, 'wb') as handle:
        pickle.dump((bar_ids, tokens), handle, protocol=pickle.HIGHEST_PROTOCOL)

# to_pickle("jazz/midi/0.mid", "jazz/pickles/0.pkl")

In [36]:
for name in tqdm(os.listdir("jazz/midi")):
    dest = os.path.join("jazz/pickles", name.replace(".mid", ".pkl"))
    src = os.path.join("jazz/midi", name)
    to_pickle(src, dest)

100%|██████████| 1994/1994 [01:01<00:00, 32.30it/s]


In [151]:
from sklearn.model_selection import train_test_split

all_songs = list(os.listdir("jazz/pickles"))
train, valid = train_test_split(all_songs, test_size=0.1)
valid, test = train_test_split(valid, test_size=0.5)

with open("jazz/meta/train.pieces.pkl", 'wb') as handle:
    pickle.dump(train, handle)

with open("jazz/meta/valid.pieces.pkl", 'wb') as handle:
    pickle.dump(valid, handle)

with open("jazz/meta/test.pieces.pkl", 'wb') as handle:
    pickle.dump(test, handle)

In [153]:
pickle.load(open('jazz/meta/train.pieces.pkl', 'rb'))[0:10]

['865.pkl',
 '146.pkl',
 '1388.pkl',
 '428.pkl',
 '1024.pkl',
 '1524.pkl',
 '802.pkl',
 '1906.pkl',
 '889.pkl',
 '233.pkl']

In [34]:
ids, tks = pickle.load(open('jazz/pickles/5.pkl', 'rb'))
tks[0:20], ids[0:10]
# [_ for _ in tks if "Chord" in _["name"]]

([{'name': 'Bar', 'value': 'None'},
  {'name': 'Beat', 'value': '0'},
  {'name': 'Tempo', 'value': '121.29'},
  {'name': 'Beat', 'value': '16'},
  {'name': 'Tempo', 'value': '46.77'},
  {'name': 'Beat', 'value': '21'},
  {'name': 'Note_Pitch', 'value': '41'},
  {'name': 'Note_Velocity', 'value': '59'},
  {'name': 'Note_Duration', 'value': '0.3.16'},
  {'name': 'Beat', 'value': '23'},
  {'name': 'Note_Pitch', 'value': '48'},
  {'name': 'Note_Velocity', 'value': '59'},
  {'name': 'Note_Duration', 'value': '0.1.16'},
  {'name': 'Beat', 'value': '24'},
  {'name': 'Note_Pitch', 'value': '53'},
  {'name': 'Note_Velocity', 'value': '63'},
  {'name': 'Note_Duration', 'value': '0.2.16'},
  {'name': 'Beat', 'value': '25'},
  {'name': 'Note_Pitch', 'value': '58'},
  {'name': 'Note_Velocity', 'value': '55'}],
 [0, 134, 282, 508, 668, 860, 1022, 1285, 1521, 1772])