In [1]:
import os

import pandas as pd

In [2]:
# Metal-type labels (as they appear in the annotation `metal_resi` column and
# the prediction `pred` column) that calc_metrics_for_resi_metal_type scores.
metal_types = {
    'ZN', 'CA', 'MG', 'MN', 'FE',
    'CU', 'NI', 'CO', 'FES', 'SF4',
    'F3S'
}


def calc_metrics_for_resi_metal_type(
    df_anno: pd.DataFrame,
    df_pred: pd.DataFrame,
    types=None,
) -> dict:
    """Residue-level precision / recall / F1 for each metal type.

    A residue is identified by the pair ``(seq_id, position)``. For every
    evaluated metal type, the set of annotated residues is compared with the
    set of predicted residues.

    Args:
        df_anno: annotations; must have columns ``seq_id``,
            ``resi_domain_posi`` and ``metal_resi`` (one metal type per row).
        df_pred: predictions; must have columns ``seq_id``, ``resi_seq_posi``
            and ``pred`` (one or more metal types joined by ``";"``).
            Non-string ``pred`` cells (e.g. NaN from an empty TSV field) and
            empty entries such as the one after a trailing ``";"`` are ignored.
        types: iterable of metal types to evaluate; defaults to
            ``metal_types``. Labels outside this set are ignored in both
            inputs instead of raising ``KeyError``.

    Returns:
        ``{type: {'precision': p, 'recall': r, 'f1': f}}`` for each evaluated
        type, in sorted (deterministic) order, followed by
        ``'f1_macro_avg'``, the unweighted mean F1 over the evaluated types
        (a plain float). Any metric whose denominator is zero is reported as 0.
    """
    # Sort so the result (and any DataFrame built from it) has a stable row
    # order; iterating a set of str depends on per-process hash randomization.
    eval_types = sorted(metal_types if types is None else set(types))

    anno_resi = {t: set() for t in eval_types}
    pred_resi = {t: set() for t in eval_types}

    # NOTE(review): annotations are keyed on `resi_domain_posi` while
    # predictions are keyed on `resi_seq_posi`. Confirm both columns use the
    # same coordinate system, otherwise true hits are silently missed.
    for seq_id, posi, metal in zip(
        df_anno['seq_id'], df_anno['resi_domain_posi'], df_anno['metal_resi']
    ):
        if metal in anno_resi:
            anno_resi[metal].add((seq_id, posi))

    for seq_id, posi, pred in zip(
        df_pred['seq_id'], df_pred['resi_seq_posi'], df_pred['pred']
    ):
        # Missing cells come back as NaN (float): treat as "no prediction".
        if not isinstance(pred, str):
            continue
        for label in pred.split(";"):
            label = label.strip()
            if label in pred_resi:
                pred_resi[label].add((seq_id, posi))

    result = dict()
    f1_total = 0
    for m in eval_types:
        true_set = anno_resi[m]
        pred_set = pred_resi[m]
        n_hit = len(true_set & pred_set)

        recall = n_hit / len(true_set) if true_set else 0
        precision = n_hit / len(pred_set) if pred_set else 0
        f1 = (2 * recall * precision / (recall + precision)
              if recall + precision else 0)

        result[m] = {'precision': precision, 'recall': recall, 'f1': f1}
        f1_total += f1

    result['f1_macro_avg'] = f1_total / len(eval_types) if eval_types else 0
    return result

In [3]:
# PROJECT_DIR is not defined anywhere in this notebook, so a fresh kernel
# would fail here with a bare NameError. Reuse it if the kernel already has it
# (e.g. from a startup script), otherwise read it from the environment, and
# fail with an actionable message if neither is set. It must point at the
# directory that contains `dataset/transform/test_metalnet.tsv`.
PROJECT_DIR = globals().get("PROJECT_DIR") or os.environ.get("PROJECT_DIR")
if not PROJECT_DIR:
    raise RuntimeError(
        "PROJECT_DIR is not set: export PROJECT_DIR (the directory containing "
        "dataset/transform/test_metalnet.tsv) before running this notebook."
    )
