# 1) Install LASErMPNN & Dependencies (Hit `Run all` -> Notebook will automatically restart -> Hit `Run all` again after restart)

In [None]:
# @title
try:
  import torch
  import torch_scatter
  import torch_cluster
  if torch.__version__ != '2.2.0+cu121':
    raise ValueError

  import pykeops
  pykeops.test_numpy_bindings()
  pykeops.test_torch_bindings()

  print('Dependencies installed successfully! Installing LASErMPNN.')
  !git clone https://github.com/polizzilab/LASErMPNN.git --depth 1
except:
  !uv pip install --system 'torch==2.2' scipy 'numpy==1.26.4' pandas scikit-learn h5py pytest prody matplotlib seaborn jupyter plotly pykeops logomaker wandb tqdm rdkit py3Dmol
  !pip install torch_scatter torch_cluster -f https://data.pyg.org/whl/torch-2.2.0+cu121.html

  print("Installed dependencies into colab. Hit 'Run all' one more time to finish installation.")

  import os
  os.kill(os.getpid(), 9)

# 2) Use LASErMPNN to Design Sequences

In [None]:
# @title Upload an input PDB file to run LASErMPNN on. (See upload prompt below cell after executing)
import os
from google.colab import files
import prody as pr
from pathlib import Path
import shutil

# One-letter codes accepted in the `disable_residues` form field below.
residue_one_letter_codes = {'C', 'D', 'S', 'Q', 'K', 'I', 'P', 'T', 'F', 'N', 'G', 'H', 'L', 'R', 'W', 'A', 'V', 'E', 'Y', 'M', 'X'}

# @markdown # Use [ProDy selection syntax](http://www.bahargroup.org/prody/manual/reference/atomic/select.html) to restrict design to specific residues or chains.
# @markdown - #### To design all protein residues, enter `protein` or leave empty.
# @markdown - #### To restrict design to a single chain (chain A), enter `chid A`
# @markdown - #### To restrict design to two chains (chain A and chain B), enter `(chid A) or (chid B)`
# @markdown - #### To restrict design to certain residue indices (residx 1 to 5), enter `resindex 1 2 3 4 5` or `resindex 1 to 5`
# @markdown - #### To restrict design to region around the ligand, enter `protein within 5.0 of (not protein)`

prody_selection_string = "" # @param {"type":"string","placeholder":"protein"}
# An empty or whitespace-only field falls back to selecting all protein residues.
if prody_selection_string is None or prody_selection_string.strip() == '':
  prody_selection_string = 'protein'

# @markdown # Comma separated list of amino acid one-letter-codes to prevent LASErMPNN from sampling, ex: Cysteine `C`
disable_residues = "" # @param {"type":"string","placeholder":""}
if disable_residues is None:
  disable_residues = ""


def parse_disabled_residues(raw, valid_codes=residue_one_letter_codes):
  """Parse a comma separated string of one-letter residue codes.

  Surrounding whitespace is ignored, codes are upper-cased, empty entries are
  skipped and duplicates are dropped. 'X' is always the first entry.

  Args:
    raw: Text from the form field, e.g. ``"C, M"``.
    valid_codes: Collection of accepted one-letter codes.

  Returns:
    List of codes starting with ``'X'`` followed by the requested codes in
    the order they were given.

  Raises:
    ValueError: If an entry is not a single, recognised one-letter code.
  """
  parsed = ['X']
  for token in raw.split(','):
    code = token.strip().upper()
    if code == '':
      continue
    if len(code) != 1 or code not in valid_codes:
      raise ValueError(f'Invalid residue in disable_residues: {token}')
    if code not in parsed:
      parsed.append(code)
  return parsed


disabled_residues = parse_disabled_residues(disable_residues)
disabled_residues_string = ','.join(disabled_residues)

