# <center>**ProBASS: a language model with sequence and structural features for predicting the effect of mutations on binding affinity**</center>
---
Here we introduce ProBASS, a fine-tuned model that incorporates features derived from two protein language models, ESM-2 and ESM-IF1. The model predicts ddGbind values, the change in binding free energy upon mutation, from both the sequence and the structural attributes of the mutated protein complex.

The model is an efficient way to predict the effect of mutations on protein binding affinity.

---

**How to provide the PDB ID of the protein complex and the CSV file containing the mutation information to ProBASS**

Enter the "PDB ID" of the protein complex in the Input Data section. It is required to calculate the effect of the mutation on binding affinity.

Specify the mutation to evaluate by filling in the corresponding fields: 'Mutated_chain', 'Partner_chain', 'Wild_type', 'Position', and 'Mutation'.

The 'Mutated_chain' and 'Partner_chain' define the interface of the protein complex. 'Wild_type' refers to the original amino acid in the protein complex, 'Position' indicates the location of the desired mutation, and 'Mutation' specifies the amino acid the user wishes to substitute for the wild type.
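
As a minimal sketch of how these fields fit together, the snippet below combines them into a conventional mutation label (for example `T11P`) and rejects values that are not one of the 20 standard amino acids. The helper name `format_mutation` is hypothetical and is not part of ProBASS; it only illustrates the expected input format.

```python
# Hypothetical helper (not part of ProBASS): validate the mutation fields
# and build a label such as "T11P".
STANDARD_AMINO_ACIDS = set("ACDEFGHIKLMNPQRSTVWY")


def format_mutation(wild_type: str, position: int, mutation: str) -> str:
    """Return a label like 'T11P' after basic validation of the inputs."""
    wild_type = wild_type.strip().upper()
    mutation = mutation.strip().upper()
    for name, residue in (("Wild_type", wild_type), ("Mutation", mutation)):
        if residue not in STANDARD_AMINO_ACIDS:
            raise ValueError(
                f"{name} must be a one-letter amino acid code, got {residue!r}"
            )
    if wild_type == mutation:
        raise ValueError("Mutation must differ from Wild_type")
    return f"{wild_type}{position}{mutation}"


print(format_mutation("T", 11, "p"))  # T11P
```

Running a check of this kind before the heavier cells can catch a mistyped residue code early.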

**Instructions for using this Colab notebook**

Two options are possible for uploading the protein complex structure.

1)	**The complex structure is downloaded directly from the PDB**. Enter the "PDB ID" of the protein complex.


2)	**The complex structure is uploaded from the user’s computer**. To upload your own complex, remove the comment symbols (#) from all lines in the section labeled "Uploading the complex instead of PDB ID". Once the lines are uncommented, you can upload your complex when the cell runs. **Before running the notebook**, name the file with its four-character code and a .pdb extension (for example, 3OTJ.pdb). The same code must be specified below under PDB ID.
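
The naming rule for uploaded files can be checked with a few lines of Python. This is only a sketch: the function name `pdb_id_from_filename` is hypothetical, and the pattern assumes classic four-character PDB IDs (a digit followed by three alphanumeric characters).

```python
import re

# Assumption: classic four-character PDB IDs (a digit followed by three
# alphanumeric characters), stored with a ".pdb" extension.
PDB_FILENAME = re.compile(r"([0-9][A-Za-z0-9]{3})\.pdb")


def pdb_id_from_filename(filename: str) -> str:
    """Return the PDB ID encoded in a file name such as '3OTJ.pdb'."""
    match = PDB_FILENAME.fullmatch(filename)
    if match is None:
        raise ValueError(f"Expected a name like '3OTJ.pdb', got {filename!r}")
    return match.group(1)


print(pdb_id_from_filename("3OTJ.pdb"))  # 3OTJ
```

The returned ID is the value that would go into the PDB ID field.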

# Environment Setup for **ProBASS**
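
Once the installation cell below has finished, a quick check such as the following sketch can confirm that the packages are importable before any heavy computation starts. The helper `missing_modules` is hypothetical; the list contains the import names that correspond to the installed packages (for example `Bio` for biopython and `esm` for the ESM repository).

```python
import importlib.util

# Import names for the packages installed by the pip commands.
REQUIRED_MODULES = ["torch_geometric", "biotite", "catboost", "esm", "requests", "Bio"]


def missing_modules(module_names):
    """Return the module names that cannot be found in the current environment."""
    return [name for name in module_names if importlib.util.find_spec(name) is None]


missing = missing_modules(REQUIRED_MODULES)
print("Missing modules:", missing if missing else "none")
```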

In [None]:
%%capture
!pip install torch-geometric
!pip install biotite==0.33.0
!pip install catboost
!pip install git+https://github.com/facebookresearch/esm.git
!pip install requests
!pip install biopython

In [None]:
import os
import numpy as np
import pandas as pd
import catboost as cb
import torch
import esm
import scipy
from numpy import asarray
from numpy import savez_compressed
import requests
from Bio.PDB import PDBParser

In [None]:
%%bash

cd /content/

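# Clone only once: a separate marker file records a completed download.
# (A marker named "ProBASS" would collide with the cloned directory.)
if [ ! -f PROBASS_READY ]; then

    # delete the ProBASS/ directory if it already exists
    if [ -d "ProBASS/" ]; then
        rm -rf ProBASS/
    fi

    # download model
    git clone https://github.com/sagagugit/ProBASS --quiet && touch PROBASS_READY
fi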

# Input Data

In [None]:
# import sys
# from contextlib import redirect_stdout

# try:
#     from google.colab import drive

#     with redirect_stdout(open(os.devnull, 'w')):
#         drive.mount('/content/drive')

#     from google.colab import files


#     print("Please upload the .pdb file")


#     uploaded = files.upload()
# except FileNotFoundError:
#     print("ERROR: \n Uploading was not successful. Please restart and try to upload the complex again.")



In [None]:
#@title PDB ID
import os
import sys
from google.colab import drive, files
import contextlib
from IPython.display import display, HTML

PDB = '3OTJ' #@param {type:"string"}
Mutated_chain = 'I' #@param {type:"string"}
Partner_chain = 'E' #@param {type:"string"}
Wild_type = 'T' #@param {type:"string"}
Position = 11 #@param {type:"integer"}
Mutation = 'P' #@param {type:"string"}

pdb_file_path = f'/content/{PDB}.pdb'




# Selecting Path

In [None]:
%%capture
%cd ProBASS
!cp /content/Input.csv /content/ProBASS

# Extracting embeddings from ESM-2 and ESM-IF1

# Extracting FASTA files for wild type, partner chain and mutated PPI
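
Residue numbers in a PDB file do not have to start at 1, so a mutation position given in PDB numbering has to be translated into an index in the extracted sequence. The sketch below illustrates that translation on made-up data; `mutate_by_pdb_position` is a hypothetical helper, not the function used in the cells that follow, and it additionally checks that the stated wild-type residue is really present at that position.

```python
def mutate_by_pdb_position(sequence, residue_numbers, wild_type, position, mutation):
    """Apply a point mutation given in PDB numbering to a one-letter sequence."""
    if len(sequence) != len(residue_numbers):
        raise ValueError("sequence and residue_numbers must have the same length")
    index = residue_numbers.index(position)  # raises ValueError if absent
    if sequence[index] != wild_type:
        raise ValueError(
            f"Expected {wild_type} at position {position}, found {sequence[index]}"
        )
    return sequence[:index] + mutation + sequence[index + 1:]


# Toy chain whose numbering starts at 5 rather than 1 (made-up data).
toy_sequence = "ACDTEF"
toy_numbers = [5, 6, 7, 8, 9, 10]
print(mutate_by_pdb_position(toy_sequence, toy_numbers, "T", 8, "P"))  # ACDPEF
```

Looking the position up in the list of residue numbers, rather than subtracting an offset, keeps the mapping explicit about which residue is being replaced.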

In [None]:
%%capture

PDB_code = PDB

url = f'https://files.rcsb.org/download/{PDB_code}.pdb'
response = requests.get(url)

if response.status_code == 200:
    with open(f'{PDB_code}.pdb', 'wb') as file:
        file.write(response.content)
    print(f'{PDB_code}.pdb has been downloaded successfully.')
else:
    print(f'Failed to download {PDB_code}.pdb. Status code: {response.status_code}')

In [None]:
from Bio.PDB import PDBParser

RESIDUE_NAME_TO_LETTER = {
    'ALA': 'A', 'ARG': 'R', 'ASN': 'N', 'ASP': 'D', 'CYS': 'C',
    'GLU': 'E', 'GLN': 'Q', 'GLY': 'G', 'HIS': 'H', 'ILE': 'I',
    'LEU': 'L', 'LYS': 'K', 'MET': 'M', 'PHE': 'F', 'PRO': 'P',
    'SER': 'S', 'THR': 'T', 'TRP': 'W', 'TYR': 'Y', 'VAL': 'V'
}

try:
    PDB_code = PDB
    pdb_file = f'{PDB_code}.pdb'
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure(PDB_code, pdb_file)

    def extract_sequence_and_check_gaps(chain_id):
        sequence = []
        start_residue_number = None
        previous_resnum = None  # Track the previous residue number to check for gaps
        for model in structure:
            for chain in model:
                if chain.get_id() == chain_id:
                    for residue in chain:
                        # Exclude heteroatoms (HETATM entries)
                        if residue.get_id()[0] != ' ':
                            continue

                        resname = residue.get_resname()
                        resnum = residue.get_id()[1]  # Extract the residue number

                        # Check for gaps in the main chain
                        if previous_resnum is not None and resnum != previous_resnum + 1:
                            raise ValueError(f"Chain {chain_id} in {PDB_code} has a gap between residues {previous_resnum} and {resnum}.")

                        if start_residue_number is None:
                            start_residue_number = resnum
                        if resname in RESIDUE_NAME_TO_LETTER:
                            sequence.append(RESIDUE_NAME_TO_LETTER[resname])
                        previous_resnum = resnum  # Update the previous residue number
        return ''.join(sequence), start_residue_number

    def adjust_positions(mutated_chain_id, position):
        _, start_residue = extract_sequence_and_check_gaps(mutated_chain_id)
        return position - start_residue + 1

    def apply_mutation(sequence, position, new_residue):
        sequence_list = list(sequence)
        sequence_list[position - 1] = new_residue
        return ''.join(sequence_list)

    # Using the provided inputs directly
    mutated_chain_id = Mutated_chain
    partner_chain_id = Partner_chain
    mutation_position = Position
    new_residue = Mutation.upper()  # Ensure the mutation is uppercase

    # Extract sequences and check for gaps in the main chain
    wild_sequence, mutated_start_residue = extract_sequence_and_check_gaps(mutated_chain_id)
    partner_sequence, _ = extract_sequence_and_check_gaps(partner_chain_id)

    # Adjust mutation position and apply mutation (this also fails early if the position is out of range)
    adjusted_position = adjust_positions(mutated_chain_id, mutation_position)
    mutated_sequence = apply_mutation(wild_sequence, adjusted_position, new_residue)

    # Write sequences to FASTA files; the wild-type file holds the unmutated sequence
    with open(f'{PDB_code}_wild.fasta', 'w') as f:
        f.write(f'> {PDB_code}_wild\n{wild_sequence}\n')

    with open(f'{PDB_code}_partner.fasta', 'w') as f:
        f.write(f'> {PDB_code}_partner\n{partner_sequence}\n')

except ValueError as ve:
    print("\033[1mERROR MESSAGE:!!!\033[0m\nThe PDB file contains broken chains.")
    print(f"Details: {ve}")
except Exception as e:
    print("\033[1mERROR MESSAGE:!!!\033[0m\nPlease verify that the inputs are properly formatted and that the mutation information is accurate.")
    print(f"Exception details: {e}")



In [None]:
PDB_code = PDB

# Helper function to read a FASTA file
def read_fasta(fasta_file):
    """Reads a FASTA file and returns the sequence."""
    with open(fasta_file, 'r') as f:
        lines = f.readlines()
    return ''.join(line.strip() for line in lines[1:])

# Load the wild-type sequence from the FASTA file
wild_sequence = read_fasta(f'{PDB_code}_wild.fasta')

# Helper function to extract residue numbers from a PDB file
def extract_residue_numbers(pdb_file, chain_id):
    """Extracts residue numbers for a specific chain from a PDB file."""
    from Bio.PDB import PDBParser
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure(PDB_code, pdb_file)

    residue_numbers = []
    for model in structure:
        for chain in model:
            if chain.get_id() == chain_id:
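                for residue in chain:
                    # Skip heteroatoms (HETATM entries) so indices match the FASTA sequence
                    if residue.get_id()[0] != ' ':
                        continue
                    residue_numbers.append(residue.get_id()[1])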
    return residue_numbers

# Helper function to map a PDB residue position to its index in the sequence
def pdb_position_to_index(pdb_residue_numbers, pdb_position):
    """Finds the sequence index for a given PDB residue position."""
    try:
        return pdb_residue_numbers.index(pdb_position)
    except ValueError:
        print(f"Warning: Position {pdb_position} not found in the PDB file.")
        return None

# Extract residue numbers for the mutated chain
pdb_residue_numbers = extract_residue_numbers(f'{PDB_code}.pdb', Mutated_chain)

# Prepare the mutated sequence
mutated_sequence = list(wild_sequence)

# Apply the mutation
idx = pdb_position_to_index(pdb_residue_numbers, Position)
if idx is not None and 0 <= idx < len(mutated_sequence):
    mutated_sequence[idx] = Mutation.upper()  # Ensure mutation is uppercase
    mutated_sequence_str = ''.join(mutated_sequence)

    # Prepare FASTA entry
    fasta_header = f'> {PDB_code}_{Position}{Mutation.upper()}\n'
    fasta_entry = fasta_header + mutated_sequence_str + '\n'

    # Write to a FASTA file
    with open(f'{PDB_code}.fasta', 'w') as f:
        f.write(fasta_entry)
else:
    print("ERROR: Mutation position is out of bounds or not found in the PDB residue numbers.")




# Extract sequence embeddings and structural embeddings

In [None]:
%%capture

# Define the PDB code
PDB_code = PDB

# Extract embeddings using `esm2_t33_650M_UR50D` for the given FASTA files
!python extract.py esm2_t33_650M_UR50D {PDB}.fasta {PDB}_esm2 --repr_layers 0 32 33 --include mean per_tok
!python extract.py esm2_t33_650M_UR50D {PDB}_wild.fasta {PDB}_esm2_wild --repr_layers 0 32 33 --include mean per_tok
!python extract.py esm2_t33_650M_UR50D {PDB}_partner.fasta {PDB}_esm2_partner --repr_layers 0 32 33 --include mean per_tok

# Load the pretrained ESM model for inverse folding
import esm
import numpy as np
from esm.inverse_folding.util import load_structure
from esm.inverse_folding.multichain_util import (
    extract_coords_from_complex, get_encoder_output_for_complex
)

model, alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50()
model = model.eval()  # Set the model to evaluation mode

# Load the structure using the PDB file
fpath = f"{PDB_code}.pdb"
chain_ids = [Mutated_chain, Partner_chain]

structure = load_structure(fpath, chain_ids)

# Extract coordinates and native sequences for the chains
coords, native_seqs = extract_coords_from_complex(structure)

print(f"Loaded chains: {list(coords.keys())}\n")

# Print native sequences for the mutated and partner chains
for chain_id in chain_ids:
    print(f"Chain {chain_id} native sequence:")
    print(native_seqs[chain_id])
    print("\n")

# Generate encoder output for the mutated chain
rep = get_encoder_output_for_complex(model, alphabet, coords, Mutated_chain)

print(f"Shape of encoder output for chain {Mutated_chain}: {rep.shape}")

# Save the representation as a NumPy file
numpy_rep = rep.detach().numpy()
np.savez(f"inverse_{PDB_code}.npz", data=numpy_rep)


# Run ProBASS to predict the ΔΔG values
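
The cells below assemble one feature vector per mutation: the per-residue ESM-2 embeddings of the partner chain and the mutated (or wild-type) chain are stacked and averaged, the wild-type mean is subtracted from the mutant mean, and the averaged ESM-IF1 encoder output is appended. The sketch below reproduces that arithmetic on random, made-up arrays so the shapes are easy to follow; the chain lengths and the helper `complex_mean` are illustrative assumptions, not part of ProBASS.

```python
import numpy as np

ESM2_DIM = 1280    # embedding width of esm2_t33_650M_UR50D
ESM_IF1_DIM = 512  # width implied by the reshape to [1, 512] in the notebook

rng = np.random.default_rng(0)
# Made-up per-residue embeddings; the chain lengths are arbitrary.
partner = rng.normal(size=(40, ESM2_DIM))
wild = rng.normal(size=(25, ESM2_DIM))
mutant = wild.copy()
mutant[10] += 0.1  # stand-in for the change caused by a point mutation
structure = rng.normal(size=(25, ESM_IF1_DIM))


def complex_mean(partner_emb, chain_emb):
    """Mean over all residues of the partner and the chain, stacked together."""
    return np.concatenate([partner_emb, chain_emb], axis=0).mean(axis=0)


# Difference of the two complex-level means, then the structural mean appended.
delta = complex_mean(partner, mutant) - complex_mean(partner, wild)
features = np.concatenate([delta, structure.mean(axis=0)])[np.newaxis, :]
print(features.shape)  # (1, 1792)
```

Because the partner chain contributes equally to both means, the difference reflects only how the mutated chain's embeddings changed, scaled by the total number of residues.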

In [None]:
PDBS = [PDB]  # list of PDB IDs to process

# The helpers below read the module-level path variables
# (FASTA_PATH2, EMB_PATH2, ..., inverse_path) that are set in the next cell.

def exctracting_embeddings_esm2(pdb):
    """Load the layer-33 ESM-2 embeddings of the mutant sequence."""
    mutations2 = []
    Xs2 = []
    for header2, _seq2 in esm.data.read_fasta(FASTA_PATH2):
        scaled_effect2 = header2.split('|')[-1]
        mutations2.append(scaled_effect2)
        # header2[1:] drops the first character of the header (the space written after '>')
        fn = f'{EMB_PATH2}/{header2[1:]}.pt'
        embs2 = torch.load(fn)
        Xs2.append(embs2['representations'][33])
    Xs2 = torch.stack(Xs2, dim=0).numpy()

    return Xs2, mutations2


def exctracting_embeddings_esm2_wild(pdb):
    """Load the layer-33 ESM-2 embeddings of the wild-type sequence."""
    mutations2_w = []
    Xs2_w = []
    for header2, _seq2 in esm.data.read_fasta(FASTA_PATH2_w):
        scaled_effect2_w = header2.split('|')[-1]
        mutations2_w.append(scaled_effect2_w)
        fn = f'{EMB_PATH2_w}/{header2[1:]}.pt'
        embs2 = torch.load(fn)
        Xs2_w.append(embs2['representations'][33])
    Xs2_w = torch.stack(Xs2_w, dim=0).numpy()

    return Xs2_w


def exctracting_embeddings_esm2_bind(pdb):
    """Load the layer-33 ESM-2 embeddings of the partner chain."""
    mutations2_b = []
    Xs2_b = []
    for header2, _seq2 in esm.data.read_fasta(FASTA_PATH2_b):
        scaled_effect2_b = header2.split('|')[-1]
        mutations2_b.append(scaled_effect2_b)
        fn = f'{EMB_PATH2_b}/{header2[1:]}.pt'
        embs2 = torch.load(fn)
        Xs2_b.append(embs2['representations'][33])
    Xs2_b = torch.stack(Xs2_b, dim=0).numpy()

    return Xs2_b


def exctracting_embeddings_1f(pdb):
    """Load the ESM-IF1 encoder output and average it over residues."""
    temp = np.load(inverse_path)
    inverse = temp['data']

    # Mean over the residue axis, reshaped to a single 512-dimensional row
    average_mean_embedding = np.mean(inverse, axis=0)
    inverse_mean_reshape = average_mean_embedding.reshape([1, 512])

    return inverse_mean_reshape

In [None]:
%%capture
ddg_values = []
embeddings = []
for pdb in PDBS:
    # Paths to the FASTA files and embeddings written by the previous cells
    FASTA_PATH2 = "/content/ProBASS/{}.fasta".format(pdb)
    EMB_PATH2 = "/content/ProBASS/{}_esm2".format(pdb)
    FASTA_PATH2_w = "/content/ProBASS/{}_wild.fasta".format(pdb)
    EMB_PATH2_w = "/content/ProBASS/{}_esm2_wild".format(pdb)
    FASTA_PATH2_b = "/content/ProBASS/{}_partner.fasta".format(pdb)
    EMB_PATH2_b = "/content/ProBASS/{}_esm2_partner".format(pdb)
    inverse_path = '/content/ProBASS/inverse_{}.npz'.format(pdb)
    Xs2, mutations2= exctracting_embeddings_esm2(pdb)
    Xs2_w= exctracting_embeddings_esm2_wild(pdb)
    Xs2_w=np.tile(Xs2_w, (len(Xs2), 1, 1))
    Xs2_b=exctracting_embeddings_esm2_bind(pdb)
    Xs2_b=np.tile(Xs2_b, (len(Xs2), 1, 1))
    inverse=exctracting_embeddings_1f(pdb)
    inverse=np.tile(inverse, (len(Xs2), 1))
    mutant_and_partner_together_esm2 = np.concatenate([Xs2_b, Xs2], axis =1)

    wild_type_and_partner_together_esm2 = np.concatenate([Xs2_b, Xs2_w], axis =1)
    mutant_and_partner_together_esm2_mean=np.mean(mutant_and_partner_together_esm2, axis=1)
    wild_type_and_partner_together_esm2_mean=np.mean(wild_type_and_partner_together_esm2, axis=1)
    ddg_1v = np.subtract(mutant_and_partner_together_esm2_mean, wild_type_and_partner_together_esm2_mean)

    ddg_esm2_with_inverse = np.concatenate([ddg_1v, inverse], axis =1)
    embeddings.append(ddg_esm2_with_inverse)

In [None]:
import numpy as np


ddg_length = len(embeddings[0])
ddg_values = [0] * ddg_length


flattened_list = ddg_values


extracted_array = embeddings[0]
Xs_test = extracted_array
ys_test = flattened_list

np.savez('test.npz', data=Xs_test, label=ys_test)

In [None]:
temp = np.load('test.npz')
X_test, test_y = temp['data'], temp['label']

In [None]:
import pandas as pd
import catboost as cb

# Assuming the following inputs are already defined in the previous cells:
# `Wild_type`, `Position`, `Mutation`, and `X_test` (test features for the model).

# Ensure that `Wild_type`, `Position`, and `Mutation` are defined
try:
    # Combine inputs to create the mutation string
    mutation_string = Wild_type + str(Position) + Mutation.upper()

    # Load the trained CatBoost model
    model = cb.CatBoostRegressor()
    model.load_model('Probass_model.cbm')

    # Generate predictions using the model
    ypred = model.predict(X_test)

    # Prepare a DataFrame with mutation information and predicted values
    predicted_df = pd.DataFrame({
        'Mutation': [mutation_string],  # Single mutation string
        'predicted_value ΔΔG kcal/mol': ypred
    })

    # Save the predictions to a CSV file
    predicted_df.to_csv('predicted_values.csv', index=False)

except NameError as e:
    print("\033[1mERROR MESSAGE:!!!\033[0m")
    print(f"Missing input: {e}. Ensure all required variables are defined in previous cells.")
except Exception as e:
    print("\033[1mERROR MESSAGE:!!!\033[0m")
    print(f"An unexpected error occurred: {e}")



# Download Predicted Binding Affinities

In [None]:
from google.colab import files
import shutil
files.download('predicted_values.csv')