In [6]:
import numpy as np 
import pandas as pd 
from collections import Counter, defaultdict
from tqdm import tqdm
from pathlib import Path
from src.panspace.utils import LogInfo
loginfo = LogInfo()

from cleanlab.dataset import (
    rank_classes_by_label_quality, 
    find_overlapping_classes,
    overall_label_health_score,
)

from cleanlab.count import get_confident_thresholds

In [7]:
PATH_EXP = Path(f"/data/bacteria/experiments/autoencoders/6mer/01_22_2024-2/cross-validation")

loss="binary_crossentropy"
hidden_activation="relu"
output_activation="sigmoid"

kfold=1

In [8]:
PATH_QUERY_RES = PATH_EXP.joinpath(f"{loss}-{hidden_activation}-{output_activation}-{kfold}-fold/test/query_results.csv")
df_query = pd.read_csv(PATH_QUERY_RES)

In [9]:
# score: average distance
cols_distance = [c for c in df_query.columns if c.startswith("distance")]

for col_d in cols_distance: 
    df_query[f"score_{col_d}"] = df_query[col_d].apply(lambda d: (2-d)/2)

df_query["ground_truth"] = df_query["ground_truth"].apply(lambda x: x.replace("[","").replace("]",""))


In [10]:

cols_label = [c for c in df_query.columns if c.startswith("label")]
cols_score = [c for c in df_query.columns if c.startswith("score")]


In [11]:
df_query["score_label"] = df_query[cols_label].apply(lambda row: Counter(row).most_common()[0][1] / 10.,axis=1)
df_query

Unnamed: 0.1,Unnamed: 0,ground_truth,sample_id_query,sample_id_0,label_0,distance_to_0,sample_id_1,label_1,distance_to_1,sample_id_2,...,score_distance_to_1,score_distance_to_2,score_distance_to_3,score_distance_to_4,score_distance_to_5,score_distance_to_6,score_distance_to_7,score_distance_to_8,score_distance_to_9,score_label
0,0,neisseria_gonorrhoeae,SAMN04624666,SAMEA104081187,neisseria_gonorrhoeae,0.000067,SAMEA3672287,neisseria_gonorrhoeae,0.000086,SAMEA3431930,...,0.999957,0.999956,0.999953,0.999952,0.999951,0.999951,0.999950,0.999949,0.999949,1.0
1,1,escherichia_coli,SAMN09784558,SAMN06029321,escherichia_coli,0.000084,SAMN08144280,escherichia_coli,0.000084,SAMEA4062301,...,0.999958,0.999952,0.999951,0.999950,0.999949,0.999948,0.999947,0.999947,0.999947,1.0
2,2,salmonella_enterica,SAMN09788966,SAMN04575359,salmonella_enterica,0.000044,SAMN09430533,salmonella_enterica,0.000046,SAMN06842097,...,0.999977,0.999977,0.999976,0.999976,0.999976,0.999976,0.999975,0.999975,0.999974,1.0
3,3,salmonella_enterica,SAMN07159320,SAMN07138696,salmonella_enterica,0.000091,SAMN06890573,salmonella_enterica,0.000094,SAMN10240765,...,0.999953,0.999949,0.999947,0.999946,0.999945,0.999944,0.999943,0.999943,0.999942,1.0
4,4,salmonella_enterica,SAMN07227128,SAMEA80635918,salmonella_enterica,0.000030,SAMEA2148364,salmonella_enterica,0.000035,SAMEA4759513,...,0.999983,0.999979,0.999978,0.999978,0.999978,0.999976,0.999976,0.999976,0.999976,1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
132274,132274,salmonella_enterica,SAMN07731205,SAMN10396305,salmonella_enterica,0.000030,SAMN09655213,salmonella_enterica,0.000034,SAMN09501203,...,0.999983,0.999983,0.999983,0.999983,0.999981,0.999981,0.999981,0.999979,0.999979,1.0
132275,132275,neisseria_meningitidis,SAMEA2570241,SAMEA2570265,neisseria_meningitidis,0.000101,SAMEA3503529,neisseria_meningitidis,0.000116,SAMEA2074075,...,0.999942,0.999934,0.999932,0.999932,0.999931,0.999928,0.999928,0.999928,0.999927,1.0
132276,132276,salmonella_enterica,SAMN03252939,SAMN08981889,salmonella_enterica,0.000040,SAMN07156430,salmonella_enterica,0.000040,SAMN04942592,...,0.999980,0.999978,0.999975,0.999975,0.999975,0.999974,0.999973,0.999973,0.999972,1.0
132277,132277,mycobacterium_tuberculosis,SAMEA2535176,SAMN08629129,mycobacterium_tuberculosis,0.000033,SAMN09090641,mycobacterium_tuberculosis,0.000033,SAMEA3233432,...,0.999983,0.999981,0.999980,0.999980,0.999980,0.999979,0.999978,0.999978,0.999977,1.0


