<a href="https://colab.research.google.com/github/sokrypton/ColabBio/blob/main/notebooks/replacement_scan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Replacement Scan

Purpose: Find how many amino acid replacements your protein can tolerate while still yielding the same predicted structure. Note: this protocol runs in single-sequence mode only (no MSA). The analysis is likely only useful for de novo designed proteins.

In [None]:
#@title setup
%%time
import os
if not os.path.isdir("params"):
  # One-time environment setup, skipped whenever a "params" directory already exists.
  # NOTE: "params" is created before the download runs, so a failed download is
  # not retried on re-run unless the directory is removed by hand.
  setup_cmds = [
    # get code
    "pip -q install git+https://github.com/sokrypton/ColabDesign.git",
    # for debugging
    "ln -s /usr/local/lib/python3.*/dist-packages/colabdesign colabdesign",
    # download params
    "mkdir params",
    "apt-get install aria2 -qq",
    "aria2c -q -x 16 https://storage.googleapis.com/alphafold/alphafold_params_2022-12-06.tar",
    "tar -xf alphafold_params_2022-12-06.tar -C params",
  ]
  # run strictly in order; exit codes are not checked
  for cmd in setup_cmds:
    os.system(cmd)

# Suppress FutureWarning messages so they do not clutter the cell output.
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import os
from colabdesign import mk_af_model, clear_mem
from colabdesign.af.alphafold.common import residue_constants
from colabdesign.shared.utils import copy_dict
# Inverse of `restype_order` (value -> key); used to print integer-encoded
# sequences back as residue letters.
num2aa = {b:a for a,b in residue_constants.restype_order.items()}

from IPython.display import HTML
from google.colab import files
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors

import tqdm.notebook
# Format string for tqdm progress bars; not referenced elsewhere in the visible cells.
TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'

def get_pdb(pdb_code=""):
  """Resolve `pdb_code` to the name of a local PDB file.

  Resolution order:
    1. `None` or "": prompt for an upload (Colab widget) and store the first
       uploaded file as "tmp.pdb".
    2. an existing file path: returned unchanged.
    3. a 4-character code: fetched from files.rcsb.org.
    4. anything else: fetched from alphafold.ebi.ac.uk as an
       AF-<code>-F1-model_v3 entry.

  Downloads go through `wget -qnc`; the exit status is not checked, so the
  returned filename is not guaranteed to exist.

  Returns:
    The filename (str) of the structure.
  """
  # nothing specified -> ask the user for a file
  if pdb_code is None or pdb_code == "":
    uploaded = files.upload()
    first_key = list(uploaded.keys())[0]
    with open("tmp.pdb","wb") as handle:
      handle.write(uploaded[first_key])
    return "tmp.pdb"

  # already on disk -> use as is
  if os.path.isfile(pdb_code):
    return pdb_code

  # otherwise download, choosing the source from the length of the code
  if len(pdb_code) == 4:
    os.system(f"wget -qnc https://files.rcsb.org/view/{pdb_code}.pdb")
    local_name = f"{pdb_code}.pdb"
  else:
    os.system(f"wget -qnc https://alphafold.ebi.ac.uk/files/AF-{pdb_code}-F1-model_v3.pdb")
    local_name = f"AF-{pdb_code}-F1-model_v3.pdb"
  return local_name

# Presumably releases memory left over from a previously built model so the
# cell can be re-run safely — confirm against ColabDesign's `clear_mem`.
clear_mem()
# Model built with the "fixbb" protocol; this single instance is reused by all later cells.
af_model = mk_af_model(protocol="fixbb")

In [None]:
#@title run
# Colab form fields: each `#@param` annotation turns the assignment on its line
# into a form widget; `#@markdown` lines are rendered as form text.
#@markdown #### Input Options
# `pdb` is handed to get_pdb(): empty = upload, existing path = local file,
# 4 characters = RCSB entry, anything else = AlphaFold DB entry.
pdb = "6MRR" #@param {type:"string"}
chain = "A" #@param {type:"string"}
#@markdown #### Model Options
# forwarded to every af_model.predict() call in the scan
num_recycles = 0 #@param ["0", "1", "2", "3"] {type:"raw"}
#@markdown #### Scan Options
num_tries = -1 #@param [-1,10,20] {type:"raw"}
#@markdown - number of mutations to try before accepting. (`-1` = greedy, try all).
mut_res = "A" # @param ["None", "A","R","N","D","C","Q","E","G","H","I","L","K","M","F","P","S","T","W","Y","V"]
#@markdown - which mutation to make (`None` = mask the position completely).
sco_type = "rmsd+dgram" # @param ["rmsd+dgram","rmsd","dgram","pae","plddt","ptm"]
#@markdown - criteria to select by (`dgram` = cce between true and predicted distogram)

