In [1]:
import os
import pandas as pd

In [10]:
####USER DEFINED VARIABLES####
r2_indir = '/Users/stephaniecrilly/Library/CloudStorage/Box-Box/kortemmelab/home/scrilly/helix_sliding/20250604_r2_hs_lib/metric_files/msd_exp_bbs/msd_r2'
r3_indir = '/Users/stephaniecrilly/Library/CloudStorage/Box-Box/kortemmelab/home/scrilly/helix_sliding/20250604_r2_hs_lib/metric_files/msd_exp_bbs/msd_r3'
outdir = '/Users/stephaniecrilly/Library/CloudStorage/Box-Box/kortemmelab/home/scrilly/helix_sliding/20250604_r2_hs_lib/metric_files/msd_exp_bbs'

seq_design_condition = 'MSD'

#filter cutoffs
strictest_rmsd_cutoff = 0.5
strict_rmsd_cutoff = 1.0
lenient_rmsd_cutoff = 1.5
plddt_cutoff = 80
pae_cutoff = 5
min0_motif_res_in_heptad = 7
min2_motif_res_in_heptad = 0

In [18]:
#MPNN for af2 backbones
input_bb_type = 'af2_bm01_loop'
seq_design_method = 'MPNN'

temps = ['01', '03', '05']

#make mpnn_log.txt file in outdir
mpnn_log_file = os.path.join(outdir, f'{seq_design_condition}_mpnn_log.txt')
with open(mpnn_log_file, 'w') as log_file:
    log_file.write(f'Processing MPNN metrics for {seq_design_condition}, {input_bb_type} backbones condition\n\n')
    log_file.write(f'Input directory: {r2_indir}\n\n')
    log_file.write(f'Output directory: {outdir}\n\n')

full_dfs_to_concat = []

for temp in temps:

    #load corresponding dataframes
    af2_df = pd.read_csv(os.path.join(r2_indir, f'af2_metrics_msd_r2_t{temp}.csv'))
    seqs_df = pd.read_csv(os.path.join(r2_indir, f'13632_ALFA_52_07144_ALFA_52_bm01_loop_t_{temp}_mpnn_ssd_new_seqs.csv'))
    socket_fil_df = pd.read_csv(os.path.join(r2_indir, f'MSD_13632_07144_socket_filtered_t_{temp}.csv'))

    thread_position = '52'
    loop = 'bm01'
    min0_bb_id = '13632'
    min2_bb_id = '07144'

    #add additional info
    af2_df['state_design'] = seq_design_condition
    af2_df['thread_position'] = thread_position
    af2_df['loop'] = loop
    af2_df['input_bb_type'] = input_bb_type
    af2_df['seq_design_method'] = seq_design_method
    af2_df['min0_bb_id'] = min0_bb_id
    af2_df['min2_bb_id'] = min2_bb_id
    af2_df[['seq_id', 'af2_model_info']] = af2_df['design_id'].str.split('_unrelaxed_', regex=True, expand=True)

    #print out number of seqs
    print(temp)
    print(f'seqs_df shape: {seqs_df.shape[0]}') #should be 1/5 of af2_df
    print(f'af2_df shape: {af2_df.shape[0]}')
    print(f'socket_fil_df shape: {socket_fil_df.shape[0]}') #should be same as af2_df, unless some sockets didn't run

    #merge af2 and seqs df
    af2_seqs_df = pd.merge(af2_df, seqs_df, how='inner', on='seq_id')

    #merge with socket filtered df
    full_df = pd.merge(af2_seqs_df, socket_fil_df, how='inner', on='design_id')
    full_df = full_df.loc[:, ~full_df.columns.str.contains('^Unnamed')]
    full_dfs_to_concat.append(full_df)

    #merge with all socket metrics for structure with cc
    socket_df = pd.read_csv(os.path.join(r2_indir, f'MSD_13632_07144_all_socket_outputs_t_{temp}.csv'))
    socket_df = socket_df.loc[:, ~socket_df.columns.str.contains('^Unnamed')]

    #merge with full_df
    cc_design_metrics_df = full_df.merge(socket_df, on=['design_id', 'socket_call', 'h1_seq', 'h1_reg', 'h2_seq', 'h2_reg', 'h1_non_canon_num_res', 'h2_non_canon_num_res'], how='right')
    print(f'socket_df shape: {socket_df.shape[0]}') #should reflect number of pdbs detected as cc
    cc_design_metrics_df.to_csv(f'{outdir}/{seq_design_condition}_{seq_design_method}_{thread_position}_{loop}_{input_bb_type}_{temp}_all_metrics.csv')

    with open(mpnn_log_file, 'a') as log_file:
        log_file.write(f'{temp} total seqs: {seqs_df.shape[0]}\n')
        log_file.write(f'{temp} af2 structures: {af2_df.shape[0]}\n')
        log_file.write(f'{temp} socket filtered structures: {socket_fil_df.shape[0]}\n')
        log_file.write(f'{temp} socket metrics: {socket_df.shape[0]} total\n\n')