In [38]:
k=0
info_scores = []
_labels = set()
for row in tqdm(df_query.to_dict("records")):

    info_row = defaultdict(float)
    for col_label, col_score in zip(cols_label,cols_score):
        label = row[col_label]
        label = label.replace("[","").replace("]","")
        score = row[col_score]
        info_row[label] += 1 # score
        _labels.update([label])

    info_scores.append(info_row)

100%|██████████| 132279/132279 [00:00<00:00, 277571.41it/s]


In [39]:
probs = pd.DataFrame(info_scores).fillna(0.) / 10.

In [40]:
# labels to integer
labels = list(set(probs.columns).union(_labels).union(df_query.ground_truth.tolist()))
labels.sort()
label2num = {label: num for num, label in enumerate(labels)}

# num_labels = [random.choice(range(len(probs.columns))) for x in df_query.ground_truth.tolist()]
num_labels = [label2num[label] for label in df_query.ground_truth.tolist()]

In [41]:
missing_cols_in_probs = set(labels) - set(probs.columns)

for col in missing_cols_in_probs:
    probs[col] = 0.0 

# _values = [{l: 0.0 for l in missing_cols_in_probs}]

  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col] = 0.0
  probs[col]

In [42]:
probs

Unnamed: 0,neisseria_gonorrhoeae,escherichia_coli,salmonella_enterica,staphylococcus_aureus,francisella_tularensis,xanthomonas_oryzae,mycobacterium_tuberculosis,streptococcus_pneumoniae,mesorhizobium_ciceri,campylobacter_coli,...,polaribacter_sp._kt25b,neisseria_sp._10022,bradyrhizobium_oligotrophicum,microbulbifer_thermotolerans,porphyrobacter_sp._caciam_03h1,mycoplasma_penetrans,halomonas_sp._gfaj-1,streptomyces_sp._wac00288,luteipulveratus_mongoliensis,beijerinckia_indica
0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
132274,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
132275,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
132276,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
132277,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [43]:
len(labels), np.array(num_labels).max()

(2359, 2358)

In [44]:
ranks = rank_classes_by_label_quality(  
    labels= num_labels,  
    pred_probs=probs.to_numpy(dtype=np.float32),  
    class_names=labels,
)

  given_conditional_noise = given_label_noise / joint.sum(axis=1)  # p(y!=k, s=k) / p(s=k)
  true_conditional_noise = true_label_noise / joint.sum(axis=0)  # p(s!=k, y=k) / p(y=k)


In [45]:
ranks

Unnamed: 0,Class Name,Class Index,Label Issues,Inverse Label Issues,Label Noise,Inverse Label Noise,Label Quality Score
0,photobacterium_gaetbulicola,1584,2,0,1.0,,0.0
1,sphingobacterium_sp._21,1943,1,0,1.0,,0.0
2,sphaerochaeta_globosa,1940,2,0,1.0,,0.0
3,sodalis_glossinidius,1936,1,0,1.0,,0.0
4,snodgrassella_alvi,1935,1,0,1.0,,0.0
...,...,...,...,...,...,...,...
2354,xenorhabdus_nematophila,2336,0,0,,,
2355,xylanimonas_cellulosilytica,2337,0,0,,,
2356,yersinia_sp._fdaargos_228,2352,0,0,,,
2357,zobellella_denitrificans,2355,0,0,,,


In [46]:
overlap_classes = find_overlapping_classes(
    labels= num_labels,  
    pred_probs=probs.to_numpy(dtype=np.float32),  
    class_names=labels,
)

In [47]:
overlap_classes.query("`Class Name A` == 'escherichia_coli' and `Num Overlapping Examples`>0")

Unnamed: 0,Class Name A,Class Name B,Class Index A,Class Index B,Num Overlapping Examples,Joint Probability
1135,escherichia_coli,mycoplasma_pulmonis,815,1379,1,8e-06
1186,escherichia_coli,megasphaera_hexanoica,815,1222,1,8e-06
1630,escherichia_coli,salinicoccus_halodurans,815,1865,1,8e-06


In [48]:
label2num["escherichia_coli"]

815

In [49]:
health_score = overall_label_health_score(
    labels= num_labels,  
    pred_probs=probs.to_numpy(dtype=np.float32),  
)

 * Overall, about 97% (127,650 of the 132,279) labels in your dataset have potential issues.
 ** The overall label health score for this dataset is: 0.03.


In [50]:
health_score

0.034994216769101666

In [51]:
conf_thresholds = get_confident_thresholds(
    labels= num_labels,
    pred_probs=probs.to_numpy(dtype=np.float32),
)

In [63]:
np.array(num_labels).max(), len(set(num_labels))

(2358, 1494)

In [64]:
probs.shape

(132279, 2359)