In [79]:
from pathlib import Path
from omegaconf import OmegaConf
import numpy as np
import pandas as pd
import mlflow
from sklearn.metrics import (
    roc_auc_score,
    roc_curve,
    precision_recall_curve,
    RocCurveDisplay,
    PrecisionRecallDisplay,
    precision_score,
    recall_score,
    accuracy_score,
    f1_score
)

# Load the project's file-path configuration and point MLflow at the tracking
# URI stored in it, so the run searches below query that tracking store.
filepaths = OmegaConf.load("../configs/filepaths/base.yaml")
mlflow.set_tracking_uri(filepaths.tracking_uri)

Wrangle top model run ids

In [80]:
# Read the exported top-models table from the artifacts directory and preview it.
top_models_csv = Path(filepaths.artifacts) / "250209_top_models.csv"
top_models = pd.read_csv(top_models_csv)  # comma is already read_csv's default separator
top_models.head(10)

Unnamed: 0,Start Time,Duration,Run ID,Name,Source Type,Source Name,User,Status,X_d_transform,batch_norm,...,epoch,train_loss,val/accuracy,val/binary_precision,val/binary_recall,val/f1,val/mcc,val/prc,val/roc,val_loss
0,2025-02-07 13:21:04,3.9h,ff90895a6070499b8fdcdc0982526885,dazzling-rat-180,LOCAL,/home/spn1560/.conda/envs/hiec/lib/python3.11/...,spn1560,FINISHED,,True,...,24,0.025103,0.809866,0.963687,0.631624,0.762649,0.653961,0.921851,0.912527,0.762649
1,2025-02-07 13:20:39,3.1h,8a155bc5673b4c6aa8667dd55755fe0b,zealous-eel-839,LOCAL,/home/spn1560/.conda/envs/hiec/lib/python3.11/...,spn1560,FINISHED,,True,...,24,0.002602,0.809679,0.972779,0.624411,0.76016,0.656536,0.959718,0.961152,0.76016
2,2025-02-07 13:20:39,3.1h,5ca6ba7b8aec4b81a366e9138b9c09b0,smiling-colt-432,LOCAL,/home/spn1560/.conda/envs/hiec/lib/python3.11/...,spn1560,FINISHED,,True,...,24,0.004807,0.797277,0.969662,0.600536,0.741214,0.635536,0.947565,0.949833,0.741214
3,2025-02-07 13:20:35,3.4h,2241478505f9497885d25b9a0dee3ffd,fortunate-panda-724,LOCAL,/home/spn1560/.conda/envs/hiec/lib/python3.11/...,spn1560,FINISHED,,True,...,24,0.027877,0.95056,0.979499,0.919235,0.948325,0.902712,0.986876,0.986808,0.948325
4,2025-02-07 13:20:33,3.4h,e638683b53f84d0ba1c8f32794413e0e,burly-goat-264,LOCAL,/home/spn1560/.conda/envs/hiec/lib/python3.11/...,spn1560,FINISHED,,True,...,24,0.044305,0.949125,0.977389,0.918329,0.946844,0.899774,0.984206,0.986065,0.946844
5,2025-02-07 13:20:14,7.3h,fb3c7a9061a744fba7dabf9d50092414,learned-steed-320,LOCAL,/home/spn1560/.conda/envs/hiec/lib/python3.11/...,spn1560,FINISHED,,,...,24,1.224699,0.772193,0.953163,0.556861,0.702544,0.589274,0.884468,0.865595,0.702544
6,2025-02-07 13:20:04,6.6h,4bf93c6830744a4b8e6d8076f14a82e1,defiant-dolphin-191,LOCAL,/home/spn1560/.conda/envs/hiec/lib/python3.11/...,spn1560,FINISHED,,,...,24,1.357588,0.939986,0.965952,0.910732,0.937423,0.881308,0.975768,0.973943,0.937423
7,2025-02-07 13:20:04,3.5h,ef2ce94281fb445ba304791f991e23cc,funny-rat-28,LOCAL,/home/spn1560/.conda/envs/hiec/lib/python3.11/...,spn1560,FINISHED,,True,...,24,0.118465,0.944258,0.975382,0.910101,0.941493,0.890394,0.982819,0.982808,0.941493


In [81]:
# Run IDs of the top models, in table order. The notebook calls these the
# "outer" runs; print them comma-separated for copy/paste.
outer_run_ids = top_models["Run ID"].tolist()
print(",".join(outer_run_ids))

