In [41]:
from pathlib import Path
import numpy as np
import torch.nn as nn
import torch
from dataclasses import dataclass
from pathlib import Path
import os

from miditok import REMI, TokenizerConfig
from miditok.pytorch_data import DataCollator, DatasetMIDI
from torch.utils.data import DataLoader
from miditok import REMI, TokenizerConfig
from transformers import GPT2Config, GPT2LMHeadModel
from tqdm import tqdm


def train(model, optim, dataloader, ctriterion, epochs):
    """Train a causal language model with next-token prediction.

    Args:
        model: module whose ``forward(input_ids, attention_mask=...)`` returns an
            object with a ``logits`` attribute (e.g. ``GPT2LMHeadModel``).
        optim: optimizer over ``model``'s parameters.
        dataloader: iterable of batches; each batch is a dict with
            ``"input_ids"`` and ``"attention_mask"`` tensors.
        ctriterion: loss called as ``ctriterion(logits_2d, labels_1d)``, e.g.
            ``nn.CrossEntropyLoss(ignore_index=pad_token)`` so padding is skipped.
        epochs: number of passes over ``dataloader``.

    Returns:
        list[float]: mean batch loss for each epoch (``nan`` for an epoch that
        yielded no batches).
    """
    model.train()
    # Batches must live on the same device as the model (cpu / cuda / mps).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    epoch_losses = []
    for epoch in range(epochs):
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch + 1}/{epochs}", leave=True)
        epoch_loss = 0.0
        num_batches = 0

        for data in progress_bar:
            optim.zero_grad()

            input_ids = data["input_ids"]
            mask = data["attention_mask"]
            if device is not None:
                input_ids = input_ids.to(device)
                mask = mask.to(device)
            out = model(input_ids, attention_mask=mask)

            # Position t predicts token t+1: drop the last logit and the first label.
            # Padded label positions are expected to be skipped by the criterion.
            shift_logits = out.logits[..., :-1, :].contiguous()
            shift_labels = input_ids[..., 1:].contiguous()
            loss = ctriterion(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

            loss.backward()
            optim.step()

            # Read the scalar once (each .item() is a device sync).
            batch_loss = loss.item()
            epoch_loss += batch_loss
            num_batches += 1
            progress_bar.set_postfix(loss=batch_loss)

        # Count batches instead of using len(dataloader): this works for iterables
        # without __len__ and avoids ZeroDivisionError on an empty loader.
        avg_loss = epoch_loss / num_batches if num_batches else float("nan")
        epoch_losses.append(avg_loss)
        print(f"Epoch {epoch + 1} completed with average loss: {avg_loss:.4f}")

    return epoch_losses

In [None]:
@dataclass()
class ModelConfig:
    """Hyperparameters shared by the GPT-2 model, the dataset and the data loader.

    ``vocab_size`` starts as the placeholder -1 and is meant to be overwritten
    with the tokenizer's vocabulary size once the tokenizer exists.
    """

    # Number of distinct token ids; -1 until filled in from the tokenizer.
    vocab_size: int = -1
    # Embedding / hidden width (passed to GPT-2 as ``n_embd``).
    d_model: int = 128
    # Attention heads per transformer layer.
    n_head: int = 2
    # Number of transformer layers (GPT-2 ``n_layer``).
    n_layers: int = 2
    # Number of sequences per training batch.
    batch_size: int = 32
    # Maximum token sequence length (model positions and dataset truncation).
    max_seq_length: int = 1024

# Seed before building the model and loader so weight init and shuffling are reproducible.
SEED = 42
torch.manual_seed(SEED)

config = ModelConfig()
tokenizer = REMI(TokenizerConfig(use_tempos=True, use_pitchdrum_tokens=False))
config.vocab_size = tokenizer.vocab_size
bos_token = tokenizer["BOS_None"]
eos_token = tokenizer["EOS_None"]
pad_token = tokenizer["PAD_None"]

gpt_config = GPT2Config(
    vocab_size=config.vocab_size,
    n_positions=config.max_seq_length,
    n_embd=config.d_model,
    n_layer=config.n_layers,
    n_head=config.n_head,
    bos_token_id=bos_token,
    eos_token_id=eos_token,
    # Lets generation utilities know which id is padding.
    pad_token_id=pad_token,
)
model = GPT2LMHeadModel(gpt_config)

# dataset setup
midi_dir = Path("pop1k7/midi_analyzed")
# Sorted so the file order (and therefore the seeded shuffle) is the same on every filesystem.
midi_files = sorted(midi_dir.glob("**/*.mid"))
# Fail loudly instead of silently training on an empty dataset (wrong CWD / path).
assert midi_files, f"No .mid files found under {midi_dir.resolve()}"
dataset = DatasetMIDI(
    files_paths=midi_files,
    tokenizer=tokenizer,
    max_seq_len=config.max_seq_length,
    bos_token_id=bos_token,
    eos_token_id=eos_token,
)

# Pad with the same id the loss ignores below, so padding never contributes to the loss.
collator = DataCollator(pad_token)
# shuffle=True: the glob order groups chunks by song folder, which would give highly
# correlated batches.
dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True, collate_fn=collator)

