In [None]:
from torch import cuda, load, device, set_grad_enabled, max, sum, cat
from preprocessing.nn_dataset import get_test_set_for_eval_classic
from transformers import GPT2LMHeadModel, AutoTokenizer
import note_seq

"""
File containing functions to quantitatively evaluate trained models
"""

def eval_on_test_set(load_path, model, criterion, set='test', notes_only=False, tokenizer=None,
                     *, max_length=2048, temperature=0.75):
    """Load a checkpoint into ``model`` and report loss/accuracy on an evaluation split.

    For every sample yielded by ``get_test_set_for_eval_classic(set)``, the input ``x`` is
    tokenized with ``tokenizer`` and a continuation is sampled with ``model.generate``.
    The continuation is decoded back to a token string, converted with
    ``token_sequence_to_note_sequence``, and scored with
    ``criterion(generated_note_sequence, y)``. Running figures are printed every
    ``pr_interval`` samples, and aggregate figures are printed at the end.

    Accuracy scoring is still a TODO placeholder (``correct`` is always 0), so the
    reported accuracy is 0 until it is implemented. Generation uses sampling
    (``do_sample=True``), so results vary between runs unless torch's RNG is seeded.

    Args:
        load_path: Path to a checkpoint loadable with ``torch.load`` that contains a
            ``'model_state_dict'`` entry.
        model: Model exposing ``load_state_dict``, ``to``, ``eval`` and ``generate``.
        criterion: Callable ``(generated_note_sequence, target)`` whose result supports
            ``.item()``.
        set: Split name passed to ``get_test_set_for_eval_classic`` (shadows the builtin
            ``set`` inside this function; kept for caller compatibility).
        notes_only: Selects the (not yet implemented) note-only scoring path.
        tokenizer: Tokenizer with ``encode``/``decode``; required despite the ``None``
            default.
        max_length: ``max_length`` forwarded to ``model.generate``.
        temperature: Sampling temperature forwarded to ``model.generate``.

    Returns:
        ``(average_loss, accuracy_percent)`` for the split. A value is ``nan`` when its
        denominator is zero (e.g. an empty split).

    Raises:
        ValueError: If ``tokenizer`` is None.
        FileNotFoundError: If no file exists at ``load_path``.
    """
    if tokenizer is None:
        raise ValueError('eval_on_test_set requires a tokenizer')

    # The model weights and the encoded inputs must live on the same device.
    dev = device('cuda' if cuda.is_available() else 'cpu')

    try:
        checkpoint = load(load_path, map_location=dev)
    except FileNotFoundError as err:
        raise FileNotFoundError(f'No file located at {load_path}, could not load parameters') from err
    model.load_state_dict(checkpoint['model_state_dict'])
    print("loaded params from", load_path)
    print(model)

    model.to(dev)
    model.eval()

    # The end-of-track token id does not change between samples; encode it once.
    eos_token_id = tokenizer.encode("TRACK_END")[0]

    def percent(numerator, denominator):
        # builtin float() accepts plain ints as well as 0-dim tensors (unlike `.float()`).
        return 100.0 * float(numerator) / denominator if denominator else float('nan')

    count = 0                # samples processed
    batch_count = 0          # accuracy denominator for the whole split
    loss_epoch = 0.0
    running_accuracy = 0     # accuracy numerator for the whole split
    running_batch_count = 0  # accuracy denominator since the last print
    print_loss_batch = 0.0   # Reset on print
    print_acc_batch = 0      # Reset on print
    pr_interval = 1

    for x, y, _psx, _i, _c in get_test_set_for_eval_classic(set):
        with set_grad_enabled(False):
            input_ids = tokenizer.encode(x, return_tensors="pt").to(dev)

            y_hat = model.generate(
                input_ids,
                max_length=max_length,
                do_sample=True,
                temperature=temperature,
                eos_token_id=eos_token_id,
            )

            # y_hat[0] is the sequence generated for the single encoded input.
            token_sequence = tokenizer.decode(y_hat[0])
            generated_note_sequence = token_sequence_to_note_sequence(token_sequence)

            # TODO: when notes_only is set, restrict the generated sequence / target to
            # note-related information before scoring.
            loss = criterion(generated_note_sequence, y)

        loss_value = loss.item()
        loss_epoch += loss_value
        print_loss_batch += loss_value

        # TODO: count correct predictions for this sample (note-related information only
        # when notes_only is set). Until then accuracy is reported as 0.
        correct = 0
        running_accuracy += correct
        print_acc_batch += correct

        count += 1
        # NOTE(review): len(x) counts the elements of x (characters if x is a string);
        # confirm this is the intended accuracy denominator.
        batch_count += len(x)
        running_batch_count += len(x)

        # print loss for recent set of samples
        if count % pr_interval == 0:
            ave_loss = print_loss_batch / pr_interval
            ave_acc = percent(print_acc_batch, running_batch_count)
            print('\t\t[%d] loss: %.3f, acc: %.3f' % (count, ave_loss, ave_acc))
            print_loss_batch = 0.0
            print_acc_batch = 0
            running_batch_count = 0

    # calculate loss and accuracy for phase
    ave_loss_epoch = loss_epoch / count if count else float('nan')
    epoch_acc = percent(running_accuracy, batch_count)
    print('\tfinished %s phase loss: %.3f, acc: %.3f' % ('eval', ave_loss_epoch, epoch_acc))
    return ave_loss_epoch, epoch_acc



In [None]:
# Placeholder: replace with the project's real token-sequence -> NoteSequence conversion.
def token_sequence_to_note_sequence(token_sequence):
    """Convert a decoded token string into a ``note_seq`` NoteSequence.

    This is a stub: the original definition had no body at all (only comments), which is
    an IndentationError that stopped the whole file from parsing. It now raises
    explicitly, so callers get a clear error until the real conversion is supplied.

    Args:
        token_sequence: Token string produced by ``tokenizer.decode``.

    Raises:
        NotImplementedError: Always, until the conversion logic is implemented.
    """
    raise NotImplementedError(
        'token_sequence_to_note_sequence must be implemented for your token vocabulary'
    )



In [None]:
# Example run: the checkpoint path and the your_* names are placeholders that must be
# replaced with a real checkpoint, model, loss criterion and tokenizer before running.
eval_on_test_set(
    load_path='path_to_model_checkpoint.pth',
    model=your_model_instance,
    criterion=your_criterion_instance,
    set='test',
    notes_only=False,
    tokenizer=your_tokenizer_instance,
)
