In [None]:
import json
from glob import glob
from tqdm import tqdm
from Bio import SeqIO
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import shutil
import os
import warnings

# Filter Inpaints

In [None]:
# Collect per-design confidence scores and linker coordinates from the
# inpainting .trb files, then save them as a CSV.

# Number of residues that follow the linker at the C-terminal end of every
# design (linker_end is always `length - FIXED_CTERM_LEN`).
FIXED_CTERM_LEN = 179

# Sorted so row order (and therefore the CSV index) is reproducible across runs.
trb_paths = sorted(glob('<path/to/your>/inpainting_outputs/*.trb'))

score_records = []
for trb in tqdm(trb_paths):  # iterating the list directly lets tqdm show a total
    # NOTE(security): allow_pickle=True unpickles the file, which can execute
    # arbitrary code -- only load .trb files you generated yourself.
    data = np.load(trb, allow_pickle=True)
    n_res = len(data['lddt'])

    # Linker length is the number before the '-' in the second-to-last
    # comma-separated entry of the first sampled mask.
    linker_len = int(data['sampled_mask'][0].split(',')[-2].split('-')[0])

    score_records.append({
        # basename/splitext only strips the trailing extension and also works
        # with non-'/' path separators.
        'name': os.path.splitext(os.path.basename(trb))[0],
        'pLDDT_full': float(np.mean(data['lddt'])),
        'pLDDT_inpaint': float(np.mean(data['inpaint_lddt'])),
        # 1-based, inclusive residue numbers of the linker.
        'linker_start': n_res - FIXED_CTERM_LEN - linker_len + 1,
        'linker_end': n_res - FIXED_CTERM_LEN,
    })

# Build the frame once so numeric columns get numeric dtypes instead of object.
scores = pd.DataFrame(
    score_records,
    columns=['name', 'pLDDT_full', 'pLDDT_inpaint', 'linker_start', 'linker_end'],
)

scores.to_csv('<desired/path/to>/inpainting_scores.csv')
scores.head()

In [None]:
# Histogram of the per-design mean 'lddt' (column pLDDT_full),
# using 25 evenly spaced bin edges between 0.75 and 0.95.
plt.figure()
full_bin_edges = np.linspace(0.75, 0.95, 25)
full_plddt = scores['pLDDT_full']
sns.histplot(full_plddt, bins=full_bin_edges)

In [None]:
# Histogram of the per-design mean 'inpaint_lddt' (column pLDDT_inpaint),
# using 25 evenly spaced bin edges between 0.4 and 0.8.
inpaint_bin_edges = np.linspace(0.4, 0.8, 25)
inpaint_plddt = scores['pLDDT_inpaint']
sns.histplot(inpaint_plddt, bins=inpaint_bin_edges)

In [None]:
# 2-D histogram of inpainted-region pLDDT (x) against whole-design pLDDT (y),
# with the marginal distributions on the sides.
sns.jointplot(
    data=scores,
    x='pLDDT_inpaint',
    y='pLDDT_full',
    kind='hist',
)

In [None]:
# Select the designs to carry forward: the 20 best by whole-design pLDDT plus
# the 20 best by inpainted-region pLDDT, with designs that appear in both
# lists kept only once.
N_TOP = 20

filtered_full = scores.sort_values(by='pLDDT_full', ascending=False, ignore_index=True).head(N_TOP)
filtered_inpaint = scores.sort_values(by='pLDDT_inpaint', ascending=False, ignore_index=True).head(N_TOP)
sub_filtered = pd.concat([filtered_full, filtered_inpaint], axis=0).drop_duplicates('name')

# `filtered` is assigned from scratch (it is not defined anywhere earlier), so
# this cell runs on a fresh kernel and re-running it does not append the same
# rows again. The index is reset because the two top-N frames both carry
# labels 0..N_TOP-1.
filtered = sub_filtered.reset_index(drop=True)

print(len(filtered))
filtered.head()

# Make Commands

In [None]:
# Build, for every selected design, the list of 1-based residue numbers on
# chain 'A' that must be kept fixed:
#   {design_name: {'A': [residue numbers]}}

# Number of residues at the C-terminal end that are always kept fixed.
FIXED_CTERM_LEN = 179

fixed_dict = {}