# training setup
optim = torch.optim.Adam(model.parameters())
ctriterion = nn.CrossEntropyLoss(ignore_index=pad_token)
# train(model, optim=optim, dataloader=dataloader, ctriterion=ctriterion, epochs=1)

In [62]:
from midi_player import MIDIPlayer

# Reference track from the dataset, resolved under the same dataset root as `midi_dir`
# (relative to the notebook's working directory) instead of a machine-specific absolute path.
test_track = str(midi_dir / "src_001" / "0.mid")
MIDIPlayer(test_track, 400)

In [64]:
# Round-trip sanity check: encode the reference track to tokens, decode it back to
# MIDI, and listen to the result to confirm the tokenization preserves the music.
z = tokenizer.encode(test_track)
# NOTE(review): assumes `z` is a per-track sequence of TokSequence; only track 0 is rendered.
tokens = torch.tensor(z[0].ids).view(1, -1)
score = tokenizer.decode(tokens)

# Write and play the same (relative) path, so the player always loads the file just written.
sample_path = "sample.mid"
score.dump_midi(sample_path)
MIDIPlayer(sample_path, 400)

# Alternative: round-trip a training example from `dataset` instead.
# x = dataset[5]
# tokens = x['input_ids'].view(1, -1)
# score = tokenizer.decode(tokens)
# score.dump_midi(sample_path)
# MIDIPlayer(sample_path, 400)

In [None]:
# Generation scaffold (work in progress): only the prompt / tempo-token setup runs;
# the sampling loop further down is still commented out.
# def test(model, tokenizer, n_target_bar = 32, temperature = 1.2, topk = 5, output_path = '', model_path = ''):
os.makedirs('./results', exist_ok=True)