ff90895a6070499b8fdcdc0982526885,8a155bc5673b4c6aa8667dd55755fe0b,5ca6ba7b8aec4b81a366e9138b9c09b0,2241478505f9497885d25b9a0dee3ffd,e638683b53f84d0ba1c8f32794413e0e,fb3c7a9061a744fba7dabf9d50092414,4bf93c6830744a4b8e6d8076f14a82e1,ef2ce94281fb445ba304791f991e23cc


In [82]:
# Hyperparameter columns that identify a model configuration. The next cell
# turns each top model's values for these columns into an MLflow param filter.
search_cols = [
    # data / training settings
    "data/neg_multiple",
    "data/split_strategy",
    "training/pos_multiplier",
    # model settings
    "model/name",
    "model/d_h_encoder",
    "model/encoder_depth",
    "model/radius",
    "model/vec_len",
]

# Preview just those columns for the top models.
top_models[search_cols]

Unnamed: 0,data/neg_multiple,data/split_strategy,training/pos_multiplier,model/name,model/d_h_encoder,model/encoder_depth,model/radius,model/vec_len
0,3,rcmcs,3,bom,300,6.0,,
1,3,rcmcs,3,rc_agg,300,4.0,,
2,3,rcmcs,3,rc_cxn,300,6.0,,
3,3,homology,3,rc_cxn,300,6.0,,
4,3,homology,3,rc_agg,300,4.0,,
5,3,rcmcs,3,mfp,300,,2.0,2048.0
6,3,homology,3,mfp,300,,2.0,2048.0
7,3,homology,3,bom,300,6.0,,


In [83]:
def _as_param_value(val):
    """Return ``val`` ready for an MLflow param comparison.

    Whole-number floats (e.g. ``6.0``) are converted to ``int`` so they format
    as ``'6'``; every other value is returned unchanged.
    """
    if isinstance(val, float) and val.is_integer():
        return int(val)
    return val


# For every top model, search MLflow for runs that share its hyperparameters
# (the non-missing `search_cols` values) and whose `data/split_idx` is not '-1'.
# NOTE(review): presumably '-1' marks the run without an inner split — confirm.
# Results are collected in the same order as the rows of `top_models`.
split_runs = []
for _, model_row in top_models.iterrows():
    clauses = ["params.'data/split_idx' != '-1'"]
    clauses.extend(
        f"params.'{col}' = '{_as_param_value(model_row[col])}'"
        for col in search_cols
        if not pd.isna(model_row[col])  # missing hyperparameters are not filtered on
    )
    split_runs.append(mlflow.search_runs(filter_string=" AND ".join(clauses)))

In [84]:
# Stack the per-model search results into a single frame of inner runs.
# The row index of each search result is kept as-is (no ignore_index), so
# index labels repeat across models (0, 1, 2, 0, 1, ...).
inner_runs = pd.concat(split_runs)
inner_runs.head()

