## Pipeline

* clean up pdb files
* Relax (done beforehand)
* Get interface metrics using the Rosetta InterfaceAnalyzerMover and SasaCalc

* Uses interface residue detection for design

# Preparation

The pdb files for 7Z0X and 6M0J were prepared as follows:
* 7Z0X: the constant region of the Fab was removed to speed up calculations
* all HETATM and CONECT lines were stripped as the corresponding entities are not part of the interfaces

In [7]:
# Imports

import pandas as pd
import numpy as np
from IPython.display import display, HTML
import py3Dmol

from pyrosetta import *
from utils import *
from mpnn import *

from ComplexInfo import ComplexInfo
from StructureInfo import StructureInfo

# Initialise PyRosetta for this kernel. The option string asks for disulfide
# detection and turns Rosetta's log output down/off (-out:level 0, -mute all).
init('-detect_disulf true -out:level 0 -mute all')

PyRosetta-4 2022 [Rosetta PyRosetta4.Release.python38.m1 2022.26+release.cd16be2910907fd8044a25e8aaf705b9474af667 2022-06-29T22:22:46] retrieved from: http://www.pyrosetta.org
(C) Copyright Rosetta Commons Member Institutions. Created in JHU by Sergey Lyskov and PyRosetta Team.


1. Clean up pdb for Rosetta compatibility

In [8]:
# Setup 7Z0X

# Chain roles within the 7Z0X complex
binder_chains = ['R']       # Input spike chain
target_chains = ['L', 'H']  # Input Ab chains

# The relaxed input structure is renumbered first; ProteinMPNN in particular
# depends on this numbering.
input_pdb = 'pdb/7Z0X_mod_rel_0002_1_0001.pdb'
input_pdb_renum = renumber_pdb(input_pdb)
print(f"Renumbered residues saved to {input_pdb_renum}.")

# Bundle the renumbered file and the chain roles under the name '7Z0X'
complex_info_7Z0X = ComplexInfo(
    input_pdb_renum,
    binder_chains,
    target_chains,
    name='7Z0X',
)

Renumbered residues saved to pdb/7Z0X_mod_rel_0002_1_0001_renum.pdb.


In [9]:
# Setup 6M0J

# Chain roles within the 6M0J complex
binder_chains = ['E']   # Input spike chain
target_chains = ['A']   # Input ACE chain

# The relaxed input structure is renumbered first; ProteinMPNN in particular
# depends on this numbering.
input_pdb = 'pdb/6M0J_mod_renum_AE_0001_renum_rel_0001_1_0001.pdb'
input_pdb_renum = renumber_pdb(input_pdb)
print(f"Renumbered residues saved to {input_pdb_renum}.")

# Bundle the renumbered file and the chain roles under the name "6M0J"
complex_info_6M0J = ComplexInfo(
    input_pdb_renum,
    binder_chains,
    target_chains,
    name="6M0J",
)

Renumbered residues saved to pdb/6M0J_mod_renum_AE_0001_renum_rel_0001_1_0001_renum.pdb.


In [10]:
# Analyze interfaces

# Run the interface analysis on every complex and collect one single-column
# frame of statistics per complex (rows = metric names, column = complex name).
stats_frames = []
for complex_info in [complex_info_7Z0X, complex_info_6M0J]:
    complex_info.run_interface_analysis()
    stats_frames.append(
        pd.DataFrame.from_dict(
            complex_info.get_interface_stats(),
            orient='index',
            columns=[complex_info.name],
        )
    )

# Join the columns with a single concat instead of re-allocating a growing
# DataFrame on every loop iteration.
df = pd.concat(stats_frames, axis=1)

display(HTML("<h2>Rosetta Interface Analysis Results</h2>"))
df

Unnamed: 0,7Z0X,6M0J
dG,-51.36,-44.31
SASA,1162.95,1718.14
hyrophobic SASA,813.9,1280.17
num_res,50,73
unsats,3,3
packstat,0.0,0.0
hbond_E,-12.24,-9.55
residue_set,"(33, 50, 51, 52, 53, 54, 55, 56, 57, 58, 98, 1...","(1, 2, 3, 5, 6, 7, 8, 9, 10, 12, 13, 16, 17, 1..."