def upload_files(upload_dir='/content/'):
  """Prompt for a structure file upload and stage it as ``temp<ext>``.

  The uploaded file is expected to be written by ``files.upload()`` into
  ``upload_dir``. It is copied to ``<upload_dir>/temp<ext>``, where ``<ext>``
  is the full recognised extension (e.g. ``.cif.gz``), and the original upload
  is then removed.

  Args:
    upload_dir: Directory the upload lands in and where the staged copy is
      written. Defaults to Colab's ``/content/``.

  Returns:
    pathlib.Path of the staged copy.

  Raises:
    ValueError: If nothing was uploaded or the file type is not one of
      .pdb, .cif, .pdb.gz, .cif.gz.
  """
  # Longest extensions first so that '.pdb.gz' is matched before '.pdb'.
  supported_suffixes = ('.pdb.gz', '.cif.gz', '.pdb', '.cif')

  upload_dict = files.upload()
  if not upload_dict:
    raise ValueError('No file was uploaded.')

  # Only the first uploaded file is used.
  uploaded_name = next(iter(upload_dict))
  upload_dir = Path(upload_dir)
  ipath = (upload_dir / uploaded_name).absolute()

  # Match on the end of the name (case-insensitively) so that extra dots in
  # the file name, e.g. 'my.protein.pdb', do not cause a rejection.
  lowered_name = uploaded_name.lower()
  matched_suffix = next((s for s in supported_suffixes if lowered_name.endswith(s)), None)
  if matched_suffix is None:
    raise ValueError(f'File type not supported: {"".join(ipath.suffixes)}')

  # Build the name in one step: repeated Path.with_suffix calls would replace
  # '.cif' with '.gz' and lose the information needed to pick the parser.
  tmp_opath = (upload_dir / f'temp{matched_suffix}').absolute()

  # If the upload is already named like the staged file there is nothing to
  # copy, and removing it would delete the only copy.
  if ipath != tmp_opath:
    shutil.copy(ipath, tmp_opath)
    os.remove(str(ipath))

  return tmp_opath

# Show the instructions before the upload call, which blocks until a file is chosen.
print("Upload a .pdb, .pdb.gz, .cif, or .cif.gz file encoding your protein complex below.")
pdb_path = upload_files()
print(Path(pdb_path).absolute().suffixes)

# Pick the parser from the staged file's suffixes.
if '.cif' in Path(pdb_path).suffixes:
  protein = pr.parseMMCIF(str(pdb_path))
else:
  protein = pr.parsePDB(str(pdb_path))

if protein is None:
  raise ValueError(f'Could not parse any atoms from the uploaded file: {pdb_path}')

# Set beta to 1.0 on every atom, then reset the residues chosen for design to 0.0.
# The inference command is run with --fix_beta, which presumably reads this
# column to decide which residues stay fixed (confirm against LASErMPNN docs).
protein.setBetas(1.0)
design_selection = protein.select(f'same residue as ({prody_selection_string})')
if design_selection is None:
  # ProDy returns None (not an empty selection) when nothing matches.
  raise ValueError(
    f'Selection "{prody_selection_string}" matched no atoms in the uploaded structure.'
  )
design_selection.setBetas(0.0)

# Strip every suffix (e.g. '.cif.gz') and write the annotated structure as a plain .pdb.
opath = Path(pdb_path)
while opath.suffixes:
  opath = opath.with_suffix('')
pdb_path = opath.with_suffix('.pdb')
pr.writePDB(str(pdb_path), protein)

print(pdb_path)


In [None]:
# @title Run LASErMPNN on a target file.
import shlex
import shutil
import subprocess

###########
### Define parameters
num_designs = 15 #@param {type: "integer"}
num_designs = int(num_designs)
if num_designs < 1:
  raise ValueError(f'num_designs must be at least 1, got {num_designs}')

sequence_sampling_temp = 0.1 #@param ['0.000001', '0.1', '0.2', '0.3', '0.5'] {type:"raw"}
###########

# Remove output from a previous run so that stale designs or a stale FASTA file
# are never mistaken for the results of this run.
shutil.rmtree('/content/output/', ignore_errors=True)