#save master mpnn df
full_df = pd.concat(full_dfs_to_concat)
full_df.to_csv(f'{outdir}/{seq_design_condition}_{seq_design_method}_{input_bb_type}_{temp}_compiled_metrics.csv')        


01
seqs_df shape: 1000
af2_df shape: 5000
socket_fil_df shape: 5000
socket_df shape: 4787
03
seqs_df shape: 1000
af2_df shape: 5000
socket_fil_df shape: 5000
socket_df shape: 4585
05
seqs_df shape: 1000
af2_df shape: 5000
socket_fil_df shape: 5000
socket_df shape: 4105


In [19]:
#MPNN for af2 backbones
input_bb_type = 'af2_bm01_loop_redesigned'
seq_design_method = 'MPNN'

temps = ['01', '03', '05']

#make mpnn_log.txt file in outdir
mpnn_log_file = os.path.join(outdir, f'{seq_design_condition}_mpnn_log.txt')
with open(mpnn_log_file, 'a') as log_file:
    log_file.write(f'Processing MPNN metrics for {seq_design_condition}, {input_bb_type} backbones condition\n\n')
    log_file.write(f'Input directory: {r3_indir}\n\n')
    log_file.write(f'Output directory: {outdir}\n\n')

full_dfs_to_concat = []

for temp in temps:

    #load corresponding dataframes
    af2_df = pd.read_csv(os.path.join(r3_indir, f'af2_metrics_t_{temp}.csv'))
    seqs_df = pd.read_csv(os.path.join(r3_indir, f'13632_ALFA_52_07144_ALFA_52_bm01_loop_redesigned_t_{temp}_mpnn_ssd_new_seqs.csv'))
    socket_fil_df = pd.read_csv(os.path.join(r3_indir, f'MSD_13632_07144_bm01_loop_redesigned_t_{temp}_socket_filtered.csv'))

    thread_position = '52'
    loop = 'bm01'
    min0_bb_id = '13632'
    min2_bb_id = '07144'

    #add additional info
    af2_df['state_design'] = seq_design_condition
    af2_df['thread_position'] = thread_position
    af2_df['loop'] = loop
    af2_df['input_bb_type'] = input_bb_type
    af2_df['seq_design_method'] = seq_design_method
    af2_df['min0_bb_id'] = min0_bb_id
    af2_df['min2_bb_id'] = min2_bb_id
    af2_df[['seq_id', 'af2_model_info']] = af2_df['design_id'].str.split('_unrelaxed_', regex=True, expand=True)

    #print out number of seqs
    print(temp)
    print(f'seqs_df shape: {seqs_df.shape[0]}') #should be 1/5 of af2_df
    print(f'af2_df shape: {af2_df.shape[0]}')
    print(f'socket_fil_df shape: {socket_fil_df.shape[0]}') #should be same as af2_df, unless some sockets didn't run

    #merge af2 and seqs df
    af2_seqs_df = pd.merge(af2_df, seqs_df, how='inner', on='seq_id')

    #merge with socket filtered df
    full_df = pd.merge(af2_seqs_df, socket_fil_df, how='inner', on='design_id')
    full_df = full_df.loc[:, ~full_df.columns.str.contains('^Unnamed')]
    full_dfs_to_concat.append(full_df)

    #merge with all socket metrics for structure with cc
    socket_df = pd.read_csv(os.path.join(r3_indir, f'MSD_13632_07144_bm01_loop_redesigned_t_{temp}_all_socket_outputs.csv'))
    socket_df = socket_df.loc[:, ~socket_df.columns.str.contains('^Unnamed')]

    #merge with full_df
    cc_design_metrics_df = full_df.merge(socket_df, on=['design_id', 'socket_call', 'h1_seq', 'h1_reg', 'h2_seq', 'h2_reg', 'h1_non_canon_num_res', 'h2_non_canon_num_res'], how='right')
    print(f'socket_df shape: {socket_df.shape[0]}') #should reflect number of pdbs detected as cc
    cc_design_metrics_df.to_csv(f'{outdir}/{seq_design_condition}_{seq_design_method}_{thread_position}_{loop}_{input_bb_type}_{temp}_all_metrics.csv')

    with open(mpnn_log_file, 'a') as log_file:
        log_file.write(f'{temp} total seqs: {seqs_df.shape[0]}\n')
        log_file.write(f'{temp} af2 structures: {af2_df.shape[0]}\n')
        log_file.write(f'{temp} socket filtered structures: {socket_fil_df.shape[0]}\n')
        log_file.write(f'{temp} socket metrics: {socket_df.shape[0]} total\n\n')

#save master mpnn df
full_df = pd.concat(full_dfs_to_concat)
full_df.to_csv(f'{outdir}/{seq_design_condition}_{seq_design_method}_{input_bb_type}_{temp}_compiled_metrics.csv')        