>Ideally, one would run relax, but it takes too long to do it here in the notebook:

```python
# Set up the score function
# Create a score function with coordinate constraints
scorefxn = get_score_function()

# Set up the FastRelax mover
relax = FastRelax()
relax.constrain_relax_to_start_coords(True)
relax.set_scorefxn(scorefxn)

# Perform relaxation
print("Relaxing structure...")
start_pose = pose.clone()

# Uncomment the next line to actually run the relaxation
# relax.apply(pose)

print("Relaxation complete.")
pyrosetta.rosetta.core.scoring.CA_rmsd(start_pose, pose)
```


# Structure 1 - 7Z0X

In [5]:
# Get the interface residues
import os
import tempfile

complex_info = complex_info_7Z0X

pose = complex_info.get_pose()
binder_if_residues, target_if_residues, binder_if_set, target_if_set = select_interface_residues(
    pose, complex_info.binder_chains, complex_info.target_chains)

# Print interface residues
print(f"Interface residues chains {''.join(complex_info.binder_chains)}:", resnums2pdb(pose, binder_if_residues))
print(f"Interface residues chains {''.join(complex_info.target_chains)}:", resnums2pdb(pose, target_if_residues))

# py3Dmol takes the structure as PDB text. Dump the pose into a private
# temporary directory and read it back: the directory is removed even if the
# dump or the read fails, and no fixed-name file in the working directory can
# be clobbered or left behind.
with tempfile.TemporaryDirectory() as tmp_dir:
    tmp_pdb = os.path.join(tmp_dir, 'pose.pdb')
    pose.dump_pdb(tmp_pdb)
    with open(tmp_pdb, 'r') as pdb:
        pdb_string = pdb.read()

display(HTML("<h2>Interface: 7Z0X</h2>"))
viewer = py3Dmol.view(width=800, height=600)
viewer.addModel(pdb_string, 'pdb')
viewer.setStyle({'cartoon': {'color': 'white'}})

# Highlight the interface residues on the binder chains in blue
for chain in complex_info.binder_chains:
    viewer.addStyle({'chain': chain, 'resi': binder_if_residues},
                    {'cartoon': {'color': 'blue'}, 'stick': {}})

# Highlight the interface residues on the target chains in red
for chain in complex_info.target_chains:
    viewer.addStyle({'chain': chain, 'resi': target_if_residues},
                    {'cartoon': {'color': 'red'}, 'stick': {}})

viewer.zoomTo()
viewer.show()

Interface residues chains R: ['361 R ', '378 R ', '380 R ', '381 R ', '382 R ', '383 R ', '384 R ', '386 R ', '388 R ', '389 R ', '390 R ', '391 R ', '392 R ', '393 R ', '394 R ', '395 R ']
Interface residues chains LH: ['33 H ', '50 H ', '52 H ', '53 H ', '54 H ', '56 H ', '58 H ', '98 H ', '100 H ', '101 H ', '102 H ', '110 H ', '111 H ', '112 H ', '157 L ', '159 L ', '218 L ', '221 L ', '222 L ', '224 L ']


# Structure 2 - 6M0J

In [6]:
# Get the interface residues
import os
import tempfile

complex_info = complex_info_6M0J

pose = complex_info.get_pose()
binder_if_residues, target_if_residues, binder_if_set, target_if_set = select_interface_residues(
    pose, complex_info.binder_chains, complex_info.target_chains)

# Print interface residues
print(f"Interface residues chains {''.join(complex_info.binder_chains)}:", resnums2pdb(pose, binder_if_residues))
print(f"Interface residues chains {''.join(complex_info.target_chains)}:", resnums2pdb(pose, target_if_residues))

# py3Dmol takes the structure as PDB text. Dump the pose into a private
# temporary directory and read it back: the directory is removed even if the
# dump or the read fails, and no fixed-name file in the working directory can
# be clobbered or left behind.
with tempfile.TemporaryDirectory() as tmp_dir:
    tmp_pdb = os.path.join(tmp_dir, 'pose.pdb')
    pose.dump_pdb(tmp_pdb)
    with open(tmp_pdb, 'r') as pdb:
        pdb_string = pdb.read()

