<a href="https://colab.research.google.com/github/prathithbhargav/AlphaMut/blob/master/3_inference_of_Helix-in-protein_trained_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Installation

In [2]:
%%capture
#@title Installing Required Packages
!pip install stable-baselines3
!pip install biotite==0.41
!pip install biopandas
!pip install py3dmol
#@markdown Press the play button to install the dependencies and run the function
!pip install py3dmol
!pip install biopython
!pip install mdanalysis
import pandas as pd
import os, glob
from biopandas.pdb import PandasPdb
import MDAnalysis as mda
from MDAnalysis.analysis import align
from MDAnalysis.analysis.rms import rmsd
def get_sequence_from_pdb(pdb_file,chain_id):
  ppdb = PandasPdb().read_pdb(pdb_file)
  sequence = ppdb.amino3to1()
  sequence = ''.join(sequence.loc[sequence['chain_id'] == chain_id, 'residue_name'])
  return sequence
def get_helix_sequence_from_pdb(pdb_file,chain_id,starting_residue,ending_residue):
  ppdb = PandasPdb().read_pdb(pdb_file)
  atom_df = ppdb.df['ATOM']
  helix_atom_df = atom_df[(atom_df['residue_number'] >= starting_residue) & (atom_df['residue_number'] <= ending_residue)]
  ppdb.df['ATOM'] = helix_atom_df
  sequence = ppdb.amino3to1()
  sequence_str = ''.join(sequence.loc[sequence['chain_id'] == chain_id, 'residue_name'])
  return sequence_str

def find_start_end(seq, whole_seq):
    start_idx = whole_seq.find(seq)
    if start_idx != -1:
        end_idx = start_idx + len(seq) - 1
    else:
        start_idx = None
        end_idx = None
    return start_idx, end_idx


def find_point_mutations_with_file(wt_csv_file, file_path,starting_residue_specified_by_user):
    wt = pd.read_csv(wt_csv_file)['Seq'].values[0]
    starting_residue = pd.read_csv(wt_csv_file)['starting_residue'].values[0]
    ending_residue =  pd.read_csv(wt_csv_file)['ending_residue'].values[0]

    mut = get_sequence_from_pdb(file_path,chain_id='A')
    mut = mut[starting_residue:ending_residue+1]
    base_number = starting_residue_specified_by_user

    # print(wt,mut)
    if len(wt) != len(mut):
        raise ValueError("Sequences must be of the same length")

    mutations = []

    # Iterate over the sequences and compare each position
    for i in range(len(wt)):
        if wt[i] != mut[i]:
            # Adjust the mutation position by adding the base number minus 1
            mutation = f"{wt[i]}{i + 1 + (base_number - 1)}{mut[i]}"
            mutations.append(mutation)

    # Number of point mutations is the length of the mutations list
    num_mutations = len(mutations)

    return num_mutations, mutations


def align_and_save_pdb_structures(mutant_pdb_file_path,folder_to_save_aligned_files):
  ref = mda.Universe(f'{protein_name}.pdb')
  mobile = mda.Universe(mutant_pdb_file_path)
  # Select Cα atoms from both structures and match by residue number
  ref_atoms = ref.select_atoms('protein and name CA')
  mobile_atoms = mobile.select_atoms('protein and name CA')

  # Check for matching residues
  ref_residues = ref_atoms.residues.resids
  mobile_residues = mobile_atoms.residues.resids
  common_residues = set(ref_residues).intersection(mobile_residues)

  # Select only the common residues for alignment
  ref_aligned_atoms = ref.select_atoms(f"resid {' '.join(map(str, common_residues))} and name CA")
  mobile_aligned_atoms = mobile.select_atoms(f"resid {' '.join(map(str, common_residues))} and name CA")
  ref_residues = ref_atoms.residues.resids
  mobile_residues = mobile_atoms.residues.resids
  common_residues = set(ref_residues).intersection(mobile_residues)

  # Select only the common residues for alignment
  ref_aligned_atoms = ref.select_atoms(f"resid {' '.join(map(str, common_residues))} and name CA")
  mobile_aligned_atoms = mobile.select_atoms(f"resid {' '.join(map(str, common_residues))} and name CA")
  mobile0 = mobile_aligned_atoms.positions - mobile_aligned_atoms.center_of_mass()
  ref0 = ref_aligned_atoms.positions - ref_aligned_atoms.center_of_mass()

  R, rmsd = align.rotation_matrix(mobile0, ref0)
  mobile.atoms.translate(-mobile_aligned_atoms.center_of_mass())
  mobile.atoms.rotate(R)
  mobile.atoms.translate(ref_aligned_atoms.center_of_mass())

  # Save the aligned mutant structure
  base_name_mutant = os.path.basename(mutant_pdb_file_path).split('.pdb')[0]
  os.makedirs(folder_to_save_aligned_files, exist_ok=True)
  mobile.atoms.write(f"{folder_to_save_aligned_files}/{base_name_mutant}_aligned.pdb")









