## Metrics of the classification using k-NN with embeddings
`test_index.py`

In [98]:
import json
import pandas as pd

from pathlib import Path
from collections import Counter, namedtuple, defaultdict

from sklearn.metrics import (
    precision_score, 
    recall_score,
    accuracy_score, 
    balanced_accuracy_score
    )


Load data from experiment

In [99]:
# Root directory holding the outputs of the experiment analysed in this notebook.
PATH_TRAIN = Path("/data/bacteria/experiments/autoencoders/6mer/26122023-2")

# Load the train/val/test split of the experiment. The cells below read the
# per-dataset label lists from datasets["labels"][<"train"|"val"|"test">].
# The encoding is passed explicitly so that parsing the JSON file does not
# depend on the default encoding of the machine's locale.
with open(PATH_TRAIN / "split-train-val-test.json", "r", encoding="utf-8") as fp:
    datasets = json.load(fp)


In [105]:
# Collect the number of samples per (label, dataset) pair in a dataframe and,
# in the same pass, build the two lookup dicts used later for the evaluation:
#   - counts_test:  label -> number of samples in the test set
#   - counts_train: label -> number of samples in train + val (every dataset
#                   other than "test")
InfoLabels = namedtuple("InfoLabels", ["label", "dataset", "count"])
counts = dict()  # NOTE(review): unused in the visible cells; kept in case another cell reads it
data = []
counts_test = dict()
counts_train = defaultdict(int)  # labels absent from train/val read as 0
for ds in ["train", "val", "test"]:
    # A dedicated name for the Counter avoids rebinding it inside the inner loop.
    label_counts = Counter(datasets["labels"][ds])
    for specie, n_samples in label_counts.items():
        data.append(InfoLabels(specie, ds, n_samples))
        if ds == "test":
            counts_test[specie] = n_samples
        else:
            counts_train[specie] += n_samples

df_infolabels = pd.DataFrame(data)

#### Load the labels assigned by k-NN using the embeddings and the faiss index
- Columns named `consensus_<k>` correspond to the label assigned by majority vote using the `k` retrieved embeddings for each query.
- The column `GT` corresponds to the true label of the query.

In [106]:
# Read the table with the k-NN results for the test queries
# (tab-separated file; its first column is used as the index).
df = pd.read_csv(
    PATH_TRAIN / "test/test_index.tsv",
    sep="\t",
    index_col=0,
)
df.head(5)

Unnamed: 0,GT,consensus_1,consensus_3,consensus_5,consensus_10,0,1,2,3,4,5,6,7,8,9
0,brucella_melitensis,dustbin,brucella_melitensis,brucella_melitensis,brucella_melitensis,dustbin,brucella_melitensis,brucella_melitensis,brucella_suis,brucella_melitensis,brucella_melitensis,brucella_melitensis,brucella_suis,brucella_melitensis,brucella_melitensis
1,acinetobacter_baumannii,acinetobacter_baumannii,acinetobacter_baumannii,acinetobacter_baumannii,acinetobacter_baumannii,acinetobacter_baumannii,acinetobacter_baumannii,acinetobacter_baumannii,acinetobacter_baumannii,acinetobacter_baumannii,acinetobacter_baumannii,acinetobacter_baumannii,acinetobacter_baumannii,acinetobacter_baumannii,acinetobacter_baumannii
2,klebsiella_variicola,klebsiella_variicola,klebsiella_variicola,klebsiella_variicola,klebsiella_variicola,klebsiella_variicola,klebsiella_variicola,klebsiella_variicola,klebsiella_variicola,klebsiella_variicola,klebsiella_variicola,klebsiella_variicola,klebsiella_pneumoniae,klebsiella_variicola,klebsiella_pneumoniae
3,klebsiella_pneumoniae,klebsiella_pneumoniae,klebsiella_pneumoniae,klebsiella_pneumoniae,klebsiella_pneumoniae,klebsiella_pneumoniae,klebsiella_pneumoniae,klebsiella_pneumoniae,klebsiella_pneumoniae,klebsiella_pneumoniae,klebsiella_pneumoniae,klebsiella_pneumoniae,klebsiella_pneumoniae,klebsiella_pneumoniae,klebsiella_pneumoniae
4,campylobacter_fetus,campylobacter_fetus,campylobacter_fetus,campylobacter_fetus,campylobacter_fetus,campylobacter_fetus,campylobacter_fetus,campylobacter_fetus,campylobacter_fetus,campylobacter_fetus,campylobacter_fetus,campylobacter_fetus,campylobacter_fetus,campylobacter_fetus,campylobacter_fetus