display(HTML("<h2>Interface: 6M0J</h2>"))
viewer = py3Dmol.view(width=800, height=600)
viewer.addModel(pdb_string, 'pdb')
viewer.setStyle({'cartoon': {'color': 'white'}})

# Highlight the interface residues on the binder chains in blue
for chain in complex_info.binder_chains:
    viewer.addStyle({'chain': chain, 'resi': binder_if_residues},
                    {'cartoon': {'color': 'blue'}, 'stick': {}})

# Highlight the interface residues on the target chains in red
for chain in complex_info.target_chains:
    viewer.addStyle({'chain': chain, 'resi': target_if_residues},
                    {'cartoon': {'color': 'red'}, 'stick': {}})

viewer.zoomTo()
viewer.show()

Interface residues chains E: ['668 E ', '670 E ', '682 E ', '712 E ', '714 E ', '718 E ', '720 E ', '721 E ', '738 E ', '740 E ', '741 E ', '749 E ', '751 E ', '752 E ', '754 E ', '758 E ', '761 E ', '763 E ', '765 E ', '766 E ', '767 E ', '768 E ', '770 E ']
Interface residues chains A: ['1 A ', '6 A ', '9 A ', '10 A ', '12 A ', '13 A ', '16 A ', '17 A ', '19 A ', '20 A ', '23 A ', '24 A ', '27 A ', '61 A ', '64 A ', '65 A ', '306 A ', '307 A ', '308 A ', '312 A ', '333 A ', '335 A ', '336 A ', '337 A ', '339 A ', '368 A ']


# Protein-MPNN

## Use all interface positions to make the final design
* use receptor interface positions for negative design
* use Ab interface for positive interface design
* use the unbound spike for positive design to make sure the mutations don't destabilize it too much

In [7]:
# One StructureInfo per analysed complex, in the order 7Z0X, 6M0J.
# Later cells address these entries by position (structs[0], structs[1]).
structs = []

for complex_info in (complex_info_7Z0X, complex_info_6M0J):
    struct_info = StructureInfo(
        complex_info.input_pdb,
        binder_chains=complex_info.binder_chains,
        target_chains=complex_info.target_chains,
    )
    structs.append(struct_info)

Select interface residues for ProteinMPNN
Select interface residues for ProteinMPNN


In [8]:
# Add holo spike (binder) chain to pose list for subsequent design
# We will use the spike structure from 7Z0X - chain R (structs[0])

base_struct = structs[0]
complex_pose = base_struct.pose

# Split the pose by chains and keep the one whose chain ID (taken from its
# first residue) is listed as a binder chain. Start from None so that a
# missing chain is reported here, instead of raising a NameError further down
# or silently reusing a stale `spike_pose` left in the kernel by an earlier run.
split_chains = complex_pose.split_by_chain()
spike_pose = None
for chain_pose in split_chains:
    if chain_pose.pdb_info().chain(1) in base_struct.binder_chains:
        spike_pose = chain_pose

if spike_pose is None:
    raise ValueError(
        f"None of the binder chains {base_struct.binder_chains} was found in the complex pose."
    )

# The isolated spike has no binding partner, hence the empty target list
holo_spike_struct = StructureInfo(spike_pose, binder_chains=base_struct.binder_chains, target_chains=[], name='holo_spike')
structs.append(holo_spike_struct)

Select interface residues for ProteinMPNN


In [9]:
# The holo spike's (binder's) designable interface will be the combination of
# the Ab and ACE interfaces; fetch both spike-side subsets here.

ab_if_set = structs[0].binder_if_set    # spike residues at the Ab interface
ace_if_set = structs[1].binder_if_set   # spike residues at the ACE interface

# Report how many positions are flagged in each subset
for label, if_set in (
    ("Ab-binding interface residues on spike:\t\t", ab_if_set),
    ("ACE-binding interface residues on spike:\t", ace_if_set),
):
    print(label, if_set.count(1))

Ab-binding interface residues on spike:		 16
ACE-binding interface residues on spike:	 23


In [10]:
# MPNN will shuffle the chains around, so the interface subsets are kept in
# dictionaries split by chain.

ab_struct, ace_struct = structs[0], structs[1]

