In [2]:
from pathlib import Path
import numpy as np
import torch.nn as nn
import torch
from dataclasses import dataclass
from pathlib import Path
import os
import miditok
from miditok import REMI, TokenizerConfig
from miditok.pytorch_data import DataCollator, DatasetMIDI
from torch.utils.data import DataLoader
from miditok import REMI, TokenizerConfig
from transformers import GPT2Config, GPT2LMHeadModel
from tqdm import tqdm
from midi_player import MIDIPlayer
from main import ModelConfig, checkpoint_load

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
# --- Tokenizer ---------------------------------------------------------------
# REMI tokenizer with tempo tokens switched on and PitchDrum tokens switched off.
beat_resolution = {(0, 4): 16, (4, 12): 8}
tkn_config = TokenizerConfig(
    beat_res=beat_resolution,
    use_tempos=True,
    use_pitchdrum_tokens=False,
)
tokenizer = REMI(tkn_config)

# --- Project-level hyperparameters --------------------------------------------
# The vocabulary size is taken from the tokenizer so the two cannot drift apart.
config = ModelConfig(
    vocab_size=tokenizer.vocab_size,
    max_seq_length=1024,
    n_layers=12,
    n_head=8,
    n_embd=512,
    batch_size=8,
    device="cpu",
)

# --- GPT-2 language model -----------------------------------------------------
# Mirror the project config into the HuggingFace config and register the
# tokenizer's BOS/EOS ids with it.
gpt_config = GPT2Config(
    bos_token_id=tokenizer["BOS_None"],
    eos_token_id=tokenizer["EOS_None"],
    vocab_size=config.vocab_size,
    n_positions=config.max_seq_length,
    n_embd=config.n_embd,
    n_head=config.n_head,
    n_layer=config.n_layers,
)
model = GPT2LMHeadModel(gpt_config)
model.to(config.device)
model.generation_config.pad_token_id = tokenizer["PAD_None"]

# `checkpoint_load` is defined in main.py (not visible here); presumably it
# restores the saved weights into `model` in place -- confirm against main.py.
CHECKPOINT_PATH = "checkpoints/cp_82.pt"
checkpoint_load(CHECKPOINT_PATH, model, config)
model.eval()  # inference mode: disables dropout

n_params = sum(weight.numel() for weight in model.parameters())
print(f"model parameters: {n_params:,}")

load checkpoint checkpoints/cp_82.pt
model parameters: 38,532,608


In [None]:
# Autoregressively sample `sample_size` token sequences, each seeded with a
# single "Bar_None" token, until every sequence has emitted EOS or the
# sequences reach `config.max_sample_length` tokens.
sample_size = 2
eos_id = tokenizer["EOS_None"]
pad_id = tokenizer["PAD_None"]

tokens = torch.tensor(
    [[tokenizer.vocab["Bar_None"]]] * sample_size,
    device=config.device,
)  # (batch_n, seq_n)

# finished[i] becomes True once row i has emitted EOS. Tracking this
# incrementally avoids rescanning the whole `tokens` tensor on every step.
finished = torch.zeros(sample_size, dtype=torch.bool, device=config.device)

with tqdm(
    total=config.max_sample_length,
    initial=tokens.size(1),  # the seed token already counts towards the total
    desc="Generating tokens",
    unit="token",
) as pbar:
    while tokens.size(1) < config.max_sample_length:
        # Generate one token at a time over a sliding window: the context is
        # capped at max_seq_length - 1 so that context + new token still fits
        # in the model's position embeddings.
        input_context = tokens[:, -(config.max_seq_length - 1):]
        output = model.generate(
            input_context,
            attention_mask=torch.ones_like(input_context),
            max_length=input_context.size(1) + 1,
            do_sample=True,
        )
        new_token = output[:, -1:]

        # Rows that already ended must not keep accumulating sampled tokens
        # after their EOS; append PAD for them instead.
        new_token = torch.where(
            finished.unsqueeze(1),
            torch.full_like(new_token, pad_id),
            new_token,
        )
        finished |= new_token.squeeze(1) == eos_id

        tokens = torch.cat((tokens, new_token), dim=1)
        pbar.update(1)

        # Stop as soon as every sequence in the batch has ended.
        if finished.all():
            break

# print("sample token len", tokens.size(1))
# # tokens = tokens.cpu()
# # tokenizer.decode(tokens)
# # score = tokenizer.decode(tokens)
# # score.dump_midi(save_path)
# # wandb.save(save_path)

Generating tokens:   0%|          | 1/3072 [00:00<01:08, 44.77token/s]

tensor(False)



