In [1]:
import sys
sys.path.append('..')
import processing
import configs.common as cc
import configs.paths as paths
from models.classifier import Classifier
from train_classifier import get_all_targets
import torch
import json
import os
from train import load_model
from generate import generate
import train
data_root = '/home/s203861/midi-classical-music/np_data/data'
band_folders = [d for d in os.listdir(data_root) if os.path.isdir(os.path.join(data_root, d))]
loader = processing.DatasetLoader(data_root)
dataloader = loader.get_dataloader_full()

  import pkg_resources
  from .autonotebook import tqdm as notebook_tqdm


In [5]:
import sys
sys.path.append('..')
import torch
from train import load_model, new_model
import configs.common as cc

no_meta = True
type = 'mamba'
name = cc.config.models.mamba #mamba
# name = cc.config.models.transformer #transformer
# name = cc.config.models.xlstm #xlstm
model = new_model(type)
model.load_state_dict(torch.load(name))
model.to('cuda')
model.eval()
print(sum(p.numel() for p in model.parameters() if p.requires_grad))
print(torch.cuda.memory_allocated())

101496482
1460876288


In [6]:
import torch
from train import filtered_logit
from collections import Counter

criterion = torch.nn.CrossEntropyLoss()
loader = processing.DatasetLoader(data_root)
dataloader = loader.get_dataloader_full()
model.eval()
total_loss = 0

intervals = {
    'pitch': (cc.start_idx['pitch'], cc.start_idx['dyn'] - 1),
    'dynamics': (cc.start_idx['dyn'], cc.start_idx['length'] - 1),
    'length': (cc.start_idx['length'], cc.start_idx['time'] - 1),
    'time': (cc.start_idx['time'], cc.start_idx['tempo'] - 1),
    'tempo': (cc.start_idx['tempo'], cc.vocab_size)
}

def get_successrate(correct, tries):
    success_rate = {}
    for key, val in tries.items():
        if val > 0:
            success_rate[key] = correct[key] / val
        else:
            success_rate[key] = None
    return success_rate

correct = {'pitch': 0, 'dynamics': 0, 'length': 0, 'time': 0, 'tempo': 0}
tries = {'pitch': 0, 'dynamics': 0, 'length': 0, 'time': 0, 'tempo': 0}

# Lists to store incorrect predictions
incorrect_src = []
incorrect_meta = []
incorrect_trg = []

for batch_idx, (src, trg, meta) in enumerate(dataloader):
    if no_meta:
        meta = torch.zeros_like(meta, device="cuda")
    with torch.no_grad():
        output = model(src, meta)
    filtered_output = filtered_logit(src, output)
    answers = trg[:, -1]
    logits_last = filtered_output[:, -1, :]

    guesses = logits_last.argmax(-1)
    
    filtered_output = filtered_output.reshape(-1, cc.vocab_size)

    for idx, (guess, answer) in enumerate(zip(guesses, answers)):
        for key, (low, high) in intervals.items():
            if low <= answer < high:
                tries[key] += 1
                if guess == answer:
                    correct[key] += 1
                else:
                    # Store incorrect predictions
                    incorrect_src.append(src[idx].cpu().numpy())
                    incorrect_meta.append(meta[idx].cpu().numpy())
                    incorrect_trg.append(trg[idx].cpu().numpy())

    trg = trg.view(-1)
    loss = criterion(filtered_output, trg)
    total_loss += loss.item()

    if (batch_idx + 1) % cc.config.values.eval_interval == 0:
        msg = f'{loss.item():.4f}'
        print(f'Step: {batch_idx+1}, Loss: {msg}')

avg_loss = total_loss / len(dataloader)
msg = f'Average Loss: {avg_loss:.4f}'
print(msg)

Step: 10, Loss: 1.0875
Step: 20, Loss: 0.9276
Step: 30, Loss: 0.9646
Step: 40, Loss: 0.9754
Step: 50, Loss: 1.0045
Step: 60, Loss: 1.3647
Step: 70, Loss: 1.6482
Step: 80, Loss: 2.0387
Step: 90, Loss: 1.0358
Step: 100, Loss: 1.0929
Step: 110, Loss: 1.1526
Step: 120, Loss: 1.5139
Step: 130, Loss: 1.5464
Step: 140, Loss: 1.1122
Step: 150, Loss: 2.2183
Step: 160, Loss: 1.2594
Step: 170, Loss: 2.1753
Step: 180, Loss: 2.3460
Step: 190, Loss: 1.8657
Step: 200, Loss: 1.4838
Step: 210, Loss: 2.1248
Step: 220, Loss: 2.0936
Step: 230, Loss: 1.5294
Step: 240, Loss: 1.3827
Step: 250, Loss: 1.0004
Step: 260, Loss: 1.4216
Step: 270, Loss: 1.0448
Step: 280, Loss: 1.3653
Step: 290, Loss: 1.2297
Step: 300, Loss: 1.4299
Step: 310, Loss: 1.5453
Step: 320, Loss: 1.3601
Step: 330, Loss: 2.5918
Step: 340, Loss: 3.1759
Step: 350, Loss: 3.2407
Step: 360, Loss: 3.3740
Step: 370, Loss: 3.2056
Step: 380, Loss: 3.2717
Step: 390, Loss: 3.2536
Step: 400, Loss: 3.3385
Step: 410, Loss: 3.0725
Step: 420, Loss: 3.0924
S

In [7]:
get_successrate(correct, tries)

{'pitch': 0.3386411889596603,
 'dynamics': 0.7434367541766109,
 'length': 0.7089627391742196,
 'time': 0.8675496688741722,
 'tempo': 0.9678800856531049}