In [None]:
from data_utils import read_abc
from model import get_model
from train import get_training_files
from tqdm import tqdm
from pathlib import Path
import torch
import youtokentome as yttm
from argparse import ArgumentParser

In [None]:
def predict_notes(model, tokenizer, keys, notes, max_context_tokens=510):
    """Generate a continuation for a tune and return it as formatted text.

    The key/header text and the note text are tokenized separately, joined
    into a single context wrapped in BOS/EOS ids, and passed to
    ``model.generate`` with greedy decoding.

    Args:
        model: Object exposing ``generate(input_ids=..., ...)`` that returns
            a batch of token-id sequences.
        tokenizer: Object exposing ``encode(text) -> list[int]`` and
            ``decode(ids, ignore_ids=...) -> list[str]``.
        keys: Header text of the tune.
        notes: Note text of the tune.
        max_context_tokens: Upper bound on ``len(keys_tokens) +
            len(notes_tokens)``. BOS and EOS are added on top of this, so
            the model sees at most ``max_context_tokens + 2`` ids.

    Returns:
        The decoded generation with every space removed and a newline
        inserted after each ``|``.
    """
    bos_id, eos_id, pad_id = 2, 3, 0

    keys_tokens = tokenizer.encode(keys)
    notes_tokens = tokenizer.encode(notes)
    print(f"notes: {(notes)}")

    # TODO fix max length of transformer
    # Drop the OLDEST note tokens so that keys + notes fits the budget. The
    # number of tokens to drop is exactly the overflow; slicing from that
    # index keeps the most recent notes. If the keys alone exceed the budget
    # the notes become empty and the keys are left untouched.
    overflow = len(keys_tokens) + len(notes_tokens) - max_context_tokens
    if overflow > 0:
        notes_tokens = notes_tokens[overflow:]

    context_tokens = [bos_id] + keys_tokens + notes_tokens + [eos_id]

    # Add a batch dimension: generate() is called with a batch of one.
    context_tokens = torch.tensor(context_tokens, dtype=torch.long).unsqueeze(0)

    if torch.cuda.is_available():
        context_tokens = context_tokens.cuda()

    # NOTE(review): an earlier revision built a `bad_words_ids` list for
    # "x8 | " here but never passed it to generate(); that dead code has been
    # removed. Generation bans no tokens, exactly as before.
    gen_tokens = model.generate(input_ids=context_tokens,
                                max_length=320,
                                min_length=32,
                                early_stopping=False,
                                num_beams=1,
                                do_sample=False,
                                no_repeat_ngram_size=0,
                                repetition_penalty=1.0,
                                bos_token_id=bos_id,
                                eos_token_id=eos_id,
                                pad_token_id=pad_id,
                                )

    # Only the first (and only) sequence of the batch is used.
    gen_tokens = gen_tokens[0].tolist()

    # ids 0-3 are skipped when decoding (0/2/3 are pad/bos/eos above;
    # 1 is presumably the unknown-token id — TODO confirm).
    notes = tokenizer.decode(gen_tokens, ignore_ids=[0, 1, 2, 3])[0]
    notes = notes.replace(" ", "").replace("|", "|\n")

    return notes

def predict(model, tokenizer, text_path, output_dir):
    """Extend one ABC file with a generated continuation.

    Parses ``text_path`` with ``read_abc``, asks ``predict_notes`` for a
    continuation, and writes the original file text followed by that
    continuation to a file of the same name inside ``output_dir``.

    Args:
        model: Passed through to ``predict_notes``.
        tokenizer: Passed through to ``predict_notes``.
        text_path: Path object of the source ABC file (its ``.name`` is used
            for the output file name).
        output_dir: Path object of the directory that receives the result.

    Returns:
        Path of the file that was written.
    """
    keys, notes = read_abc(text_path)
    destination = output_dir / text_path.name

    print(f"keys: {keys}")
    print(f"notes: {notes}")

    continuation = predict_notes(model, tokenizer, keys, notes)

    # Read the source completely before opening the destination for writing.
    with open(text_path) as source_file:
        original_abc = source_file.read()

    with open(destination, "w") as target_file:
        target_file.write(original_abc + continuation)

    return destination
        


In [None]:
# --- Configuration: where inputs live and where predictions are written ---
datapath = "cleaned_data"
checkpoint_path = "ABCModel/checkpoint-10000/pytorch_model.bin"
tokenizer_path = "abc.yttm"
output_dir = Path("predict_abc")

# Only the first ten files returned by get_training_files are predicted.
test_paths = get_training_files(datapath)[:10]

# --- Load tokenizer and model weights ---
tokenizer = yttm.BPE(tokenizer_path)
model = get_model(tokenizer.vocab_size())
# Load onto CPU first; the model is moved to the GPU below when one exists.
checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
model.load_state_dict(checkpoint)
if torch.cuda.is_available():
    model = model.cuda()
# A freshly constructed nn.Module is in training mode. Switch to eval mode so
# dropout (if the model has any) is disabled and greedy generation is
# deterministic.
model.eval()

output_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Run prediction on every selected file and report where each result landed.
for source_path in test_paths:
    abc_path = predict(model, tokenizer, source_path, output_dir)
    print(abc_path)