ab_split_set_dict = split_residue_set_by_chain(ab_if_set, ab_struct.chain_lengths, ab_struct.chain_ids)
ace_split_set_dict = split_residue_set_by_chain(ace_if_set, ace_struct.chain_lengths, ace_struct.chain_ids)

# Keep only the entries of the ACE complex that belong to its binder chains
ace_binder_set_dict = {
    chain: subset
    for chain, subset in ace_split_set_dict.items()
    if chain in ace_struct.binder_chains
}

In [11]:
# Combine spike protein's interface subsets

# Look the spike chain up from each structure's binder definition instead of
# hard-coding the chain letters.
ab_spike_chain = structs[0].binder_chains[0]
ace_spike_chain = structs[1].binder_chains[0]
ab_spike_mask = ab_split_set_dict[ab_spike_chain]
ace_spike_mask = ace_split_set_dict[ace_spike_chain]

# The masks are merged position by position, which is only meaningful when
# both describe the same number of spike residues. Fail loudly otherwise
# instead of raising an IndexError or silently dropping trailing positions.
if len(ab_spike_mask) != len(ace_spike_mask):
    raise ValueError(
        f"Spike chains differ in length: {ab_spike_chain}={len(ab_spike_mask)}, "
        f"{ace_spike_chain}={len(ace_spike_mask)}"
    )

# New subset (vector1_bool) holding the logical OR of both interfaces:
# a position is designable if it is at the Ab interface or at the ACE interface.
spike_combined_if_set = rosetta.utility.vector1_bool()
for ab_flag, ace_flag in zip(ab_spike_mask, ace_spike_mask):
    spike_combined_if_set.append(1 if (ab_flag or ace_flag) else 0)

# Print the size of the resulting combined subset
print("Combined spike residue subset:", spike_combined_if_set.count(1))

# Set designable residues of the holo spike to this combined set
structs[2].binder_if_set = spike_combined_if_set

Combined spike residue subset: 31


In [12]:
# Setup MPNN runs and run designs
# Each structure is prepared and then run before moving on to the next one,
# so setup_mpnn() always precedes run_mpnn() for the same structure.

for struct in structs:
    struct.setup_mpnn()
    struct.run_mpnn()

Setup Selectors
Fix all positions that are in the binder but not part of the binding interface
Prepare mpnn files
Fixed 125 positions in chain 0
Fixed 112 positions in chain 1
Fixed 178 positions in chain 2
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125, 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112, 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23

In [13]:
# Collect the per-position probabilities of the spike chain from each run and
# report at how many positions the probability row is not all zero.

ab_probs = structs[0].probs_per_chain['R']
n_ab = sum(1 for p in ab_probs if sum(p) != 0)
print(f"Non-zero probabilities for Ab interface at \t {n_ab} positions")

ace_probs = structs[1].probs_per_chain['E']
n_ace = sum(1 for p in ace_probs if sum(p) != 0)
print(f"Non-zero probabilities for ACE interface at \t {n_ace} positions")

holo_probs = structs[2].probs_per_chain['R']
n_holo = sum(1 for p in holo_probs if sum(p) != 0)
print(f"Non-zero probabilities for both interfaces at \t {n_holo} positions")

Non-zero probabilities for Ab interface at 	 16 positions
Non-zero probabilities for ACE interface at 	 23 positions
Non-zero probabilities for both interfaces at 	 31 positions


In [14]:
# Combine probabilities with pre-set weights.
# The ACE interface probabilities get a negative weight to design against binding

p_ab = 0.6
p_ace = -0.2
p_holo = 0.6

# Compare with a tolerance: binary floating point cannot represent most decimal
# weights exactly, so an exact `== 1` test can reject weights that do add up
# to 1 on paper.
assert np.isclose(p_ab + p_ace + p_holo, 1.0), "weights have to add up to 1"

# Weighted element-wise sum of the three probability arrays. No re-normalisation
# is applied here, and the negative ACE weight means entries may drop below zero.
combined_probs = p_ab*ab_probs + p_ace*ace_probs + p_holo*holo_probs
print(f"Non-zero probabilities for combined interface at {len([p for p in combined_probs if sum(p) != 0])} positions")

