In [2]:
import copy
import math
import os

import torch
import numpy as np
import matplotlib.pyplot as plt

import librosa
import librosa.display

from SAMAF import SAMAF
# from SinhalaSongsDataset import SinhalaSongsDataset
from EvaluationSinhalaSongsDataset import EvaluationSinhalaSongsDataset

In [3]:
def draw_mfccs(mfccs):
    """Plot an MFCC tensor as a heat map on a fresh matplotlib figure.

    The first two dimensions of ``mfccs`` are swapped before plotting, the
    x axis is labelled as time, and a colour bar and the title "MFCC" are
    added.

    Args:
        mfccs: tensor exposing ``transpose`` and ``numpy`` (a CPU torch
            tensor in this notebook).
    """
    plt.figure()
    # specshow needs a NumPy array, with the first two axes swapped.
    coefficients = mfccs.transpose(0, 1).numpy()
    librosa.display.specshow(coefficients, x_axis="time")
    plt.colorbar()
    plt.title("MFCC")
    plt.tight_layout()

In [4]:
def make_index(dataloader, model_params, device):
    """Build the fingerprint index for every song yielded by ``dataloader``.

    A SAMAF model (``embedding_dim=196``) is created on ``device``, loaded
    with ``model_params`` and run in inference mode. Every embedding vector
    is turned into a boolean hash by sign (a component is True when it is
    strictly positive) and stored together with its song id and its position
    inside the song's embedding sequence.

    Args:
        dataloader: iterable yielding ``(song_ids, mfccs)`` batches.
        model_params: state dict accepted by ``SAMAF.load_state_dict``.
        device: torch device that hosts the model and the input batches.

    Returns:
        list of ``(hash, song_id, offset)`` tuples, where ``hash`` is a
        boolean NumPy array, ``song_id`` is the id exactly as yielded by the
        dataloader and ``offset`` is the index of the hash within the song.
    """
    model = SAMAF(embedding_dim=196).to(device)
    model.load_state_dict(model_params)
    # Inference only: disable any training-time layer behaviour (e.g. dropout).
    model.eval()

    index = []

    with torch.no_grad():
        for song_ids, mfccs in dataloader:
            # The model lives on `device`, so its input has to be there too.
            embeddings, _ = model(mfccs.to(device))
            # Sign binarisation done by NumPy in one shot instead of calling a
            # Python function once per element.
            hashes = embeddings.detach().cpu().numpy() > 0
            for j, song_id in enumerate(song_ids):
                for offset, embedding in enumerate(hashes[j]):
                    index.append((embedding, song_id, offset))

    return index

In [5]:
def evaluate(dataloader, model_parms, device, index):
    """Measure how often query songs are identified correctly via ``index``.

    Each query batch is embedded with a SAMAF model (``embedding_dim=196``)
    and binarised by sign. For every hash of a query song, all index entries
    at the minimum Hamming distance are collected and each distinct song id
    among them receives one vote. The song id with the most votes is the
    prediction; on a tie the first id in ``np.unique`` order wins. The running
    accuracy is printed after every batch.

    Args:
        dataloader: iterable yielding ``(song_ids, mfccs)`` query batches.
        model_parms: state dict accepted by ``SAMAF.load_state_dict``.
        device: torch device that hosts the model and the input batches.
        index: sized sequence of ``(hash, song_id, offset)`` tuples whose
            hashes all have the same length as the query hashes.

    Returns:
        tuple ``(correct_matches, incorrect_matches)``.

    Raises:
        ValueError: if ``index`` is empty, since no match can be produced.
    """
    if len(index) == 0:
        raise ValueError("cannot evaluate against an empty index")

    model = SAMAF(embedding_dim=196).to(device)
    model.load_state_dict(model_parms)
    # Inference only: disable any training-time layer behaviour (e.g. dropout).
    model.eval()

    # Loop-invariant: stack the indexed hashes and ids once so that every
    # query hash is compared against the whole index in a single NumPy call.
    index_hashes = np.array([key for key, _, _ in index], dtype=bool)
    index_song_ids = np.array([music_id for _, music_id, _ in index])

    def best_matching_ids(query_hash):
        """Return the distinct song ids whose hashes are closest to ``query_hash``."""
        # Hamming distance to every indexed hash; the smallest distance is the
        # highest similarity score.
        distances = np.logical_xor(index_hashes, query_hash).sum(axis=1)
        return np.unique(index_song_ids[distances == distances.min()])

    correct_matches = 0
    incorrect_matches = 0

    with torch.no_grad():
        for song_ids, mfccs in dataloader:
            # The model lives on `device`, so its input has to be there too.
            embeddings, _ = model(mfccs.to(device))
            hashes = embeddings.detach().cpu().numpy() > 0
            for j, song_id in enumerate(song_ids):
                # One vote per (query hash, best-matching song) pair.
                candidate_matchings = []
                for query_hash in hashes[j]:
                    candidate_matchings.extend(best_matching_ids(query_hash))
                matching_song_ids, counts = np.unique(
                    np.array(candidate_matchings), return_counts=True
                )
                matched_song_id = matching_song_ids[np.argmax(counts)]
                if song_id == matched_song_id:
                    correct_matches += 1
                else:
                    incorrect_matches += 1
            print("Current Accuracy", correct_matches/(correct_matches+incorrect_matches))
    return correct_matches, incorrect_matches

In [6]:
# Everything runs on the CPU; the checkpoint tensors are remapped to it on load.
device = torch.device("cpu")
checkpoint = torch.load("../data/L1-D196-B20-E100-EXP2/snapshot-10.pytorch", map_location=device)

# The feature directory is machine specific. It can be overridden through the
# SINHALA_SONGS_FEATURES_DIR environment variable; the original location is
# kept as the default so existing setups keep working unchanged.
features_directory = os.environ.get(
    "SINHALA_SONGS_FEATURES_DIR",
    "/home/pasinducw/Downloads/Research-Datasets/Sinhala-Songs/features",
)

# Dataset used to build the index (indexing=True), loaded with the DataLoader's
# default batch size and without shuffling.
index_dataset = EvaluationSinhalaSongsDataset(root_dir=features_directory, trim_seconds=10, indexing=True)
index_dataloader = torch.utils.data.DataLoader(index_dataset, shuffle=False)

# Dataset used for the queries, read in batches of 32 in a fixed order.
query_dataset = EvaluationSinhalaSongsDataset(root_dir=features_directory, trim_seconds=10)
query_dataloader = torch.utils.data.DataLoader(query_dataset, batch_size=32, shuffle=False)


In [7]:
# Build the fingerprint index from the indexing dataloader, using the best
# weights stored in the checkpoint.
index = make_index(
    dataloader=index_dataloader,
    model_params=checkpoint['best_model_weights'],
    device=device,
)

In [8]:
# Compare the number of stored hashes with the number of distinct ones.
hashes = np.array([entry[0] for entry in index])
unique = np.unique(hashes, axis=0)

print(hashes.shape, unique.shape, sep="\n")


(7900, 196)
(2908, 196)


In [None]:
# Run the queries against the index; the cell displays (correct, incorrect).
evaluate(
    dataloader=query_dataloader,
    model_parms=checkpoint['best_model_weights'],
    device=device,
    index=index,
)