In [1]:
import os
import sys

from matplotlib import pyplot as plt
from matplotlib import patches
from matplotlib import colors
import pretty_midi
import pandas as pd
import IPython.display as ipd
import glob
import numpy as np
import muspy
import pypianoroll
import torch
from util.play_midi import play_midi

from miditok import REMI, TokenizerConfig
from miditok.pytorch_data import DatasetTok, DataCollator
from pathlib import Path
from symusic import Score
from torch.utils.data import DataLoader


In [2]:
config = TokenizerConfig()
tokenizer = REMI()
midi_files = list(Path("data/jazz_midi/").glob("**/*.mid"))
print(f"number of midi_files = {len(midi_files)}")
print(f"number of vocab = {len(tokenizer.vocab)}")

number of midi_files = 934
number of vocab = 282


In [3]:
print(len(midi_files)) 
for file in midi_files: # removing corrupted files/ files that cant be fully understood
    try:
        Score(file)
    except:
        midi_files.remove(file)
print(len(midi_files))
for file in midi_files: # on windows this has to be run 2 times, i have no idea why.... very wierd
    try:
        Score(file)
    except:
        midi_files.remove(file)
print(len(midi_files))

934
914
914


In [57]:
song = Score(midi_files[0]) 
tokenized_song = tokenizer.midi_to_tokens(song)

song1 =tokenizer.tokens_to_midi(tokenized_song)
ids = tokenizer._tokens_to_ids(tokenized_song[0])
tokenizer._ids_to_tokens(ids)
tokenized_song