df_anno = pd.read_table(f"{PROJECT_DIR}/dataset/transform/test_metalnet.tsv")

In [4]:
# Score the predictions in ../pred_metal_type.tsv against the test annotations;
# one row per metal type after transposing.
# NOTE(review): the saved output below has no ZN or f1_macro_avg rows, which
# the current function produces; it looks stale, so re-run before trusting it.
result = calc_metrics_for_resi_metal_type(
    df_anno, pd.read_table("../pred_metal_type.tsv")
)
pd.DataFrame(result).T

Unnamed: 0,precision,recall,f1
MN,0.340278,0.544444,0.418803
F3S,0.5,0.666667,0.571429
CO,0.357143,0.25,0.294118
FES,0.903226,0.8,0.848485
CU,0.642857,0.642857,0.642857
SF4,0.813008,0.917431,0.862069
FE,0.22449,0.415094,0.291391
MG,0.341463,0.25,0.28866
NI,0.166667,0.128205,0.144928
CA,0.739336,0.541667,0.625251


In [5]:
# Score the LMetalSite predictions (pred_metal_type_LMetalSite.tsv) against the
# same test annotations.
# NOTE(review): the saved output below has no ZN or f1_macro_avg rows, which
# the current function produces; it looks stale, so re-run before trusting it.
result = calc_metrics_for_resi_metal_type(
    df_anno, pd.read_table("pred_metal_type_LMetalSite.tsv")
)
pd.DataFrame(result).T

Unnamed: 0,precision,recall,f1
MN,0.060811,0.9,0.113924
F3S,0.0,0.0,0.0
CO,0.0,0.0,0.0
FES,0.0,0.0,0.0
CU,0.0,0.0,0.0
SF4,0.0,0.0,0.0
FE,0.0,0.0,0.0
MG,0.117391,0.482143,0.188811
NI,0.0,0.0,0.0
CA,0.36489,0.689236,0.477163


In [6]:
# Score the predictions in pred_metal_type_new_motifs.tsv against the same
# test annotations.
# NOTE(review): the saved output below has no ZN or f1_macro_avg rows, which
# the current function produces; it looks stale, so re-run before trusting it.
result = calc_metrics_for_resi_metal_type(
    df_anno, pd.read_table("pred_metal_type_new_motifs.tsv")
)
pd.DataFrame(result).T

Unnamed: 0,precision,recall,f1
MN,0.071942,0.111111,0.087336
F3S,0.1,0.333333,0.153846
CO,0.0,0.0,0.0
FES,0.088816,0.771429,0.159292
CU,0.0,0.0,0.0
SF4,0.249135,0.66055,0.361809
FE,0.026012,0.169811,0.045113
MG,0.229299,0.214286,0.221538
NI,0.023256,0.076923,0.035714
CA,0.564706,0.166667,0.257373


In [7]:
# Score the predictions in pred_metal_type_old_motifs.tsv against the same
# test annotations.
# NOTE(review): the saved output below has no ZN or f1_macro_avg rows, which
# the current function produces; it looks stale, so re-run before trusting it.
result = calc_metrics_for_resi_metal_type(
    df_anno, pd.read_table("pred_metal_type_old_motifs.tsv")
)
pd.DataFrame(result).T

Unnamed: 0,precision,recall,f1
MN,0.083333,0.033333,0.047619
F3S,0.0,0.0,0.0
CO,0.0,0.0,0.0
FES,0.082609,0.542857,0.143396
CU,0.0,0.0,0.0
SF4,0.280992,0.623853,0.387464
FE,0.0,0.0,0.0
MG,0.125,0.017857,0.03125
NI,0.0,0.0,0.0
CA,0.647727,0.098958,0.171687