with torch.no_grad():
    model.eval()
    batch_size = 1

    words = []
    # for _ in range(batch_size):

    # The prompt starts with a bar marker.
    ws = [tokenizer.vocab['Bar_None']]

    # Vocab keys have the form "<Type>_<Value>" ('Bar_None', 'PAD_None', 'Pitch_60', ...),
    # so tempo tokens are matched by the "Tempo_" prefix. A substring such as
    # 'Tempo Value' (with a space) never occurs in these keys.
    tempo_tokens = [v for k, v in tokenizer.vocab.items() if k.startswith('Tempo_')]
    # miditok's REMI uses a single "Tempo_<bpm>" token type; there is no separate
    # "Tempo Value" token as in the original REMI vocabulary, so tempo_values stays empty.
    tempo_classes = tempo_tokens
    tempo_values = []
    if not tempo_tokens:
        print("warning: tokenizer vocab has no 'Tempo_' tokens; was it built with use_tempos=True?")

    print(tempo_classes, tempo_values)

    # NOTE(review): 'Position_1/16' is REMI-paper naming; the miditok key is presumably
    # of the form 'Position_<n>' — check tokenizer.vocab before enabling this.
    # ws.append(tokenizer.vocab['Position_1/16'])
    # ws.append(int(np.random.choice(tempo_classes)))  # raises ValueError if tempo_classes is empty
    # words.append(ws)

    # generate
    # NOTE(review): this loop still references names that are not defined in this notebook
    # (`temperature_sampling`, `utils`, `word2event`, and the `test` parameters
    # `n_target_bar`, `temperature`, `topk`, `output_path`); define them before uncommenting.
    # original_length = len(words[0])
    # initial_flag = 1
    # current_generated_bar = 0
    # print('Start generating')
    # while current_generated_bar < n_target_bar:
    #     print("\r", current_generated_bar, end="")
    #     # input
    #     if initial_flag:
    #         temp_x = np.zeros((batch_size, original_length))
    #         for b in range(batch_size):
    #             for z, t in enumerate(words[b]):
    #                 temp_x[b][z] = t
    #         initial_flag = 0
    #     else:
    #         temp_x_new = np.zeros((batch_size, 1))
    #         for b in range(batch_size):
    #             temp_x_new[b][0] = words[b][-1]
    #         temp_x = np.array([np.append(temp_x[0], temp_x_new[0])])
    #
    #     temp_x = torch.Tensor(temp_x).long()
    #     # The model only has config.max_seq_length positions: feed at most the last ones.
    #     temp_x = temp_x[:, -config.max_seq_length:]
    #
    #     # The model returns an output object; the next-token scores are in `.logits`.
    #     output_logits = model(temp_x.to(next(model.parameters()).device)).logits
    #
    #     # sampling
    #     _logit = output_logits[0, -1].to('cpu').detach().numpy()
    #     word = temperature_sampling(
    #         logits=_logit,
    #         temperature=temperature,
    #         topk=topk)

    #     words[0].append(word)

    #     if word == tokenizer.vocab['Bar_None']:
    #         current_generated_bar += 1

    # utils.write_midi(
    #     words=words[0],
    #     word2event=word2event,
    #     output_path=output_path,
    #     prompt_path=None)

[] []


dict_items([('PAD_None', 0), ('BOS_None', 1), ('EOS_None', 2), ('MASK_None', 3), ('Bar_None', 4), ('Pitch_21', 5), ('Pitch_22', 6), ('Pitch_23', 7), ('Pitch_24', 8), ('Pitch_25', 9), ('Pitch_26', 10), ('Pitch_27', 11), ('Pitch_28', 12), ('Pitch_29', 13), ('Pitch_30', 14), ('Pitch_31', 15), ('Pitch_32', 16), ('Pitch_33', 17), ('Pitch_34', 18), ('Pitch_35', 19), ('Pitch_36', 20), ('Pitch_37', 21), ('Pitch_38', 22), ('Pitch_39', 23), ('Pitch_40', 24), ('Pitch_41', 25), ('Pitch_42', 26), ('Pitch_43', 27), ('Pitch_44', 28), ('Pitch_45', 29), ('Pitch_46', 30), ('Pitch_47', 31), ('Pitch_48', 32), ('Pitch_49', 33), ('Pitch_50', 34), ('Pitch_51', 35), ('Pitch_52', 36), ('Pitch_53', 37), ('Pitch_54', 38), ('Pitch_55', 39), ('Pitch_56', 40), ('Pitch_57', 41), ('Pitch_58', 42), ('Pitch_59', 43), ('Pitch_60', 44), ('Pitch_61', 45), ('Pitch_62', 46), ('Pitch_63', 47), ('Pitch_64', 48), ('Pitch_65', 49), ('Pitch_66', 50), ('Pitch_67', 51), ('Pitch_68', 52), ('Pitch_69', 53), ('Pitch_70', 54), ('Pitch