In [121]:
# Sorted list of the distinct ground-truth labels found in the test results.
classes = sorted(df["GT"].unique())

# Ground truth vs. the label in the `consensus_10` column.
y_true = df["GT"]
y_pred = df["consensus_10"]

In [122]:
# (accuracy, balanced accuracy, number of ground-truth classes)
(
    accuracy_score(y_true, y_pred),
    balanced_accuracy_score(y_true, y_pred),
    len(classes),
)

(0.9698362539500144, 0.9327673877718559, 90)

In [123]:
# Per-class precision and recall, reported in the order given by `classes`
# (average=None returns one score per label instead of a single aggregate).
# A class that is never predicted has an undefined precision (0 / 0);
# zero_division=0 reports it as 0.0 explicitly instead of relying on the
# default, which yields the same value but emits an UndefinedMetricWarning.
precision = precision_score(
    y_true, y_pred, average=None, labels=classes, zero_division=0
)
recall = recall_score(y_true, y_pred, average=None, labels=classes)

  _warn_prf(average, modifier, msg_start, len(result))


In [124]:
# One record per class: how many test queries it has, how many samples of it
# are counted in `counts_train`, and its precision and recall.
Metrics = namedtuple("Metrics", ["label","n_queries", "n_index", "precision","recall"])
data_metrics = [
    Metrics(
        label=label,
        n_queries=counts_test[label],
        n_index=counts_train[label],
        precision=label_precision,
        recall=label_recall,
    )
    for label, label_precision, label_recall in zip(classes, precision, recall)
]

In [125]:
# Show the 30 classes with the lowest precision.
(
    pd.DataFrame(data_metrics)
    .sort_values(by="precision")
    .head(30)
)

Unnamed: 0,label,n_queries,n_index,precision,recall
79,streptococcus_sp_group_b,3,24,0.0,0.0
11,burkholderia_cenocepacia,4,35,0.5,0.5
10,brucella_suis,7,68,0.5,0.285714
9,brucella_melitensis,12,111,0.611111,0.916667
27,enterobacter_hormaechei,69,622,0.658824,0.811594
83,vibrio_shilonii,3,30,0.75,1.0
26,enterobacter_cloacae,90,814,0.802817,0.633333
75,streptococcus_mitis,16,140,0.833333,0.625
82,treponema_pallidum,5,47,0.833333,1.0
25,dustbin,63,564,0.866667,0.825397


In [128]:
# Queries whose `consensus_10` label is vibrio_shilonii.
df.loc[df["consensus_10"] == "vibrio_shilonii"]

Unnamed: 0,GT,consensus_1,consensus_3,consensus_5,consensus_10,0,1,2,3,4,5,6,7,8,9
202,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii
1434,enterobacter_cloacae,enterobacter_cloacae,enterobacter_cloacae,enterobacter_cloacae,vibrio_shilonii,enterobacter_cloacae,enterobacter_cloacae,vibrio_vulnificus,enterobacter_cloacae,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii
1840,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii
2144,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii


In [129]:
# Queries whose true label (`GT`) is vibrio_shilonii.
df.loc[df["GT"] == "vibrio_shilonii"]


Unnamed: 0,GT,consensus_1,consensus_3,consensus_5,consensus_10,0,1,2,3,4,5,6,7,8,9
202,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii
1840,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii
2144,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii,vibrio_shilonii
