<a href="https://colab.research.google.com/github/sokrypton/embo-2025/blob/main/alphafold2_single.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# AlphaFold - single sequence input
- WARNING - For DEMO and educational purposes only.
- For natural proteins you often need more than a single sequence to accurately predict the structure. See the [ColabFold](https://colab.research.google.com/github/sokrypton/ColabDesign/blob/gamma/af/examples/predict.ipynb) notebook if you want to predict the protein structure from a multiple sequence alignment (MSA). That said, this notebook can be useful for evaluating *de novo* designed proteins and for learning the idealized principles of proteins.

### Tips and Instructions
 - In the 3D display, hover the mouse over an amino acid to see its name and position number.
 - Use "/" (or ":") to specify chain breaks (e.g. sequence="AAA/AAA").



In [None]:
%%time
# @title Setup
#@markdown  click the little ▶ play icon to the left to setup.

import os
import sys
import re
import time
import numpy as np
from typing import Dict, List, Tuple, Optional, Any

# IPython and display imports
from IPython.utils import io
from IPython.display import HTML
import matplotlib
from matplotlib import animation
import matplotlib.pyplot as plt
import tqdm.notebook

# JAX imports
import jax
import jax.numpy as jnp

# Progress bar format
TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'

# Global setup flags. Values left over from an earlier run of this cell are
# kept (the same pattern used for `predictor` further down), so re-running the
# cell skips the slow install / download / model-load steps instead of
# unconditionally resetting the flags and redoing everything.
SETUP_DONE = globals().get("SETUP_DONE", False)
LIBRARY_IMPORTED = globals().get("LIBRARY_IMPORTED", False)
DEVICE = globals().get("DEVICE", None)

# Setup environment
if not SETUP_DONE:
    print("Setting up AlphaFold environment...")

    # First run only (no params/ directory yet): fetch colabfold.py, start the
    # AlphaFold parameter download in a background shell, install ColabDesign,
    # then block until the background job touches params/done.txt.
    if not os.path.isdir("params"):
        os.system("wget -qnc https://raw.githubusercontent.com/sokrypton/ColabFold/main/beta/colabfold.py")
        print("Installing ColabDesign...")
        os.system("(mkdir params; apt-get install aria2 -qq; \
        aria2c -q -x 16 https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar; \
        tar -xf alphafold_params_2021-07-14.tar -C params; \
        touch params/done.txt )&")

        os.system("pip -q install git+https://github.com/sokrypton/ColabDesign.git@beta")
        os.system("ln -s /usr/local/lib/python3.*/dist-packages/colabdesign colabdesign")

        # Wait for parameter download
        if not os.path.isfile("params/done.txt"):
            print("Downloading AlphaFold params...")
            while not os.path.isfile("params/done.txt"):
                time.sleep(5)

    # Disable Triton GEMM for JAX versions whose minor number is > 3 (0.4+).
    # NOTE(review): this only inspects the minor version; revisit for JAX >= 1.0.
    if int(jax.__version__.split(".")[1]) > 3:
        os.environ["XLA_FLAGS"] = "--xla_gpu_enable_triton_gemm=false"

    # Try TPU setup first; on any failure fall back to whatever JAX reports.
    # `except Exception` (not a bare except) so Ctrl-C is not swallowed.
    try:
        import jax.tools.colab_tpu
        jax.tools.colab_tpu.setup_tpu()
        print('Running on TPU')
        DEVICE = "tpu"
    except Exception:
        if jax.local_devices()[0].platform == 'cpu':
            print("WARNING: no GPU detected, will be using CPU")
            DEVICE = "cpu"
        else:
            print('Running on GPU')
            DEVICE = "gpu"

    # NOTE(review): nothing in this notebook creates an 'af_backprop'
    # directory; this looks like a leftover — confirm before removing.
    sys.path.append('af_backprop')
    SETUP_DONE = True
    print(f"Environment setup complete. Using device: {DEVICE}")

# Import ColabDesign libraries and build the shared model runner once.
if not LIBRARY_IMPORTED:
    print("Importing AlphaFold libraries...")

    try:
        from colabdesign.af.loss import get_plddt, get_pae
        from colabdesign.af.prep import prep_input_features
        from colabdesign.af.inputs import update_seq, update_aatype
        from colabdesign.af.alphafold.common import protein
        from colabdesign.af.alphafold.model import data, config, model
        from colabdesign.af.alphafold.common import residue_constants
        from colabdesign.rf.utils import make_animation
        import py3Dmol
        import colabfold as cf

        # Setup model configuration: no recycling inside the model itself
        # (recycles are driven externally by AlphaFoldPredictor) and no
        # sub-batching.
        # NOTE(review): the model_5_ptm *config* is paired with model_2_ptm
        # *weights*; presumably deliberate (a config without template modules)
        # — confirm before changing either name.
        cfg = config.model_config("model_5_ptm")
        cfg.model.num_recycle = 0
        cfg.model.global_config.subbatch_size = None
        model_name = "model_2_ptm"

        model_params = data.get_model_haiku_params(
            model_name=model_name,
            data_dir=".",
            fuse=True
        )
        model_runner = model.RunModel(cfg, model_params)

        LIBRARY_IMPORTED = True
        print("Libraries imported successfully.")

    except ImportError as e:
        print(f"Error importing ColabDesign libraries: {e}")
        print("Make sure ColabDesign is properly installed.")
        raise

class AlphaFoldPredictor:
    """Single-sequence AlphaFold predictor with a reusable, length-bucketed JIT model.

    Inputs are padded to a power-of-two length (minimum 16) so that one compiled
    model can serve many sequence lengths. Per-recycle outputs are cached, so
    asking for more recycles on the same sequence continues from the last one
    already computed.

    Relies on module-level names created by the setup code of this notebook
    (``model_runner``, ``model_params``, ``prep_input_features``, ``cf``, ...).
    """

    def __init__(self, verbose: bool = False):
        self.verbose = verbose
        self.current_seq = ""  # cleaned input (incl. chain breaks) the cached state belongs to
        self.r = -1  # index of the last completed recycle; -1 means none yet
        self.max_length = -1  # padded length the model is compiled for; -1 means not compiled
        self.runner = None  # jitted model function returned by setup_model
        self.I = None  # input dict passed to self.runner
        self.outs = []  # per-recycle runner outputs, converted to numpy
        self.positions = []  # per-recycle atom positions, trimmed to the sequence length
        self.plddts = []  # per-recycle pLDDT, trimmed to the sequence length
        self.paes = []  # per-recycle PAE, trimmed to the sequence length
        self.last_Ls = []  # chain lengths of the last processed sequence (for plotting)

    def get_next_power_of_2(self, length: int, min_size: int = 16) -> int:
        """Return the smallest power of two >= ``length``, but never less than ``min_size``."""
        if length <= min_size:
            return min_size
        # For length > 1, 2 ** bit_length(length - 1) is the next power of two >= length.
        return 1 << (length - 1).bit_length()

    def should_recompile(self, length: int) -> bool:
        """Return True if the compiled model cannot, or should not, be reused for ``length``.

        Recompile when nothing is compiled yet, when ``length`` does not fit the
        current buffer, or when the buffer is more than 4x larger than needed
        *and* a smaller bucket actually exists. Without that last condition,
        very short sequences (whose bucket is pinned at the 16-residue minimum)
        would trigger a recompile, and a state reset, on every call.
        """
        if self.max_length == -1:
            return True
        if length > self.max_length:
            return True
        return (self.max_length > length * 4
                and self.get_next_power_of_2(length) < self.max_length)

    def setup_model(self, max_len: int):
        """Build a jitted runner and its input dict for sequences padded to ``max_len``.

        Returns:
            ``(runner, I)``. ``runner(I)`` performs one AlphaFold pass. ``I``
            holds the input features, the model parameters and a placeholder
            ``"length"``. The caller must add ``"seq"`` and ``"prev"`` (and set
            the real ``"length"``) before calling ``runner``.
        """
        if self.verbose:
            print(f"Setting up model with max_length: {max_len}")

        inputs = prep_input_features(max_len)

        def runner(I):
            # Update sequence
            inputs = I["inputs"]
            inputs["prev"] = I["prev"]  # recycling state from the previous pass

            # Padding positions hold -1, which one_hot maps to an all-zero row.
            seq_oh = jax.nn.one_hot(I["seq"], 20)[None]
            update_seq(seq_oh, inputs)
            update_aatype(seq_oh, inputs)

            # Mask out padding positions beyond the real sequence length.
            mask = jnp.arange(inputs["residue_index"].shape[0]) < I["length"]
            inputs["seq_mask"] = inputs["seq_mask"].at[:].set(mask)
            inputs["msa_mask"] = inputs["msa_mask"].at[:].set(mask)
            inputs["residue_index"] = jnp.where(mask, inputs["residue_index"], 0)

            # Fixed key: the caller disables dropout, so runs are deterministic.
            key = jax.random.PRNGKey(0)
            outputs = model_runner.apply(I["params"], key, inputs)

            aux = {
                "final_atom_positions": outputs["structure_module"]["final_atom_positions"],
                "final_atom_mask": outputs["structure_module"]["final_atom_mask"],
                "plddt": get_plddt(outputs),
                "pae": get_pae(outputs),
                "length": I["length"],
                "seq": I["seq"],
                "prev": outputs["prev"],
                "residue_idx": inputs["residue_index"]
            }
            return aux

        return jax.jit(runner), {
            "inputs": inputs,
            "params": model_params,
            "length": max_len
        }

    def save_pdb(self, outs: Dict[str, Any], filename: str):
        """Write the unpadded structure in ``outs`` to ``filename`` as PDB.

        B-factors hold pLDDT scaled by 100 for every present atom.
        """
        n = int(outs["length"])
        p = {
            "residue_index": outs["residue_idx"] + 1,  # PDB numbering is 1-based
            "aatype": outs["seq"],
            "atom_positions": outs["final_atom_positions"],
            "atom_mask": outs["final_atom_mask"],
            "plddt": outs["plddt"]
        }
        # Drop padding positions beyond the real sequence length.
        p = jax.tree_util.tree_map(lambda x: x[:n], p)
        b_factors = 100 * p.pop("plddt")[:, None] * p["atom_mask"]
        p = protein.Protein(**p, b_factors=b_factors)
        pdb_lines = protein.to_pdb(p)

        with open(filename, 'w') as f:
            f.write(pdb_lines)

    def process_sequence(self, sequence: str) -> Tuple[str, List[int], int]:
        """Clean and validate an input sequence.

        The input is upper-cased. The 20 standard amino acids and the chain-break
        characters "/" and ":" are kept. Any other ASCII letter is replaced by
        "G", with one warning per distinct letter. Everything else (whitespace,
        digits, punctuation, non-ASCII characters) is dropped.

        Returns:
            ``(clean_sequence, Ls, length)``: the cleaned sequence still
            containing chain breaks, the lengths of the non-empty chains, and
            the total number of residues.

        Raises:
            ValueError: if no residues remain after cleaning.
        """
        standard_aa = set("ACDEFGHIKLMNPQRSTVWY")
        kept = []
        replaced = set()

        for char in sequence.upper():
            if char in "/:" or char in standard_aa:  # chain breaks and standard residues
                kept.append(char)
            elif char.isascii() and char.isalpha():  # A-Z letters outside the standard set
                kept.append("G")
                if char not in replaced:
                    replaced.add(char)
                    print(f"Warning: Non-standard amino acid '{char}' replaced with 'G'")

        clean_sequence = "".join(kept)

        # Empty segments (leading/trailing or repeated breaks) are not chains.
        Ls = [len(s) for s in clean_sequence.replace(":", "/").split("/") if s]
        length = sum(Ls)

        if length == 0:
            raise ValueError("Empty sequence after cleaning")

        return clean_sequence, Ls, length

    def predict_structure(self,
                          sequence: str,
                          recycles: int = 0,
                          color: str = "confidence",
                          show_sidechains: bool = True,
                          show_mainchains: bool = False) -> Dict[str, Any]:
        """Fold ``sequence`` through recycle index ``recycles``, reusing cached work.

        Args:
            sequence: amino-acid sequence; "/" or ":" mark chain breaks.
            recycles: index of the last recycle to run (``recycles + 1`` passes in total).
            color, show_sidechains, show_mainchains: accepted for interface
                compatibility but unused here; pass them to :meth:`plot`.

        Returns:
            dict with the cached per-recycle ``outs``/``positions``/``plddts``/``paes``,
            the chain lengths ``Ls``, ``length``, ``max_length_used`` and
            ``recycles``. If more recycles are cached than requested, all cached
            ones are returned.
        """
        ori_sequence, Ls, length = self.process_sequence(sequence)
        self.last_Ls = Ls  # Store for later plotting

        if self.verbose:
            print(f"Processing sequence of length: {length}")
            print(f"Original sequence: {ori_sequence}")

        # Determine optimal max_length using powers of 2
        if self.should_recompile(length):
            old_max_length = self.max_length
            self.max_length = self.get_next_power_of_2(length)
            if self.verbose:
                print(f"Recompiling: {old_max_length} -> {self.max_length}")
            self.runner, self.I = self.setup_model(self.max_length)
            # The fresh input dict has no sequence or recycling state, and the
            # cached results were computed for the old buffer: force a reset.
            self.current_seq = ""
        elif self.verbose:
            print(f"Reusing compiled model with max_length: {self.max_length}")

        # Reset if sequence changed (or the model was just rebuilt)
        if ori_sequence != self.current_seq:
            if self.verbose:
                print("New sequence detected, resetting state...")
            self.outs = []
            self.positions = []
            self.plddts = []
            self.paes = []
            self.r = -1

            # Residues only (chain breaks removed), padded with -1 to max_length.
            aa_sequence = re.sub(r"[^A-Z]", "", ori_sequence)
            seq = np.array([
                residue_constants.restype_order.get(aa, 0)
                for aa in aa_sequence
            ])
            seq = np.pad(seq, [0, self.max_length - length], constant_values=-1)

            # Update inputs, restart recycle from an all-zero state
            self.I.update({
                "seq": seq,
                "length": length,
                "prev": {
                    'prev_msa_first_row': np.zeros([self.max_length, 256]),
                    'prev_pair': np.zeros([self.max_length, self.max_length, 128]),
                    'prev_pos': np.zeros([self.max_length, 37, 3])
                }
            })

            self.I["inputs"]["use_dropout"] = False
            # Offset residue indices between chains to encode chain breaks.
            self.I["inputs"]['residue_index'][:] = cf.chain_break(
                np.arange(self.max_length), Ls, length=32
            )
            self.current_seq = ori_sequence

        # Run prediction with progress bar
        if self.verbose:
            print(f"Running prediction with {recycles} recycles...")
        with tqdm.notebook.tqdm(total=(recycles + 1), bar_format=TQDM_BAR_FORMAT) as pbar:
            # Count recycles already cached from an earlier call as done.
            pbar.update(max(0, min(self.r + 1, recycles + 1)))

            # Run remaining recycles, feeding each pass's "prev" into the next.
            while self.r < recycles:
                O = jax.tree_util.tree_map(np.asarray, self.runner(self.I))
                self.positions.append(O["final_atom_positions"][:length])
                self.plddts.append(O["plddt"][:length])
                self.paes.append(O["pae"][:length, :length])
                self.I["prev"] = O["prev"]
                self.outs.append(O)
                self.r += 1
                pbar.update(1)

        # Store results without automatic visualization
        return {
            'outs': self.outs,
            'positions': self.positions,
            'plddts': self.plddts,
            'paes': self.paes,
            'Ls': Ls,
            'length': length,
            'max_length_used': self.max_length,
            'recycles': recycles
        }

    def plot(self,
             recycle: Optional[int] = None,
             color: str = "confidence",
             show_sidechains: bool = True,
             show_mainchains: bool = False,
             show_confidence: bool = True,
             show_plddt_legend: bool = True,
             size: Tuple[int, int] = (800, 480)) -> None:
        """Display the 3D structure and confidence plots for one cached recycle.

        Writes ``out_recycle_{recycle}.pdb`` and shows it in a viewer with hover
        labels. Optionally shows the pLDDT legend and pLDDT/PAE plots. Summary
        statistics are printed only when ``verbose`` is set. Nothing is returned.

        Args:
            recycle: recycle index to show; ``None`` means the last one and
                negative values count from the end.
            color: "confidence" (alias of "lDDT"), "chain" or "rainbow".
            show_sidechains: whether to show side chains.
            show_mainchains: whether to show main chains.
            show_confidence: whether to show the pLDDT/PAE plots.
            show_plddt_legend: whether to show the pLDDT legend (confidence coloring only).
            size: size of the 3D viewer as (width, height).

        Raises:
            ValueError: if nothing has been predicted yet or ``recycle`` is out of range.
        """
        n_done = len(self.outs)
        if not n_done:
            raise ValueError("No prediction results available. Run predict_structure() first.")

        if recycle is None:
            recycle = n_done - 1
        index = recycle if recycle >= 0 else recycle + n_done
        if not 0 <= index < n_done:
            raise ValueError(f"Recycle {recycle} not available. Only {len(self.outs)} recycles completed.")
        recycle = index

        if color == "confidence":
            color = "lDDT"

        if self.verbose:
            print(f"Plotting prediction at recycle={recycle}")

        # Save PDB file
        pdb_filename = f"out_recycle_{recycle}.pdb"
        self.save_pdb(self.outs[recycle], pdb_filename)

        # Create 3D visualization with residue-name:number hover labels
        v = cf.show_pdb(pdb_filename, show_sidechains, show_mainchains, color,
                        color_HP=True, size=size, Ls=self.last_Ls)
        v.setHoverable({}, True,
                       '''function(atom,viewer,event,container){if(!atom.label){atom.label=viewer.addLabel("      "+atom.resn+":"+atom.resi,{position:atom,backgroundColor:'mintcream',fontColor:'black'});}}''',
                       '''function(atom,viewer){if(atom.label){viewer.removeLabel(atom.label);delete atom.label;}}''')
        v.show()

        # Show pLDDT legend if requested and using confidence coloring
        if show_plddt_legend and color == "lDDT":
            legend = cf.plot_plddt_legend()
            legend.show()

        if show_confidence:
            confidence_plot = cf.plot_confidence(
                self.plddts[recycle] * 100,
                self.paes[recycle],
                Ls=self.last_Ls
            )
            confidence_plot.show()

        if self.verbose:
            plddt = self.plddts[recycle] * 100
            print(f"\n=== Prediction Summary (Recycle {recycle}) ===")
            print(f"Mean pLDDT: {np.mean(plddt):.2f}")
            print(f"pLDDT range: {np.min(plddt):.2f} - {np.max(plddt):.2f}")
            print(f"Mean PAE: {np.mean(self.paes[recycle]):.2f} Å")

    def create_animation(self) -> "HTML":
        """Return an HTML animation of the backbone trace across cached recycles.

        Uses atom index 1 of each residue (CA in the atom37 ordering) and
        per-recycle pLDDT. At least two frames (i.e. recycles >= 1) are required.
        """
        if len(self.positions) <= 1:
            message = "No animation available (need at least 1 recycle)"
            if self.verbose:
                print(message)
            return HTML(f"<p>{message}</p>")

        animation_html = make_animation(
            np.asarray(self.positions)[..., 1, :],
            np.asarray(self.plddts) * 100.0,
            Ls=self.last_Ls,
            ref=-1,
            align_to_ref=True,
            verbose=self.verbose
        )
        return HTML(animation_html)

# Keep one predictor across re-runs of this cell so its compiled model and
# cached recycles survive; only build a new one when none exists yet.
try:
    predictor  # existence check only
except NameError:
    predictor = AlphaFoldPredictor()

In [None]:
#@title Enter the amino acid sequence to fold ⬇️
#@markdown click the little ▶ play icon to the left to run.

# collect user inputs
sequence = 'GGGGGGGGG' #@param {type:"string"}
recycles = 0 #@param ["0", "1", "2", "4", "8", "16", "32", "64"] {type:"raw"}

#@markdown #### Display options
color = "confidence" #@param ["chain", "confidence", "rainbow"]
show_sidechains = True #@param {type:"boolean"}
show_mainchains = False #@param {type:"boolean"}

# Fold the sequence (cached recycles are reused when it is unchanged), then
# render the latest recycle together with its pLDDT/PAE plots.
results = predictor.predict_structure(sequence=sequence, recycles=recycles)

display_options = {
  "color": color,
  "show_sidechains": show_sidechains,
  "show_mainchains": show_mainchains,
  "show_confidence": True,
}
predictor.plot(**display_options)

In [None]:
#@title Animate
#@markdown - Animate trajectory if more than 0 recycle(s)
# The final expression is rendered by the notebook: either the trajectory
# animation or a short notice when only a single frame is available.
trajectory_html = predictor.create_animation()
trajectory_html