In [None]:
import os
import numpy as np
import pandas as pd
import pickle
import matplotlib.pyplot as plt
from PIL import Image
%matplotlib inline

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from ctd.eval import eval_verification, calculate_metrics

In [None]:
def eval(task_path, descs_path, dist_type='l2'):
    """Run ``eval_verification`` on one task CSV and one pickled descriptor file.

    Note that this definition shadows Python's built-in ``eval`` for the
    rest of the notebook.

    Args:
        task_path: Path of the CSV file; it is read with ``pd.read_csv``.
        descs_path: Path of the pickle file holding the descriptors. Only
            point this at files you trust: unpickling can execute
            arbitrary code.
        dist_type: Forwarded unchanged as the ``dist_type`` keyword of
            ``eval_verification``.

    Returns:
        Whatever ``eval_verification`` returns, passed through untouched.
    """
    with open(descs_path, 'rb') as descs_file:
        loaded_descs = pickle.load(descs_file)

    task_samples = pd.read_csv(task_path)

    return eval_verification(
        task_samples,
        loaded_descs,
        1,  # NOTE(review): meaning of this positional argument is not visible here; confirm in ctd.eval
        verbose=False,
        dist_type=dist_type,
    )

In [None]:
def show_pairs(results, data_path):
    """Print and plot each image pair next to its prediction and ground truth.

    For every entry the two relative paths are printed, followed by the
    prediction and whether the ground-truth label equals 1; the two images
    are then drawn side by side in one figure with the axes hidden.

    Args:
        results: Iterable of ``(p, pair)`` tuples. ``p`` must be indexable
            (``p[0]`` is printed as the prediction) and ``pair`` must unpack
            into ``(path1, path2, gt)``.
        data_path: Directory that ``path1`` and ``path2`` are joined onto.

    Returns:
        None. Output is produced through ``print`` and ``plt.show``.
    """
    for p, pair in results:
        print(pair[:2])
        path1, path2, gt = pair
        print('Prediction:', p[0], 'GT:', gt == 1)

        # Open both images in a context manager so their file handles are
        # released even if plotting fails; imshow is called while they are open.
        with Image.open(os.path.join(data_path, path1)) as img1, \
                Image.open(os.path.join(data_path, path2)) as img2:
            fig, axs = plt.subplots(1, 2, figsize=(10, 5))
            axs[0].imshow(img1)
            axs[1].imshow(img2)

        axs[0].axis('off')
        axs[1].axis('off')
        plt.tight_layout()
        plt.show()
        # Release the figure explicitly so a long list of pairs does not
        # accumulate open figures under backends that keep them alive.
        plt.close(fig)

In [None]:
# Evaluate the test verification task with the triplet ResNet-18 descriptor
# file, using the 'l2' distance type. The notebook's own `eval` helper is
# called here, not Python's built-in.
results, predictions, pairs = eval(
    task_path='tasks/verification_test.csv',
    descs_path='features/triplet_resnet18_test.p',
    dist_type='l2',
)

In [None]:
# Split every (prediction, pair) tuple into the four confusion-matrix buckets
# in a single pass. A prediction counts as positive when it is truthy, and a
# pair counts as positive when its third field (pair[2]) is truthy.
true_positives, true_negatives = [], []
false_positives, false_negatives = [], []
for p, pair in zip(predictions, pairs):
    predicted_positive = bool(p)
    labelled_positive = bool(pair[2])
    if predicted_positive and labelled_positive:
        true_positives.append((p, pair))
    elif predicted_positive:
        false_positives.append((p, pair))
    elif labelled_positive:
        false_negatives.append((p, pair))
    else:
        true_negatives.append((p, pair))

tp = len(true_positives)
tn = len(true_negatives)
fp = len(false_positives)
fn = len(false_negatives)
t = tp + tn  # correctly classified pairs
f = fp + fn  # misclassified pairs

# Each ratio falls back to 0.0 when its denominator is zero (no positive
# predictions, no positive labels, or no pairs at all) instead of raising
# ZeroDivisionError and aborting the cell.
pre = tp / (tp + fp) if (tp + fp) else 0.0
rec = tp / (tp + fn) if (tp + fn) else 0.0
f1 = 2 * pre * rec / (pre + rec) if (pre + rec) else 0.0
acc = t / (t + f) if (t + f) else 0.0

print('# true positives : {:4d}'.format(tp))
print('# true negatives : {:4d}'.format(tn))
print('# false positives: {:4d}'.format(fp))
print('# false negatives: {:4d}'.format(fn))
print('# total pairs    : {:4d}'.format(t + f))
print()
print('precision: {:0.4f}'.format(pre))
print('recall   : {:0.4f}'.format(rec))
print('f1       : {:0.4f}'.format(f1))
print('acc      : {:0.4f}'.format(acc))

In [None]:
# Inspect the first ten false positives (pairs predicted positive whose
# label field is falsy); image paths are resolved relative to 'data'.
show_pairs(false_positives[:10], 'data')