In [99]:
import io
from pathlib import Path
from miditok import REMI, TokenizerConfig

import utils
from model import MidiGPT2
import torch
import mido
import rtmidi

# List the MIDI output ports rtmidi can see (informational only; playback
# below goes through mido).
midiout = rtmidi.MidiOut()
available_ports = midiout.get_ports()
print(available_ports)

# NOTE(review): absolute path from the filesystem root — confirm this is not
# meant to be a project-relative "./checkpoints/..." path.
checkpoint_path = "/checkpoints/legacy/pertok-1/pertok-pretrain-final.ckpt"
# MIDI output port that receives the generated performance.
PORT_NAME = "GarageBand Virtual In"
# Optional MIDI file to use as the generation prompt; None starts from BOS only.
prompt_midi_path = None  # e.g. "./data/test/test3.mid"
MAX_NEW_TOKENS = 512

# Fail fast (before the expensive checkpoint load) if the port is missing.
mido_output_names = mido.get_output_names()
if PORT_NAME not in mido_output_names:
    raise RuntimeError(
        f"MIDI output port {PORT_NAME!r} not found; available: {mido_output_names}"
    )

tokenizer = utils.get_tokenizer()

model = MidiGPT2.load_from_checkpoint(checkpoint_path, tokenizer=tokenizer)
model.eval()
model.freeze()

# Only generation needs no_grad; decoding and playback happen outside it.
with torch.no_grad():
    if prompt_midi_path is None:
        # Batch of one sequence containing just the BOS token.
        prompt_ids = [[tokenizer["BOS_None"]]]
    else:
        prompt_ids = [seq.ids for seq in tokenizer.encode(prompt_midi_path)]
    input_ids = torch.tensor(prompt_ids, dtype=torch.long).to(model.device)

    generated = model.model.generate(
        input_ids=input_ids,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=True,
        temperature=1,
        top_k=50,
        top_p=0.95,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer["EOS_None"],
    )

generated_ids = generated.tolist()
print(f"prompt tokens: {len(input_ids[0])}, total tokens: {len(generated_ids[0])}")

# Round-trip through an in-memory MIDI file so mido can schedule playback.
midi_bytes = tokenizer.decode(generated_ids).dumps_midi()
midi = mido.MidiFile(file=io.BytesIO(midi_bytes))

# The `with` block guarantees the port is closed; autoreset=True makes mido send
# all-notes-off / reset-controllers on close, so an interrupted playback
# (e.g. KeyboardInterrupt in the notebook) does not leave notes hanging.
with mido.open_output(PORT_NAME, autoreset=True) as port:
    for msg in midi.play():
        port.send(msg)






['Roland Digital Piano', 'IAC Driver Bus 1', 'GarageBand Virtual In', 'Logic Pro Virtual In']
1
