In [2]:
import os
import pandas as pd
import librosa
import librosa.display
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
import wandb


In [3]:
def load_npz_data(file, run_path):
  """Download ``file`` from a W&B run, load it with ``np.load``, then delete the local copy.

  Args:
    file: name of the file stored in the run, e.g. 'db.npz'.
    run_path: W&B run path, e.g. "pasinducw/seq2seq-covers80-eval/2kbsx67u".

  Returns:
    Whatever ``np.load`` returns for the file. NOTE(review): if the file is a
    genuine zip-based .npz, that is a lazy ``NpzFile`` that still refers to the
    (now deleted) file — confirm these artifacts are plain pickled arrays.

  Raises:
    FileNotFoundError: if ``wandb.restore`` cannot find ``file`` in the run.
  """
  ref = wandb.restore(file, run_path=run_path)
  if ref is None:
    # wandb.restore signals a missing file by returning None instead of raising.
    raise FileNotFoundError("%s not found in W&B run %s" % (file, run_path))
  local_path = ref.name
  # restore() hands back an open file object; only its path is needed, so close
  # it before deleting the file (removing an open file fails on Windows).
  ref.close()
  try:
    # NOTE(review): allow_pickle=True can execute arbitrary code embedded in the
    # file — only use with run artifacts you trust.
    parsed = np.load(local_path, allow_pickle=True)
  finally:
    # Always clean up the downloaded copy, even if loading fails.
    os.remove(local_path)
  return parsed

def _first_match_rank(work_id, matches):
  """Return the 1-based position of ``work_id`` in ``matches``, or None if absent.

  Args:
    work_id: the query's work id, as ``str`` (UTF-8 encoded before comparing)
      or already as ``bytes``.
    matches: match list for one query; ``matches['work_id']`` holds byte-string
      work ids. NOTE(review): presumably ordered best match first — confirm
      against the evaluation script.

  If the work id occurs several times, only the first occurrence counts.
  """
  key = work_id.encode() if isinstance(work_id, str) else work_id
  # flatnonzero always returns a 1-D index array, so zero, one or many hits are
  # handled uniformly (argwhere().squeeze() changes rank with the hit count).
  hits = np.flatnonzero(np.asarray(matches['work_id']) == key)
  if hits.size == 0:
    return None
  return int(hits[0]) + 1

def calculate_mr1(results, no_match_weight=165):
  """Mean rank of the first correct match (MR1) over a set of queries.

  Args:
    results: iterable of ``(work_id, matches)`` pairs, one per query.
    no_match_weight: rank charged to a query whose work never appears in its
      matches. NOTE(review): the default 165 is presumably chosen to exceed the
      number of database entries — confirm.

  Returns:
    The mean per-query rank as a float.

  Raises:
    ValueError: if ``results`` is empty.
  """
  ranks = []
  for work_id, matches in results:
    rank = _first_match_rank(work_id, matches)
    ranks.append(no_match_weight if rank is None else rank)
  if not ranks:
    raise ValueError("calculate_mr1: results is empty")
  return float(np.mean(ranks))

def calculate_accuracy(results):
  """Top-1 accuracy: fraction of queries whose first match is the correct work.

  Args:
    results: iterable of ``(work_id, matches)`` pairs, one per query.

  Returns:
    ``correct / total`` as a float in [0, 1].

  Raises:
    ValueError: if ``results`` is empty.
  """
  top1_hits = [_first_match_rank(work_id, matches) == 1 for work_id, matches in results]
  if not top1_hits:
    raise ValueError("calculate_accuracy: results is empty")
  return sum(top1_hits) / len(top1_hits)

## Load the data

Fetch the reference database and the per-augmentation query results from Weights & Biases.

In [16]:
# All evaluation runs live in this W&B project; every run path below is
# "<EVAL_PROJECT>/<run id>".
EVAL_PROJECT = "pasinducw/seq2seq-covers80-eval"

# NOTE(review): the database uses the same run id as the "noise0.01" query set
# below — confirm that is intentional.
db = load_npz_data('db.npz', run_path="%s/%s" % (EVAL_PROJECT, "2kbsx67u"))

# (augmentation label, W&B run id) for every evaluated query set.
query_runs = [
  ("speed0.90", "186mrl2e"),
  ("speed0.95", "1elqrhtn"),
  ("speed0.99", "pr3v3t6z"),
  ("speed1.01", "10knlmlq"),
  ("speed1.05", "2wxt16ox"),
  ("speed1.10", "139ynjxd"),

  ("pitch-4", "10gkvmlg"),
  ("pitch-3", "35xed4uf"),
  ("pitch-2", "36dht253"),
  ("pitch-1", "2knwwhfk"),
  ("pitch+1", "23nmkjl1"),
  ("pitch+2", "6a26xmja"),
  ("pitch+3", "1zg8s438"),
  ("pitch+4", "3w5bnoyb"),

  ("noise0.01", "2kbsx67u"),
  ("noise0.05", "37lw9urg"),
  ("noise0.1", "1jpriocb"),
  ("noise0.2", "1maxoicn"),
  ("noise0.3", "6n0yry0u"),
  ("noise0.4", "17qw7c03"),
  ("noise0.5", "7jf9wvok"),
]

# [label, full run path] pairs.
queries = [[label, "%s/%s" % (EVAL_PROJECT, run_id)] for label, run_id in query_runs]

# (label, loaded query results) for each query set.
query_data = [(label, load_npz_data('query_results.npz', run_path=run_path)) for label, run_path in queries]

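## Analyze the data

For each augmentation, report top-1 accuracy (the fraction of queries whose first returned match is the correct work) and MR1 (the mean 1-based rank of the first correct match; a query with no correct match counts as rank 165).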

In [17]:
# One line per augmented query set: top-1 accuracy and mean rank of the first correct match.
for label, results in query_data:
  mr1 = calculate_mr1(results)
  accuracy = calculate_accuracy(results)
  print("[%s]\tAccuracy: %.2f\tMR1: %f" % (label, accuracy, mr1))

[speed0.90]	Accuracy: 0.82	MR1: 1.481707
[speed0.95]	Accuracy: 0.88	MR1: 1.262195
[speed0.99]	Accuracy: 0.84	MR1: 1.371951
[speed1.01]	Accuracy: 0.82	MR1: 1.426829
[speed1.05]	Accuracy: 0.83	MR1: 1.567073
[speed1.10]	Accuracy: 0.85	MR1: 1.335366
[pitch-4]	Accuracy: 0.05	MR1: 22.195122
[pitch-3]	Accuracy: 0.08	MR1: 39.134146
[pitch-2]	Accuracy: 0.06	MR1: 61.030488
[pitch-1]	Accuracy: 0.37	MR1: 13.762195
[pitch+1]	Accuracy: 0.29	MR1: 15.737805
[pitch+2]	Accuracy: 0.02	MR1: 61.390244
[pitch+3]	Accuracy: 0.06	MR1: 52.036585
[pitch+4]	Accuracy: 0.10	MR1: 26.353659
[noise0.01]	Accuracy: 1.00	MR1: 1.000000
[noise0.05]	Accuracy: 1.00	MR1: 1.000000
[noise0.1]	Accuracy: 0.98	MR1: 1.024390
[noise0.2]	Accuracy: 0.93	MR1: 1.567073
[noise0.3]	Accuracy: 0.92	MR1: 1.121951
[noise0.4]	Accuracy: 0.73	MR1: 4.317073
[noise0.5]	Accuracy: 0.65	MR1: 7.353659
