# <center>**ProBASS: a language model with sequence and structural features for predicting the effect of mutations on binding affinity**</center>
---
Here we introduce ProBASS, a fine-tuned model that incorporates features derived from two protein language models: ESM-2, which captures the sequence of the mutated protein complex, and ESM-IF1, which captures its structure. The model is designed to predict ddG<sub>bind</sub> values, i.e. the change in binding free energy caused by a mutation.

The model is an efficient way to predict the effect of mutations on protein binding affinity.

---

**Instructions for providing ProBASS with the PDB ID of the protein complex and the CSV file containing the mutation information**

Please enter the PDB ID of the protein complex in the **Input Data** section below; it is required to calculate the effect of the mutations on binding affinity.

The user can specify the desired mutations for binding affinity calculation by providing the information in a CSV file named 'Input.csv'. This file should include five columns with the following headers: 'Mutated_chain', 'Partner_chain', 'Wild_type', 'Position', and 'Mutation'.

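The 'Mutated_chain' and 'Partner_chain' columns give the chain IDs that define the interface of the protein complex; all rows should use the same Mutated_chain/Partner_chain pair. 'Wild_type' is the one-letter code of the original amino acid in the protein complex, 'Position' indicates the location of the desired mutation, and 'Mutation' is the one-letter code of the amino acid the user wishes to substitute for the wild type. Note that 'Position' is counted from 1 along the mutated chain's sequence as extracted from the PDB file (standard amino acids only), so it can differ from the residue numbering used inside the PDB file. For example, the row `A,B,L,45,A` replaces the 45th residue of chain A (a leucine) with alanine, with chain B as the binding partner.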


# Environment setup for **ProBASS**

In [1]:
%%capture
# Install the packages used by this notebook: torch, the PyTorch Geometric
# extensions (presumably required by ESM-IF1 — TODO confirm), biotite, CatBoost
# (the ProBASS regressor), fair-esm (ESM-2 / ESM-IF1), requests (PDB download)
# and Biopython (PDBParser).
# NOTE(review): the torch line targets the cu113 wheel index, while the PyG line
# points at wheels built for torch-2.0.0+cu113. Confirm both match the torch
# that is actually installed in the runtime (pip may keep a preinstalled torch).
# NOTE(review): only biotite is pinned; pin the others for reproducibility.
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu113
!pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-2.0.0+cu113.html
!pip install biotite==0.33.0
!pip install catboost
!pip install git+https://github.com/facebookresearch/esm.git
!pip install requests
!pip install biopython

In [2]:
import os
import warnings

import numpy as np
import pandas as pd
import catboost as cb
import torch
import esm
import scipy
from numpy import asarray
from numpy import savez_compressed
import requests
from Bio.PDB import PDBParser

In [3]:
%%bash

cd /content/

# Clone the ProBASS repository (extract.py and the trained CatBoost model) once.
# A separate marker file records a finished clone. The previous check,
# `[ ! -f ProBASS ]`, looked for a *regular file* named ProBASS, which never
# exists next to the ProBASS/ directory. So every re-run deleted and re-cloned
# the repository, discarding the files that later cells write inside ProBASS/.
if [ ! -f ProBASS.cloned ] || [ ! -d ProBASS/ ]; then

    # delete any previous (possibly partial) ProBASS/ checkout
    if [ -d "ProBASS/" ]; then
        rm -rf ProBASS/
    fi

    # download model; only mark the clone as done if git succeeded
    git clone https://github.com/sagagugit/ProBASS --quiet && touch ProBASS.cloned
fi

# Input Data

In [5]:
#@title Please provide the PDB ID and mutation information
PDB = '3OTJ'#@param {type:"string"}
# Later cells use PDB in file names, shell commands and the RCSB download URL,
# so drop stray whitespace from the form value and use the upper-case ID.
PDB = PDB.strip().upper()
try:
  from google.colab import drive
  drive.mount('/content/drive')
  from google.colab import files
  # Upload Input.csv (see the instructions above for the required columns).
  uploaded = files.upload()
except FileNotFoundError:
  uploaded = {}
  print("ERROR: \n Uploading was not successful. Please restart and try to upload the complex again")

# The rest of the notebook reads a file named exactly 'Input.csv'.
if 'Input.csv' not in uploaded:
  print("WARNING: no file named 'Input.csv' was uploaded in this run; later "
        "cells will use /content/Input.csv only if it already exists.")


Mounted at /content/drive


Saving Input.csv to Input.csv


# Selecting Path

In [6]:
%%capture
%cd ProBASS
!cp /content/Input.csv /content/ProBASS

# Extracting embeddings from ESM2 and ESM-IF1

# Extracting FASTA files for the wild-type, partner-chain and mutated PPI sequences

In [7]:
# Download the coordinates of the complex from the RCSB PDB file service.
# Later cells read f'{PDB_code}.pdb' from the working directory. A failed
# download therefore has to stop the notebook here; previously the failure
# message was hidden by %%capture and surfaced later as an obscure parser error.
PDB_code = PDB  # taken from the form in the Input Data section

# Construct the URL to download the PDB-format file
url = f'https://files.rcsb.org/download/{PDB_code}.pdb'

# Send a GET request; the timeout avoids hanging forever on a stalled connection
response = requests.get(url, timeout=60)

if response.status_code == 200:
    # Save the file to the local filesystem
    with open(f'{PDB_code}.pdb', 'wb') as file:
        file.write(response.content)
    print(f'{PDB_code}.pdb has been downloaded successfully.')
else:
    raise RuntimeError(
        f'Failed to download {PDB_code}.pdb from {url}. '
        f'Status code: {response.status_code}'
    )

In [8]:
# Define residue name to single-letter code mapping
RESIDUE_NAME_TO_LETTER = {
    'ALA': 'A', 'ARG': 'R', 'ASN': 'N', 'ASP': 'D', 'CYS': 'C',
    'GLU': 'E', 'GLN': 'Q', 'GLY': 'G', 'HIS': 'H', 'ILE': 'I',
    'LEU': 'L', 'LYS': 'K', 'MET': 'M', 'PHE': 'F', 'PRO': 'P',
    'SER': 'S', 'THR': 'T', 'TRP': 'W', 'TYR': 'Y', 'VAL': 'V'
}

# Load the CSV file
input_csv = 'Input.csv'
df = pd.read_csv(input_csv)

# Define PDB code
PDB_code = PDB

# Read the PDB file
pdb_file = f'{PDB_code}.pdb'
parser = PDBParser(QUIET=True)
structure = parser.get_structure(PDB_code, pdb_file)


def extract_sequence(chain_id, model_index=0):
    """Return the one-letter sequence of ``chain_id`` in one model of ``structure``.

    Only residues whose name is a key of RESIDUE_NAME_TO_LETTER (the 20
    standard amino acids) are kept, so waters, ligands and non-standard
    residues are skipped. Only the model at position ``model_index`` is read.
    Iterating over every model would repeat the chain once per model for
    multi-model (e.g. NMR) entries.

    Returns an empty string if the chain is not present in that model.
    """
    model = structure.get_list()[model_index]
    if chain_id not in model:
        return ''
    return ''.join(
        RESIDUE_NAME_TO_LETTER[residue.get_resname()]
        for residue in model[chain_id]
        if residue.get_resname() in RESIDUE_NAME_TO_LETTER
    )


# The downstream cells build ONE wild-type FASTA and ONE partner FASTA, and the
# ESM-IF1 embedding is computed for a single mutated chain. Every row of
# Input.csv must therefore describe the same chain pair. Previously the last
# row silently won here, while the IF1 cell used the first row.
chain_pairs = {
    (str(mutated), str(partner))
    for mutated, partner in zip(df['Mutated_chain'], df['Partner_chain'])
}
if len(chain_pairs) != 1:
    raise ValueError(
        'Input.csv must use exactly one Mutated_chain/Partner_chain pair, '
        f'found: {sorted(chain_pairs)}'
    )
mutated_chain_id, partner_chain_id = next(iter(chain_pairs))

mutated_sequence = extract_sequence(mutated_chain_id)
partner_sequence = extract_sequence(partner_chain_id)
for chain_id, chain_sequence in ((mutated_chain_id, mutated_sequence),
                                 (partner_chain_id, partner_sequence)):
    if not chain_sequence:
        raise ValueError(
            f'Chain {chain_id!r} has no standard amino-acid residues in the '
            f'first model of {pdb_file}; check the chain IDs in Input.csv.'
        )

# Save the sequences to FASTA files. The '> ' header format is kept unchanged:
# the embedding loader later slices header[1:], which appears to rely on the
# space after '>' (do not remove without checking that loader).
with open(f'{PDB_code}_wild.fasta', 'w') as f:
    f.write(f'> {PDB_code}_wild\n{mutated_sequence}\n')

with open(f'{PDB_code}_partner.fasta', 'w') as f:
    f.write(f'> {PDB_code}_partner\n{partner_sequence}\n')

print(f'FASTA files for {PDB_code} have been created.')

FASTA files for 3OTJ have been created.


In [9]:
# Define the PDB code
PDB_code = PDB

# Load the Input.csv file
input_csv = 'Input.csv'
df = pd.read_csv(input_csv)

# Standard amino acids accepted as substitutions (one-letter codes).
STANDARD_AMINO_ACIDS = set('ACDEFGHIKLMNPQRSTVWY')


def read_fasta(fasta_file):
    """Return the sequence of a single-record FASTA file.

    The first line (the header) is skipped, and all remaining lines are joined
    with surrounding whitespace stripped.
    """
    with open(fasta_file, 'r') as f:
        lines = f.readlines()
    return ''.join(line.strip() for line in lines[1:])


wild_sequence = read_fasta(f'{PDB_code}_wild.fasta')


def position_to_index(position):
    """Convert a 1-based sequence position (as given in Input.csv) to a 0-based index."""
    return int(position) - 1


# One FASTA record per row of Input.csv, in the same order, so that the
# predictions can later be matched back to the rows.
fasta_entries = []
for row_number, row in df.iterrows():
    position = int(row['Position'])
    mutation = str(row['Mutation']).strip().upper()
    wild_type = str(row['Wild_type']).strip().upper()
    idx = position_to_index(position)

    # An out-of-range position used to be skipped silently, producing a
    # "mutant" identical to the wild type (i.e. all-zero features).
    if not 0 <= idx < len(wild_sequence):
        raise ValueError(
            f'Input.csv row index {row_number}: Position {position} is outside '
            f'the mutated chain sequence (length {len(wild_sequence)}).'
        )
    # Exactly one standard residue keeps every mutant the same length as the
    # wild type; the embeddings of all mutants are stacked later.
    if len(mutation) != 1 or mutation not in STANDARD_AMINO_ACIDS:
        raise ValueError(
            f'Input.csv row index {row_number}: Mutation {row["Mutation"]!r} '
            'is not a one-letter standard amino-acid code.'
        )
    if wild_sequence[idx] != wild_type:
        warnings.warn(
            f'Input.csv row index {row_number}: Wild_type is {wild_type} but '
            f'the extracted sequence has {wild_sequence[idx]} at position '
            f'{position}. Position is counted from 1 along the extracted chain '
            'sequence, not by PDB residue number.'
        )

    mutated_sequence_str = wild_sequence[:idx] + mutation + wild_sequence[idx + 1:]

    # Keep the '> ' header format: the embedding loader slices header[1:] to
    # recover the label that extract.py used as the .pt file name.
    fasta_header = f'> {PDB_code}_{position}{mutation}\n'
    fasta_entries.append(fasta_header + mutated_sequence_str + '\n')

with open(f'{PDB_code}.fasta', 'w') as f:
    f.writelines(fasta_entries)

print(f'Mutated FASTA file for {PDB_code} has been created with formatted headers.')



Mutated FASTA file for 3OTJ has been created with formatted headers.


# Extract sequence embeddings and structural embeddings

In [11]:
%%capture
#Extrat seqeunce embeddings from ESM2 seperately for wild type, partner and mutated PPI
!python extract.py esm2_t33_650M_UR50D 3OTJ.fasta 3OTJ_esm2 --repr_layers 0 32 33 --include mean per_tok

!python extract.py esm2_t33_650M_UR50D 3OTJ_wild.fasta 3OTJ_esm2_wild --repr_layers 0 32 33 --include mean per_tok

!python extract.py esm2_t33_650M_UR50D 3OTJ_partner.fasta 3OTJ_esm2_partner --repr_layers 0 32 33 --include mean per_tok

model, alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50()
model = model.eval()

fpath = PDB + '.pdb'
input_file = 'Input.csv'
df = pd.read_csv(input_file)

chain_ids = list(set(df['Mutated_chain'].tolist() + df['Partner_chain'].tolist()))
structure = esm.inverse_folding.util.load_structure(fpath, chain_ids)
coords, native_seqs = esm.inverse_folding.multichain_util.extract_coords_from_complex(structure)

print(f'Loaded chains: {list(coords.keys())}\n')

for chain_id in chain_ids:
    print(f'Chain {chain_id} native sequence:')
    print(native_seqs[chain_id])
    print('\n')


mutated_chain_ids = df['Mutated_chain'].unique()


target_chain_id = mutated_chain_ids[0]
rep = esm.inverse_folding.multichain_util.get_encoder_output_for_complex(model, alphabet, coords, target_chain_id)
len(coords), rep.shape
print(len(coords), rep.shape)
print(target_chain_id)

numpy_rep =rep.detach().numpy()
print(numpy_rep)
np.savez(f"inverse_{PDB}.npz", data=numpy_rep)

# Run ProBASS to predict the ΔΔG values

In [12]:
# The prediction loop in the next cell iterates over PDBS (a single entry).
# PDB_code is no longer rebound to this list, so it keeps meaning the PDB ID
# string used by the earlier cells.
PDBS = [PDB]


def _load_esm2_layer(fasta_path, emb_path, layer=33):
    """Stack the ESM-2 representations saved by ``extract.py`` for a FASTA file.

    For each record of ``fasta_path`` the file ``<emb_path>/<label>.pt`` is
    loaded, and its ``['representations'][layer]`` tensor is collected.
    ``label`` is the header returned by ``esm.data.read_fasta`` minus its
    first character.

    Args:
        fasta_path: FASTA file whose records were embedded by ``extract.py``.
        emb_path: directory containing the ``.pt`` files.
        layer: representation layer to read (33 is the layer the ProBASS
            features are built from).

    Returns:
        ``(array, labels)``. ``array`` is the tensors stacked along a new axis 0
        and converted to numpy, so all of them must have the same shape.
        ``labels`` is the last ``'|'``-separated field of every header, in file
        order.

    Raises:
        ValueError: if the FASTA file contains no records.
    """
    reps = []
    labels = []
    for header, _seq in esm.data.read_fasta(fasta_path):
        labels.append(header.split('|')[-1])
        # NOTE(review): header[1:] assumes the header comes back with the space
        # written after '>' by the FASTA cells still attached.
        embs = torch.load(f'{emb_path}/{header[1:]}.pt')
        reps.append(embs['representations'][layer])
    if not reps:
        raise ValueError(f'No sequences found in {fasta_path}')
    return torch.stack(reps, dim=0).numpy(), labels


def exctracting_embeddings_esm2(pdb, fasta_path=None, emb_path=None):
    """Return ``(Xs2, mutations2)`` for the mutated sequences.

    Paths default to the globals FASTA_PATH2 / EMB_PATH2 set by the prediction
    loop. ``pdb`` is unused and kept for the existing call sites.
    """
    if fasta_path is None:
        fasta_path = FASTA_PATH2
    if emb_path is None:
        emb_path = EMB_PATH2
    return _load_esm2_layer(fasta_path, emb_path)


def exctracting_embeddings_esm2_wild(pdb, fasta_path=None, emb_path=None):
    """Return the stacked ESM-2 representations of the wild-type sequence.

    Paths default to the globals FASTA_PATH2_w / EMB_PATH2_w set by the
    prediction loop. ``pdb`` is unused and kept for the existing call sites.
    """
    if fasta_path is None:
        fasta_path = FASTA_PATH2_w
    if emb_path is None:
        emb_path = EMB_PATH2_w
    return _load_esm2_layer(fasta_path, emb_path)[0]


def exctracting_embeddings_esm2_bind(pdb, fasta_path=None, emb_path=None):
    """Return the stacked ESM-2 representations of the partner (binding) chain.

    Paths default to the globals FASTA_PATH2_b / EMB_PATH2_b set by the
    prediction loop. ``pdb`` is unused and kept for the existing call sites.
    """
    if fasta_path is None:
        fasta_path = FASTA_PATH2_b
    if emb_path is None:
        emb_path = EMB_PATH2_b
    return _load_esm2_layer(fasta_path, emb_path)[0]

def exctracting_embeddings_1f(pdb, path=None, dim=512):
    """Return the residue-averaged ESM-IF1 embedding as a single-row array.

    Reads the ``data`` array from the ``.npz`` file at ``path``, averages it
    over axis 0 and reshapes the result to ``(1, dim)``.

    Args:
        pdb: unused; kept for the existing call sites.
        path: ``.npz`` file written by the embedding cell. Defaults to the
            global ``inverse_path`` set by the prediction loop.
        dim: expected embedding width. The ProBASS features use 512. Pass None
            to accept any width.

    Raises:
        ValueError: if the averaged embedding does not have ``dim`` values.
    """
    if path is None:
        path = inverse_path
    # The context manager closes the archive; np.load on .npz keeps it open otherwise.
    with np.load(path) as archive:
        per_residue = archive['data']
    average_mean_embedding = np.mean(per_residue, axis=0)
    return average_mean_embedding.reshape(1, -1 if dim is None else dim)

In [13]:
%%capture
# Build the ProBASS feature matrix: one row per record of <pdb>.fasta.
# NOTE: the loader functions from the previous cell read the module-level
# path variables assigned below (FASTA_PATH2, EMB_PATH2, FASTA_PATH2_w,
# EMB_PATH2_w, FASTA_PATH2_b, EMB_PATH2_b, inverse_path) rather than their
# `pdb` argument, so these assignments must run before the calls.
ddg_values = []  # not read in this cell; the next cell rebuilds it
embeddings = []  # one feature matrix per entry of PDBS
for pdb in PDBS:
    # FASTA_PATH, EMB_PATH, the *_1V paths and csv_path are not read by the
    # loader functions defined in the previous cell (presumably leftovers).
    FASTA_PATH = "/content/ProBASS/{}.fasta".format(pdb)
    EMB_PATH = "/content/ProBASS/{}_1V".format(pdb)
    FASTA_PATH2 = "/content/ProBASS/{}.fasta".format(pdb)
    EMB_PATH2 = "/content/ProBASS/{}_esm2".format(pdb)
    FASTA_PATH_w = "/content/ProBASS/{}_wild.fasta".format(pdb)
    EMB_PATH_w = "/content/ProBASS/{}_1V_wild".format(pdb)
    FASTA_PATH2_w = "/content/ProBASS/{}_wild.fasta".format(pdb)
    EMB_PATH2_w = "/content/ProBASS/{}_esm2_wild".format(pdb)
    FASTA_PATH_b = "/content/ProBASS/{}_partner.fasta".format(pdb)
    EMB_PATH_b = "/content/ProBASS/{}_1V_partner".format(pdb)
    FASTA_PATH2_b = "/content/ProBASS/{}_partner.fasta".format(pdb)
    EMB_PATH2_b = "/content/ProBASS/{}_esm2_partner".format(pdb)
    inverse_path = '/content/ProBASS/inverse_{}.npz'.format(pdb)
    csv_path = "/content/ProBASS/{}.csv".format(pdb)
    # ESM-2 layer-33 representations of every mutant, stacked on axis 0 (one
    # entry per mutant), plus the FASTA labels.
    Xs2, mutations2= exctracting_embeddings_esm2(pdb)
    # Wild-type and partner representations are repeated len(Xs2) times on
    # axis 0 so they line up row-by-row with the mutants.
    Xs2_w= exctracting_embeddings_esm2_wild(pdb)
    Xs2_w=np.tile(Xs2_w, (len(Xs2), 1, 1))
    Xs2_b=exctracting_embeddings_esm2_bind(pdb)
    Xs2_b=np.tile(Xs2_b, (len(Xs2), 1, 1))
    # Residue-averaged ESM-IF1 embedding (one row), repeated once per mutant.
    inverse=exctracting_embeddings_1f(pdb)
    inverse=np.tile(inverse, (len(Xs2), 1))
    # Concatenate partner + chain along axis 1 (presumably the residue axis of
    # the per-token representations — TODO confirm) and average over it.
    mutant_and_partner_together_esm2 = np.concatenate([Xs2_b, Xs2], axis =1)

    wild_type_and_partner_together_esm2 = np.concatenate([Xs2_b, Xs2_w], axis =1)
    mutant_and_partner_together_esm2_mean=np.mean(mutant_and_partner_together_esm2, axis=1)
    wild_type_and_partner_together_esm2_mean=np.mean(wild_type_and_partner_together_esm2, axis=1)
    # Sequence feature: mutant-complex mean minus wild-type-complex mean.
    ddg_1v = np.subtract(mutant_and_partner_together_esm2_mean, wild_type_and_partner_together_esm2_mean)

    # Final features per mutant: the ESM-2 difference vector followed by the
    # ESM-IF1 vector (concatenated on axis 1).
    ddg_esm2_with_inverse = np.concatenate([ddg_1v, inverse], axis =1)
    embeddings.append(ddg_esm2_with_inverse)

In [14]:
# Assemble the test set for the CatBoost regressor. PDBS has a single entry,
# so the feature matrix is embeddings[0] (one row per mutation). The labels are
# placeholder zeros (ddG is what is being predicted); they are kept only so that
# test.npz carries both a 'data' and a 'label' array.
Xs_test = embeddings[0]
ys_test = [0] * len(Xs_test)

np.savez('test.npz', data=Xs_test, label=ys_test)

In [15]:
# Reload the feature matrix written above ('label' holds placeholder zeros).
# The context manager closes the .npz archive once the arrays are read.
with np.load('test.npz') as archive:
    X_test, test_y = archive['data'], archive['label']

In [16]:
#Load your Input.csv
input_data = pd.read_csv('Input.csv')
#Load your model and make predictions
model = cb.CatBoostRegressor()
loaded_model1 = cb.CatBoostRegressor()
loaded_model1.load_model('Probass_model.cbm')

ypred = loaded_model1.predict(X_test)


input_data['Mutation'] = input_data['Wild_type'] + input_data['Position'].astype(str) + input_data['Mutation']

predicted_df = pd.DataFrame({'Mutation': input_data['Mutation'], 'predicted_value': ypred})

predicted_df.to_csv('predicted_values.csv', index=False)


# Download Predicted Binding Affinities

In [17]:
# Send the predictions file written by the previous cell to the browser.
import shutil
from google.colab import files

files.download('predicted_values.csv')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>