In [100]:
from compare_wandb import load_attention_msa_runs
import numpy as np
from tqdm import tqdm
from pathlib import Path
import wandb

from mogwai.parsing import read_contacts
from mogwai.metrics import precisions_in_range
import torch
import seaborn as sns

from mogwai.utils.functional import apc

api = wandb.Api()


In [149]:
dict_df = load_attention_msa_runs(['protbert_bfd'])
df = dict_df['protbert_bfd']

100%|██████████| 748/748 [00:00<00:00, 44102.16it/s]

l37wrnsa has 748 runs





In [150]:
df.shape

(723, 16)

In [148]:
# # download protbert predicted contacts from wandb. takes several minutes
# entity = 'proteins'
# project = 'gremlin-contacts'
# pred_filename = 'predicted_contacts.npy'
# for index, row in tqdm(df.iterrows()):
#     run_id = row["run_id"]
#     pdb = row["pdb"]
#     run = api.run(f'{entity}/{project}/{run_id}')
#     downpath = Path(f'./wandb_download/{entity}/{run_id}')
#     downpath.mkdir(parents=True, exist_ok=True)
#     run.file(pred_filename).download(downpath, replace=True)

In [160]:
metrics_keys = ['short_auc', 'short_pr_at_l', 'short_pr_at_l_2', 'short_pr_at_l_5', 'short_pr_at_l_10', 'medium_auc', 'medium_pr_at_l', 'medium_pr_at_l_2', 'medium_pr_at_l_5', 'medium_pr_at_l_10', 'long_auc', 'long_pr_at_l', 'long_pr_at_l_2', 'long_pr_at_l_5', 'long_pr_at_l_10']

In [161]:
# compute precisions

rows = []
model = 'protbert-bfd'
pred_filename = 'predicted_contacts.npy'
for index, row in tqdm(df.iterrows()):
    metrics = {}
    run_id = row["run_id"]
    pdb = row["pdb"]
    metrics['pdb'] = pdb
    run = api.run(f'{entity}/{project}/{run_id}')
    downpath = Path(f'./wandb_download/{entity}/{run_id}')
    npz_filename = f'data/npz/{pdb}.npz'
    targets = torch.from_numpy(read_contacts(npz_filename))
    with open(downpath / pred_filename, 'rb') as f:
        # do apc
        predictions = apc(torch.from_numpy(np.load(f)))
        short = precisions_in_range(predictions, targets, minsep=6, maxsep=12)
        for k, v in short.items():
            metrics[f'short_{k}'] = float(v.squeeze())
        medium = precisions_in_range(predictions, targets, minsep=12, maxsep=24)
        for k, v in medium.items():
            metrics[f'medium_{k}'] = float(v.squeeze())
        long = precisions_in_range(predictions, targets, minsep=24)
        for k, v in long.items():
            metrics[f'long_{k}'] = float(v.squeeze())
    rows.append(metrics)

723it [00:04, 154.88it/s]


In [162]:
import pandas as pd
new_metrics_df = pd.DataFrame.from_records(rows)

In [163]:
merged_df = df.merge(new_metrics_df, on='pdb')

In [164]:
import pickle as pkl
with open('protbert_bfd_metrics_df.pkl', 'wb') as f:
    pkl.dump(merged_df, f)

In [156]:
merged_df

