In [4]:
from transformers import GenerationConfig
import io

from src import utils
from src.model import MidiGPT2
import torch
import mido
import rtmidi

# Show which MIDI output ports the rtmidi backend can see, so the playback
# port name used later in this cell can be checked by eye.
midiout = rtmidi.MidiOut()
print(available_ports := midiout.get_ports())
tokenizer = utils.get_tokenizer()

# Trained checkpoint to sample from (path is relative to the notebook's
# working directory).
checkpoint_path = (
    "../checkpoints/legacy/pertok-1-extend-2/"
    "gpt2-midi-epoch=31-train_loss=0.1332.ckpt"
)
model = MidiGPT2.load_from_checkpoint(checkpoint_path, tokenizer=tokenizer)

# Inference only: put the module in eval mode and freeze its parameters.
model.eval()
model.freeze()

# Sampling defaults for this model: top-k (5) combined with nucleus (0.95)
# sampling at unit temperature, stopping at the tokenizer's EOS token.
gen_config = GenerationConfig(
    do_sample=True,
    temperature=1.0,
    top_k=5,
    top_p=0.95,
    max_new_tokens=2048,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer["EOS_None"],
)

# Export the wrapped Hugging Face model, then write the generation config
# into the same directory. The config is written second, so it is the
# version left on disk.
model.model.save_pretrained("midi-gpt2/")
gen_config.save_pretrained("midi-gpt2/")

# Optional: compile the wrapped model for faster generation.
# model.model = torch.compile(model.model)

# MIDI output port to stream the generated piece to. It is validated before
# generation so that a typo fails immediately, not after a long sampling run.
port_name = 'GarageBand Virtual In'
available_output_names = mido.get_output_names()
if port_name not in available_output_names:
    raise ValueError(
        f"MIDI output port {port_name!r} not found; available: {available_output_names}"
    )

with torch.no_grad():
    # Unconditional generation: the prompt is the BOS token alone.
    # To continue an existing MIDI file instead, encode it as the prompt:
    # input_ids = torch.tensor([x.ids for x in tokenizer.encode("./data/test/test2.mid")], dtype=torch.long).to(model.device)
    input_ids = torch.tensor([[tokenizer["BOS_None"]]], dtype=torch.long).to(model.device)

    # Sample with the same GenerationConfig that is exported next to the model,
    # instead of restating each argument, so the two settings cannot diverge.
    # NOTE(review): max_new_tokens=2048 must fit the model's context window
    # (n_positions); confirm against the checkpoint's config.
    generated = model.model.generate(input_ids=input_ids, generation_config=gen_config)

generated_ids = generated.tolist()
# For a decoder-only model the returned sequence includes the prompt, so
# report both lengths.
print(f"prompt tokens: {len(input_ids[0])}, total tokens: {len(generated_ids[0])}")

# Decode tokens to a score, serialise it to Standard MIDI File bytes, and
# re-parse those bytes with mido so the events can be streamed in real time.
midi_bytes = tokenizer.decode(generated_ids).dumps_midi()
midi_file = io.BytesIO(midi_bytes)
midi = mido.MidiFile(file=midi_file)

# The context manager closes the port when playback ends, so re-running the
# cell does not accumulate open connections. reset() in `finally` sends
# All Notes Off / Reset All Controllers, so an interrupted playback
# (e.g. KeyboardInterrupt) does not leave notes sounding on the synth.
with mido.open_output(port_name) as port:
    try:
        for msg in midi.play():  # play() sleeps between messages to keep timing
            print(msg)
            port.send(msg)
    finally:
        port.reset()






['Roland Digital Piano', 'IAC Driver Bus 1', 'Logic Pro Virtual In', 'GarageBand Virtual In']
1
program_change channel=0 program=0 time=0
note_on channel=0 note=80 velocity=63 time=2.2318649999999995
note_off channel=0 note=80 velocity=63 time=0.18598874999999998
note_on channel=0 note=78 velocity=63 time=0
note_off channel=0 note=78 velocity=63 time=0.18598874999999998
note_on channel=0 note=78 velocity=79 time=0.37197749999999996
note_on channel=0 note=71 velocity=79 time=0
note_on channel=0 note=75 velocity=79 time=0
note_on channel=0 note=47 velocity=63 time=0
note_off channel=0 note=47 velocity=63 time=0.37197749999999996
note_on channel=0 note=54 velocity=79 time=0
note_off channel=0 note=54 velocity=79 time=0.37197749999999996
note_on channel=0 note=59 velocity=79 time=0
note_off channel=0 note=78 velocity=79 time=0.37197749999999996
note_off channel=0 note=71 velocity=79 time=0
note_off channel=0 note=75 velocity=79 time=0
note_on channel=0 note=82 velocity=79 time=0
note_off c