Unnamed: 0,run_id,experiment_id,status,artifact_uri,start_time,end_time,metrics.val/binary_recall,metrics.val/binary_precision,metrics.val/roc,metrics.val/f1,...,params.model/model,params.model/pred_head,params.training/n_epochs,params.data/subdir_patt,tags.mlflow.source.type,tags.mlflow.source.name,tags.mlflow.runName,tags.mlflow.user,params.model/radius,params.model/vec_len
0,af1e9fbcef4546e689652a15ef6123ff,0,FINISHED,file:///projects/p30041/spn1560/hiec/results/r...,2025-02-03 10:48:14.384000+00:00,2025-02-04 08:33:09.719000+00:00,0.773198,0.963145,0.95855,0.857535,...,mpnn_dim_red,DotSig,25,sprhea_v3_folded_pt_ns/rcmcs/3fold,LOCAL,/home/spn1560/.conda/envs/hiec/lib/python3.11/...,grandiose-robin-923,spn1560,,
1,eb05684d7fd846eaa95654c7814abd40,0,FINISHED,file:///projects/p30041/spn1560/hiec/results/r...,2025-02-03 10:10:25.056000+00:00,2025-02-04 10:21:46.689000+00:00,0.668354,0.966711,0.946668,0.789959,...,mpnn_dim_red,DotSig,25,sprhea_v3_folded_pt_ns/rcmcs/3fold,LOCAL,/home/spn1560/.conda/envs/hiec/lib/python3.11/...,gaudy-mule-125,spn1560,,
2,e2c936866af944a1aaa3e73380d2c072,0,FINISHED,file:///projects/p30041/spn1560/hiec/results/r...,2025-02-03 09:29:31.200000+00:00,2025-02-04 13:07:34.342000+00:00,0.595301,0.957135,0.942648,0.733743,...,mpnn_dim_red,DotSig,25,sprhea_v3_folded_pt_ns/rcmcs/3fold,LOCAL,/home/spn1560/.conda/envs/hiec/lib/python3.11/...,mysterious-fawn-710,spn1560,,
0,db987cbcf4c6492b90fd2f1adab33b96,0,FINISHED,file:///projects/p30041/spn1560/hiec/results/r...,2025-02-01 02:06:29.565000+00:00,2025-02-01 13:56:03.526000+00:00,0.690844,0.970544,0.957237,0.806823,...,mpnn_dim_red,DotSig,25,sprhea_v3_folded_pt_ns/rcmcs/3fold,LOCAL,/home/spn1560/.conda/envs/hiec/lib/python3.11/...,silent-seal-423,spn1560,,
1,04b0957ef18f4ed69ce0f114702c8f00,0,FINISHED,file:///projects/p30041/spn1560/hiec/results/r...,2025-02-01 01:46:49.096000+00:00,2025-02-02 05:38:13.628000+00:00,0.680673,0.968373,0.948078,0.79907,...,mpnn_dim_red,DotSig,25,sprhea_v3_folded_pt_ns/rcmcs/3fold,LOCAL,/home/spn1560/.conda/envs/hiec/lib/python3.11/...,valuable-chimp-577,spn1560,,


In [86]:
# All run IDs of interest: the top-model runs first, then the matched inner runs.
run_ids = top_models['Run ID'].to_list()
run_ids.extend(inner_runs['run_id'])

# Comma-separated list for copy/paste, followed by the total count.
print(",".join(run_ids))
print(len(run_ids))

ff90895a6070499b8fdcdc0982526885,8a155bc5673b4c6aa8667dd55755fe0b,5ca6ba7b8aec4b81a366e9138b9c09b0,2241478505f9497885d25b9a0dee3ffd,e638683b53f84d0ba1c8f32794413e0e,fb3c7a9061a744fba7dabf9d50092414,4bf93c6830744a4b8e6d8076f14a82e1,ef2ce94281fb445ba304791f991e23cc,af1e9fbcef4546e689652a15ef6123ff,eb05684d7fd846eaa95654c7814abd40,e2c936866af944a1aaa3e73380d2c072,db987cbcf4c6492b90fd2f1adab33b96,04b0957ef18f4ed69ce0f114702c8f00,95963b479bb9434ba0c4fb43a05dc7a4,62d84f7e370843489adab8c62efa74ab,3d93a80556ae4538951ea0122c48bb84,4cddcf9c4aa4439a989a5832a3df20de,74c15732aa1747b09eef3f3a67826260,4b80162c6a6849b1ace18e889cfbe571,d320827e0c754801b76edf5efe82d8f9,25aaed1cb3b04b028724c775048c9f39,9c3608ffe8dd439598958350a2cb9fb2,6d850586e6674fd2b4eb520b1d3b7494,c749ed3999f4408aa8e39d0ec4e7ff4a,62491f3e711e4d8b8bbe2ac27c85c4db,e38a5553209a4d5c8063f0e2abb54d74,bc114570a3a24a6e8384666ba854d320,96036ba3999d437fb79560ec4b0e5f88,a785183783074fee877eb15f7bb2a337,81b9c114084446338974bcda6b36c4dc,47547b9e0b

Tune threshold

In [None]:
# Planning note for the threshold-tuning cells, kept as a bare string literal.
'''
Get target_output
Append to df
Iterate thru
Calc F1
Pick best
'''