[['Bar_None',
  'Position_0',
  'Pitch_76',
  'Velocity_95',
  'Duration_0.4.8',
  'Position_4',
  'Pitch_78',
  'Velocity_95',
  'Duration_0.4.8',
  'Position_8',
  'Pitch_74',
  'Velocity_95',
  'Duration_0.4.8',
  'Position_12',
  'Pitch_71',
  'Velocity_95',
  'Duration_0.4.8',
  'Position_20',
  'Pitch_73',
  'Velocity_95',
  'Duration_0.4.8',
  'Position_24',
  'Pitch_69',
  'Velocity_95',
  'Duration_0.4.8',
  'Bar_None',
  'Position_0',
  'Pitch_64',
  'Velocity_95',
  'Duration_0.4.8',
  'Position_4',
  'Pitch_66',
  'Velocity_95',
  'Duration_0.4.8',
  'Position_8',
  'Pitch_62',
  'Velocity_95',
  'Duration_0.4.8',
  'Position_12',
  'Pitch_59',
  'Velocity_95',
  'Duration_0.4.8',
  'Position_20',
  'Pitch_61',
  'Velocity_95',
  'Duration_0.4.8',
  'Position_24',
  'Pitch_57',
  'Velocity_95',
  'Duration_0.4.8',
  'Bar_None',
  'Position_0',
  'Pitch_52',
  'Velocity_95',
  'Duration_0.4.8',
  'Position_4',
  'Pitch_54',
  'Velocity_95',
  'Duration_0.4.8',
  'Position_8'

In [91]:
# tokenizer.learn_bpe(vocab_size=500, files_paths=midi_files)







In [109]:
# len(tokenizer.vocab)
# tokenizer["PAD_None"]
# tokenizer["BOS_None"]

# import torchtext.transforms as T

# text_transform = T.Sequential(
#     TransformerTokenizer(tokenizer),  # Tokenize
#     T.VocabTransform(tokenizer_vocab),  # Conver to vocab IDs
#     T.Truncate(max_input_length - 2),  # Cut to max length
#     T.AddToken(token=tokenizer_vocab["[CLS]"], begin=True),  # BOS token
#     T.AddToken(token=tokenizer_vocab["[SEP]"], begin=False),  # EOS token
#     T.ToTensor(padding_value=tokenizer_vocab["[PAD]"]),  # Convert to tensor and pad
# )

1

In [18]:
from transformers import DataCollatorForLanguageModeling
batch_size = 256

dataset = DatasetTok( # seq_len = start + seq_len + end
    files_paths=midi_files,
    min_seq_len=52, # 52
    max_seq_len=batch_size-1, # 514
    tokenizer=tokenizer
)

collator = DataCollator(
    tokenizer["PAD_None"], tokenizer["BOS_None"], tokenizer["EOS_None"], copy_inputs_as_labels=True, 
)

data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

data_loader = DataLoader(dataset=dataset, batch_size=batch_size, collate_fn=collator,)

Loading data: data/jazz_midi: 100%|██████████| 914/914 [00:40<00:00, 22.85it/s]


In [19]:
for batch in data_loader:
    x = batch
    break
print(type(x))


<class 'dict'>


In [20]:
full_inputs = BatchEncoding()
full_inputs.update(x)
full_inputs.to(device)
len(full_inputs["input_ids"][0])

256

In [9]:
y = tokenizer._ids_to_tokens(x["input_ids"][0].tolist())


In [34]:
from transformers import AutoConfig, GPT2LMHeadModel
from transformers.tokenization_utils_base import BatchEncoding

device = "cuda" if torch.cuda.is_available() else "cpu"
device


'cuda'

In [46]:
context_length = 256
n_layer = 2
n_head = 4
n_emb = 64 # 512

config = AutoConfig.from_pretrained(
    "gpt2",
    vocab_size=len(tokenizer),
    n_positions=context_length,
    n_layer=n_layer,
    n_head=n_head,
    pad_token_id=tokenizer["PAD_None"],
    bos_token_id=tokenizer["BOS_None"],
    eos_token_id=tokenizer["EOS_None"],
    n_embd=n_emb
)


In [47]:
model = GPT2LMHeadModel(config)
model_size = sum(t.numel() for t in model.parameters())
print(f"GPT-2 size: {model_size} parameters")
model.to(device)

import gc

# model.cpu()
# del model
# gc.collect()
# torch.cuda.empty_cache()

# model.to(device)

GPT-2 size: 134528 parameters


In [65]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

epochs = 10
for epoch in range(epochs):
    total_loss = []
    model.train()
    for batch in data_loader:
        model.zero_grad()
        full_inputs = BatchEncoding()
        full_inputs.update(batch)
        full_inputs.to(device)
        inputs = batch["input_ids"].to(device)
        labels = batch["labels"].to(device)
        mask = batch["attention_mask"].to(device)
        outputs = model(**full_inputs)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        total_loss.append(loss)
    print(f"Epock = {epoch+1} \nloss = {loss}")


Epock = 1 
loss = 1.6729111671447754
Epock = 2 
loss = 1.65283203125
Epock = 3 
loss = 1.6504460573196411
Epock = 4 
loss = 1.6400920152664185
Epock = 5 
loss = 1.6259053945541382
Epock = 6 
loss = 1.6089879274368286
Epock = 7 
loss = 1.596057653427124
Epock = 8 
loss = 1.5919075012207031
Epock = 9 
loss = 1.589436650276184
Epock = 10 
loss = 1.5815975666046143


In [70]:
model.eval()
with torch.no_grad():
    for batch in data_loader:
        full_inputs = BatchEncoding()
        full_inputs.update(batch)
        full_inputs.to(device)
        inputs = batch["input_ids"]
        labels = batch["labels"].to(device)
        outputs = model(**full_inputs)
        break
# print(outputs)

In [74]:
batch_size = 256
output_tokens = torch.tensor([[0]*context_length]*batch_size)
for i in range(len(outputs["logits"])):
    output_tokens[i] = torch.argmax(outputs["logits"][i], dim=-1)

tokens_output = tokenizer._ids_to_tokens(output_tokens[0].tolist())
print(len(inputs[0]))
print(len(output_tokens[0].tolist()))

256
256


In [58]:
midi_song_output = tokenizer.tokens_to_midi([tokens_output])

In [59]:
midi_song_output

Score(ttype=Tick, tpq=8, begin=0, end=248, tracks=1, notes=49, time_sig=1, key_sig=0, markers=0, lyrics=0)

In [63]:
output_file_path = Path("data/output/", "decoded_midi.mid")
midi_song_output.dump_midi(output_file_path)

In [64]:
play_midi(output_file_path)