# Warnings are silenced only while the PDB files are parsed, instead of being
# switched off for the rest of the kernel session.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")

    for _, row in tqdm(filtered.iterrows(), total=len(filtered)):
        pdb_path = f'<path/to/your>/inpainting_outputs/{row["name"]}.pdb'

        # Only the first chain record of the file is used.
        first_record = next(SeqIO.parse(pdb_path, 'pdb-atom'), None)
        if first_record is None:
            raise ValueError(f'No chain could be read from {pdb_path}')
        seq = first_record.seq
        n_res = len(seq)

        if 'normal' in row['name']:
            # Fix everything except the linker (linker_start..linker_end inclusive).
            linker_start = int(row['linker_start'])
            linker_end = int(row['linker_end'])
            fixed_res = list(range(1, linker_start)) + list(range(linker_end + 1, n_res + 1))
        else:
            # Fix the last FIXED_CTERM_LEN residues ...
            fixed_res = list(range(n_res - FIXED_CTERM_LEN + 1, n_res + 1))
            # ... plus every tryptophan ('W') in the region before them.
            fixed_res.extend(
                pos
                for pos, aa in enumerate(seq[:-FIXED_CTERM_LEN], start=1)
                if aa == 'W'
            )

        fixed_dict[row["name"]] = {'A': fixed_res}

In [None]:
# Store the whole fixed-residue mapping as one JSON object on a single line.
fixed_jsonl_path = '<desired/path/to>/MPNN_fixed_residues.jsonl'
with open(fixed_jsonl_path, 'w') as f:
    serialized = json.dumps(fixed_dict)
    f.write(f'{serialized}\n')

In [None]:
# Full path of the PDB file behind every selected design.
pdbs = filtered['name'].map(
    lambda design: f'<path/to/your>/inpainting_outputs/{design}.pdb'
)

# Cut the paths into consecutive groups of at most 16.
batches = []
for start in range(0, len(pdbs), 16):
    stop = min(start + 16, len(pdbs))
    batches.append(pdbs[start:stop])

In [None]:
# Copy each batch of PDB files into its own MPNN_inputs/batch_<i> directory.
for i, batch in enumerate(batches):
    batch_dir = f'<desired/path/to>/MPNN_inputs/batch_{i}'
    # makedirs also creates the MPNN_inputs parent when it is missing, and
    # exist_ok=True lets the cell be re-run without a FileExistsError.
    os.makedirs(batch_dir, exist_ok=True)
    for pdb in batch:
        fname = os.path.basename(pdb)
        shutil.copy(pdb, os.path.join(batch_dir, fname))

In [None]:
# Write one SLURM submission script (<batch>_cmd.sh) next to every batch
# directory. Each script first parses the batch's PDB files into a jsonl and
# then runs ProteinMPNN on it with the fixed-residue file.

# Only directories are kept: on a re-run the glob would otherwise also match
# the generated batch_*_cmd.sh / .out / _parsed.jsonl files.
batch_dirs = sorted(
    path for path in glob('<desired/path/to>/MPNN_inputs/batch*')
    if os.path.isdir(path)
)

for batch in batch_dirs:
    # Adjacent string literals inside parentheses are concatenated, which
    # allows comments between the pieces (a comment after a trailing
    # backslash is a syntax error).
    script = (
        '#!/bin/bash\n'
        '#SBATCH -p long\n'
        '#SBATCH --mem=2g\n'
        '#SBATCH -c 1\n'
        f'#SBATCH --output {batch}.out\n'
        # helper script shipped with the ProteinMPNN installation
        'python <path/to/your>/proteinmpnn/helper_scripts/parse_multiple_chains.py '
        f'--input_path={batch} --output_path={batch}_parsed.jsonl\n'
        f'python <path/to/your>/proteinmpnn/protein_mpnn_run.py --jsonl_path {batch}_parsed.jsonl '
        # same location the fixed-residue jsonl is written to by this notebook
        '--fixed_positions_jsonl <desired/path/to>/MPNN_fixed_residues.jsonl --out_folder <desired/path/to>/MPNN_outputs '
        '--omit_AAs C --num_seq_per_target 16 --sampling_temp "0.1" --batch_size 8 --model_name "v_48_020"'
    )
    with open(f'{batch}_cmd.sh', 'w') as f:
        f.write(script)