In [None]:
# Tune one decision threshold per top ("outer") model using its inner runs:
# for every candidate threshold, binarize each inner run's predictions, compute
# the binary F1 per inner run, average across inner runs, and keep the
# threshold with the highest mean F1.
#
# NOTE(review): the candidate grid spans [0, 1], which assumes the 'logits'
# column holds probabilities rather than raw logits — confirm.
# NOTE(review): when this was last run, several tuned thresholds landed on the
# second grid point (~0.0101), so the optimum may lie below the grid's
# resolution; consider a finer or log-spaced grid near 0.
thresholds = np.linspace(0, 1, num=100)
best_thresholds = {}   # outer run id -> tuned threshold
best_inner_f1s = {}    # outer run id -> mean inner-run F1 at that threshold

# split_runs was built by iterating top_models row by row, so it must line up
# one-to-one with outer_run_ids; zip would otherwise silently truncate.
assert len(outer_run_ids) == len(split_runs), "outer_run_ids and split_runs are misaligned"

for oid, df in zip(outer_run_ids, split_runs):
    # Load (labels, scores) once per inner run; they are reused for every threshold.
    y_logits = []
    for inner_id in df['run_id']:
        target_output_path = Path(filepaths.results) / 'predictions' / inner_id / 'target_output.parquet'
        target_output = pd.read_parquet(target_output_path)
        y_logits.append((target_output['y'], target_output['logits']))

    # Without inner runs the mean F1 below would be a division by zero; fail
    # with a message that names the offending outer run instead.
    if not y_logits:
        raise ValueError(f"No inner-split runs found for outer run {oid}; cannot tune a threshold")

    best_th = thresholds[0]
    best_f1 = 0
    for th in thresholds:

        # Get f1 for each split
        f1s = []
        for y, logits in y_logits:
            ypred = (logits > th).astype(np.int32)
            f1 = f1_score(y_true=y, y_pred=ypred, average='binary')
            f1s.append(f1)

        # Strict '>' keeps the lowest threshold among ties.
        mean_f1 = sum(f1s) / len(f1s)
        if mean_f1 > best_f1:
            best_f1 = mean_f1
            best_th = th

    best_thresholds[oid] = best_th
    best_inner_f1s[oid] = best_f1

In [145]:
# Show the tuned threshold for each outer run ID.
best_thresholds

{'ff90895a6070499b8fdcdc0982526885': 0.010101010101010102,
 '8a155bc5673b4c6aa8667dd55755fe0b': 0.010101010101010102,
 '5ca6ba7b8aec4b81a366e9138b9c09b0': 0.010101010101010102,
 '2241478505f9497885d25b9a0dee3ffd': 0.15151515151515152,
 'e638683b53f84d0ba1c8f32794413e0e': 0.12121212121212122,
 'fb3c7a9061a744fba7dabf9d50092414': 0.010101010101010102,
 '4bf93c6830744a4b8e6d8076f14a82e1': 0.05050505050505051,
 'ef2ce94281fb445ba304791f991e23cc': 0.26262626262626265}

In [153]:
# Score each top ("outer") model on its own predictions, binarized at the
# threshold tuned for it, and print one line per model.
for oid in outer_run_ids:
    target_output_path = Path(filepaths.results) / 'predictions' / oid / 'target_output.parquet'
    target_output = pd.read_parquet(target_output_path)
    y, logits = target_output['y'], target_output['logits']

    ypred = (logits > best_thresholds[oid]).astype(np.int32)
    f1 = f1_score(y_true=y, y_pred=ypred, average='binary')

    # Look up this run's row once to label the output line.
    is_this_run = top_models["Run ID"] == oid
    name = top_models.loc[is_this_run, "model/name"].values[0]
    split_strategy = top_models.loc[is_this_run, "data/split_strategy"].values[0]

    print(f"{oid}, {name}, {split_strategy}:  {f1}")
    

ff90895a6070499b8fdcdc0982526885, bom, rcmcs:  0.8225960012694383
8a155bc5673b4c6aa8667dd55755fe0b, rc_agg, rcmcs:  0.8698524099494271
5ca6ba7b8aec4b81a366e9138b9c09b0, rc_cxn, rcmcs:  0.821209182023925
2241478505f9497885d25b9a0dee3ffd, rc_cxn, homology:  0.9527385637433657
e638683b53f84d0ba1c8f32794413e0e, rc_agg, homology:  0.9521246052253804
fb3c7a9061a744fba7dabf9d50092414, mfp, rcmcs:  0.7176850412531388
4bf93c6830744a4b8e6d8076f14a82e1, mfp, homology:  0.9397424269907492
ef2ce94281fb445ba304791f991e23cc, bom, homology:  0.9475248242371529
