In [1]:
# Install the project's dependencies into the environment of the running kernel
# (%pip, unlike !pip, targets the kernel's interpreter).
# NOTE: restart the kernel afterwards if any package was installed or upgraded.
%pip install -r requirements.txt

Note: you may need to restart the kernel to use updated packages.


In [2]:
import io
import torch
import pickle
from gpt_mini.model import GPT
import gpt_mini.midi_encoder as midi_encoder
from symusic import Score, Track, Note
from gpt_mini.config import DEFAULT_DEVICE, CONFIG

In [3]:
# Build the GPT from the project CONFIG and load trained weights for inference.
# NOTE(review): `block_size` is not used by the model below; its context length
# comes from CONFIG["model"]["block_size"]. Kept only in case other cells read it.
block_size = 36
model_config = GPT.get_default_config()
# No preset architecture: layer/head/embedding sizes are set explicitly below
# (presumably preset and explicit sizes are mutually exclusive, as in minGPT).
model_config.model_type = None
model_config.n_layer = CONFIG["model"]["n_layer"]
model_config.n_head = CONFIG["model"]["n_head"]
model_config.n_embd = CONFIG["model"]["n_embed"]  # config key is spelled "n_embed"
model_config.vocab_size = CONFIG["model"]["vocab_size"]  # tokenizer vocabulary size
model_config.block_size = CONFIG["model"]["block_size"]  # input context length
model = GPT(model_config)

# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint cannot execute code while loading (a state_dict needs
# nothing more). map_location only controls where the loaded tensors are
# materialized. NOTE(review): load_state_dict copies them into the existing
# parameters, so `model` itself is not moved to DEFAULT_DEVICE here.
model.load_state_dict(
    torch.load(
        "./checkpoints/gpt_2500.pt",
        map_location=torch.device(DEFAULT_DEVICE),
        weights_only=True,
    )
)
model.eval()  # inference mode (disables dropout etc.)

# Load the trained MIDI tokenizer pickled at CONFIG["tokenizer"]["model"].
# The `with` block guarantees the file is closed even if unpickling fails.
# SECURITY: pickle.load can execute arbitrary code; only load tokenizer files
# you produced yourself or otherwise trust.
with open(CONFIG["tokenizer"]["model"], "rb") as tokenizer_file:
    tokenizer = pickle.load(tokenizer_file)

number of parameters: 2.51M


In [4]:
from pathlib import Path

def generate(model, prompt: str, num_samples=5, steps=64, do_sample=True,
             top_k=40, output_path=Path("./output", "model_output.mid"),
             midi_tokenizer=None):
    """Continue a MIDI prompt with ``model`` and save the result as a drum track.

    The prompt file is tokenized, replicated ``num_samples`` times along a new
    batch dimension, and extended by ``steps`` tokens with ``model.generate``.
    Only the first sample of the batch is decoded, written to ``output_path``
    and returned.

    Args:
        model: Model exposing ``parameters()`` and
            ``generate(idx, max_new_tokens=..., do_sample=..., top_k=...)``.
        prompt: Path to the MIDI file used as the prompt (not raw text).
        num_samples: Batch size passed to ``model.generate``.
            NOTE(review): only sample 0 is used, so values > 1 only add compute.
        steps: Number of new tokens to generate (``max_new_tokens``).
        do_sample: Forwarded to ``model.generate``.
        top_k: Forwarded to ``model.generate`` (previously fixed at 40).
        output_path: Destination of the decoded MIDI file; missing parent
            directories are created.
        midi_tokenizer: Object with ``encode``/``decode``; defaults to the
            notebook-level ``tokenizer`` loaded in an earlier cell.

    Returns:
        The score returned by ``decode``, with its first track (if it has any)
        flagged as a drum track.
    """
    # Only touch the notebook global when no tokenizer was passed explicitly.
    tok = midi_tokenizer if midi_tokenizer is not None else tokenizer

    # Use the ids of the first token sequence produced for the prompt file.
    prompt_ids = tok.encode(prompt)[0].ids

    # (T,) -> (num_samples, T); expand creates the batch dimension as a view.
    x = torch.tensor(prompt_ids, dtype=torch.long).unsqueeze(0).expand(num_samples, -1)

    # Put the prompt on the same device as the model weights, so generation
    # still works after the model has been moved off the CPU.
    first_param = next(iter(model.parameters()), None)
    if first_param is not None:
        x = x.to(first_param.device)

    # Forward the model `steps` times to extend every sample in the batch.
    y = model.generate(x, max_new_tokens=steps, do_sample=do_sample, top_k=top_k)

    # Decode sample 0 only, dropping token id 0 (assumed to be padding).
    generated_ids = [int(token_id) for token_id in y[0].detach().cpu().tolist()
                     if token_id != 0]
    out = tok.decode([generated_ids])

    # Guard against a decoded score with no tracks instead of raising IndexError.
    if out.tracks:
        out.tracks[0].is_drum = True

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    out.dump_midi(output_path)

    return out

In [5]:
# Continue the example prompt and show the tracks of the generated score.
midi = generate(model, prompt="./input/1.mid")
display(midi.tracks)

symusic.core.TrackTickList([Track(ttype=Tick, program=0, is_drum=true, name=Acoustic Grand Piano, notes=24, lyrics=0)])