In [1]:
import sys

# Make the parent directory importable so that `herbarium.pylib` (imported in
# the next cell) resolves. Presumably the notebook lives one directory below
# the project root; TODO confirm. The relative entry is resolved against the
# kernel's working directory. The guard keeps re-runs of this cell from
# appending a duplicate entry on every execution.
if '..' not in sys.path:
    sys.path.append('..')

In [2]:
import argparse
import textwrap
from collections import defaultdict, namedtuple
from pathlib import Path

import sklearn.metrics as metrics

from herbarium.pylib import db

In [3]:
# Data directory, relative to the kernel's working directory, and the SQLite
# database whose `tests` table is read below.
DATA = Path('..', 'data')
DB = DATA.joinpath('angiosperms.sqlite')

## Get test results

In [4]:
# Bucket every row of the `tests` table by its (test_set, split_set) pair so
# each bucket can be scored independently in the next section.
results = defaultdict(list)

rows = db.rows_as_dicts(DB, "select * from tests")
for test_row in rows:
    results[test_row["test_set"], test_row["split_set"]].append(test_row)

## Score each (test_set, split_set) group, then sort by split_set and descending f1

In [5]:
# One scored (test_set, split_set) group: identifiers, confusion-matrix
# counts, then summary metrics. Field order is part of the contract, because
# code that indexes these tuples positionally depends on it.
Results = namedtuple(
    "Results",
    ["split", "test", "tn", "fp", "fn", "tp", "acc", "prec", "recall", "f1"],
)

In [6]:
by_split_set = []

# Score each (test_set, split_set) group: confusion-matrix counts plus
# accuracy / precision / recall / f1 for the positive class (label 1).
for (test_set, split_set), rows in results.items():
    y_true = [r['target'] for r in rows]
    # NOTE(review): round() uses round-half-to-even, so a score of exactly
    # 0.5 becomes 0. This also assumes `pred` is a score in [0, 1]; confirm.
    y_pred = [round(r['pred']) for r in rows]

    # Pin labels to [0, 1] so the matrix is always 2x2. Without this, a group
    # in which only one class appears (in both y_true and y_pred) produces a
    # 1x1 matrix and the four counts cannot be extracted.
    cm = metrics.confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()

    accuracy = metrics.accuracy_score(y_true, y_pred)
    f1 = metrics.f1_score(y_true, y_pred)
    prec = metrics.precision_score(y_true, y_pred)
    recall = metrics.recall_score(y_true, y_pred)

    by_split_set.append(
        Results(
            split=split_set,
            test=test_set,
            tn=tn,
            fp=fp,
            fn=fn,
            tp=tp,
            acc=accuracy,
            prec=prec,
            recall=recall,
            f1=f1,
        )
    )

In [7]:
# Order by split_set ascending, then best f1 first within each split.
# Named fields (rather than t[0] / t[-1]) keep the sort key correct even if
# fields are added to or reordered in Results.
by_split_set = sorted(by_split_set, key=lambda res: (res.split, -res.f1))

## Stats per test

In [15]:
# Fixed-width table with one row per (split_set, test_set), in the order of
# by_split_set and with a blank line between consecutive split_sets.
print(f'{"split_set":<20} {"test_set":<35} '
      f'{"TP":^4}   {"FP":^4}   {"TN":^4}   {"FN":^4}   '
      f'{"acc":^6}   {"prec":^6}   {"recall":^6}   {"f1":^6}')

print(f'{"---------":<20} {"--------":<35} '
      '----   ----   ----   ----   '
      '------   ------   ------   ------')

# None (not '') marks "no row printed yet", so a split named '' still
# separates correctly from the next one.
prev = None
for s in by_split_set:

    # The blank-line separation relies on by_split_set being grouped by split.
    if prev is not None and prev != s.split:
        print()

    print(
        f'{s.split:<20} {s.test:<35} '
        f'{s.tp:4d}   {s.fp:4d}   {s.tn:4d}   {s.fn:4d}   '
        f'{s.acc:0.4f}   {s.prec:0.4f}   {s.recall:0.4f}   {s.f1:0.4f}'
    )

    prev = s.split

split_set            test_set                             TP     FP     TN     FN     acc      prec    recall     f1  
---------            --------                            ----   ----   ----   ----   ------   ------   ------   ------
flowering            b0_flowers_all_orders_2             2190     16    287    104   0.9538   0.9927   0.9547   0.9733
flowering            b0_flowers_all_orders_1             2170     16    287    124   0.9461   0.9927   0.9459   0.9688

flowering_2_orders   b3_flowers_2_orders_unfrozen_2_acc   636     12     11     19   0.9543   0.9815   0.9710   0.9762
flowering_2_orders   b1_flowers_2_orders_frozen_2_acc     626      2     21     29   0.9543   0.9968   0.9557   0.9758
flowering_2_orders   b1_flowers_2_orders_unfrozen_4_acc   645     22      1     10   0.9528   0.9670   0.9847   0.9758
flowering_2_orders   b1_flowers_2_orders_frozen_1_acc     624      1     22     31   0.9528   0.9984   0.9527   0.9750
flowering_2_orders   b3_flowers_2_orders_frozen