def get_sco(aux, sco_type):
  """Score one prediction and build a one-line summary of it.

  Args:
    aux: prediction output; only `aux["log"]` is read. It must contain
      "rmsd", "plddt" and "ptm" (always used for the summary) plus whatever
      the selected criterion needs ("dgram_cce", "pae").
    sco_type: selection criterion. "rmsd", "dgram", "pae" (multiplied by
      31.0), "plddt" and "ptm" (both turned into `1 - value`) are recognised;
      any other value, including "rmsd+dgram", uses rmsd + dgram_cce.

  Returns:
    `(loss, print_str)`: the scalar score (the scan keeps the candidate with
    the smallest one) and a formatted string with rmsd / loss / plddt / ptm.
  """
  log = aux["log"]
  # Scorers are wrapped in lambdas so only the selected criterion's keys are read.
  scorers = {
    "rmsd":  lambda: log["rmsd"],
    "dgram": lambda: log["dgram_cce"],
    "pae":   lambda: log["pae"] * 31.0,
    "plddt": lambda: 1 - log["plddt"],
    "ptm":   lambda: 1 - log["ptm"],
  }
  combined = lambda: log["rmsd"] + log["dgram_cce"]
  loss = scorers.get(sco_type, combined)()
  print_str = f"rmsd={aux['log']['rmsd']:.3f} loss={loss:.3f} plddt={aux['log']['plddt']:.3f} ptm={aux['log']['ptm']:.3f}"
  return loss, print_str

# Residue index written into every accepted position; -1 is used when
# mut_res == "None" (position masked instead of mutated).
MUT = -1 if mut_res == "None" else residue_constants.restype_order[mut_res]

# With num_tries == 0 no candidate would ever be evaluated; fail early with a
# clear message instead of after the first (expensive) prediction.
if num_tries == 0:
  raise ValueError("num_tries must be -1 (greedy, try all) or a positive integer")

pdb_filename = get_pdb(pdb)
af_model.prep_inputs(pdb_filename=pdb_filename, chain=chain)
WT = af_model._wt_aatype.copy()
NEW_SEQ = WT.copy()

# Step 0: predict the unmodified sequence to obtain the baseline.
af_model.predict(seq=NEW_SEQ, verbose=False, hard=False, num_recycles=num_recycles)
af_model._save_results(verbose=False)

# Scan history; index i holds the state after i accepted replacements.
AUXS = [copy_dict(af_model.aux)]
RMSDS = [af_model.aux["log"]["rmsd"]]
SEQS = [NEW_SEQ]
loss,print_str = get_sco(af_model.aux, sco_type)
print(f">{af_model._k} {print_str}")
print("".join([num2aa.get(a,"X") for a in SEQS[-1]]))

af_model._k += 1

# RMS_MTX[step, position] = rmsd of the candidate tested at that step
# (NaN where a position was not tested).
RMS_MTX = np.full((len(WT),len(WT)),np.nan)

n = 0
# Each iteration accepts exactly one replacement, until every position equals MUT.
while np.any(NEW_SEQ != MUT):
  pos = np.where(NEW_SEQ != MUT)[0]
  if num_tries > -1:
    # Random subset of the remaining positions.
    # NOTE: drawn from NumPy's global RNG, which is not seeded in this cell.
    pos = np.random.permutation(pos)[:num_tries]

  # Keep only the best candidate seen so far rather than buffering every aux
  # dict, so peak memory does not grow with the number of candidates.
  # Strict `<` keeps the first of equally good candidates.
  best_loss, best_seq, best_aux = None, None, None
  for t in pos:
    test_seq = NEW_SEQ.copy()
    test_seq[t] = MUT
    aux = af_model.predict(seq=test_seq, return_aux=True, verbose=False, hard=False, num_recycles=num_recycles)
    RMS_MTX[n,t] = aux["log"]["rmsd"]
    loss = get_sco(aux, sco_type)[0]
    if best_aux is None or loss < best_loss:
      best_loss, best_seq, best_aux = loss, test_seq, aux

  # accept best
  NEW_SEQ = best_seq
  # read rmsd from aux["log"], the same place as every other rmsd in this notebook
  RMSDS.append(best_aux["log"]["rmsd"])
  SEQS.append(NEW_SEQ)
  AUXS.append(best_aux)

  # make the accepted candidate the model's current state before saving it
  af_model.aux = best_aux
  af_model.set_seq(seq=NEW_SEQ)
  af_model._save_results(verbose=False)

  print_str = get_sco(af_model.aux, sco_type)[1]
  print(f">{af_model._k} {print_str}")
  print("".join([num2aa.get(a,"X") for a in SEQS[-1]]))
  af_model._k += 1
  n += 1