Unnamed: 0,sweep_name,pdb,pdb_idx,msa_length,num_seqs,run_state,pr_at_L,pr_at_L_apc,pr_at_L_5,pr_at_L_5_apc,...,medium_auc,medium_pr_at_l,medium_pr_at_l_2,medium_pr_at_l_5,medium_pr_at_l_10,long_auc,long_pr_at_l,long_pr_at_l_2,long_pr_at_l_5,long_pr_at_l_10
0,protbert_bfd,2w3o_1_A,747,100.0,17351.0,finished,0.750000,0.750000,0.800000,0.850000,...,0.704679,0.570000,0.700000,0.850000,0.900000,0.556302,0.370000,0.520000,0.750000,1.000000
1,protbert_bfd,1eqz_1_B,746,108.0,4997.0,finished,0.175926,0.166667,0.171875,0.187500,...,0.241490,0.157407,0.240741,0.285714,0.400000,0.190753,0.101852,0.185185,0.285714,0.300000
2,protbert_bfd,3no0_1_A,745,276.0,9408.0,finished,0.630435,0.641304,0.757576,0.763636,...,0.375726,0.271739,0.304348,0.527273,0.777778,0.740657,0.561594,0.768116,0.909091,0.925926
3,protbert_bfd,1xju_1_A,744,156.0,2110.0,finished,0.314103,0.320513,0.397849,0.408602,...,0.313058,0.192308,0.294872,0.451613,0.533333,0.391359,0.288462,0.371795,0.483871,0.533333
4,protbert_bfd,4ew5_1_B,743,97.0,2254.0,finished,0.226804,0.268041,0.241379,0.258621,...,0.191926,0.123711,0.208333,0.315789,0.111111,0.030036,0.051546,0.020833,0.000000,0.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
718,protbert_bfd,1zli_1_A,4,306.0,2169.0,finished,0.532680,0.522876,0.677596,0.655738,...,0.339267,0.176471,0.281046,0.508197,0.800000,0.642519,0.493464,0.594771,0.852459,0.833333
719,protbert_bfd,1u3j_1_B,3,99.0,1167.0,finished,0.444444,0.484848,0.525424,0.559322,...,0.495119,0.323232,0.510204,0.684211,0.666667,0.428153,0.343434,0.428571,0.526316,0.444444
720,protbert_bfd,1rkt_1_B,2,200.0,24747.0,finished,0.545000,0.545000,0.641667,0.633333,...,0.402071,0.210000,0.350000,0.650000,0.850000,0.539373,0.380000,0.530000,0.700000,0.800000
721,protbert_bfd,3eud_1_A,1,99.0,22284.0,finished,0.525253,0.494949,0.593220,0.627119,...,0.460487,0.353535,0.428571,0.578947,0.666667,0.378670,0.272727,0.408163,0.473684,0.444444


In [157]:
df[df['pdb'] == '2w3o_1_A']

Unnamed: 0,sweep_name,pdb,pdb_idx,msa_length,num_seqs,run_state,pr_at_L,pr_at_L_apc,pr_at_L_5,pr_at_L_5_apc,auc,auc_apc,run_id,log_num_seqs,model,use_bias
0,protbert_bfd,2w3o_1_A,747,100.0,17351.0,finished,0.75,0.75,0.8,0.85,0.843321,0.852246,71zg0efm,9.761405,protbert_bfd,False


In [158]:
merged_df[merged_df['pdb'] == '2w3o_1_A']

Unnamed: 0,sweep_name,pdb,pdb_idx,msa_length,num_seqs,run_state,pr_at_L,pr_at_L_apc,pr_at_L_5,pr_at_L_5_apc,...,medium_auc,medium_pr_at_l,medium_pr_at_l_2,medium_pr_at_l_5,medium_pr_at_l_10,long_auc,long_pr_at_l,long_pr_at_l_2,long_pr_at_l_5,long_pr_at_l_10
0,protbert_bfd,2w3o_1_A,747,100.0,17351.0,finished,0.75,0.75,0.8,0.85,...,0.704679,0.57,0.7,0.85,0.9,0.556302,0.37,0.52,0.75,1.0


In [159]:
new_metrics_df[new_metrics_df['pdb'] == '2w3o_1_A']

Unnamed: 0,pdb,short_auc,short_pr_at_l,short_pr_at_l_2,short_pr_at_l_5,short_pr_at_l_10,medium_auc,medium_pr_at_l,medium_pr_at_l_2,medium_pr_at_l_5,medium_pr_at_l_10,long_auc,long_pr_at_l,long_pr_at_l_2,long_pr_at_l_5,long_pr_at_l_10
0,2w3o_1_A,0.649591,0.48,0.58,0.85,1.0,0.704679,0.57,0.7,0.85,0.9,0.556302,0.37,0.52,0.75,1.0