01
seqs_df shape: 1000
af2_df shape: 5000
socket_fil_df shape: 5000
socket_df shape: 3864
03
seqs_df shape: 1000
af2_df shape: 4940
socket_fil_df shape: 4940
socket_df shape: 3929
05
seqs_df shape: 1000
af2_df shape: 4970
socket_fil_df shape: 4970
socket_df shape: 3834


In [20]:
passing_designs_to_concat = []
min0_passing_designs_to_concat = []
min2_passing_designs_to_concat = []

for file in os.listdir(outdir):
    if file.endswith('all_metrics.csv'):
        print(f"Processing file: {file}")

        # Read in the data
        df = pd.read_csv(os.path.join(outdir, file), index_col=0)
        print(df.shape)
        print(len(df['sequence'].unique()))

        #filter for plddt
        df = df.query('avg_plddt_no_loop > @plddt_cutoff', engine='python').copy()
        print(f"plddt filtered: {len(df['sequence'].unique())}")
        
        #filter for pae
        df = df.query('avg_pae_no_loop < @pae_cutoff', engine='python').copy()
        print(f"pae filtered: {len(df['sequence'].unique())}")

        #get prediction with lowest rmsd to each state
        lowest_rmsd_min0_df = df.sort_values('min0_all_rmsd_no_loop', ascending=True).drop_duplicates('sequence').sort_index()
        lowest_rmsd_min0_df = lowest_rmsd_min0_df.query('min0_all_rmsd_no_loop < @lenient_rmsd_cutoff').copy()
        print(f"rmsd filtered (min0 only): {len(lowest_rmsd_min0_df['sequence'].unique())}")

        lowest_rmsd_min2_df = df.sort_values('min2_all_rmsd_no_loop', ascending=True).drop_duplicates('sequence').sort_index()
        lowest_rmsd_min2_df = lowest_rmsd_min2_df.query('min2_all_rmsd_no_loop < @lenient_rmsd_cutoff').copy()
        print(f"rmsd filtered (min2 only): {len(lowest_rmsd_min2_df['sequence'].unique())}")

        intersect_df = pd.merge(lowest_rmsd_min0_df, lowest_rmsd_min2_df, how='inner', on='sequence', suffixes=('_min0', '_min2'))
        print(f"rmsd filtered: {len(intersect_df['sequence'].unique())}")
        passing_designs_to_concat.append(intersect_df)

        #get sequences that pass filters for only one state
        min0_only_df = lowest_rmsd_min0_df[~lowest_rmsd_min0_df['sequence'].isin(intersect_df['sequence'])]
        print(f"min0 only: {len(min0_only_df['sequence'].unique())}")
        min0_passing_designs_to_concat.append(min0_only_df)

        min2_only_df = lowest_rmsd_min2_df[~lowest_rmsd_min2_df['sequence'].isin(intersect_df['sequence'])]
        print(f"min2 only: {len(min2_only_df['sequence'].unique())}")
        min2_passing_designs_to_concat.append(min2_only_df)

all_passing_designs_df = pd.concat(passing_designs_to_concat)
min0_passing_designs_df = pd.concat(min0_passing_designs_to_concat)
min2_passing_designs_df = pd.concat(min2_passing_designs_to_concat)
print(all_passing_designs_df.shape)
print(len(all_passing_designs_df['sequence'].unique()))
all_passing_designs_df.to_csv(f'{outdir}/exp_bbs_msd_designs_passing.csv')
min0_passing_designs_df.to_csv(f'{outdir}/exp_bbs_msd_designs_passing_min0_only.csv')
min2_passing_designs_df.to_csv(f'{outdir}/exp_bbs_msd_designs_passing_min2_only.csv')

Processing file: MSD_MPNN_52_bm01_af2_bm01_loop_01_all_metrics.csv
(4787, 44)
991
plddt filtered: 989
pae filtered: 956
rmsd filtered (min0 only): 888
rmsd filtered (min2 only): 80
rmsd filtered: 15
min0 only: 873
min2 only: 65
Processing file: MSD_MPNN_52_bm01_af2_bm01_loop_redesigned_01_all_metrics.csv
(3864, 44)
903
plddt filtered: 886
pae filtered: 674
rmsd filtered (min0 only): 587
rmsd filtered (min2 only): 101
rmsd filtered: 17
min0 only: 570
min2 only: 84
Processing file: MSD_MPNN_52_bm01_af2_bm01_loop_05_all_metrics.csv
(4105, 44)
944
plddt filtered: 924
pae filtered: 786
rmsd filtered (min0 only): 550
rmsd filtered (min2 only): 272
rmsd filtered: 39
min0 only: 511
min2 only: 233
Processing file: MSD_MPNN_52_bm01_af2_bm01_loop_redesigned_03_all_metrics.csv
(3929, 44)
912
plddt filtered: 899
pae filtered: 714
rmsd filtered (min0 only): 543
rmsd filtered (min2 only): 196
rmsd filtered: 28
min0 only: 515
min2 only: 168
Processing file: MSD_MPNN_52_bm01_af2_bm01_loop_03_all_metric