In [1]:
%%capture
!unzip Helix_in_protein.zip

# Choosing the protein and helix to disrupt

In [3]:
%%capture
import py3Dmol
pdb_id = '1do4' #@param {type:"string"}
#@markdown  - make sure to use a protein with a helix and a sequence length < 250
chain_id = "A" #@param {type:"string"}
#@markdown - specify the chain
protein_name = 'Myoglobin' #@param {type:"string"}
#@markdown - specify a name for the protein, helpful, but not required.
# number of iterations
num_iterations = 0 #@param {type:"raw"}
#@markdown - specify how many iterations to run the model (does not mean you will get that many successful disruptions)
max_mutations =7#@param {type:'raw'}
#@markdown - specify the maximum mutations that you want in the helix. Note that the model is trained for 13 maximum mutations.
!wget https://files.rcsb.org/download/{pdb_id}.pdb -O {protein_name}.pdb
starting_residue = 82 #@param {type:"raw"}
#@markdown - specify the starting residue number of the helix, according to the PDB numbering
ending_residue = 96 #@param {type:"raw"}
#@markdown - specify the ending residue number of the helix, according to the PDB numbering









In [4]:
#@title Displaying the helix to disrupt
#@markdown The helix displayed in magenta is the one we are going to disrupt
with open(f"{protein_name}.pdb") as ifile:
    system = "".join([x for x in ifile]) # adopted from https://william-dawson.github.io/using-py3dmol.html
view = py3Dmol.view(width=400, height=300)
view.addModelsAsFrames(system)
view.setStyle({'model': -1}, {"cartoon": {'color': 'teal','opacity': 0.7}})
resid ={'resi':f'{starting_residue}-{ending_residue}'}
view.setStyle(resid, {"cartoon": {'color': 'magenta'}})

view.zoomTo()
view.show()

In [5]:
#@title Creating the input file for the model
os.makedirs('input_csv_files',exist_ok=True)
helix_sequence = get_helix_sequence_from_pdb(f"{protein_name}.pdb",chain_id, starting_residue, ending_residue)
protein_sequence = get_sequence_from_pdb(f"{protein_name}.pdb",chain_id)
start, end = find_start_end(helix_sequence, protein_sequence)
dict_for_df = {'PDB':f'{protein_name}', "Seq":helix_sequence, "SeqLen":len(helix_sequence),  "pdb_id":pdb_id, "chain_id":chain_id,  "whole_protein_sequence":protein_sequence,  "length_of_whole_sequence":len(protein_sequence),  "starting_residue":start+1,  "ending_residue":end+1,  "dataset":"validate"}
pd.DataFrame(dict_for_df,index=[0]).to_csv(f'input_csv_files/{protein_name}.csv')

In [9]:
#@title Running the inference model functions
import Helix_in_protein
import pandas as pd
import numpy as np
import glob, os, shutil, pickle
import gymnasium as gym
import numpy as np
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.results_plotter import load_results, ts2xy
from stable_baselines3.common.noise import NormalActionNoise
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3 import PPO

def evaluate_model(estimator, env, num_episodes):
    total_reward_episode = [0] * num_episodes
    actions_taken_in_episodes = {}
    number_of_mutations_per_episode = [0] * num_episodes
    file_chosen_for_mutations = {}
    for episode in tqdm(range(num_episodes)):
        actions_taken_in_episodes[episode] = []
        file_chosen_for_mutations[episode] = [0]
        rewards = []

        state, info = env.reset()
        while True:
            action= estimator.predict(state)[0]
            next_state, reward, terminated, truncated, info = env.step(action)
            total_reward_episode[episode] += reward
            rewards.append(reward)
            name_of_the_protein = os.path.basename(env.path_of_template_pdb_file),
            actions_taken_in_episodes[episode].append([info['amino_acid_position'],info['new_amino_acid']])
            number_of_mutations_per_episode[episode]+=1
            file_chosen_for_mutations[episode][0] = str(name_of_the_protein[0])
            if terminated or truncated:
                break
            state = next_state
#         print(f'{episode} episode done')
    return total_reward_episode,np.mean(total_reward_episode),np.std(total_reward_episode),actions_taken_in_episodes,number_of_mutations_per_episode,file_chosen_for_mutations