# -> Now we have the probabilities for the spike protein.


Non-zero probabilities for combined interface at 31 positions


In [15]:
# Sanitation
# Stop right away if any entry of the combined array is NaN.
has_nan = np.isnan(combined_probs).any()
if has_nan:
    raise ValueError("NaNs present in combined_probs even after normalization.")

# Cell output: dimensions of the combined array
combined_probs.shape

(194, 21)

In [16]:
# Write a fasta file with the final designed sequence

struct = structs[2]

# Derive the sequence before touching the file, so a failure here cannot leave
# a header-only record behind.
designed_sequence = derive_sequence_from_probs(combined_probs, struct.pose.sequence(), method='argmax')
fasta_path = f"mpnn_output/seqs/pose_{os.path.basename(struct.pdb_file)[:-4]}.fa"

# The record is appended to the file. The sequence line is terminated with a
# newline so that a later append (e.g. re-running this cell) starts its header
# on a fresh line instead of being glued onto this sequence.
with open(fasta_path, 'a') as seqfile:
    seqfile.write(">multi_design\n")
    seqfile.write(designed_sequence + "\n")

# Print the sequence of the pose itself
print(">Sequence for spike protein")
print(struct.pose.sequence())

>Sequence for spike protein
TNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFSTFKCYGVSPTKLNDLCFTNVYADSFVIRGDEVRQIAPGQTGKIADYNYKLPDDFTGCVIAWNSNNLDSKVGGNYNYLYRLFRKSNLKPFERDISTEIYQAGSTPCNGVEGFNCYFPLQSYGFQPTNGVGYQPYRVVVLSFELLHAPATVCG


In [17]:
# Create sequences for all chains in each original PDB
# This can be used for complex structure predictions with e.g. AlphaFold

for struct in structs:
    # Work on a shallow copy: assigning into struct.probs_per_chain directly
    # would overwrite the structure's own MPNN result for the binder chain,
    # so re-running earlier cells would then read the combined probabilities.
    split_probs = dict(struct.probs_per_chain)
    split_probs[struct.binder_chains[0]] = combined_probs

    # np.vstack needs a real sequence; a dict_values view triggers a warning
    complete_probs = np.vstack(list(split_probs.values()))

    print(struct.pdb_file)

    # Walk the chains in dictionary order, slicing the matching stretch out of
    # the full pose sequence for each one.
    # NOTE(review): assumes the dictionary order equals the chain order in the pose.
    start = 0
    for chain, probs in split_probs.items():
        end = start + len(probs)
        print(">", chain)
        print(derive_sequence_from_probs(probs, struct.pose.sequence()[start:end], method='argmax'))
        start = end
    print()


7Z0X_mod_rel_0002_1_0001_renum.pdb
> H
EVQLVESGGGLVQPGGSLRLSCAASGFTVSSNYMSWVRQAPGKGLEWVSAIYSGDSTYYADSVKGRFTISRHNPKNTLYLQMNSLRAEDTAVYYCARLVGALTNIVVSGDGGAFDIWGQGTMVTV
> L
SYELTQPASVSGSPGQSITISCTGTSSDVGSYNLVSWYQQHPGKAPKLMIYEVSKRPSGVSNRFSGSKSGNTASLTISGLQAEDEVDYYCCSYAGSSTWVFGGGTKLTVLGQ
> R
TNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFSTFKCYGVSPTKLNDLCFTNVYADSFVIKGSEVRQIAPGQTGQIADYNYKLPDDFTGCVIAWNSNNLDSKVGGNKNYLYRKYRKSNLKPFERDISTEIYQANSVPCNGKTGYNCYSPLASYNFDPSNPPGDQPYRVVVLSFELLHAPATVCG

