In [38]:
import pandas as pd

def calc_fn_pred(
    df_anno: pd.DataFrame,
    pred_residues: set,
):
    """Find annotated residues and metal sites that the prediction missed.

    A residue is identified by the tuple ``(seq_id, resi_domain_posi, resi)``
    and a site by ``(seq_id, metal_chain, metal_pdb_seq_num, metal_resi)``.

    Args:
        df_anno: annotation table; must contain the columns ``seq_id``,
            ``resi_domain_posi``, ``resi``, ``metal_chain``,
            ``metal_pdb_seq_num`` and ``metal_resi``.
        pred_residues: predicted residues as ``(seq_id, position, residue)``
            tuples, compared directly against the annotated residue tuples.

    Returns:
        ``(df_fn_resi, df_fn_site)``:
        ``df_fn_resi`` holds the rows of ``df_anno`` whose residue is not in
        ``pred_residues``; ``df_fn_site`` holds the rows of ``df_anno`` that
        belong to a site for which *every* annotated residue was missed.
    """
    resi_cols = ['seq_id', 'resi_domain_posi', 'resi']
    site_cols = ['seq_id', 'metal_chain', 'metal_pdb_seq_num', 'metal_resi']

    def select_rows(cols, keys):
        # Build the membership mask column-wise with zip (the same way the
        # key sets are built) instead of a row-wise DataFrame.apply, which
        # constructs a Series per row and is far slower.
        mask = pd.Series(
            [key in keys for key in zip(*(df_anno[c] for c in cols))],
            index=df_anno.index,
            dtype=bool,
        )
        return df_anno[mask]

    # Residues labelled as true but not predicted.
    true_resi = set(zip(*(df_anno[c] for c in resi_cols)))
    fn_resi = true_resi.difference(pred_residues)
    df_fn_resi = select_rows(resi_cols, fn_resi)

    # A site counts as a false negative only when all of its residues were
    # missed. Note: groupby skips rows whose site key contains NaN.
    fn_site = set()
    for site, df_site in df_anno.groupby(site_cols):
        site_resi = set(zip(*(df_site[c] for c in resi_cols)))
        if site_resi <= fn_resi:
            fn_site.add(site)
    df_fn_site = select_rows(site_cols, fn_site)

    return df_fn_resi, df_fn_site

def get_sites(df: pd.DataFrame):
    """Return the distinct metal sites present in ``df``.

    Each site is the tuple ``(seq_id, metal_chain, metal_pdb_seq_num,
    metal_resi)`` taken from one row; duplicates collapse because the
    result is a set.
    """
    site_columns = ('seq_id', 'metal_chain', 'metal_pdb_seq_num', 'metal_resi')
    per_column_values = (df[column] for column in site_columns)
    return set(zip(*per_column_values))

def calc_fn_freq_on_type(
    true_sites: set,
    fn_sites: set,
):
    """Compute, per metal type, the fraction of true sites that were missed.

    The metal type of a site is the last element of its tuple.

    Args:
        true_sites: all annotated sites.
        fn_sites: sites counted as false negatives.

    Returns:
        dict mapping each metal type found in ``true_sites`` to
        ``(# fn sites of that type) / (# true sites of that type)``.
        Metal types that occur only in ``fn_sites`` are not reported.
    """
    def metal_to_num(sites: set):
        """Count sites per metal type."""
        dict_metal_to_num = dict()
        for s in sites:
            metal = s[-1]
            dict_metal_to_num[metal] = dict_metal_to_num.get(metal, 0) + 1
        return dict_metal_to_num

    true_metal_num = metal_to_num(true_sites)
    fn_metal_num = metal_to_num(fn_sites)

    # A metal type with no false negatives has frequency 0; dict.get supplies
    # that default without a bare ``except`` that would hide unrelated errors.
    return {
        metal: fn_metal_num.get(metal, 0) / true_num
        for metal, true_num in true_metal_num.items()
    }

In [39]:
# Load the annotation table and the predicted residue pairs.
df_anno = pd.read_table(f"{PROJECT_DIR}/dataset/transform/test_metalnet.tsv")
df_pred = pd.read_table("../../pred_pairs.tsv")

# Keep only the pairs flagged with filter_by_graph == 1.
df_pred = df_pred[df_pred['filter_by_graph'] == 1]

# Collect both members of every kept pair as (seq_id, position, residue) keys.
# Built column-wise with zip rather than DataFrame.iterrows, which is slow and
# coerces each row to a single dtype.
# NOTE(review): predictions are keyed by resi_seq_posi_* while calc_fn_pred
# matches annotations on resi_domain_posi — confirm both use the same numbering.
pred_resi = set(zip(df_pred['seq_id'], df_pred['resi_seq_posi_1'], df_pred['resi_1']))
pred_resi |= set(zip(df_pred['seq_id'], df_pred['resi_seq_posi_2'], df_pred['resi_2']))

In [40]:
# Annotation rows whose residue was not predicted, and annotation rows that
# belong to sites for which every residue was missed.
df_fn_resi, df_fn_site = calc_fn_pred(df_anno, pred_resi)

# Distinct site keys among the fully-missed sites and among all annotations.
fn_sites = get_sites(df_fn_site)
all_sites = get_sites(df_anno)

# Row counts of the two frames, then the number of distinct sites in each set.
# NOTE(review): four values are echoed below, so the kernel presumably displays
# every expression rather than only the last one — confirm the IPython setting.
len(df_fn_resi)
len(df_fn_site)
len(fn_sites)
len(all_sites)

594

503

242

687

In [42]:
# Per metal type: fraction of annotated sites that were missed entirely.
fn_freq = calc_fn_freq_on_type(all_sites, fn_sites)
fn_freq

{'ZN': 0.21052631578947367,
 'CA': 0.4829059829059829,
 'MN': 0.15625,
 'SF4': 0.13333333333333333,
 'MG': 0.6329113924050633,
 'FES': 0.0,
 'CU': 0.3,
 'FE': 0.1,
 'NI': 0.8,
 'CO': 0.125,
 'F3S': 0.0}