In [None]:
from ariautils.midi import MidiDict
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
import torch


# Pick the compute device: Apple's Metal backend (MPS) when PyTorch reports it
# as available, otherwise the CPU.
if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"


def _redirect_cuda_to_device(self, *args, **kwargs):
    """Stand-in for ``Tensor.cuda`` that moves the tensor to ``device``.

    Whatever arguments the caller passes to ``.cuda()`` are accepted and
    ignored; the tensor is always sent to the module-level ``device``.
    """
    return self.to(device)


# Globally replace ``Tensor.cuda`` so every ``.cuda()`` call lands on ``device``.
# NOTE(review): presumably needed because the remotely loaded model code calls
# ``.cuda()`` directly -- confirm.
torch.Tensor.cuda = _redirect_cuda_to_device


# MIDI file used as the prompt, and the path the generated continuation is
# written to.
PROMPT_MIDI_LOAD_PATH = "../data/test/noc.mid"
CONTINUATION_MIDI_SAVE_PATH = "../data/test/continuation.midi"

# Checkpoint id shared by the model and its tokenizer.
_CHECKPOINT = "loubb/aria-medium-base"

# trust_remote_code=True lets the checkpoint's own Python code run locally.
model = AutoModelForCausalLM.from_pretrained(_CHECKPOINT, trust_remote_code=True)
model = model.to(device)

tokenizer = AutoTokenizer.from_pretrained(
    _CHECKPOINT,
    add_dim_token=False,
    add_eos_token=True,
    trust_remote_code=True,
)

# Build the model input from the prompt file in three steps:
# MIDI file -> tokens -> integer ids -> tensor on ``device``.
midi_dict = MidiDict.from_midi(PROMPT_MIDI_LOAD_PATH)

# Both add_eos_token and add_dim_token are switched off for the prompt.
tokens = tokenizer.tokenize(
    midi_dict,
    add_dim_token=False,
    add_eos_token=False,
)
token_ids = tokenizer._tokenizer.encode(tokens)

# Wrap the ids in an outer list so the tensor gets a leading axis of size 1.
prompt_batch = [token_ids]
prompt_input_ids = torch.tensor(prompt_batch, device=device)

# Upper bound on the total sequence length (prompt + generated tokens) that is
# handed to generate() as ``max_length``.
MAX_SEQUENCE_LENGTH = 1024

# ``max_length`` counts the prompt as well, so a prompt that already holds
# MAX_SEQUENCE_LENGTH or more tokens leaves no room for a continuation. Only in
# that case, keep just the leading half of the budget as the prompt; shorter
# prompts are passed through untouched.
generation_prompt = prompt_input_ids
prompt_length = generation_prompt.shape[-1]
if prompt_length >= MAX_SEQUENCE_LENGTH:
    kept_length = MAX_SEQUENCE_LENGTH // 2
    print(
        f"Prompt has {prompt_length} tokens, which leaves no room within "
        f"max_length={MAX_SEQUENCE_LENGTH}; using only the first "
        f"{kept_length} tokens as the prompt."
    )
    generation_prompt = generation_prompt[..., :kept_length]

# The prompt tensor is created directly on ``device``, so it is passed as-is
# instead of going through another ``.to(device)``.
continuation = model.generate(
    generation_prompt,
    max_length=MAX_SEQUENCE_LENGTH,
    do_sample=True,  # sample from the distribution rather than decode greedily
    temperature=0.97,
    top_p=0.95,
    use_cache=True,
)

# Turn the generated ids of the first sequence in the batch back into MIDI.
continuation_ids = continuation[0].tolist()
midi_dict = tokenizer.decode(continuation_ids)

# Print the decoded token sequence so the raw output can be inspected.
print(tokenizer._tokenizer.decode(continuation_ids))

# Convert to a MIDI file object and write it to disk.
continuation_midi = midi_dict.to_midi()
continuation_midi.save(CONTINUATION_MIDI_SAVE_PATH)

print(f"Saved continuation to {CONTINUATION_MIDI_SAVE_PATH}")