6M0J_mod_renum_AE_0001_renum_rel_0001_1_0001_renum.pdb
> A
STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNNAGDKWSAFLKEQSTLAQMYPLQEIQNLTVKLQLQALQQNGSSVLSEDKSKRLNTILNTMSTIYSTGKVCNPDNPQECLLLEPGLNEIMANSLDYNERLWAWESWRSEVGKQLRPLYEEYVVLKNEMARANHYEDYGDYWRGDYEVNGVDGYDYSRGQLIEDVEHTFEEIKPLYEHLHAYVRAKLMNAYPSYISPIGCLPAHLLGDMWGRFWTNLYSLTVPFGQKPNIDVTDAMVDQAWDAQRIFKEAEKFFVSVGLPNMTQGFWENSMLTDPGNVQKAVCHPTAWDLGKGDFRILMCTKVTMDDFLTAHHEMGHIQYDMAYAAQPFLLRNGANEGFHEAVGEIMSLSAATPKHLKSIGLLSPDFQEDNETEINFLLKQALTIVGTLPFTYMLEKWRWMVFKGEIPKDQWMKKW

  complete_probs = np.vstack(split_probs.values())


In [None]:
# Give ESM a shot, although expectations are low

import json
import os

import requests


url = "https://health.api.nvidia.com/v1/biology/nvidia/esmfold"

# The API token is read from the environment and must never be stored in the
# notebook. The token that used to be embedded in this cell was committed in
# clear text and should be revoked.
api_key = os.environ.get("NVIDIA_API_KEY")
if not api_key:
    raise RuntimeError("Set the NVIDIA_API_KEY environment variable before running this cell.")

# Fold the designed sequence (argmax over the combined probabilities), not the
# unmodified pose sequence: the result is stored as 'design_esm.pdb'.
design_sequence = derive_sequence_from_probs(combined_probs, structs[2].pose.sequence(), method='argmax')

payload = { "sequence": design_sequence }
headers = {
    "accept": "application/json",
    "content-type": "application/json",
    "authorization": f"Bearer {api_key}"
}

# A timeout keeps the cell from hanging forever; raise_for_status turns an HTTP
# error into an exception here instead of a confusing KeyError further down.
response = requests.post(url, json=payload, headers=headers, timeout=300)
response.raise_for_status()

# Write to PDB file
d = json.loads(response.text)
with open("design_esm.pdb",'w') as esm_file:
    esm_file.write(d['pdbs'][0])

# Approach 2 - Rosetta-based Filterscan

Here, we calculate the interface ΔΔG upon mutation to any other amino acid.

By looking at Alanine mutations only, we effectively have an in silico alanine scan. At the same time, we get suggestions for beneficial and detrimental mutations.

> **Due to time restrictions, this will only be demonstrated and not done to completion.**

### Get interface ddGs
>This is a pretty standard protocol

1. Get interface dG using the InterfaceAnalyzerMover
2. Mutate
3. Relax 8 Angstrom around mutation site, incl repack the entire interface
4. get new dG using the InterfaceAnalyzerMover
5. Calculate ddG

In [7]:
# Development helper: reload the local `utils` module so that edits made to it
# are picked up without restarting the kernel, then repeat the star-import to
# rebind the names it exports.
import importlib
import utils
importlib.reload(utils)
from utils import *

In [None]:
%%time

# Example for one position on the spike-ACE complex

complex_info = complex_info_6M0J
input_pdb = complex_info.input_pdb

chain_1 = ''.join(complex_info.binder_chains)
chain_2 = ''.join(complex_info.target_chains)

chain_id = chain_1[0]
residue_number = 710

# We focus here on polar residues as Rosetta loves big hyrophobics, but they make the proteins greasy
# and sometimes lower expression yields.
amino_acids_to_scan = "AGVSTPHRKED"

ddG_results = mutate_to_all_amino_acids_parallel(input_pdb, 
    chain_id, residue_number, chain_1, chain_2, num_cores=6, amino_acids=amino_acids_to_scan)

# Print the results
display_dict = {aa: val['ddG'] for aa,val in ddG_results.items()}
ddG_df = pd.DataFrame(list(display_dict.items()), columns=['Amino Acid', 'ΔΔG (REU)'])
#ddG_df

In [12]:
# Display the ΔΔG table assembled by the scan cell
ddG_df

Unnamed: 0,Amino Acid,ΔΔG (REU)
0,VAL,0.0
1,ALA,-1.184594
2,THR,-0.730225
3,GLY,-1.061884
4,PRO,-1.336816
5,SER,-0.488272
6,GLU,-1.224774
7,ARG,-1.107445
8,LYS,-1.633658
9,HIS,-1.032904