def validate_model_disease(saved_model_file_path,csv_file_path,maximum_mutations_allowed,number_of_iterations):
    base_name_of_saved_model = os.path.basename(saved_model_file_path)
    # creating the folder to save structures
    number_of_episodes_to_run_validation = number_of_iterations
    csv_file_base_path = os.path.basename(csv_file_path).split('.')[0]
    base_name_of_saved_model = base_name_of_saved_model
    os.makedirs(f'{protein_name}_disrupted',exist_ok=True)
    algorithm_used_for_training = 'PPO'

    env = Helix_in_protein.ProteinEvolution(file_containing_sequence_database=csv_file_path,
                                        protein_length_limit=250,
                                          folder_to_save_validation_files=f'{protein_name}_disrupted',
                                          reward_cutoff=30,
                                          unique_path_to_give_for_file='validation_try',
                                          sequence_encoding_type='esm',
                                          maximum_number_of_allowed_mutations_per_episode=maximum_mutations_allowed,
                                          use_proline=False,
                                               use_plddt_in_reward=False,
                                          validation=True)

    # folder_for_saving_stuff = f'disrupted_helices/results_obtained_{csv_file_base_path}_{maximum_mutations_allowed}_only'
    # os.makedirs(folder_for_saving_stuff,exist_ok=True)
    loaded_estimator = PPO.load(saved_model_file_path)


    total_reward_array,mean_validation_reward, standard_deviation,actions_taken, mutations_array, files_chosen_for_mutations = evaluate_model(estimator=loaded_estimator, env=env,num_episodes=number_of_episodes_to_run_validation)
    print(mutations_array)


In [10]:
#@title Running the model
import warnings
warnings.filterwarnings("ignore")
validate_model_disease('AlphaMut1.0',
                       csv_file_path=f'input_csv_files/{protein_name}.csv',
                       maximum_mutations_allowed=7,
                       number_of_iterations=10)

  0%|          | 0/10 [00:00<?, ?it/s]

[5, 4, 7, 7, 7, 5, 7, 7, 7, 7]


In [11]:
#@title Showing the disrupted structures

for file in glob.glob(f"{protein_name}_disrupted/*.pdb"):
  align_and_save_pdb_structures(file,f"{protein_name}_aligned")
number_of_structures = len(os.listdir(f"{protein_name}_aligned"))
grid_size = int(number_of_structures**0.5) + 1
view = py3Dmol.view(width=600, height=600,viewergrid = (grid_size,grid_size))
list_of_aligned_structures = glob.glob(f'{protein_name}_aligned/*.pdb')
for i in range(grid_size):
    for j in range(grid_size):
      index = i * grid_size + j
      if index >= number_of_structures:
        wt_i= i
        wt_j= j


        break

      with open(list_of_aligned_structures[index]) as ifile:
        system = "".join([x for x in ifile])
      view.addModelsAsFrames(system,viewer=(i,j))
      view.setStyle({'model': -1}, {"cartoon": {'color': 'teal','opacity': 0.7}},viewer=(i,j))
      resid ={'resi':f'{start}-{end}'}
      view.setStyle(resid, {"cartoon": {'color': 'magenta'}},viewer=(i,j))
      mutations = find_point_mutations_with_file(wt_csv_file=f'input_csv_files/{protein_name}.csv',
                                                 file_path=list_of_aligned_structures[index],
                                                 starting_residue_specified_by_user=starting_residue)[1]
      residue_id_to_label = start + 1
      mutations = str(mutations).strip('[')
      mutations = mutations.strip(']')

      if index < number_of_structures:
        view.addLabel(f"{mutations}",{'fontColor':'black','backgroundColor':'orange','fontSize':10,'backgroundOpacity':0.8},{'resi':residue_id_to_label},viewer=(i,j))
      # view.clear()
      view.zoomTo()
with open(f"{protein_name}.pdb") as ifile:
    system = "".join([x for x in ifile]) # adopted from https://william-dawson.github.io/using-py3dmol.html
# view = py3Dmol.view(width=400, height=300)
view.addModelsAsFrames(system,viewer=(wt_i,wt_j))
view.setStyle({'model': -1}, {"cartoon": {'color': 'teal','opacity': 0.7}},viewer=(wt_i,wt_j))
resid ={'resi':f'{starting_residue}-{ending_residue}'}
view.setStyle(resid, {"cartoon": {'color': 'magenta'}},viewer=(wt_i,wt_j))
view.addLabel(f"WT {pdb_id}",{'fontColor':'black','backgroundColor':'orange','fontSize':10,'backgroundOpacity':0.8},{'resi':residue_id_to_label},viewer=(wt_i,wt_j))

view.zoomTo()
view.show()