<a href="https://colab.research.google.com/github/smart111/CryoEM-scripts/blob/master/BioEmu.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Biomolecular Emulator (BioEmu) in ColabFold**
<img src="https://github.com/microsoft/bioemu/raw/main/assets/emu.png" height="130" align="right" style="height:240px">

[BioEmu](https://github.com/microsoft/bioemu) is a framework for emulating biomolecular dynamics and integrating structural prediction tools to accelerate research in structural biology and protein engineering. This notebook runs BioEmu together with ColabFold, which generates the MSA, and then uses Foldseek to cluster the sampled conformations.



For more details, please read the [BioEmu Preprint](https://www.biorxiv.org/content/10.1101/2024.12.05.626885v2).


## To run
Choose the desired sampling config in the first cell, then either run each cell sequentially or click `Runtime -> Run All`.

In [1]:
#@title Sample with following config
# Colab form cell: every assignment tagged with `#@param` below is a user-editable field.
# Whitespace inside `sequence` is tolerated; it is stripped further down in this cell.
#@markdown - `sequence`: Monomer sequence to sample
sequence = " MITRSRCRRSLLWFLVFHGGATATGAPSGGKELSQTPTWAVAVVCTFLILISHLLEKGLQRLANWLWKKHKNSLLEALEKIKAELMILGFISLLLTFGEPYILKICVPRKAALSMLPCLSEDTVLFQKLAPSSLSRHLLAAGDTSINCKQGSEPLITLKGLHQLHILLFFLAIFHIVYSLITMMLSRLKIRGWKKWEQETLSNDYEFSIDHSRLRLTHETSFVREHTSFWTTTPFFFYVGCFFRQFFVSVERTDYLTLRHGFISAHLAPGRKFNFQRYIKRSLEDDFKLVVGISPVLWASFVIFLLFNVNGWRTLFWASIPPLLIILAVGTKLQAIMATMALEIVETHAVVQGMPLVQGSDRYFWFDCPQLLLHLIHFALFQNAFQITHFFWIWYSFGLKSCFHKDFNLVVSKLFLCLGALILCSYITLPLYALVTQMGSHMKKAVFDENVQVGLVGWAQKVKQKRDLKAAASNGDEGSSQAGPGPDSGSGSAPAAGPGAGFAGIQLSRVTRNNAGDTNNEITPDHNNMADQLTDDQISEFKEAFSLFDKDGDGCITTKELGTVMRSLGQNPTEAELQDMINEVDADG NGTIDFPEFLNLMARKMKDTDSEEELKEAFRVFDKDQNGFISAAELRHVMTNLGEKLTDE EVDEMIREADVDGDGQINYEEFVKVMMAK"  #@param {type:"string"}
#@markdown - `num_samples`: Number of samples requested
num_samples = 10  #@param {type:"integer"}
#@markdown - `jobname`: Name assigned to this job
jobname = "ML1_cam7"  #@param {type:"string"}
#@markdown - `filter_samples`: Whether to filter unphysical samples (e.g., those containing chain breaks) from the written samples
filter_samples = True #@param {type:"boolean"}
# #@param {type:"boolean"}
#@markdown - `model_name`: v1.1 = journal-accepted version | v1.0 = preprint version
model_name = "bioemu-v1.1" #@param ["bioemu-v1.0", "bioemu-v1.1"]
# ------------------------
# Copied logic from ColabFold
# ------------------------
import os
import re
import hashlib

def add_hash(x, seq):
    """Return ``x`` followed by an underscore and the first 5 hex digits of SHA-1(seq)."""
    digest = hashlib.sha1(seq.encode()).hexdigest()
    short_digest = digest[:5]
    return "_".join((x, short_digest))

def folder_is_free(folder):
    """Tell whether ``folder`` is unused, i.e. nothing exists at that path yet."""
    already_there = os.path.exists(folder)
    return not already_there

# Directory under which every job folder is created. The collision check below
# probes this same directory instead of relying on the current working directory.
_jobs_base_dir = "/content"

# Keep only word characters in the job name and drop all whitespace from the sequence.
jobname_clean = re.sub(r'\W+', '', jobname)
sequence = "".join(sequence.split())
if not sequence:
    raise ValueError("`sequence` is empty; please provide a monomer sequence to sample")
jobname = add_hash(jobname_clean, sequence)

# If a folder for this job already exists, append the first free numeric suffix.
if not folder_is_free(os.path.join(_jobs_base_dir, jobname)):
    n = 0
    while not folder_is_free(os.path.join(_jobs_base_dir, f"{jobname}_{n}")):
        n += 1
    jobname = f"{jobname}_{n}"

output_dir = os.path.join(_jobs_base_dir, jobname)
os.makedirs(output_dir, exist_ok=True)


In [None]:
#@title Install dependencies
import os
import sys

# Marker file whose presence means the one-time installation below has completed.
_is_bioemu_setup_file = '/content/.BIOEMU_SETUP'

conda_prefix = '/usr/local/'
miniconda_link = 'https://repo.anaconda.com/miniconda/Miniconda3-py311_25.5.1-1-Linux-x86_64.sh'
miniconda_basename = os.path.basename(miniconda_link)
os.makedirs(conda_prefix, exist_ok=True)

if not os.path.exists(_is_bioemu_setup_file):
  # Install Miniconda into conda_prefix.
  os.system(f'wget {miniconda_link}')
  os.system(f'chmod +x {miniconda_basename}')
  os.system(f'./{miniconda_basename} -b -f -p {conda_prefix}')

  # Accept the channel terms of service before the first `conda install`
  # (presumably required for non-interactive installs from these channels).
  os.system('conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main')
  os.system('conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r')
  os.system(f'conda install -q -y --prefix {conda_prefix} python=3.11')

  # `uv` has to be available before it is used to install bioemu.
  os.system('/usr/bin/python3 -m pip install uv')
  os.system('uv pip install --prerelease if-necessary-or-explicit bioemu')

  os.system('conda install -c conda-forge openmm cuda-version=11 --yes')

  # Foldseek binary used by the clustering cell.
  os.system('wget https://mmseqs.com/foldseek/foldseek-linux-avx2.tar.gz; tar xvzf foldseek-linux-avx2.tar.gz')

  # The installer is no longer needed; tolerate a failed download.
  if os.path.exists(miniconda_basename):
    os.unlink(miniconda_basename)

  # Write the marker last so that an interrupted setup is retried on the next run.
  os.system(f"touch {_is_bioemu_setup_file}")

# Interpreter-level configuration is per kernel session, so it is applied on
# every run (not only on first install) to survive a kernel restart.
_conda_site_packages = os.path.join(conda_prefix, 'lib/python3.11/site-packages/')
if _conda_site_packages not in sys.path:
  sys.path.append(_conda_site_packages)

os.environ['CONDA_PREFIX'] = conda_prefix
os.environ['CONDA_PREFIX_1'] = os.path.join(conda_prefix, 'envs/myenv')
os.environ['CONDA_DEFAULT_ENV'] = 'base'



In [None]:
#@title Run BioEmu

from bioemu.sample import main as sample

# Samples are written into the per-job folder under /content.
output_dir = '/content/{}'.format(jobname)
sample(
    sequence=sequence,
    num_samples=num_samples,
    model_name=model_name,
    output_dir=output_dir,
    filter_samples=filter_samples,
)

In [None]:
#@title Write samples and run `foldseek`
#@markdown - `n_write_samples`: Number of samples to randomly select for clustering. Set to `-1` to select all available samples
#@markdown - `tmscore_threshold`: TM-score threshold used for foldseek clustering
#@markdown - `coverage_threshold`: Coverage threshold used for foldseek clustering
#@markdown - `seq_id`: Sequence identity threshold used for foldseek clustering

# Form parameters: `n_write_samples` is passed to write_some_samples(), the
# three thresholds are passed to foldseek_cluster() further down in this cell.
n_write_samples = -1 #@param {type:"integer"}
tmscore_threshold = 0.6 #@param {type: "number"}
coverage_threshold = 0.7 #@param {type: "number"}
seq_id = 0.95 #@param {type: "number"}

import numpy as np
import mdtraj

# Install py3Dmol once; the marker file is only written when the install succeeded
# so that a failed install is retried on the next run.
_py3dmol_installed_file = '/content/.py3dmol'
if not os.path.exists(_py3dmol_installed_file):
    if os.system('uv pip install py3Dmol') == 0:
        os.system(f"touch {_py3dmol_installed_file}")

import py3Dmol
pdb_sample_dir = os.path.join('/content', 'pdb_samples')
os.makedirs(pdb_sample_dir, exist_ok=True)

# The sample directory is shared between runs. Remove PDB files written by a
# previous run so that foldseek only clusters the samples written below.
for _stale_name in os.listdir(pdb_sample_dir):
    if _stale_name.startswith('sample_') and _stale_name.endswith('.pdb'):
        os.remove(os.path.join(pdb_sample_dir, _stale_name))

def write_some_samples(topology_file: str, trajectory_file: str, output_dir:str, n_samples: int) -> None:
    """
    Writes frames of a trajectory as individual PDB files named `sample_<frame index>.pdb`

    Args:
        topology_file: path of the topology file used to load the trajectory
        trajectory_file: path of the trajectory file to read frames from
        output_dir: directory the PDB files are written to
        n_samples: number of frames to pick at random without replacement,
            or `-1` to write every frame

    Raises:
        ValueError: if `n_samples` is below -1 or exceeds the number of frames
    """
    # Reject invalid requests before paying for loading the trajectory.
    if n_samples < -1:
        raise ValueError(f"n_samples must be -1 (all frames) or non-negative, got {n_samples}")

    traj = mdtraj.load(trajectory_file, top=topology_file)
    if n_samples > traj.n_frames:
        raise ValueError(
            f"Requested {n_samples} samples but the trajectory only has {traj.n_frames} frames"
        )

    if n_samples == -1:
        sample_indices = np.arange(traj.n_frames)
    else:
        sample_indices = np.random.choice(np.arange(traj.n_frames), size=n_samples, replace=False)
    for idx in sample_indices:
        traj[idx].save_pdb(os.path.join(output_dir, f'sample_{idx}.pdb'))


# Locations of the topology and trajectory files inside the job's output folder.
topology_file, trajectory_file = (
    os.path.join(output_dir, _filename) for _filename in ("topology.pdb", "samples.xtc")
)

# Dump the requested number of frames as PDB files for clustering.
write_some_samples(
    topology_file=topology_file,
    trajectory_file=trajectory_file,
    output_dir=pdb_sample_dir,
    n_samples=n_write_samples,
)

# Foldseek
import os
import subprocess
import tempfile

import pandas as pd

def parse_foldseek_cluster_results(cluster_table_path: str) -> dict[int, list[str]]:
    """
    Reads a foldseek cluster table and groups its rows by the first column

    Args:
        cluster_table_path: path of the whitespace-separated, header-less
            cluster table written by foldseek

    Returns:
        Dictionary mapping a running cluster index to the sorted list of
        second-column entries that share the same first-column entry
    """
    table = pd.read_csv(cluster_table_path, sep=r"\s+", header=None)

    # groupby yields (first-column value, sub-frame) pairs; only the sub-frame is needed.
    return {
        cluster_idx: sorted(list(member_rows[1]))
        for cluster_idx, (_representative, member_rows) in enumerate(table.groupby(0))
    }


def foldseek_cluster(
    input_dir: str,
    out_prefix: str | None = None,
    tmscore_threshold: float = 0.7,
    coverage_threshold: float = 0.9,
    seq_id: float = 0.7,
    coverage_mode: int = 1,
    foldseek_bin: str = "/content/foldseek/bin/foldseek",
) -> dict[int, list[str]]:
    """
    Runs foldseek easy cluster

    Args:
        input_dir (str): input directory with .cif or .pdb files
        out_prefix (str | None): the prefix of the output files, if None a temporary directory will be used
        tmscore_threshold (float): the tm-score threshold used for clustering
        coverage_threshold (float): the coverage threshold used for clustering
        seq_id (float): the sequence identity threshold used for clustering
        coverage_mode (int): mode used by mmseqs/foldseek to compute coverage
        foldseek_bin (str): path of the foldseek executable

    Returns:
        Dictionary mapping cluster indices to members

    Raises:
        RuntimeError: if foldseek exits with a non-zero return code
    """

    with tempfile.TemporaryDirectory() as temp_dir, tempfile.TemporaryDirectory() as temp_out_dir:
        if out_prefix is None:
            out_prefix = os.path.join(temp_out_dir, "output")

        # Pass the arguments as a list (no shell) so that paths containing
        # spaces or shell metacharacters are handed to foldseek verbatim.
        command = [
            foldseek_bin,
            "easy-cluster",
            input_dir,
            out_prefix,
            temp_dir,
            "-c",
            str(coverage_threshold),
            "--min-seq-id",
            str(seq_id),
            "--tmscore-threshold",
            str(tmscore_threshold),
            "--cov-mode",
            str(coverage_mode),
            "--single-step-clustering",
        ]
        res = subprocess.run(command)
        if res.returncode != 0:
            raise RuntimeError(
                f"Something went wrong with foldseek (exit code {res.returncode})"
            )

        # Parse while the temporary directories still exist: with the default
        # out_prefix the cluster table lives inside temp_out_dir.
        cluster_idx_to_members = parse_foldseek_cluster_results(out_prefix + "_cluster.tsv")

    return cluster_idx_to_members

!chmod +x '/content/foldseek/bin/foldseek'

# Get foldseek clusters
clusters = foldseek_cluster(input_dir=pdb_sample_dir, tmscore_threshold=tmscore_threshold,
                            coverage_threshold=coverage_threshold, seq_id=seq_id)
n_clusters = len(clusters)
print(f'{n_clusters} clusters detected')

# Write foldseek clusters to output dir
import json

with open(os.path.join(output_dir, 'foldseek_clusters.json'), 'w') as json_handle:
    json.dump(clusters, json_handle)


# Write XTC with one sample per cluster only
cluster_trajs = []
for _cluster_idx, samples in clusters.items():
    sample = list(samples)[0] # Choose first sample in cluster
    pdb_file = os.path.join(pdb_sample_dir, f"{sample}.pdb")
    traj = mdtraj.load_pdb(pdb_file)
    cluster_trajs.append(traj)
joint_traj = mdtraj.join(cluster_trajs)
cluster_topology_file = os.path.join(output_dir, "clustered_topology.pdb")
cluster_trajectory_file = os.path.join(output_dir, "clustered_samples.xtc")
joint_traj[0].save_pdb(cluster_topology_file)
joint_traj.save_xtc(cluster_trajectory_file)


In [None]:
#@title Display structure
import os
import ipywidgets as widgets
import py3Dmol
from IPython.display import display, clear_output

# Create interactive widgets for cluster and sample selection.
def _make_index_slider(label, upper):
    """Build an integer slider running from 0 to `upper`, starting at 0, with continuous updates off."""
    return widgets.IntSlider(
        value=0,
        min=0,
        max=upper,
        step=1,
        description=label,
        continuous_update=False
    )

# One slider picks the cluster, the other a sample inside it; the sample
# slider's upper bound is adjusted later based on the selected cluster.
cluster_slider = _make_index_slider('Cluster No:', n_clusters - 1)
sample_slider = _make_index_slider('Sample Idx:', 0)
display(cluster_slider, sample_slider)

# Function to visualize a PDB file using py3Dmol.
def show_pdb(pdb_file: str, show_sidechains: bool = False, show_mainchains: bool = True):
    """
    Builds a py3Dmol view of a PDB file, drawn as a spectrum-colored cartoon

    Args:
        pdb_file: path of the PDB file to display
        show_sidechains: additionally draw sidechain atoms as sticks/spheres
        show_mainchains: additionally draw the backbone atoms (C, O, N, CA) as sticks

    Returns:
        The py3Dmol view, or None (after printing a message) if the file does not exist
    """
    # Read the file first so that no viewer is created when the file is missing.
    try:
        with open(pdb_file, 'r') as f:
            pdb_content = f.read()
    except FileNotFoundError:
        print(f"File not found: {pdb_file}")
        return None

    view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js')
    view.addModel(pdb_content, 'pdb')
    view.setStyle({'cartoon': {'color': 'spectrum'}})

    if show_sidechains:
        BB = ['C', 'O', 'N']
        # Everything except GLY/PRO: all non-backbone atoms as sticks.
        view.addStyle({'and': [{'resn': ["GLY", "PRO"], 'invert': True}, {'atom': BB, 'invert': True}]},
                      {'stick': {'colorscheme': "WhiteCarbon", 'radius': 0.3}})
        # GLY: only the CA atom, as a sphere.
        view.addStyle({'and': [{'resn': "GLY"}, {'atom': 'CA'}]},
                      {'sphere': {'colorscheme': "WhiteCarbon", 'radius': 0.3}})
        # PRO: everything except C and O as sticks.
        view.addStyle({'and': [{'resn': "PRO"}, {'atom': ['C', 'O'], 'invert': True}]},
                      {'stick': {'colorscheme': "WhiteCarbon", 'radius': 0.3}})
    if show_mainchains:
        BB = ['C', 'O', 'N', 'CA']
        view.addStyle({'atom': BB}, {'stick': {'colorscheme': "WhiteCarbon", 'radius': 0.3}})

    view.zoomTo()
    return view

# Helper to update the sample slider's maximum value based on the selected cluster.
def update_sample_slider(cluster_no):
    """Set the sample slider's upper bound to the last index of cluster `cluster_no` (never below 0)."""
    member_count = len(list(clusters[cluster_no]))
    sample_slider.max = max(member_count - 1, 0)
    # Fall back to the first sample when the current selection is past the new bound.
    if sample_slider.max < sample_slider.value:
        sample_slider.value = 0

# Main function to update the viewer whenever widget values change.
def update_view(change=None):
    """
    Redraws the sliders and the structure selected by them

    Args:
        change: change notification passed by the widget observer; unused
    """
    cluster_no = cluster_slider.value
    update_sample_slider(cluster_no)
    available_samples = list(clusters[cluster_no])
    sample_idx = sample_slider.value

    # Clearing the output also removes the sliders, so they are displayed again.
    clear_output(wait=True)
    display(cluster_slider, sample_slider)

    if sample_idx >= len(available_samples):
        print(f"Only {len(available_samples)} samples available in cluster {cluster_no}")
        return

    chosen_sample = available_samples[sample_idx]
    # Use the directory the samples were written to rather than a path
    # relative to the current working directory.
    pdb_path = os.path.join(pdb_sample_dir, f"{chosen_sample}.pdb")

    # Check if the file exists before attempting to open it.
    if not os.path.exists(pdb_path):
        print(f"File not found: {pdb_path}")
        return

    print(f"Displaying sample {sample_idx} from cluster {cluster_no}")
    view = show_pdb(pdb_path)
    if view:
        view.show()

# Observe changes to the slider values.
# Redraw whenever either slider's value changes.
for _slider in (cluster_slider, sample_slider):
    _slider.observe(update_view, names='value')

# Draw once up front so a structure is visible before any interaction.
update_view()


In [None]:
#@title (Optional) Reconstruct sidechains + Run MD relaxation
#@markdown - `reconstruct_sidechains`: whether to reconstruct sidechains via `hpacker`
#@markdown - `run_md`: check to run MD after sidechain reconstruction, otherwise only sidechain reconstruction is performed
#@markdown - `md_protocol`: `LOCAL_MINIMIZATION`: fast but only resolves local problems ; `NVT_EQUIL`: slow but might resolve more severe issues
#@markdown - `one_per_cluster`: Reconstruct sidechains / optionally run MD for only one sample within each foldseek cluster

#@markdown **WARNING**: this step can be quite expensive depending on how many samples you have requested / sequence length. You may want to check the `one_per_cluster` option.


reconstruct_sidechains = True #@param {type: "boolean"}
run_md = True #@param {type:"boolean"}
one_per_cluster = True #@param {type:"boolean"}
md_protocol = "LOCAL_MINIMIZATION" #@param ["LOCAL_MINIMIZATION", "NVT_EQUIL"] {type:"string"}
import bioemu.sidechain_relax
# The joined component must be relative: os.path.join discards everything
# before an absolute component, which would drop conda_prefix from the path.
bioemu.sidechain_relax.HPACKER_PYTHONBIN = os.path.join(conda_prefix, 'envs/hpacker/bin/python')

from bioemu.sidechain_relax import main as sidechainrelax
from bioemu.sidechain_relax import MDProtocol
# Convert the form's string choice into the corresponding enum member.
md_protocol = MDProtocol[md_protocol]
os.environ['CONDA_PREFIX_1'] = conda_prefix

# Select the input files explicitly in both cases, so that unchecking
# `one_per_cluster` after an earlier run does not keep using the cluster files.
if one_per_cluster:
    topology_file = cluster_topology_file
    trajectory_file = cluster_trajectory_file
else:
    topology_file = os.path.join(output_dir, "topology.pdb")
    trajectory_file = os.path.join(output_dir, "samples.xtc")

prefix = 'hpacker-openmm'
if reconstruct_sidechains:
    relaxed_dir = os.path.join(output_dir, prefix)
    os.makedirs(relaxed_dir, exist_ok=True)
    sidechainrelax(pdb_path=topology_file, xtc_path=trajectory_file,
                  outpath=relaxed_dir, prefix=prefix, md_protocol=md_protocol,
                  md_equil=run_md)
    if run_md:
        # Marker file recording that MD was run for this output directory.
        os.system(f'touch {relaxed_dir}/.RELAXED')