# Pass the arguments as a list (no shell) so that form-supplied values are
# never interpreted by a shell.
command = [
  'python', '-m', 'LASErMPNN.run_batch_inference',
  str(pdb_path.absolute()), '/content/output/', str(num_designs),
  '--output_fasta',
  '--sequence_temp', str(sequence_sampling_temp),
  '-d', 'cuda:0',
  '--fix_beta', '--repack_all',
  '--disabled_residues', disabled_residues_string,
]
# Human-readable equivalent of the command, printed for reference only.
command_string = 'cd /content/; ' + ' '.join(shlex.quote(arg) for arg in command)
print(command_string)

out = subprocess.run(command, cwd='/content/', capture_output=True)

print(out.stderr.decode('utf-8', errors='replace'))
print(out.stdout.decode('utf-8', errors='replace'))

# Fail loudly here rather than with a confusing FileNotFoundError below.
if out.returncode != 0:
  raise RuntimeError(f'LASErMPNN exited with status {out.returncode}; see the output above.')

with open('/content/output/designs.fasta', 'r') as f:
  print(f.read())

In [None]:
# @title Visualize Generated Designs
import py3Dmol
from pathlib import Path

path_options = sorted(list(Path('/content/output/').glob('*.pdb')))
# @markdown # Select a design index from `0` to `num_designs - 1`.
design_to_visualize = 0 # @param {"type":"integer","placeholder":"0"}
# Reject negative indices as well as indices past the number of generated files.
if not 0 <= design_to_visualize < len(path_options):
  raise ValueError(
    f'design_to_visualize must be between 0 and {len(path_options) - 1}, got {design_to_visualize}'
  )

path_to_visualize = f'/content/output/design_{design_to_visualize}.pdb'
if not Path(path_to_visualize).is_file():
  raise FileNotFoundError(f'Design file not found: {path_to_visualize}')
print('visualizing', path_to_visualize)

view = py3Dmol.view(width=1000)
# Read through a context manager so the file handle is closed promptly.
with open(path_to_visualize, 'r') as pdb_file:
  view.addModel(pdb_file.read(), 'pdb', keepH=True)
view.setBackgroundColor('grey')

# Cartoon for the whole model; swap in the commented line to restrict to one chain.
view.setStyle({}, {'cartoon': {'color':'green'}})
# view.setStyle({'chain':'C'}, {'cartoon': {'color':'green'}})


# Sticks for all atoms on top of the cartoon.
view.addStyle({}, {'stick': {'colorscheme': 'greyCarbon'}})
# view.addStyle({'chain':'C'}, {'stick': {'colorscheme': 'greyCarbon'}})

# Highlight residues named 'GG2' in cyan; change the residue name to match
# the ligand in your own structure.
# view.addStyle({'within':{'distance':'8.0', 'sel':{'resn':'GG2'}}}, {'stick': {'colorscheme': 'greenCarbon'}})
view.addStyle({'resn':'GG2'}, {'stick': {'colorscheme':'cyanCarbon'}})
view.zoomTo()
view.show()

In [None]:
# @title Download LASErMPNN Outputs.
import os
import zipfile
from pathlib import Path

from google.colab import files

# @markdown ### Select what to download.
download_pdbs = True # @param {"type": "boolean"}
download_fasta = True # @param {"type": "boolean"}

default_zip_output = '/content/lasermpnn_outputs.zip'
output_dir = Path('/content/output/')

# Start from a clean archive so files from an earlier download are not kept.
if os.path.exists(default_zip_output):
  os.remove(default_zip_output)

if download_pdbs:
  # Collect everything that belongs in the archive BEFORE triggering the
  # download, so the FASTA file is included in the downloaded zip.
  members = sorted(output_dir.glob('*.pdb'))
  if download_fasta:
    members += sorted(output_dir.glob('*.fasta'))
  if not members:
    raise FileNotFoundError(f'No output files found in {output_dir}; run LASErMPNN first.')

  with zipfile.ZipFile(default_zip_output, 'w', compression=zipfile.ZIP_DEFLATED) as archive:
    for member in members:
      # Store by base name only (the equivalent of `zip -j`).
      archive.write(member, arcname=member.name)
  files.download(default_zip_output)
elif download_fasta:
  # FASTA only: no archive needed.
  files.download('/content/output/designs.fasta')