In [None]:
#@title plot results (optional)

plot_type = "line" # @param ["line", "heatmap"]
dpi = 100 #@param {type:"integer"}

# RMSD tick positions shared by the line plot's y-axis and the heatmap's colorbar.
RMSD_TICKS = [0.5, 1, 2, 4, 8, 16, 32]

if plot_type == "line":
  # x-axis: fraction of positions equal to MUT at each step
  REPLACED = [(seq == MUT).sum() / len(seq) for seq in SEQS]
  # pLDDT of the accepted prediction at each step
  PLDDT = [step_aux['log']["plddt"] for step_aux in AUXS]

  fig, ax1 = plt.subplots(figsize=(5, 4), dpi=dpi)

  # left axis: RMSD on a log scale
  color_rmsd = 'tab:blue'
  ax1.plot(REPLACED, RMSDS, color=color_rmsd)
  ax1.set_yscale("log")
  ax1.set_ybound(lower=0.5, upper=32)
  ax1.set_yticks(RMSD_TICKS)
  ax1.set_yticklabels(RMSD_TICKS, color=color_rmsd)
  ax1.set_xlim([0, 1])
  ax1.set_xlabel(f"fraction of '{mut_res}'")
  ax1.set_ylabel("rmsd", color=color_rmsd)

  # right axis: pLDDT on a linear 0..1 scale
  ax2 = ax1.twinx()
  color_plddt = 'tab:orange'
  ax2.plot(REPLACED, PLDDT, color=color_plddt)
  ax2.set_ylim(0, 1)
  ax2.set_ylabel("pLDDT", color=color_plddt)
  ax2.tick_params(axis='y', labelcolor=color_plddt)

  plt.show()

elif plot_type == "heatmap":
  # RMS_MTX is indexed [step, position]; transpose so steps run along x
  plt.figure(dpi=dpi)
  plt.imshow(RMS_MTX.T, cmap="bwr", norm=colors.LogNorm(vmin=0.5, vmax=32))
  cbar = plt.colorbar(ticks=RMSD_TICKS)
  cbar.ax.set_yticklabels([str(tick) for tick in RMSD_TICKS])
  plt.xlabel("step")
  plt.ylabel("position")
  plt.show()

In [None]:
#@title animate protein (optional)
dpi = 100 #@param {type:"integer"}

# Wrap the animation in HTML so it renders as this cell's output.
# Presumably built from the states stored by `_save_results` during the scan — confirm.
HTML(af_model.animate(dpi=dpi))

In [None]:
#@title display protein (optional)
num_muts = -1 #@param {type:"integer"}
#@markdown - Enter index of protein to show (-1 = auto, 0 = no mutations, 10 = 10 mutations, etc)
color = "pLDDT" #@param ["chain", "pLDDT", "rainbow"]
show_sidechains = True #@param {type:"boolean"}
show_mainchains = False #@param {type:"boolean"}
color_HP = False #@param {type:"boolean"}
if num_muts == -1:
  # auto: the last step whose RMSD is still strictly below the unmutated
  # baseline (RMSDS[0]). When no step qualifies (including a scan with zero
  # steps) fall back to the unmutated protein instead of raising IndexError
  # on an empty selection.
  below_baseline = np.flatnonzero(np.array(RMSDS) < RMSDS[0])
  num_muts = int(below_baseline[-1]) if below_baseline.size else 0

# Print the summary line and sequence for the selected step, then draw it.
print(f">{num_muts}", get_sco(AUXS[num_muts],sco_type)[1])
print("".join([num2aa.get(a,"X") for a in SEQS[num_muts]]))
af_model.plot_pdb(aux=AUXS[num_muts],
                  show_sidechains=show_sidechains,
                  show_mainchains=show_mainchains,
                  color=color, color_HP=color_HP)