# **Native sequence redesign with ProteinMPNN and evolutionary information (v2)**


[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pb3lab/AI4PD_2025/blob/main/tutorials/native_enzdes/native_enzdes_v2.ipynb)

## Introduction

In this notebook, we show how to combine ProteinMPNN with experimental information about active-site residues and with evolutionary information (residue conservation computed from a filtered multiple sequence alignment, MSA) to perform **native sequence redesign of natural enzymes**.

This tutorial is largely based on the breakthrough article by **Kiera Sumida** and colleagues, published in [JACS in 2024](https://pubs.acs.org/doi/10.1021/jacs.3c10941). It shares some choices with that article and differs in others:
1) As in the original article, active-site residues are defined as residues with backbone atoms within 7 Å or sidechain atoms within 6 Å of the ligand.
2) We generate the MSA with [MMseqs2](https://github.com/soedinglab/MMseqs2) instead of [HHblits](https://github.com/soedinglab/hh-suite) to retrieve homologous enzymes. The main reason for using MMseqs2 over HHblits is its much higher speed.
3) In the original article, four iterative HHblits searches were performed against the UniRef30 database at E-value cutoffs of 1e-50, 1e-30, 1e-10 and 1e-4. Here, MMseqs2 searches databases built from large sequence collections such as UniRef30 together with environmental sequences.
4) As in the original article, the MSA is filtered at 90% maximum pairwise sequence identity, 50% minimum coverage and 30% minimum identity to the query.
5) In the original article, each position in the MSA was ranked by how highly conserved its most frequent amino acid was, and the top 30%, 50% and 70% most conserved positions were selected to fix. Here, we only fix the top 70% most conserved positions.
6) We use the same ProteinMPNN model, trained with 0.2 Å Gaussian noise added to the training-set protein backbones.
7) In the original article, three ProteinMPNN sampling temperatures (0.1, 0.2 and 0.3) were used during sequence generation, whereas only 0.2 is used in this tutorial.
8) Only 4 sequences are generated here, whereas 144 sequences were generated in the original article.
9) As in the original article, only AlphaFold model 4 is used for structure prediction, but candidates are filtered only on pLDDT > 85 and not on Cα RMSD < 2 Å. The RMSD filter will be added in the future.
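Item 5 describes a conservation ranking, and a few lines of Python are enough to sketch it. This is a minimal illustration and not the code used in the original article or in step 4 of Part 1. It rests on several assumptions. The MSA has already been reduced to equal-length strings aligned to the query, with `-` for gaps. Each column is scored by the fraction of all sequences that carry its most frequent residue; gaps count in the denominator but never win. The number of kept positions is rounded. The function name `top_conserved_positions` and the toy alignment are hypothetical.

```python
from collections import Counter

def top_conserved_positions(aligned_seqs, frac=0.7):
    """Return 1-indexed MSA columns whose most frequent residue is most conserved.

    Illustrative assumptions: all sequences have the query's length, '-' marks
    a gap, and a gap is never counted as the most frequent residue.
    """
    n_seqs = len(aligned_seqs)
    length = len(aligned_seqs[0])
    scores = []
    for col in range(length):
        counts = Counter(seq[col] for seq in aligned_seqs if seq[col] != "-")
        top_count = counts.most_common(1)[0][1] if counts else 0
        # Fraction of all sequences that carry the most frequent residue
        scores.append((top_count / n_seqs, col + 1))
    # Most conserved first; ties broken by position so the result is deterministic
    scores.sort(key=lambda item: (-item[0], item[1]))
    n_keep = int(round(frac * length))
    return sorted(pos for _, pos in scores[:n_keep])

toy_msa = ["MKTA", "MKSG", "MRTV", "MKQL"]  # hypothetical query-aligned MSA
print(top_conserved_positions(toy_msa, frac=0.5))  # [1, 2]
print(top_conserved_positions(toy_msa, frac=0.7))  # [1, 2, 3]
```

Positions selected this way are then combined with the active-site residues to form the set of positions kept fixed during redesign.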

# Part 0. Install the different packages required to run this tutorial

### **Please install all the different dependencies at the beginning of the tutorial in the order they are indicated in this notebook.**

In [None]:
#@title 1) Install ProteinMPNN
!pip install jupyter_bokeh --quiet
import json, time, os, sys, glob

if not os.path.isdir("ProteinMPNN"):
  os.system("git clone -q https://github.com/dauparas/ProteinMPNN.git")
sys.path.append('/content/ProteinMPNN')

In [None]:
#@title 2) Setup ProteinMPNN model
import matplotlib.pyplot as plt
import shutil
import warnings
import numpy as np
import torch
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split, Subset
import copy
import torch.nn as nn
import torch.nn.functional as F
import random
import os.path
from protein_mpnn_utils import loss_nll, loss_smoothed, gather_edges, gather_nodes, gather_nodes_t, cat_neighbors_nodes, _scores, _S_to_seq, tied_featurize, parse_PDB
from protein_mpnn_utils import StructureDataset, StructureDatasetPDB, ProteinMPNN

device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
#v_48_010=version with 48 edges 0.10A noise
model_name = "v_48_020" #@param ["v_48_002", "v_48_010", "v_48_020", "v_48_030"]


backbone_noise=0.00               # Standard deviation of Gaussian noise to add to backbone atoms

path_to_model_weights='/content/ProteinMPNN/vanilla_model_weights'
hidden_dim = 128
num_layers = 3
model_folder_path = path_to_model_weights
if model_folder_path[-1] != '/':
    model_folder_path = model_folder_path + '/'
checkpoint_path = model_folder_path + f'{model_name}.pt'

checkpoint = torch.load(checkpoint_path, map_location=device)
print('Number of edges:', checkpoint['num_edges'])
noise_level_print = checkpoint['noise_level']
print(f'Training noise level: {noise_level_print}A')
model = ProteinMPNN(num_letters=21, node_features=hidden_dim, edge_features=hidden_dim, hidden_dim=hidden_dim, num_encoder_layers=num_layers, num_decoder_layers=num_layers, augment_eps=backbone_noise, k_neighbors=checkpoint['num_edges'])
model.to(device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
print("Model loaded")
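The checkpoint names in the `model_name` dropdown follow the convention given in the comment above (`v_48_010` = 48 edges, 0.10 Å noise). The small helper below decodes a name into those two numbers. You can then compare them with the "Number of edges" and "Training noise level" lines printed when the checkpoint loads. It is a hypothetical convenience function, not part of ProteinMPNN. It assumes the three-digit suffix is the noise in hundredths of an ångström.

```python
import re

def parse_mpnn_model_name(name):
    """Decode a vanilla ProteinMPNN checkpoint name, e.g. 'v_48_020' -> (48, 0.2).

    Assumed convention: v_<number of neighbour edges>_<noise in 0.01 A units>.
    """
    match = re.fullmatch(r"v_(\d+)_(\d{3})", name)
    if match is None:
        raise ValueError(f"Unrecognised model name: {name!r}")
    num_edges = int(match.group(1))
    noise_angstrom = int(match.group(2)) / 100
    return num_edges, noise_angstrom

for name in ["v_48_002", "v_48_010", "v_48_020", "v_48_030"]:
    edges, noise = parse_mpnn_model_name(name)
    print(f"{name}: {edges} edges, {noise:.2f} A training noise")
```

With the default `v_48_020`, this gives 48 edges and 0.20 Å, the 0.2 Å model described in the introduction.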

In [None]:
#@title 3) Helper functions for ProteinMPNN
def make_tied_positions_for_homomers(pdb_dict_list):
    my_dict = {}
    for result in pdb_dict_list:
        all_chain_list = sorted([item[-1:] for item in list(result) if item[:9]=='seq_chain']) #A, B, C, ...
        tied_positions_list = []
        chain_length = len(result[f"seq_chain_{all_chain_list[0]}"])
        for i in range(1,chain_length+1):
            temp_dict = {}
            for j, chain in enumerate(all_chain_list):
                temp_dict[chain] = [i] #needs to be a list
            tied_positions_list.append(temp_dict)
        my_dict[result['name']] = tied_positions_list
    return my_dict
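`make_tied_positions_for_homomers` ties residue *i* of every chain together, so ProteinMPNN designs the same sequence for all copies of a homomer. The standalone copy below runs the same logic on a toy record. The `name` and `seq_chain_<X>` keys mimic the dictionaries produced by `parse_PDB`, but the record itself is invented for illustration.

```python
def make_tied_positions_for_homomers(pdb_dict_list):
    """Same logic as the helper above: tie position i across all chains of each entry."""
    tied = {}
    for result in pdb_dict_list:
        chains = sorted(key[-1:] for key in result if key[:9] == "seq_chain")
        length = len(result[f"seq_chain_{chains[0]}"])
        tied[result["name"]] = [{chain: [i] for chain in chains} for i in range(1, length + 1)]
    return tied

toy_dimer = {"name": "toy_dimer", "seq_chain_A": "MKV", "seq_chain_B": "MKV"}
print(make_tied_positions_for_homomers([toy_dimer]))
# {'toy_dimer': [{'A': [1], 'B': [1]}, {'A': [2], 'B': [2]}, {'A': [3], 'B': [3]}]}
```

For a single-chain input, each position is tied only to itself (`[{'A': [1]}, {'A': [2]}, ...]`), so the helper has no effect on the design.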

In [None]:
#@title 4) Setup ColabFold without AMBER (i.e. without relax)

from sys import version_info
python_version = f"{version_info.major}.{version_info.minor}"

use_amber = False
use_templates = True
python_version = python_version

In [None]:
#@title 5) Install dependencies for running MMseqs2 and ColabFold
%%time
%%bash -s $use_amber $use_templates $python_version

set -e

USE_AMBER=$1
USE_TEMPLATES=$2
PYTHON_VERSION=$3

if [ ! -f COLABFOLD_READY ]; then
  # install dependencies
  # We have to use "--no-warn-conflicts" because colab already has a lot preinstalled with requirements different to ours
  pip install -q --no-warn-conflicts "colabfold[alphafold-minus-jax] @ git+https://github.com/sokrypton/ColabFold"
  if [ -n "${TPU_NAME}" ]; then
    pip install -q --no-warn-conflicts -U dm-haiku==0.0.10 jax==0.3.25
  fi
  ln -s /usr/local/lib/python3.*/dist-packages/colabfold colabfold
  ln -s /usr/local/lib/python3.*/dist-packages/alphafold alphafold
  # hack to fix TF crash
  rm -f /usr/local/lib/python3.*/dist-packages/tensorflow/core/kernels/libtfkernel_sobol_op.so
  touch COLABFOLD_READY
fi

# Download params (~1min)
python -m colabfold.download

# setup conda
if [ ${USE_AMBER} == "True" ] || [ ${USE_TEMPLATES} == "True" ]; then
  if [ ! -f CONDA_READY ]; then
    wget -qnc https://github.com/conda-forge/miniforge/releases/download/25.3.1-0/Miniforge3-25.3.1-0-Linux-x86_64.sh
    bash Miniforge3-25.3.1-0-Linux-x86_64.sh -bfp /usr/local 2>&1 1>/dev/null
    rm Miniforge3-25.3.1-0-Linux-x86_64.sh
    conda config --set auto_update_conda false
    touch CONDA_READY
  fi
fi
# setup template search
if [ ${USE_TEMPLATES} == "True" ] && [ ! -f HH_READY ]; then
  conda install -y -q -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 python="${PYTHON_VERSION}" 2>&1 1>/dev/null
  touch HH_READY
fi
# setup openmm for amber refinement
if [ ${USE_AMBER} == "True" ] && [ ! -f AMBER_READY ]; then
  conda install -y -q -c conda-forge openmm=8.2.0 python="${PYTHON_VERSION}" pdbfixer 2>&1 1>/dev/null
  touch AMBER_READY
fi

### **Once you have finished installing these dependencies, we are ready to perform the tutorial**

# Part 1. Determine the active site residues and highly conserved residues of the enzyme to fix during sequence redesign

### Please follow the steps in the order they are presented in this tutorial

In [None]:
#@title 1) Load Structure, Extract Sequence, and Split Protein and Ligand Chains
#@markdown ---
#@markdown ### 1. PDB Input
#@markdown Choose to download from RCSB or upload your own file.
pdb_id = 'AZO1' #@param {type:"string"}
use_upload = True #@param {type:"boolean"}
#@markdown > If uploading and the chain below is not found, the script falls back to the first chain in the file as the protein.
#@markdown ---
#@markdown ### 2. Protein (Target) Chain
#@markdown Provide the chain letter for your main protein target.
target_chain_letter = 'A' #@param {type:"string"}
#@markdown ---
#@markdown ### 3. Ligand
#@markdown Choose your method for extracting the ligand.
ligand_extraction_method = "By HETATM Residue Name" #@param ["By Chain", "By HETATM Residue Name"]
#@markdown **If 'By Chain':** Provide chain letter.
#@markdown **If 'By HETATM Residue Name':** Provide the chain letter where the ligand(s) are located, and their 3-letter name(s).
ligand_chain_letter = 'C' #@param {type:"string"}
ligand_residue_name = 'NAP' #@param {type:"string"}
#@markdown ---

import os
from google.colab import files
import warnings

try:
    from Bio.PDB import (
        PDBParser, PDBIO, Select,
        Polypeptide, CaPPBuilder, PDBExceptions
    )
    from Bio.Seq import Seq
    from Bio.SeqRecord import SeqRecord
    from Bio import SeqIO
except ImportError:
    print("Biopython not found. Installing...")
    !pip install biopython
    from Bio.PDB import (
        PDBParser, PDBIO, Select,
        Polypeptide, CaPPBuilder, PDBExceptions
    )
    from Bio.Seq import Seq
    from Bio.SeqRecord import SeqRecord
    from Bio import SeqIO

# Suppress Biopython warnings
warnings.filterwarnings("ignore", category=PDBExceptions.PDBConstructionWarning)

# --- 0A. Python-based renumbering function for the protein ---
def renumber_protein_pdb(filename):
    """
    Reads a PDB file, keeps only ATOM and TER lines, and renumbers
    residues sequentially from 1, incrementing on 'N' atoms.
    """
    renumbered_lines = []
    current_res_num = 0

    try:
        with open(filename, 'r') as f:
            lines = f.readlines()

        for line in lines:
            if line.startswith("ATOM"):
                atom_name = line[12:16].strip()

                # Increment residue number if this is a Nitrogen atom
                # This mimics the `if( $3=="N" ) ++num;` logic
                if atom_name == "N":
                    current_res_num += 1

                # Safety check if PDB doesn't start with N
                if current_res_num == 0:
                    current_res_num = 1

                new_res_num_str = str(current_res_num).rjust(4)

                # Splice in the new res num
                new_line = line[:22] + new_res_num_str + line[26:]
                renumbered_lines.append(new_line)

            elif line.startswith("TER"):
                renumbered_lines.append(line)

        # Overwrite the original file
        with open(filename, 'w') as f:
            f.writelines(renumbered_lines)

        print(f"✅ Successfully filtered and renumbered {filename}")
        return current_res_num # Return the last residue number used

    except Exception as e:
        print(f"🔥 Error renumbering protein {filename}: {e}")
        return 0

# --- 0B. Python-based renumbering function for ligands ---
def renumber_and_filter_ligand(filename, start_res_num, new_chain_letter, is_last_file=False):
    """
    Reads a PDB file, filters for ATOM/HETATM/TER, renumbers residues,
    and sets a new chain ID. Adds END instead of TER if it's the last file.
    """
    renumbered_lines = []
    current_offset = -1 # Will be 0 on the first new residue
    last_original_res_num_str = None

    # Ensure chain letter is a single character
    if not isinstance(new_chain_letter, str) or len(new_chain_letter) != 1:
        print(f"⚠️ Invalid chain letter '{new_chain_letter}'. Defaulting to 'X'.")
        new_chain_letter = 'X'

    try:
        with open(filename, 'r') as f:
            lines = f.readlines()

        for line in lines:
            if line.startswith("ATOM") or line.startswith("HETATM"):
                original_res_num_str = line[22:26] # Keep as string for comparison

                if original_res_num_str != last_original_res_num_str:
                    current_offset += 1 # Increment for each new residue
                    last_original_res_num_str = original_res_num_str

                new_res_num = start_res_num + current_offset
                new_res_num_str = str(new_res_num).rjust(4)

                # Splice in the new chain ID and new res num
                new_line = line[:21] + new_chain_letter + new_res_num_str + line[26:]
                renumbered_lines.append(new_line)

            elif line.startswith("TER"):
                # Only append TER if it's NOT the last file
                if not is_last_file:
                    renumbered_lines.append(line)

        # After processing all lines, add END if it IS the last file
        if is_last_file:
            renumbered_lines.append("END\n")

        # Overwrite the original file
        with open(filename, 'w') as f:
            f.writelines(renumbered_lines)

        final_res_num = start_res_num + current_offset
        print(f"✅ Successfully filtered and renumbered {filename} (Chain {new_chain_letter}, Res {start_res_num}-{final_res_num})")
        return final_res_num # Return the last residue number used

    except Exception as e:
        print(f"🔥 Error renumbering {filename}: {e}")
        return start_res_num - 1

# --- Biopython Selectors ---
class ProteinSelect(Select):
    """ Selects only the ATOM records for a specific protein chain. """
    def __init__(self, chain_id):
        self.chain_id = chain_id

    def accept_chain(self, chain):
        return chain.id == self.chain_id

    def accept_residue(self, residue):
        # Only accept standard residues (no HETATMs)
        return residue.id[0] == ' '

    def accept_atom(self, atom):
        return True

class LigandChainSelect(Select):
    """ Selects all ATOM/HETATM records for a specific ligand chain. """
    def __init__(self, chain_id):
        self.chain_id = chain_id

    def accept_chain(self, chain):
        return chain.id == self.chain_id

    def accept_residue(self, residue):
        return True # Accept both ATOM and HETATM

    def accept_atom(self, atom):
        return True

class HetatmResidueSelect(Select):
    """ Selects only a specific HETATM residue. """
    def __init__(self, target_residue):
        self.target_residue = target_residue

    def accept_chain(self, chain):
        return chain.id == self.target_residue.get_parent().id

    def accept_residue(self, residue):
        return residue == self.target_residue

    def accept_atom(self, atom):
        return True

# --- 1. Handle PDB Input (Upload vs. Download) ---
full_pdb_filename = ""
pdb_id_base = ""

if use_upload:
    print("Please upload your PDB file...")
    try:
        uploaded = files.upload()
        if not uploaded:
            raise Exception("File upload cancelled.")

        uploaded_filename = next(iter(uploaded)) # Get the first filename
        pdb_content = uploaded[uploaded_filename].decode('utf-8') # Get the file content

        full_pdb_filename = "uploaded_file.pdb"
        with open(full_pdb_filename, 'w') as f:
            f.write(pdb_content) # Write the content to a local file

        pdb_id_base = os.path.splitext(uploaded_filename)[0] # Use file name as base
        print(f"Using uploaded file: {uploaded_filename} (saved as {full_pdb_filename})")
    except Exception as e:
        print(f"🔥 Error during file upload: {e}")
        raise SystemExit("File upload failed.")
else:
    pdb_id_base = pdb_id
    full_pdb_filename = f"{pdb_id_base}.pdb"
    if not os.path.isfile(full_pdb_filename):
        print(f"Downloading {full_pdb_filename} from RCSB PDB...")
        !wget -q -O {full_pdb_filename} https://files.rcsb.org/download/{full_pdb_filename}
        print(f"Successfully downloaded {full_pdb_filename}.")
    else:
        print(f"{full_pdb_filename} already exists. Using local file.")

# --- 2. Define Output Filenames ---
target_fasta_filename = "protein.fasta"
target_pdb_filename = "protein.pdb"

# --- 3. Load Structure and Process Chains ---
if full_pdb_filename and os.path.exists(full_pdb_filename):
    try:
        print(f"\nLoading full structure from {full_pdb_filename}...")
        parser = PDBParser(QUIET=True)
        structure = parser.get_structure(pdb_id_base, full_pdb_filename)
        model = structure[0] # Assume first model
        last_protein_res_num = 0

        # --- 3A. Process Target Chain (Protein) ---
        print(f"\nProcessing target chain: {target_chain_letter}")

        if target_chain_letter not in model:
            if use_upload:
                original_letter = target_chain_letter
                target_chain_letter = next(iter(model.child_dict.keys()))
                print(f"Warning: Chain '{original_letter}' not found. Using first chain '{target_chain_letter}' as protein.")
            else:
                print(f"⚠️ Error: Target protein chain '{target_chain_letter}' not found in PDB. Stopping.")
                raise SystemExit()

        protein_chain = model[target_chain_letter]

        # Get sequence
        sequence = ""
        for res in protein_chain.get_residues():
            # Check if it's a standard amino acid
            if res.id[0] == ' ' and Polypeptide.is_aa(res.get_resname(), standard=True):
                try:
                    # Map the three-letter residue name to its one-letter code
                    sequence += Polypeptide.protein_letters_3to1[res.get_resname()]
                except KeyError:
                    sequence += 'X' # Unknown residue

        # Empty description so the FASTA header is just ">protein"
        seq_record = SeqRecord(Seq(sequence), id="protein", description="")
        with open(target_fasta_filename, 'w') as f:
            SeqIO.write(seq_record, f, "fasta")
        print(f"✅ Successfully saved sequence to {target_fasta_filename}")

        # Save protein structure (ATOM lines only)
        io = PDBIO()
        io.set_structure(structure)
        io.save(target_pdb_filename, ProteinSelect(target_chain_letter))
        print(f"✅ Successfully saved structure to {target_pdb_filename}")

        # Renumber protein PDB
        last_protein_res_num = renumber_protein_pdb(target_pdb_filename)
        if last_protein_res_num > 0:
            print(f"Protein renumbered. Last residue number is: {last_protein_res_num}")
        else:
            print("⚠️ Warning: Protein renumbering failed or protein was empty.")

        # --- 3B. Process Ligand(s) ---
        current_ligand_res_start = last_protein_res_num + 1
        start_chain_ord = ord(target_chain_letter.upper())
        ligand_chain_counter = 1 # To assign B, C, D...

        if ligand_extraction_method == "By Chain":
            print(f"\nProcessing ligand by chain: {ligand_chain_letter}")

            if ligand_chain_letter not in model:
                print(f"⚠️ Error: Ligand chain '{ligand_chain_letter}' not found.")
            else:
                ligand_pdb_filename = f"{ligand_chain_letter}_ligand.pdb"
                io.save(ligand_pdb_filename, LigandChainSelect(ligand_chain_letter))
                print(f"✅ Successfully saved structure to {ligand_pdb_filename}")

                # Renumber this ligand chain
                current_new_chain_letter = chr(start_chain_ord + ligand_chain_counter)
                renumber_and_filter_ligand(
                    ligand_pdb_filename,
                    current_ligand_res_start,
                    current_new_chain_letter,
                    is_last_file=True # Assume this is the only ligand
                )

        elif ligand_extraction_method == "By HETATM Residue Name":
            if not ligand_residue_name:
                print("⚠️ Error: Please provide a HETATM residue name to search for.")
            else:
                ligand_names_to_find = [name.strip().upper() for name in ligand_residue_name.split(',')]
                print(f"\nProcessing ligands by HETATM name(s): {ligand_names_to_find}")

                if ligand_chain_letter not in model:
                    print(f"⚠️ Error: Specified ligand chain '{ligand_chain_letter}' not found. Cannot search for HETATMs.")
                else:
                    print(f"Searching for HETATMs on chain '{ligand_chain_letter}'...")

                    search_chain = model[ligand_chain_letter]
                    matching_residues = []

                    for res in search_chain.get_residues():
                        res_name_upper = res.get_resname().strip().upper()
                        # Check if it's a HETATM and matches one of the names
                        if res.id[0] != ' ' and res_name_upper in ligand_names_to_find:
                            matching_residues.append(res)

                    found_ligands = len(matching_residues)
                    ligand_file_counts = {}

                    if found_ligands == 0:
                        print(f"⚠️ No HETATM residues matching {ligand_names_to_find} found on chain {ligand_chain_letter}.")
                    else:
                        print(f"Found {found_ligands} matching HETATM residues. Processing...")

                        for i, res in enumerate(matching_residues):
                            is_last = (i == found_ligands - 1) # Check if this is the last item
                            res_name_upper = res.get_resname().strip().upper()

                            # Determine chain letter
                            current_new_chain_letter = chr(start_chain_ord + ligand_chain_counter)
                            ligand_chain_counter += 1 # Increment for the *next* ligand

                            # Handle file naming to prevent overwrites
                            count = ligand_file_counts.get(res_name_upper, 0) + 1
                            ligand_file_counts[res_name_upper] = count

                            ligand_pdb_filename = f"{res_name_upper}.pdb"
                            if count > 1:
                                ligand_pdb_filename = f"{res_name_upper}_{count}.pdb"

                            # Save this single residue to a file
                            io.save(ligand_pdb_filename, HetatmResidueSelect(res))
                            print(f"✅ Extracted HETATM {i+1}/{found_ligands} ({res_name_upper}) to {ligand_pdb_filename}")

                            # Call renumbering with the new chain letter and is_last flag
                            last_num_used = renumber_and_filter_ligand(
                                ligand_pdb_filename,
                                current_ligand_res_start,
                                current_new_chain_letter,
                                is_last_file=is_last
                            )

                            # Update the *next* starting number
                            current_ligand_res_start = last_num_used + 1

    except Exception as e:
        print(f"🔥 Error during PDB processing: {e}")
else:
    print("🔥 Error: PDB file not found or specified. Halting.")

print("\nProcess complete.")
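Both renumbering helpers in the cell above work by splicing fixed PDB columns:

- the atom name sits in columns 13-16 (`line[12:16]`);
- the chain ID sits in column 22 (`line[21]`);
- the residue number sits in columns 23-26 (`line[22:26]`).

The sketch below isolates the protein rule (start a new residue at every `N` atom) as a pure function on strings, so it can be checked without touching files. `renumber_atom_lines` and `make_atom_line` are hypothetical helpers. Unlike `renumber_protein_pdb`, this sketch drops `TER` records and does not write anything back to disk.

```python
def make_atom_line(serial, atom_name, res_name, chain, res_num):
    """Build a minimal fixed-width ATOM record (coordinates set to zero)."""
    return (f"ATOM  {serial:5d} {atom_name:^4} {res_name:3} {chain}{res_num:4d}    "
            f"{0.0:8.3f}{0.0:8.3f}{0.0:8.3f}  1.00  0.00\n")

def renumber_atom_lines(lines):
    """Keep ATOM records and renumber residues from 1, starting a new residue at each 'N'."""
    out, res_num = [], 0
    for line in lines:
        if not line.startswith("ATOM"):
            continue  # HETATM, TER, END, ... are dropped in this sketch
        if line[12:16].strip() == "N":
            res_num += 1
        res_num = max(res_num, 1)  # guard for a chain that does not start with N
        out.append(line[:22] + str(res_num).rjust(4) + line[26:])
    return out

pdb_lines = [
    make_atom_line(1, "N", "MET", "A", 10),
    make_atom_line(2, "CA", "MET", "A", 10),
    make_atom_line(3, "N", "LYS", "A", 12),
    "HETATM    4  C1  NAP A 900       0.000   0.000   0.000  1.00  0.00\n",
]
renumbered = renumber_atom_lines(pdb_lines)
print([int(line[22:26]) for line in renumbered])  # [1, 1, 2]
```

The ligand helper uses the same splice but also overwrites column 22 with the new chain letter. It also advances the residue number whenever the original residue number changes, rather than on `N` atoms.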

In [None]:
#@title 2) Define Active Site Residues Around Ligand
import os
import warnings

# --- Configuration ---
#@markdown ### 1. Input Files
protein_pdb_file = "protein.pdb"  #@param {type:"string"}
ligand_pdb_files_str = "NAP.pdb"  #@param {type:"string"}

#@markdown ### 2. Ligand Parameters
#@markdown **Important:** The chains here MUST match the chains you assigned
#@markdown in the previous PDB splitting step (e.g., 'B', 'C', etc.)
ligand_chains_str = "B"   #@param {type:"string"}

#@markdown ### 3. Output File
merged_pdb_file = "protein_ligand_complex.pdb"  #@param {type:"string"}

#@markdown ### 4. Active Site Definition
#@markdown ---
#@markdown Define the active site by two criteria:
#@markdown 1.  Protein **backbone** atoms ('N', 'CA', 'C', 'O') within X angstroms of the ligand.
#@markdown 2.  Protein **sidechain** atoms (all others) within Y angstroms of the ligand.
cutoff_backbone = 7.0  #@param {type:"number"}
cutoff_sidechain = 6.0  #@param {type:"number"}

# --- 1. Install/Import Biopython ---
try:
    from Bio.PDB import PDBParser, NeighborSearch, PDBExceptions
    from Bio.PDB.Atom import Atom
except ImportError:
    print("Biopython not found. Installing...")
    !pip install biopython
    from Bio.PDB import PDBParser, NeighborSearch, PDBExceptions
    from Bio.PDB.Atom import Atom

# Suppress Biopython warnings
warnings.filterwarnings("ignore", category=PDBExceptions.PDBConstructionWarning)

# --- 2. Merge Protein and Ligand PDBs ---
# This step is necessary to create a single complex for Biopython to analyze
print("\n--- Merging PDB Files ---")

processed_ligand_pdbs = [f.strip() for f in ligand_pdb_files_str.split(',')]

with open(merged_pdb_file, 'w') as outfile:
    if os.path.exists(protein_pdb_file):
        with open(protein_pdb_file, 'r') as infile:
            for line in infile:
                # Filter out CRYST1, REMARK, and any existing END
                if not line.startswith(("CRYST1", "REMARK", "END")):
                    outfile.write(line)
        print(f"Added {protein_pdb_file} to {merged_pdb_file}")
    else:
        print(f"⚠️ Warning: Protein file {protein_pdb_file} not found. Merged file will only contain ligands.")

    # Add a TER record if protein was added and it didn't end with one
    if os.path.exists(protein_pdb_file):
        with open(protein_pdb_file, 'r') as f:
            lines = f.readlines()
        if lines and not lines[-1].startswith("TER"):
            outfile.write("TER\n")

    for ligand_file in processed_ligand_pdbs:
        if not os.path.exists(ligand_file):
            print(f"⚠️ Warning: Ligand file {ligand_file} not found. Skipping.")
            continue

        with open(ligand_file, 'r') as infile:
            for line in infile:
                # Skip the ligand's own END line; a single END is written at the very end
                if not line.startswith("END"):
                    outfile.write(line)
        print(f"Added {ligand_file} to {merged_pdb_file}")

    outfile.write("END\n")

print(f"✅ Merged file created: {merged_pdb_file}")


# --- 3. Load Structure and Prepare Atoms ---
print("\n--- Loading Structure with Biopython ---")
try:
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure("complex", merged_pdb_file)
    model = structure[0] # Assume first model
except Exception as e:
    print(f"🔥 Error parsing merged PDB file: {e}")
    print("   Please check the merged file for errors.")
    raise SystemExit

ligand_chains = [c.strip() for c in ligand_chains_str.split(',')]
protein_chains = []
for chain in model:
    if chain.id not in ligand_chains:
        protein_chains.append(chain.id)

print(f"Protein chains identified: {protein_chains}")
print(f"Ligand chains identified: {ligand_chains}")

# Define backbone atom names
backbone_atoms = {'N', 'CA', 'C', 'O'}

ligand_atoms = []
protein_bb_atoms = []
protein_sc_atoms = []

print("Separating protein (backbone/sidechain) and ligand atoms...")
for chain in model:
    if chain.id in ligand_chains:
        for atom in chain.get_atoms():
            ligand_atoms.append(atom)

    elif chain.id in protein_chains:
        for residue in chain.get_residues():
            # Only process standard protein residues
            if residue.id[0] == ' ': # ' ' indicates a standard residue
                for atom in residue.get_atoms():
                    if atom.name in backbone_atoms:
                        protein_bb_atoms.append(atom)
                    else:
                        protein_sc_atoms.append(atom)

print(f"Found {len(ligand_atoms)} ligand atoms.")
print(f"Found {len(protein_bb_atoms)} protein backbone atoms.")
print(f"Found {len(protein_sc_atoms)} protein sidechain atoms.")

if not ligand_atoms:
    print("🔥 Error: No ligand atoms found. Cannot proceed.")
    raise SystemExit

# --- 4. Perform Neighbor Search ---
print("Finding nearby residues...")

# Create the NeighborSearch object from all ligand atoms
ns = NeighborSearch(ligand_atoms)

# Find all protein backbone atoms within the backbone cutoff
nearby_bb_atoms = set()
for atom in protein_bb_atoms:
    nearby = ns.search(atom.coord, cutoff_backbone, 'A') # 'A' = atom level
    if nearby:
        nearby_bb_atoms.add(atom)

# Find all protein sidechain atoms within the sidechain cutoff
nearby_sc_atoms = set()
for atom in protein_sc_atoms:
    nearby = ns.search(atom.coord, cutoff_sidechain, 'A') # 'A' = atom level
    if nearby:
        nearby_sc_atoms.add(atom)

print(f"Found {len(nearby_bb_atoms)} backbone atoms within {cutoff_backbone} Å.")
print(f"Found {len(nearby_sc_atoms)} sidechain atoms within {cutoff_sidechain} Å.")

# --- 5. Map Atoms to Residues and Finalize List ---
active_site_residues = set()

# Get parent residues for backbone atoms
for atom in nearby_bb_atoms:
    active_site_residues.add(atom.get_parent())

# Get parent residues for sidechain atoms
for atom in nearby_sc_atoms:
    active_site_residues.add(atom.get_parent())

print("\n--- Active Site Analysis Complete ---")
if not active_site_residues:
    print("⚠️ No active site residues found with the given cutoffs.")
else:
    print(f"✅ Found {len(active_site_residues)} unique residues in the active site.")

    # Sort residues by chain and residue number
    sorted_residues = sorted(list(active_site_residues),
                             key=lambda res: (res.get_parent().id, res.id[1]))

    active_site_pdb_names = []

    # CRITICAL: Create the global variable for the next script
    # This must be a list of 1-indexed residue numbers
    global active_site_residue_numbers
    active_site_residue_numbers = []

    print("Active Site Residues (Chain + PDB Number):")
    for res in sorted_residues:
        chain_id = res.get_parent().id
        res_num = res.id[1]
        res_name = res.get_resname()

        active_site_pdb_names.append(f"{chain_id}{res_num} ({res_name})")
        active_site_residue_numbers.append(res_num)

    print(", ".join(active_site_pdb_names))

    print("\nCorresponding Residue Numbers (for next script):")
    print(active_site_residue_numbers)

print("\n🎯 Script finished successfully.")
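The cell above uses Biopython's `NeighborSearch` to apply the two distance cutoffs. To check the criterion on a handful of atoms, the same rule can be written as a brute-force NumPy distance calculation. This sketch makes three assumptions:

- atoms are given as a plain list of `(res_num, atom_name, xyz)` tuples rather than Biopython objects;
- a distance exactly equal to the cutoff counts as inside;
- the O(N×M) loop is acceptable because it is only meant for small toy examples.

`find_active_site` is a hypothetical name.

```python
import numpy as np

BACKBONE_ATOMS = {"N", "CA", "C", "O"}

def find_active_site(protein_atoms, ligand_coords, cutoff_backbone=7.0, cutoff_sidechain=6.0):
    """Return sorted residue numbers meeting the backbone/sidechain distance criteria.

    protein_atoms: iterable of (res_num, atom_name, (x, y, z)) tuples (hypothetical format).
    ligand_coords: array-like of shape (n_ligand_atoms, 3).
    """
    ligand = np.asarray(ligand_coords, dtype=float)
    selected = set()
    for res_num, atom_name, xyz in protein_atoms:
        # Distance from this protein atom to its nearest ligand atom
        nearest = np.linalg.norm(ligand - np.asarray(xyz, dtype=float), axis=1).min()
        cutoff = cutoff_backbone if atom_name in BACKBONE_ATOMS else cutoff_sidechain
        if nearest <= cutoff:
            selected.add(res_num)
    return sorted(selected)

ligand_xyz = [[0.0, 0.0, 0.0]]  # toy single-atom "ligand" at the origin
atoms = [
    (1, "CA", (6.5, 0.0, 0.0)),  # backbone at 6.5 A -> inside the 7 A cutoff
    (2, "CB", (6.5, 0.0, 0.0)),  # sidechain at 6.5 A -> outside the 6 A cutoff
    (3, "CA", (8.0, 0.0, 0.0)),  # backbone too far
    (3, "OG", (5.0, 0.0, 0.0)),  # sidechain at 5 A -> residue 3 selected
]
print(find_active_site(atoms, ligand_xyz))  # [1, 3]
```

Residue numbers returned this way play the same role as `active_site_residue_numbers`. Because the protein was renumbered from 1, they can be used directly as 1-indexed sequence positions.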

In [None]:
#@title 3) Generate MSA (.a3m) and optionally filter based on different parameters
#@markdown ---
#@markdown ### 1. Specify Input
#@markdown Enter the name of the FASTA file you generated.
input_fasta_file = 'protein.fasta' #@param {type:"string"}

#@markdown ---
#@markdown ### 2. HHfilter Options
#@markdown Check the box to run hhfilter on the generated .a3m file.
run_hhfilter = True #@param {type:"boolean"}
id_redundancy = 90 #@param {type:"integer"}
coverage = 50 #@param {type:"integer"}
query_identity = 30 #@param {type:"integer"}
#@markdown ---
import os
import sys

# 1. Check if the input FASTA file exists
if not os.path.isfile(input_fasta_file):
    print(f"🔥 Error: Input file not found: {input_fasta_file}")
    print("Please make sure the filename matches the one from the previous step.")
    sys.exit(f"File not found: {input_fasta_file}")
else:
    print(f"Found input file: {input_fasta_file}")

# 2. Define the output directory as the current folder
output_dir = "." # This will save files to /content/

# 3. Run the colabfold_batch command
print(f"Running colabfold_batch on {input_fasta_file}...")
print(f"This will generate the .a3m file and then stop.")

# Run the alignment generation
!colabfold_batch {input_fasta_file} {output_dir} --msa-mode "mmseqs2_uniref_env" --msa-only

# 4. Check the output files from colabfold_batch
print("\nAlignment generation complete.")

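# --- 5. Run HHfilter ---
# colabfold_batch names the .a3m after the FASTA record ID (">protein" here),
# which matches the FASTA file name, so the same base name can be reused
base_name = os.path.splitext(input_fasta_file)[0]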
original_a3m = f"{base_name}.a3m"
filtered_a3m = f"{base_name}.filtered.a3m"

if run_hhfilter:
    print(f"\nRunning hhfilter on {original_a3m}...")
    if not os.path.isfile(original_a3m):
        print(f"🔥 Error: The original alignment file {original_a3m} was not found. Skipping filter.")
    else:
        # Build and run the hhfilter command
        !hhfilter -i {original_a3m} -o {filtered_a3m} -id {id_redundancy} -cov {coverage} -qid {query_identity}
        print(f"Filtering complete. Filtered file saved as: {filtered_a3m}")
else:
    print("\nSkipping hhfilter step.")

# --- 6. Find and report the final .a3m file ---
print("\n--- Final Output ---")
try:
    if run_hhfilter and os.path.isfile(filtered_a3m):
        print(f"✅ Your FINAL filtered alignment file is ready: {filtered_a3m}")
    elif os.path.isfile(original_a3m):
        print(f"✅ Your original (unfiltered) alignment file is ready: {original_a3m}")
    else:
        print(f"⚠️ Error: No .a3m file ({original_a3m} or {filtered_a3m}) was found.")

except Exception as e:
    print(f"\n⚠️ Error finding .a3m file: {e}")
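
The three `hhfilter` options above act on the match-state (uppercase) columns of the .a3m file. The sketch below approximates their meaning under simple per-column definitions:

- **Coverage** is the percentage of query residues aligned to a hit residue.
- **Query identity** is the percentage of those aligned pairs that are identical.
- **Redundancy removal** greedily drops any hit that is more than `max_id`% identical to a hit already kept.

hhfilter's own implementation differs in its details, for example in how it chooses which redundant sequences to keep. Treat this as an illustration only. All function names here are hypothetical.

```python
def match_states(a3m_seq):
    """Keep match states (uppercase) and deletions ('-'); drop lowercase insertions."""
    return "".join(c for c in a3m_seq if c.isupper() or c == "-")

def coverage_and_identity(reference, hit):
    """Return (coverage %, identity %) of `hit` relative to `reference`."""
    ref, h = match_states(reference), match_states(hit)
    if len(ref) != len(h):
        raise ValueError("sequences have different numbers of match states")
    ref_residues = sum(1 for a in ref if a != "-")
    aligned = [(a, b) for a, b in zip(ref, h) if a != "-" and b != "-"]
    identical = sum(1 for a, b in aligned if a == b)
    coverage = 100.0 * len(aligned) / ref_residues if ref_residues else 0.0
    identity = 100.0 * identical / len(aligned) if aligned else 0.0
    return coverage, identity

def filter_hits(query, hits, max_id=90, min_cov=50, min_qid=30):
    """Greedy approximation of hhfilter -id/-cov/-qid on a list of aligned hits."""
    kept = []
    for hit in hits:
        cov, qid = coverage_and_identity(query, hit)
        if cov < min_cov or qid < min_qid:
            continue  # too short or too dissimilar to the query
        # Drop the hit if it is too similar to one that was already kept
        if any(coverage_and_identity(k, hit)[1] > max_id for k in kept):
            continue
        kept.append(hit)
    return kept

query = "ACDEFGHIKL"
hits = [
    "ACDaaEFGHWK-",  # 9/10 query residues covered, 8/9 identical ('aa' is an insertion)
    "ACDEFGHWK-",    # same match states as the first hit -> redundant
    "AC--------",    # only 20% coverage -> removed
]
print(coverage_and_identity(query, hits[0]))
print(filter_hits(query, hits))
```

With the notebook's parameters, the call would be `filter_hits(query, hits, max_id=id_redundancy, min_cov=coverage, min_qid=query_identity)`.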

In [None]:
#@title 4) Find Conserved Residues based on the MSA
#@markdown ---
#@markdown ### 1. Specify Input Files
#@markdown Enter the name of the filtered .a3m file (from previous step).
filtered_a3m_file = 'protein.filtered.a3m' #@param {type:"string"}
#@markdown
#@markdown ### 2. Conservation Settings
#@markdown Fraction of most-conserved positions to fix (e.g., 0.5 = 50%; the introduction describes fixing the top 70%, i.e. 0.7).
frac_conserved = 0.5 #@param {type:"number"}
#@markdown
#@markdown ### 3. Output File
#@markdown Name for the final .txt file containing the fixed positions.
output_txt_file = "fixed_positions.txt" #@param {type:"string"}
#@markdown ---

import os
import sys
import numpy as np

# --- 1. Self-contained A3M parser ---
def robust_parse_a3m(a3m_file_path):
    """
    Parses an .a3m file, removes insertions (lowercase), and converts to numbers.
    Any lines before the first '>' header (e.g. the '#' line ColabFold writes) are skipped.
    """
    # Mapping from AA to number (0-19 are AA, 20 is gap)
    aa_to_num = {
        'A': 0, 'R': 1, 'N': 2, 'D': 3, 'C': 4, 'Q': 5, 'E': 6, 'G': 7, 'H': 8, 'I': 9,
        'L': 10, 'K': 11, 'M': 12, 'F': 13, 'P': 14, 'S': 15, 'T': 16, 'W': 17, 'Y': 18, 'V': 19,
        '-': 20, 'X': 20, 'B': 20, 'Z': 20 # Treat unknowns as gaps
    }

    msa_sequences = []
    seq = ""
    found_first_header = False # Flag to skip junk lines at the start

    try:
        with open(a3m_file_path, 'r') as f:
            for line in f:
                line = line.strip()
                if line.startswith('>'):
                    found_first_header = True # We've found the first sequence, start parsing
                    if seq: # save previous sequence
                        msa_sequences.append(seq)
                    seq = "" # start new sequence
                elif found_first_header and line: # Only append if we're parsing and line is not empty
                    seq += line
    except FileNotFoundError:
        print(f"🔥 Error: The alignment file '{a3m_file_path}' was not found.")
        sys.exit(1)
    except Exception as e:
        print(f"🔥 Error reading {a3m_file_path}: {e}")
        sys.exit(1)


    if seq: # save last sequence
        msa_sequences.append(seq)

    if not msa_sequences:
        raise ValueError(f"No sequences found in {a3m_file_path}. File might be malformed.")

    # Get query length from the first sequence
    query_seq = msa_sequences[0]
    query_L = len([c for c in query_seq if c.isupper() or c == '-'])

    if query_L == 0:
        raise ValueError(f"Query sequence found, but it has no match/delete characters (Length is 0).")

    msa_aligned = []
    for seq in msa_sequences:
        aligned_seq = ""
        for char in seq:
            # Keep only uppercase (match) and gaps (delete)
            if char.isupper() or char == '-':
                aligned_seq += char

        # Ensure all sequences have the same length as the query
        if len(aligned_seq) == query_L:
            msa_aligned.append(aligned_seq)

    if not msa_aligned:
        raise ValueError(f"No valid, aligned sequences found in {a3m_file_path}.")

    # Convert to numbers
    msa_numeric = []
    for seq in msa_aligned:
        num_seq = [aa_to_num.get(char.upper(), 20) for char in seq] # .upper() for safety
        msa_numeric.append(num_seq)

    return {
        'msa': np.array(msa_aligned),
        'msa_num': np.array(msa_numeric, dtype=int)
    }
# --- End of Parser ---


# --- 2. Check for 'active_site_residue_numbers' variable ---
# This variable should be created by the define_active_site_biopython.py script
if 'active_site_residue_numbers' not in locals() and 'active_site_residue_numbers' not in globals():
    print("🔥 Error: The 'active_site_residue_numbers' list was not found.")
    print("   Please re-run the 'Define Active Site (Biopython)' script first.")
    sys.exit(1)
else:
    # Ensure it's a numpy array for union1d
    active_site_np = np.array(active_site_residue_numbers)
    print(f"Found active_site_residue_numbers list with {len(active_site_np)} residues.")

# --- 3. Run Conservation Analysis ---
try:
    if not os.path.isfile(filtered_a3m_file):
        print(f"🔥 Error: The alignment file '{filtered_a3m_file}' was not found.")
        sys.exit(1)

    print(f"Parsing alignment file: {filtered_a3m_file}...")
    # Use our new robust parser
    aln = robust_parse_a3m(filtered_a3m_file)

    msa_num = aln['msa_num']
    L = msa_num.shape[1] # Get length (L) from the numeric array
    num_seqs = msa_num.shape[0]

    if num_seqs == 0 or L == 0:
        print(f"🔥 Error: The alignment file '{filtered_a3m_file}' contains no valid sequences or has length 0.")
        sys.exit(1)

    print(f"Alignment loaded. Length: {L}, Number of sequences: {num_seqs}")

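    # Calculate conservation: counts has shape (21, L); rows 0-19 are amino acids, row 20 is gaps
    counts = np.stack([np.bincount(column, minlength=21) for column in msa_num.T]).T
    # Use amino-acid counts only (gaps excluded) for the low-count filter below,
    # so that columns that are almost entirely gaps cannot rank as highly conserved
    max_count = np.max(counts[:20], axis=0)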

    freq = counts / num_seqs

    # Handle columns that are 100% gaps (to avoid divide-by-zero)
    freq_sum = freq[:20].sum(axis=0)
    freq_sum[freq_sum == 0] = 1.0 # Set sum to 1.0 to prevent error

    freq_norm = freq[:20] / freq_sum
    max_freq_norm = np.max(freq_norm, axis=0)

    # Only apply low-count penalty if we have more than one sequence
    if num_seqs > 1:
        max_freq_norm[max_count < 10] = 0 # Don't choose positions with low counts
    else:
        print("Only 1 sequence found, skipping low-count filter.")

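    # Select the top fraction of positions (1-indexed along the query sequence).
    # This assumes the PDB residue numbering also starts at 1 with no gaps, so that
    # these positions can be combined directly with the active-site residue numbers.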
    num_conserved = int(L * frac_conserved)
    conserved_residues = np.argsort(max_freq_norm)[::-1][:num_conserved] + 1 # make 1-indexed
    conserved_residues.sort()

    print(f"Found {len(conserved_residues)} conserved residues (top {frac_conserved*100}%).")

    # --- 4. Combine (union) with the active site ---
    # Use the active_site_np array we defined earlier
    fixed_positions = np.union1d(conserved_residues, active_site_np)
    fixed_positions.sort()

    print("\n--- Final Results ---")
    print(f"Active Site Residues ({len(active_site_np)}):")
    print(active_site_np)
    print(f"Conserved Residues ({len(conserved_residues)}):")
    print(conserved_residues)
    print(f"All Fixed Positions ({len(fixed_positions)}):")
    print(fixed_positions)

    # --- 5. Save Final List to TXT File ---
    print(f"\nSaving final fixed positions to {output_txt_file}...")
    fixed_positions_str = [str(int(res)) for res in fixed_positions]

    with open(output_txt_file, 'w') as f:
        f.write(' '.join(fixed_positions_str))

    print(f"✅ Successfully saved fixed positions to {output_txt_file}.")

except Exception as e:
    print(f"🔥 An error occurred: {e}")
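
To see what the conservation score does, here is a self-contained toy version. The procedure has three steps:

1. Score each column by the frequency of its most common amino acid among the non-gap residues.
2. Keep the top `frac` of columns (1-indexed).
3. Merge the result with the active site by a set union.

The helper name `conserved_positions`, the toy alignments and the active-site positions are hypothetical. In this sketch the low-count filter compares the count of the most frequent *amino acid* (gaps excluded) with `min_count`. A column that is almost entirely gaps therefore cannot score as perfectly conserved.

```python
import numpy as np

AA_ORDER = "ARNDCQEGHILKMFPSTWYV"                      # indices 0-19, same order as the parser above
AA_TO_NUM = {aa: i for i, aa in enumerate(AA_ORDER)}   # anything else -> 20 (gap)

def conserved_positions(msa, frac, min_count=1):
    """Return the 1-indexed columns whose most frequent amino acid is most conserved."""
    num = np.array([[AA_TO_NUM.get(c, 20) for c in seq] for seq in msa])
    counts = np.stack([np.bincount(col, minlength=21) for col in num.T]).T  # shape (21, L)
    aa_counts = counts[:20]                      # ignore the gap row
    totals = aa_counts.sum(axis=0)
    totals[totals == 0] = 1                      # all-gap columns: avoid division by zero
    top_count = aa_counts.max(axis=0)
    score = top_count / totals                   # frequency of the most common amino acid
    score[top_count < min_count] = 0.0           # too few residues to trust the column
    n_keep = int(num.shape[1] * frac)
    return np.sort(np.argsort(score)[::-1][:n_keep] + 1)

msa = ["MKVLA",
       "MKILG",
       "MRVLA",
       "MKVLS"]
conserved = conserved_positions(msa, frac=0.4)
print(conserved)                              # [1 4]: the all-M and all-L columns
active_site = [2, 5]                          # hypothetical active-site positions
print(np.union1d(conserved, active_site))     # [1 2 4 5]

gappy = ["MKVL-", "MKIL-", "MRVL-", "MKVLW"]  # column 5 holds a single W
print(conserved_positions(gappy, frac=0.4, min_count=2))  # [1 4]
```

Without `min_count=2`, column 5 of `gappy` would score 1.0 (one W out of one residue) and tie with the genuinely conserved columns.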

# Part 2. Run ProteinMPNN with evolutionary information

In [None]:
#@title 1) Setting up ProteinMPNN with fixed positions
import re
import os
import numpy as np

#@markdown ---
#@markdown ### 1. Input PDB
#@markdown Path to the PDB file (must be in your Colab folder).
pdb_path = "protein.pdb" #@param {type:"string"}

homomer = False #@param {type:"boolean"}
designed_chain = "A" #@param {type:"string"}
fixed_chain = "" #@param {type:"string"}

if designed_chain == "":
  designed_chain_list = []
else:
  designed_chain_list = re.sub("[^A-Za-z]+",",", designed_chain).split(",")

if fixed_chain == "":
  fixed_chain_list = []
else:
  fixed_chain_list = re.sub("[^A-Za-z]+",",", fixed_chain).split(",")

chain_list = list(set(designed_chain_list + fixed_chain_list))

#@markdown - `designed_chain`: Chain(s) to design (e.g., "A").
#@markdown - `fixed_chain`: Chain(s) to keep fixed (e.g., "C").

#@markdown ### 2. Design Options
#@markdown Number of sequences to generate.
num_seqs = 1 #@param {type:"integer"}
num_seq_per_target = num_seqs

#@markdown - Sampling temperature for amino acids, T=0.0 means taking argmax, T>>1.0 means sample randomly.
sampling_temp = "0.2" #@param ["0.0001", "0.1", "0.15", "0.2", "0.25", "0.3", "0.5"]

#@markdown - `omit_AAs`: Specify amino acids to omit (e.g., "XC" to omit Cys and Unknown).
omit_AAs = "XC" #@param {type:"string"}
#@markdown ---

#@markdown ### 3. Fixed Positions (Optional)
#@markdown Provide the .txt file of conserved/active site residues to fix.
fixed_positions_file = "fixed_positions.txt" #@param {type:"string"}
#@markdown The chain these fixed positions apply to.
fixed_positions_chain = "A" #@param {type:"string"}
#@markdown ---

# --- Remaining ProteinMPNN run parameters (defaults) ---
save_score=0
save_probs=0
score_only=0
conditional_probs_only=0
conditional_probs_only_backbone=0
batch_size=1
max_length=20000
out_folder='.'
jsonl_path=''
pssm_multi=0.0
pssm_threshold=0.0
pssm_log_odds_flag=0
pssm_bias_flag=0

##############################################################

folder_for_outputs = out_folder

NUM_BATCHES = num_seq_per_target//batch_size
BATCH_COPIES = batch_size
temperatures = [float(item) for item in sampling_temp.split()]
omit_AAs_list = omit_AAs
alphabet = 'ACDEFGHIKLMNPQRSTVWYX'

omit_AAs_np = np.array([AA in omit_AAs_list for AA in alphabet]).astype(np.float32)

chain_id_dict = None
fixed_positions_dict = None
pssm_dict = None
omit_AA_dict = None
bias_AA_dict = None
tied_positions_dict = None
bias_by_res_dict = None
bias_AAs_np = np.zeros(len(alphabet))

# --- Read the fixed positions from the .txt file ---
if fixed_positions_file and fixed_positions_chain:
    try:
        with open(fixed_positions_file, 'r') as f:
            # Read the space-separated list of numbers
            residues_to_fix = [int(res) for res in f.read().split()]

        if residues_to_fix:
            # Get the PDB name (basename without .pdb)
            pdb_name = os.path.basename(pdb_path).replace('.pdb', '')

            # Build the dictionary in the format ProteinMPNN expects: {pdb_name: {chain: [1-based positions]}}
            fixed_positions_dict = {
                pdb_name: {
                    fixed_positions_chain: residues_to_fix
                }
            }
            print(f"✅ Successfully read {len(residues_to_fix)} fixed positions for chain {fixed_positions_chain} from {fixed_positions_file}.")
            print(f"Fixed positions: {residues_to_fix}")
        else:
            print(f"⚠️ {fixed_positions_file} was found but is empty. No specific residues fixed.")

    except FileNotFoundError:
        print(f"⚠️ {fixed_positions_file} not found. No specific residues fixed (besides 'fixed_chain').")
    except Exception as e:
        print(f"🔥 Error reading {fixed_positions_file}: {e}")
else:
    print("No fixed positions file provided. Only 'fixed_chain' (if any) will be fixed.")
# --- End of fixed-position parsing ---


###############################################################
# Check if PDB file exists before proceeding
if not os.path.isfile(pdb_path):
    print(f"🔥 Error: PDB file not found at {pdb_path}")
    print("Please make sure the file is in your Colab folder.")
else:
    pdb_dict_list = parse_PDB(pdb_path, input_chain_list=chain_list)
    dataset_valid = StructureDatasetPDB(pdb_dict_list, truncate=None, max_length=max_length)

    chain_id_dict = {}
    chain_id_dict[pdb_dict_list[0]['name']]= (designed_chain_list, fixed_chain_list)

    print(f"\nChain assignments: {chain_id_dict}")
    for chain in chain_list:
      l = len(pdb_dict_list[0][f"seq_chain_{chain}"])
      print(f"Length of chain {chain} is {l}")

    if homomer:
      tied_positions_dict = make_tied_positions_for_homomers(pdb_dict_list)
    else:
      tied_positions_dict = None
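
ProteinMPNN's `fixed_positions_dict` has the shape `{pdb_name: {chain: [positions]}}`. As we read ProteinMPNN's `tied_featurize`, these positions are 1-based indices into the chain sequence, not PDB residue numbers. The two only coincide when the chain is numbered contiguously from 1.

The minimal sketch below converts PDB residue numbers to 1-based indices, assuming contiguous numbering. It raises an error otherwise, since chains with missing residues or insertion codes need a manual mapping. The function name `pdb_to_mpnn_positions` and the example numbering are hypothetical.

```python
def pdb_to_mpnn_positions(chain_residue_numbers, residues_to_fix):
    """Map PDB residue numbers to 1-based sequence positions.

    chain_residue_numbers -- PDB residue numbers of the chain, in order
    residues_to_fix       -- PDB residue numbers to keep fixed
    Assumes contiguous numbering (no gaps, no insertion codes).
    """
    numbers = list(chain_residue_numbers)
    first = numbers[0]
    if numbers != list(range(first, first + len(numbers))):
        raise ValueError("chain numbering is not contiguous; map positions manually")
    positions = sorted({res - first + 1 for res in residues_to_fix})
    bad = [p for p in positions if not 1 <= p <= len(numbers)]
    if bad:
        raise ValueError(f"residues outside the chain: {[p + first - 1 for p in bad]}")
    return positions

# Hypothetical chain numbered 3..12 in the PDB file
fixed = pdb_to_mpnn_positions(range(3, 13), [3, 7, 12])
print(fixed)                        # [1, 5, 10]
print({"protein": {"A": fixed}})    # {'protein': {'A': [1, 5, 10]}}
```

The chain's residue numbers can be collected with Biopython before writing `fixed_positions.txt`, for example `[res.id[1] for res in structure[0]['A'] if res.id[0] == ' ']`, where `structure` is loaded with `PDBParser().get_structure(...)`.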

In [None]:
#@title 2) Run ProteinMPNN with fixed positions
#@markdown ---
#@markdown ### 1. Output FASTA File
output_fasta_file = "generated_sequences.fasta"  #@param {type:"string"}

#@markdown ---
#@markdown ### 2. A3M Generation
source_a3m_file = "protein.filtered.a3m"          #@param {type:"string"}
output_msa_folder = "msa"                        #@param {type:"string"}
#@markdown ---

import os, copy, torch, numpy as np

# --- 1. Read base alignment ---
os.makedirs(output_msa_folder, exist_ok=True)
alignment_body_lines = []
try:
    with open(source_a3m_file, 'r') as f:
        found_first_seq = False
        for line in f:
            line = line.rstrip('\n\r') + '\n'  # normalize newlines
            if line.startswith('>'):
                if not found_first_seq:
                    found_first_seq = True
                    continue  # skip query header
                alignment_body_lines.append(line)
            elif found_first_seq:
                alignment_body_lines.append(line)
    print(f"Read {len(alignment_body_lines)//2} alignment hits from {source_a3m_file}.")
except Exception as e:
    print(f"⚠️ WARNING reading {source_a3m_file}: {e}")

def _alignment_tail_to_write(alignment_lines):
    """Return alignment hits excluding the query sequence."""
    if not alignment_lines:
        return []
    if not alignment_lines[0].startswith('>'):
        return alignment_lines[1:]
    return alignment_lines

# --- 2. Generate sequences ---
with torch.no_grad():
    print('Generating sequences...')
    print(f"Saving generated sequences to: {output_fasta_file}")
    with open(output_fasta_file, 'w') as fasta_f:
        for ix, protein in enumerate(dataset_valid):
            score_list, all_probs_list, all_log_probs_list, S_sample_list = [], [], [], []
            batch_clones = [copy.deepcopy(protein) for _ in range(BATCH_COPIES)]
            X, S, mask, lengths, chain_M, chain_encoding_all, chain_list_list, visible_list_list, masked_list_list, masked_chain_length_list_list, chain_M_pos, omit_AA_mask, residue_idx, dihedral_mask, tied_pos_list_of_lists_list, pssm_coef, pssm_bias, pssm_log_odds_all, bias_by_res_all, tied_beta = tied_featurize(
                batch_clones, device, chain_id_dict, fixed_positions_dict, omit_AA_dict, tied_positions_dict,
                pssm_dict, bias_by_res_dict)
            pssm_log_odds_mask = (pssm_log_odds_all > pssm_threshold).float()
            name_ = batch_clones[0]['name']

            randn_1 = torch.randn(chain_M.shape, device=X.device)
            log_probs = model(X, S, mask, chain_M * chain_M_pos, residue_idx, chain_encoding_all, randn_1)
            mask_for_loss = mask * chain_M * chain_M_pos
            native_score = _scores(S, log_probs, mask_for_loss).cpu().data.numpy()

            for temp in temperatures:
                for j in range(NUM_BATCHES):
                    randn_2 = torch.randn(chain_M.shape, device=X.device)
                    if tied_positions_dict is None:
                        sample_dict = model.sample(
                            X, randn_2, S, chain_M, chain_encoding_all, residue_idx, mask=mask, temperature=temp,
                            omit_AAs_np=omit_AAs_np, bias_AAs_np=bias_AAs_np, chain_M_pos=chain_M_pos,
                            omit_AA_mask=omit_AA_mask, pssm_coef=pssm_coef, pssm_bias=pssm_bias,
                            pssm_multi=pssm_multi, pssm_log_odds_flag=bool(pssm_log_odds_flag),
                            pssm_log_odds_mask=pssm_log_odds_mask, pssm_bias_flag=bool(pssm_bias_flag),
                            bias_by_res=bias_by_res_all)
                        S_sample = sample_dict["S"]
                    else:
                        sample_dict = model.tied_sample(
                            X, randn_2, S, chain_M, chain_encoding_all, residue_idx, mask=mask, temperature=temp,
                            omit_AAs_np=omit_AAs_np, bias_AAs_np=bias_AAs_np, chain_M_pos=chain_M_pos,
                            omit_AA_mask=omit_AA_mask, pssm_coef=pssm_coef, pssm_bias=pssm_bias,
                            pssm_multi=pssm_multi, pssm_log_odds_flag=bool(pssm_log_odds_flag),
                            pssm_log_odds_mask=pssm_log_odds_mask, pssm_bias_flag=bool(pssm_bias_flag),
                            tied_pos=tied_pos_list_of_lists_list[0], tied_beta=tied_beta,
                            bias_by_res=bias_by_res_all)
                        S_sample = sample_dict["S"]

                    log_probs = model(X, S_sample, mask, chain_M * chain_M_pos,
                                      residue_idx, chain_encoding_all, randn_2,
                                      use_input_decoding_order=True,
                                      decoding_order=sample_dict["decoding_order"])
                    mask_for_loss = mask * chain_M * chain_M_pos
                    scores = _scores(S_sample, log_probs, mask_for_loss).cpu().data.numpy()

                    for b_ix in range(BATCH_COPIES):
                        masked_chain_length_list = masked_chain_length_list_list[b_ix]
                        masked_list = masked_list_list[b_ix]
                        seq_recovery_rate = torch.sum(
                            torch.sum(torch.nn.functional.one_hot(S[b_ix],21)
                                      * torch.nn.functional.one_hot(S_sample[b_ix],21), axis=-1)
                            * mask_for_loss[b_ix]) / torch.sum(mask_for_loss[b_ix])
                        seq = _S_to_seq(S_sample[b_ix], chain_M[b_ix])
                        score = scores[b_ix]
                        native_seq = _S_to_seq(S[b_ix], chain_M[b_ix])

                        # --- Native written once ---
                        if b_ix == 0 and j == 0 and temp == temperatures[0]:
                            native_seq = "".join(native_seq.split("/")).strip()
                            native_score_print = np.format_float_positional(np.float32(native_score.mean()), unique=False, precision=4)
                            line = f">native, score={native_score_print}\n{native_seq}\n"
                            fasta_f.write(line)
                            print(line.rstrip())

                            a3m_filename = os.path.join(output_msa_folder, f"{name_}_native.a3m")
                            try:
                                with open(a3m_filename, 'w') as a3m_f:
                                    native_len = len(native_seq)
                                    a3m_f.write(f"#{native_len}\t1\n>native\n{native_seq}\n")
                                    tail = _alignment_tail_to_write(alignment_body_lines)
                                    a3m_f.writelines(tail)
                                print(f"Wrote native A3M: {a3m_filename}")
                            except Exception as e:
                                print(f"⚠️ Warning: Could not write native A3M {a3m_filename}. Error: {e}")

                        # --- Generated sequences ---
                        seq = "".join(seq.split("/")).strip()
                        score_print = np.format_float_positional(np.float32(score), unique=False, precision=4)
                        seq_rec_print = np.format_float_positional(np.float32(seq_recovery_rate.detach().cpu().numpy()), unique=False, precision=4)
                        sample_index = j * BATCH_COPIES + b_ix + 1
                        line = f">sample{sample_index}, T={temp}, score={score_print}, seq_recovery={seq_rec_print}\n{seq}\n"
                        fasta_f.write(line)
                        print(line.rstrip())

                        sample_name = f"sample{sample_index}"
                        a3m_filename = os.path.join(output_msa_folder, f"{name_}_{sample_name}.a3m")
                        try:
                            with open(a3m_filename, 'w') as a3m_f:
                                seq_len = len(seq)
                                a3m_f.write(f"#{seq_len}\t1\n>{sample_name}\n{seq}\n")
                                tail = _alignment_tail_to_write(alignment_body_lines)
                                a3m_f.writelines(tail)
                            print(f"Wrote sample A3M: {a3m_filename}")
                        except Exception as e:
                            print(f"⚠️ Warning: Could not write sample A3M {a3m_filename}. Error: {e}")

    print(f"\n✅ All sequences saved to {output_fasta_file} and A3Ms in '{output_msa_folder}'.")
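
The `seq_recovery` value in each FASTA header is computed with the mask `mask * chain_M * chain_M_pos`. As we understand ProteinMPNN, `chain_M_pos` is zero at fixed positions. Recovery is therefore measured only over the designable positions and is not inflated by residues that are identical by construction.

Below is a string-based sketch of the same idea. It assumes a single designed chain and 1-based fixed positions, and the helper name is hypothetical.

```python
def recovery_over_designable(native, design, fixed_positions):
    """Fraction of designable (non-fixed) positions where design matches native."""
    if len(native) != len(design):
        raise ValueError("native and design must have the same length")
    fixed = set(fixed_positions)                               # 1-based positions
    designable = [i for i in range(len(native)) if i + 1 not in fixed]
    if not designable:
        raise ValueError("every position is fixed; recovery is undefined")
    same = sum(native[i] == design[i] for i in designable)
    return same / len(designable)

native = "MKVLAG"
design = "MRVLSG"
print(recovery_over_designable(native, design, fixed_positions=[1, 4]))  # 0.5
print(sum(a == b for a, b in zip(native, design)) / len(native))        # overall identity, ~0.667
```

The gap between the two numbers shows why overall identity overstates how much of the native sequence the model recovered by itself.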

In [None]:
#@title 3) Visualize the generated sequences using an MSA Viewer in Google Colab
#The following code is modified from the wonderful viewer developed by Damien Farrell
#https://dmnfarrell.github.io/bioinformatics/bokeh-sequence-aligner

#Importing all modules first
import os, io, random
import string
import numpy as np

from Bio.Seq import Seq
from Bio.Align import MultipleSeqAlignment
from Bio import AlignIO, SeqIO

import panel as pn
import panel.widgets as pnw
pn.extension()

from bokeh.plotting import figure
from bokeh.models import ColumnDataSource, Plot, Grid, Range1d
from bokeh.models.glyphs import Text, Rect
from bokeh.layouts import gridplot

# Set up the residue color code (a simplified scheme loosely based on Zappo)
def get_colors(seqs):
    # Make one color per residue, reading the alignment row by row
    text = [i for s in list(seqs) for i in s]
    # Color residues by broad physicochemical class
    clrs =  {'K':'red',
             'R':'red',
             'H':'red',
             'D':'green',
             'E':'green',
             'Q':'blue',
             'N':'blue',
             'S':'blue',
             'T':'blue',
             'A':'blue',
             'I':'blue',
             'L':'blue',
             'M':'blue',
             'V':'blue',
             'F':'orange',
             'Y':'orange',
             'W':'orange',
             'C':'blue',
             'P':'yellow',
             'G':'orange',
             '-':'white'}
    colors = [clrs.get(i, 'white') for i in text]  # residues not in the table (e.g. X) are shown in white
    return colors

#Setting up the MSA viewer
def view_alignment(aln, fontsize="9pt", plot_width=800):
    """Bokeh sequence alignment view"""

    #make sequence and id lists from the aln object
    seqs = [rec.seq for rec in (aln)]
    ids = [rec.id for rec in aln]
    text = [i for s in list(seqs) for i in s]
    colors = get_colors(seqs)
    N = len(seqs[0])
    S = len(seqs)
    width = .4

    x = np.arange(1,N+1)
    y = np.arange(0,S,1)
    #creates a 2D grid of coords from the 1D arrays
    xx, yy = np.meshgrid(x, y)
    #flattens the arrays
    gx = xx.ravel()
    gy = yy.flatten()
    #use recty for rect coords with an offset
    recty = gy+.5
    h= 1/S
    #now we can create the ColumnDataSource with all the arrays
    source = ColumnDataSource(dict(x=gx, y=gy, recty=recty, text=text, colors=colors))
    plot_height = len(seqs)*15+50
    x_range = Range1d(0,N+1, bounds='auto')
    if N>100:
        viewlen=100
    else:
        viewlen=N
    #view_range is for the close up view
    view_range = (0,viewlen)
    tools="xpan, xwheel_zoom, reset, save"

    #entire sequence view (no text, with zoom)
    p = figure(title=None, width= plot_width, height=50,
               x_range=x_range, y_range=(0,S), tools=tools,
               min_border=0, toolbar_location='below')
    rects = Rect(x="x", y="recty",  width=1, height=1, fill_color="colors",
                 line_color=None, fill_alpha=0.6)
    p.add_glyph(source, rects)
    p.yaxis.visible = False
    p.grid.visible = False

    #sequence text view with ability to scroll along x axis
    p1 = figure(title=None, width=plot_width, height=plot_height,
                x_range=view_range, y_range=ids, tools="xpan,reset",
                min_border=0, toolbar_location='below')#, lod_factor=1)
    glyph = Text(x="x", y="y", text="text", text_align='center',text_color="black",
                text_font="monospace",text_font_size=fontsize)
    rects = Rect(x="x", y="recty",  width=1, height=1, fill_color="colors",
                line_color=None, fill_alpha=0.4)
    p1.add_glyph(source, glyph)
    p1.add_glyph(source, rects)

    p1.grid.visible = False
    p1.xaxis.major_label_text_font_style = "bold"
    p1.yaxis.minor_tick_line_width = 0
    p1.yaxis.major_tick_line_width = 0

    p = gridplot([[p],[p1]], toolbar_location='below')
    return p

#Loading the viewer by indicating the MSA file and format to read
#@markdown Name of the MSA file (including the filetype)
MSAfile = 'generated_sequences.fasta' #@param {type:"string"}
MSAformat = 'fasta' #@param {type:"string"}
aln = AlignIO.read(MSAfile,MSAformat)
p = view_alignment(aln, plot_width=900)
pn.pane.Bokeh(p)
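
Each header written by the sampling cell has the form `>sample1, T=0.2, score=..., seq_recovery=...`. A small parser makes it easy to rank the designs; it is plain Python, and the helper names are hypothetical. In ProteinMPNN the score is the average negative log-probability of the sequence given the backbone, so lower values mean sequences the model considers more likely. The FASTA text below is a toy example with made-up scores, not real results.

```python
def parse_designs(fasta_text):
    """Parse '>name, key=value, ...' headers into dicts with a 'seq' field."""
    records = []
    for line in fasta_text.splitlines():
        line = line.strip()
        if not line:
            continue
        if line.startswith(">"):
            name, *fields = [f.strip() for f in line[1:].split(",")]
            record = {"name": name, "seq": ""}
            for field in fields:
                key, _, value = field.partition("=")
                record[key] = float(value)
            records.append(record)
        elif records:
            records[-1]["seq"] += line   # a sequence may span several lines
    return records

example = """>native, score=1.2345
MKVLAG
>sample1, T=0.2, score=0.9000, seq_recovery=0.5000
MRVLSG
>sample2, T=0.2, score=0.8000, seq_recovery=0.2500
MRILSG
"""
designs = [r for r in parse_designs(example) if r["name"] != "native"]
for r in sorted(designs, key=lambda r: r["score"]):   # lower score = more likely under the model
    print(r["name"], r["score"], r["seq"])
```

On the notebook's output, the call would be `parse_designs(open("generated_sequences.fasta").read())`.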

# Part 3. Predict structures of the designed sequences with AF2

In [None]:
#@title 1) Run ColabFold using custom MSAs and a single model
#@markdown ---
#@markdown ### 1. Input / Output Folders
msa_dir = 'msa' #@param {type:"string"}
predictions_dir = 'predictions' #@param {type:"string"}
csv_output_file = 'confidence_metrics.csv' #@param {type:"string"}
#@markdown ---
#@markdown ### 2. Model Settings
#@markdown Specify model number(s) to run (e.g., "3" or "1,2,3,4,5")
model_order = "4" #@param {type:"string"}
#@markdown ---

import os
import sys
import json
import glob
import numpy as np
import pandas as pd
from pathlib import Path

# --- 1. Setup directories ---
result_dir_path = Path(predictions_dir)
msa_dir_path = Path(msa_dir)
os.makedirs(result_dir_path, exist_ok=True)

# --- 2. Run ColabFold using the simple batch command ---
print("🚀 Starting ColabFold batch run...")
print(f"   Input: {msa_dir}")
print(f"   Output: {predictions_dir}")
print(f"   Models: {model_order}")

!colabfold_batch \
  --model-order {model_order} \
  --model-type alphafold2_ptm \
  {msa_dir} \
  {predictions_dir}

print("\n✅ Prediction run complete.")

# --- 3. Gather confidence results ---
print(f"\n📊 Parsing results to create {csv_output_file}...")
results = []

# Find the score file of the top-ranked model for each design
# (ColabFold writes "<jobname>_scores_rank_001_*.json" directly into the output folder)
json_files = sorted(result_dir_path.glob("*_scores_rank_001_*.json"))

if not json_files:
    print(f"🔥 Error: No JSON result files found in {predictions_dir}.")
    print("   Please check if the predictions ran correctly and produced output.")
    print(f"   (Was looking for files like: *_scores_rank_001_*.json)")
else:
    print(f"   Found {len(json_files)} result files to parse.")
    for jf in json_files:
        try:
            # The jobname is everything before "_scores" in the file name
            jobname = jf.name.split("_scores")[0]

            with open(jf) as fh:
                data = json.load(fh)

            # Average the per-residue pLDDT values
            avg_plddt = np.mean(data.get("plddt", [])) if data.get("plddt") else None
            ptm = data.get("ptm", None)

            model_name = "unknown"
            if "model_name" in data:
                model_name = data["model_name"]
            elif "model" in data:
                model_name = data["model"]
            else:
                for i in model_order.split(','):
                    if f"model_{i.strip()}" in jf.name:
                        model_name = f"model_{i.strip()}"
                        break

            results.append({
                "sequence_name": jobname,
                "avg_plddt": avg_plddt,
                "ptm": ptm,
                "model_used": model_name
            })
        except Exception as e:
            print(f"⚠️ Error parsing {jf.name}: {e}")

    if results:
        df = pd.DataFrame(results).sort_values("avg_plddt", ascending=False)
        out_path = result_dir_path / csv_output_file
        df.to_csv(out_path, index=False)
        print(f"\n✅ Successfully saved confidence metrics to: {out_path}")
        print("\n--- Top Results ---")
        print(df.head())
    else:
        print("🔥 No results were successfully parsed. Check predictions directory.")

print("\n🎉 All done.")
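
The next cell superimposes each prediction on the experimental structure with Biopython's `Superimposer`, optionally excluding poorly fitting residues. To make the "Iterative Exclusion" idea concrete, here is a NumPy-only sketch on Cα coordinate arrays. It performs a least-squares (Kabsch) fit, then repeatedly refits on the residues that lie within `cutoff` Å after the previous fit.

This is one common variant of outlier rejection, not necessarily the exact procedure used in the cell below. The function names and the synthetic coordinates are hypothetical.

```python
import numpy as np

def kabsch(mobile, target):
    """Return (R, t) minimizing the RMSD of mobile @ R.T + t onto target (both (N, 3))."""
    mob_c, tgt_c = mobile.mean(axis=0), target.mean(axis=0)
    H = (mobile - mob_c).T @ (target - tgt_c)        # 3x3 covariance matrix
    U, _, Vt = np.linalg.svd(H)
    d = np.sign(np.linalg.det(Vt.T @ U.T))           # guard against reflections
    R = Vt.T @ np.diag([1.0, 1.0, d]) @ U.T
    return R, tgt_c - mob_c @ R.T

def rmsd(a, b):
    return float(np.sqrt(((a - b) ** 2).sum(axis=1).mean()))

def iterative_fit(mobile, target, cutoff=2.0, max_cycles=5, min_atoms=3):
    """Refit on atoms within `cutoff` Å after each fit; return (RMSD over kept atoms, mask)."""
    keep = np.ones(len(mobile), dtype=bool)
    for _ in range(max_cycles):
        R, t = kabsch(mobile[keep], target[keep])
        dist = np.linalg.norm(mobile @ R.T + t - target, axis=1)
        new_keep = dist <= cutoff
        if new_keep.sum() < min_atoms or np.array_equal(new_keep, keep):
            break                                    # converged, or too few atoms left
        keep = new_keep
    R, t = kabsch(mobile[keep], target[keep])
    return rmsd(mobile[keep] @ R.T + t, target[keep]), keep

# Synthetic example: a rotated, shifted copy of 100 random "CA" positions with 2 displaced atoms
rng = np.random.default_rng(0)
target = rng.normal(scale=5.0, size=(100, 3))
theta = np.deg2rad(30)
Rz = np.array([[np.cos(theta), -np.sin(theta), 0.0],
               [np.sin(theta),  np.cos(theta), 0.0],
               [0.0, 0.0, 1.0]])
mobile = target @ Rz.T + np.array([1.0, 2.0, 3.0])
mobile[:2] += 10.0                                   # two outliers, about 17 Å off

all_rmsd, _ = iterative_fit(mobile, target, max_cycles=0)   # plain all-atom fit
core_rmsd, kept = iterative_fit(mobile, target)
print(f"all atoms: {all_rmsd:.2f} Å, core: {core_rmsd:.2e} Å over {kept.sum()} atoms")
```

In the notebook, `mobile` and `target` would be the matched Cα coordinates of a prediction and of `protein.pdb`, for example `np.array([atom.coord for atom in ca_atoms])`. Setting `max_cycles=0` reduces the function to an ordinary all-atom fit.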

In [None]:
#@title 2) Calculate RMSD & Save Aligned Structures
#@markdown ---
#@markdown ### 1. File Locations
#@markdown ---
#@markdown Path to your single experimental/reference PDB file:
experimental_pdb = "protein.pdb" #@param {type:"string"}
#@markdown ---
#@markdown Folder where ColabFold saved the predictions and CSV:
predictions_dir = "predictions" #@param {type:"string"}
#@markdown ---
#@markdown Name of the CSV file to read and update:
csv_output_file = "confidence_metrics.csv" #@param {type:"string"}
#@markdown ---
#@markdown **Folder to save all aligned PDBs for visualization:**
aligned_pdb_folder = "aligned" #@param {type:"string"}
#@markdown ---
#@markdown ### 2. Alignment Options
#@markdown ---
#@markdown Select the method for structural superposition:
alignment_mode = "All Atoms" #@param ["All Atoms", "Iterative Exclusion", "Specific Residues"]
#@markdown ---
#@markdown **For "Iterative Exclusion" mode:**
#@markdown Cutoff in Å. Residues with Cα distance > cutoff after alignment will be excluded.
iterative_rmsd_cutoff = 2.0 #@param {type:"number"}
#@markdown Maximum number of iterations to run.
iterative_max_cycles = 5 #@param {type:"integer"}
#@markdown ---
#@markdown **For "Specific Residues" mode:**
#@markdown Provide a comma-separated list of residues or ranges (e.g., "10-50, 80, 91-100").
residue_list_to_align = "1-188" #@param {type:"string"}
#@markdown ---

import os
import glob
import pandas as pd
import numpy as np
import sys
import re
from pathlib import Path

# Try to import Biopython, which is needed for RMSD
try:
    from Bio.PDB import PDBParser, Superimposer, PDBIO
except ImportError:
    print(" Biopython not found. Installing...")
    !pip install biopython
    from Bio.PDB import PDBParser, Superimposer, PDBIO

# --- Helper Function to Parse Residue List ---
def parse_residue_list(res_string):
    """Parses a residue string like "10-50, 80, 91-100" into a set of integers."""
    residue_set = set()
    if not res_string:
        return residue_set

    parts = res_string.split(',')
    for part in parts:
        part = part.strip()
        if not part:
            continue
        if '-' in part:
            try:
                start, end = part.split('-')
                start_res = int(start.strip())
                end_res = int(end.strip())
                residue_set.update(range(start_res, end_res + 1))
            except ValueError:
                print(f"⚠️ Warning: Could not parse range '{part}'. Skipping.")
        else:
            try:
                residue_set.add(int(part.strip()))
            except ValueError:
                print(f"⚠️ Warning: Could not parse residue number '{part}'. Skipping.")
    return residue_set

# --- 1. Load Reference Structure & Setup Folders ---
print(f"Loading reference structure: {experimental_pdb}")
pdb_parser = PDBParser(QUIET=True)
io = PDBIO() # Initialize PDB saver

# Create the output folder for aligned PDBs
os.makedirs(aligned_pdb_folder, exist_ok=True)

try:
    ref_structure = pdb_parser.get_structure("reference", experimental_pdb)
except FileNotFoundError:
    print(f"🔥 Error: Experimental PDB not found at: {experimental_pdb}")
    sys.exit(1)

# Get all C-alpha atoms as a dictionary keyed by residue number (assumes a single chain; numbers from several chains would collide)
try:
    ref_ca_dict = {
        atom.get_parent().id[1]: atom
        for atom in ref_structure[0].get_atoms()
        if atom.name == "CA" and atom.get_parent().id[0] == ' ' # Ensure it's a standard residue
    }
except Exception as e:
    print(f"🔥 Error parsing reference PDB: {e}")
    print("   Make sure it's a valid PDB file.")
    sys.exit(1)

if not ref_ca_dict:
    print(f"🔥 Error: No standard C-alpha atoms (CA) found in {experimental_pdb}")
    sys.exit(1)

print(f"Loaded {len(ref_ca_dict)} C-alpha atoms from reference.")

# --- Save a copy of the reference PDB to the aligned folder ---
print(f"\nSaving reference structure to {aligned_pdb_folder}...")
io.set_structure(ref_structure)
ref_output_name = f"{Path(experimental_pdb).stem}_ref.pdb"
ref_output_path = os.path.join(aligned_pdb_folder, ref_output_name)
io.save(ref_output_path)
print(f"  Saved: {ref_output_path}")

# --- 2. Load CSV File ---
csv_path = os.path.join(predictions_dir, csv_output_file)
try:
    df = pd.read_csv(csv_path)
except FileNotFoundError:
    print(f"🔥 Error: CSV file not found at: {csv_path}")
    print("   Please run the previous cell to generate the CSV.")
    sys.exit(1)

# --- 3. Parse alignment residues if needed ---
alignment_residue_set = set()
if alignment_mode == "Specific Residues":
    alignment_residue_set = parse_residue_list(residue_list_to_align)
    if not alignment_residue_set:
        print(f"🔥 Error: 'Specific Residues' mode selected, but no valid residues found in '{residue_list_to_align}'.")
        sys.exit(1)
    print(f"Running in 'Specific Residues' mode. Aligning on {len(alignment_residue_set)} specified residues.")
elif alignment_mode == "Iterative Exclusion":
    print(f"Running in 'Iterative Exclusion' mode (Cutoff={iterative_rmsd_cutoff} Å, Max Cycles={iterative_max_cycles}).")
else:
    print("Running in 'All Atoms' mode.")

# --- 4. Loop Through Predictions and Calculate RMSD ---
rmsd_list = []
aligned_residues_list = []
super_imposer = Superimposer()

print("\nCalculating RMSD and saving aligned structures...")
for index, row in df.iterrows():
    jobname = row['sequence_name']

    # Find the corresponding PDB file
    pdb_pattern = os.path.join(predictions_dir, f"{jobname}_unrelaxed_rank_001_*.pdb")
    pdb_files = glob.glob(pdb_pattern)

    if not pdb_files:
        print(f"⚠️ Warning: No PDB file found for {jobname}. Skipping.")
        rmsd_list.append(np.nan)
        aligned_residues_list.append(np.nan)
        continue

    predicted_pdb_path = pdb_files[0]

    try:
        # Load the predicted structure and get its C-alpha dict
        sample_structure = pdb_parser.get_structure(jobname, predicted_pdb_path)
        sample_ca_dict = {
            atom.get_parent().id[1]: atom
            for atom in sample_structure[0].get_atoms()
            if atom.name == "CA" and atom.get_parent().id[0] == ' '
        }

        if not sample_ca_dict:
            print(f"⚠️ Warning: No C-alpha atoms found in {jobname}. Skipping.")
            rmsd_list.append(np.nan)
            aligned_residues_list.append(np.nan)
            continue

        # --- Start Alignment Logic ---
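        rmsd_val = np.nan
        num_aligned = 0
        # Fallback message; each mode overwrites it on success, so it is always defined
        print_msg = f"  ⚠️ {jobname}: Alignment failed."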

        if alignment_mode == "Specific Residues":
            common_res_ids = alignment_residue_set & set(ref_ca_dict.keys()) & set(sample_ca_dict.keys())
            if len(common_res_ids) < len(alignment_residue_set):
                print(f"  Info for {jobname}: Using {len(common_res_ids)} common residues out of {len(alignment_residue_set)} requested.")

            if not common_res_ids:
                print(f"  ⚠️ Warning: No common residues for alignment in {jobname}. Skipping.")
                rmsd_list.append(np.nan)
                aligned_residues_list.append(0)
                continue

            ref_atoms = [ref_ca_dict[res_id] for res_id in common_res_ids]
            sample_atoms = [sample_ca_dict[res_id] for res_id in common_res_ids]
            num_aligned = len(common_res_ids)

            super_imposer.set_atoms(ref_atoms, sample_atoms)
            # Apply transformation to the *entire* structure
            super_imposer.apply(sample_structure[0].get_atoms())
            rmsd_val = super_imposer.rms

            print_msg = f"  ✅ {jobname}: RMSD = {rmsd_val:.3f} Å (on {num_aligned} specified residues)"


        elif alignment_mode == "Iterative Exclusion":
            current_res_ids = set(ref_ca_dict.keys()) & set(sample_ca_dict.keys())

            for i in range(iterative_max_cycles):
                if not current_res_ids:
                    print_msg = f"  🔥 Error for {jobname}: No atoms left to align during iteration."
                    rmsd_val = np.nan
                    num_aligned = 0
                    break

                ref_atoms_subset = [ref_ca_dict[res_id] for res_id in current_res_ids]
                sample_atoms_subset = [sample_ca_dict[res_id] for res_id in current_res_ids]

                # Fit on the current core, then move the whole predicted model
                super_imposer.set_atoms(ref_atoms_subset, sample_atoms_subset)
                super_imposer.apply(sample_structure[0].get_atoms())
                rmsd_val = super_imposer.rms
                num_aligned = len(current_res_ids)  # atoms used in this fit

                # Keep residues whose CA-CA distance (Atom - Atom, in Å) is below the cutoff
                new_res_ids = {
                    res_id for res_id in current_res_ids
                    if ref_ca_dict[res_id] - sample_ca_dict[res_id] < iterative_rmsd_cutoff
                }

                if len(new_res_ids) == len(current_res_ids):
                    print_msg = f"  ✅ {jobname}: Converged. RMSD = {rmsd_val:.3f} Å (on {num_aligned} core atoms)"
                    break

                print(f"    Iter {i+1} for {jobname}: RMSD={rmsd_val:.3f} Å ({num_aligned} atoms) -> Removing {num_aligned - len(new_res_ids)} outliers.")
                current_res_ids = new_res_ids
            else:
                print_msg = f"  ⚠️ {jobname}: Max iterations reached. Final RMSD = {rmsd_val:.3f} Å (on {num_aligned} core atoms)"

        else: # "All Atoms" mode (default)
            common_res_ids = set(ref_ca_dict.keys()) & set(sample_ca_dict.keys())

            ref_atoms = [ref_ca_dict[res_id] for res_id in common_res_ids]
            sample_atoms = [sample_ca_dict[res_id] for res_id in common_res_ids]
            num_aligned = len(common_res_ids)

            if len(ref_atoms) != len(ref_ca_dict) or len(sample_atoms) != len(sample_ca_dict):
                print(f"  Info for {jobname}: Found {num_aligned} common atoms for alignment.")

            super_imposer.set_atoms(ref_atoms, sample_atoms)
            # Apply transformation to the *entire* structure
            super_imposer.apply(sample_structure[0].get_atoms())
            rmsd_val = super_imposer.rms

            print_msg = f"  ✅ {jobname}: RMSD = {rmsd_val:.3f} Å (on {num_aligned} atoms)"

        # --- Save the aligned structure ---
        if not np.isnan(rmsd_val):
            output_filename = os.path.join(aligned_pdb_folder, f"{jobname}_aligned.pdb")
            io.set_structure(sample_structure)
            io.save(output_filename)
            print(print_msg + f" -> Saved to {output_filename}")
        else:
            print(print_msg) # Print error/warning message from loop

        rmsd_list.append(rmsd_val)
        aligned_residues_list.append(num_aligned)

    except Exception as e:
        print(f"🔥 Error processing {jobname}: {e}")
        rmsd_list.append(np.nan)
        aligned_residues_list.append(np.nan)

# --- 5. Update DataFrame and Save ---
df['rmsd_to_exp (Å)'] = rmsd_list
df['rmsd_aligned_residues'] = aligned_residues_list
df = df.sort_values("rmsd_to_exp (Å)", ascending=True)

# Save the updated CSV
df.to_csv(csv_path, index=False)

print(f"\n✅ Successfully calculated RMSD, saved aligned PDBs to '{aligned_pdb_folder}', and updated {csv_path}.")
print(f"   Mode used: {alignment_mode}")
print("\n--- Results (Sorted by RMSD) ---")
print(df.head())
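
`Bio.PDB.Superimposer` finds the rotation and translation that minimize the RMSD between two paired atom lists, a least-squares fit solved with SVD. The sketch below reproduces that fit, and the outlier-trimming loop of the *Iterative Exclusion* mode, in plain NumPy on paired `(N, 3)` coordinate arrays. It is an illustration under those assumptions, not the code path the notebook runs: `kabsch_superpose` and `iterative_superpose` are hypothetical helper names, and the 4 Å default cutoff is arbitrary.

```python
import numpy as np


def kabsch_superpose(ref, mob):
    """Least-squares superposition of `mob` onto `ref` (paired (N, 3) arrays).

    Returns (rmsd, rot, trans) such that mob @ rot + trans best matches ref.
    """
    ref = np.asarray(ref, dtype=float)
    mob = np.asarray(mob, dtype=float)
    ref_center = ref.mean(axis=0)
    mob_center = mob.mean(axis=0)
    # Covariance of the centered coordinates (row-vector convention)
    h = (mob - mob_center).T @ (ref - ref_center)
    u, _, vt = np.linalg.svd(h)
    # Flip the last axis if needed so the result is a proper rotation, not a reflection
    d = 1.0 if np.linalg.det(u @ vt) >= 0 else -1.0
    rot = u @ np.diag([1.0, 1.0, d]) @ vt
    trans = ref_center - mob_center @ rot
    diff = mob @ rot + trans - ref
    rmsd = float(np.sqrt((diff ** 2).sum(axis=1).mean()))
    return rmsd, rot, trans


def iterative_superpose(ref, mob, cutoff=4.0, max_cycles=5):
    """Refit on a shrinking core, dropping pairs farther apart than `cutoff` (Å).

    Returns (rmsd, mask) where `mask` marks the atom pairs used in the final fit.
    """
    ref = np.asarray(ref, dtype=float)
    mob = np.asarray(mob, dtype=float)
    keep = np.ones(len(ref), dtype=bool)
    used = keep
    rmsd = float("nan")
    for _ in range(max_cycles):
        if keep.sum() < 3:
            raise ValueError("Fewer than 3 atom pairs left to superimpose")
        used = keep
        rmsd, rot, trans = kabsch_superpose(ref[used], mob[used])
        # Per-atom distances after applying the fit to *all* atoms
        dist = np.linalg.norm(mob @ rot + trans - ref, axis=1)
        keep = used & (dist < cutoff)
        if keep.sum() == used.sum():  # no outliers removed: converged
            break
    return rmsd, used
```

Feeding it the CA coordinates (`atom.coord`) of the matched residues in `ref_ca_dict` and `sample_ca_dict` should give RMSD values comparable to the notebook's, up to floating-point differences. Returning the mask of the *last fit* keeps the reported RMSD and atom count consistent with each other.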

In [None]:
#@title 3) Plot pLDDT, pTM, and RMSD (Interactive, Wide)
#@markdown ---
#@markdown ### 1. File Locations
#@markdown ---
#@markdown Folder where your CSV file is located:
predictions_dir = "predictions" #@param {type:"string"}
#@markdown ---
#@markdown Name of the CSV file to read:
csv_output_file = "confidence_metrics.csv" #@param {type:"string"}
#@markdown ---
#@markdown **Name of your native/reference sequence:**
#@markdown (This must match the 'sequence_name' in the CSV exactly for red coloring)
native_sequence_name = "protein_native" #@param {type:"string"}
#@markdown ---
#@markdown ### 2. Output
#@markdown ---
#@markdown Name of the interactive plot file to save (must be .json):
plot_output_file_json = "all_metrics_plot.json" #@param {type:"string"}
#@markdown ---

import os
import pandas as pd
import numpy as np
import altair as alt
import sys

# --- 1. Load Data ---
csv_path = os.path.join(predictions_dir, csv_output_file)
try:
    df = pd.read_csv(csv_path)
except FileNotFoundError:
    print(f"🔥 Error: CSV file not found at: {csv_path}")
    print("   Please run the previous cells to generate and update the CSV.")
    sys.exit(1)

# Check if required columns exist
required_cols = ['sequence_name', 'avg_plddt', 'ptm', 'rmsd_to_exp (Å)']
if not all(col in df.columns for col in required_cols):
    print(f"🔥 Error: The CSV file is missing one or more required columns.")
    print(f"   Required: {required_cols}")
    print(f"   Found: {list(df.columns)}")
    sys.exit(1)

# --- 2. Prepare Data for Plotting ---

# Rescale pLDDT from 0-100 to 0-1 so it shares a range with pTM
print("Applying pLDDT / 100 transformation...")
df['avg_plddt'] = df['avg_plddt'] / 100.0

# Create the 'model_type' column for coloring
if native_sequence_name not in df['sequence_name'].values:
    print(f"⚠️ Warning: Native sequence name '{native_sequence_name}' not found in CSV.")
    print("   All points will be colored orange.")
    df['model_type'] = "Designed"
else:
    df['model_type'] = np.where(
        df['sequence_name'] == native_sequence_name,
        'Native',
        'Designed'
    )

# "Melt" the DataFrame from wide to long format
df_melted = df.melt(
    id_vars=['sequence_name', 'model_type'],
    value_vars=['avg_plddt', 'ptm', 'rmsd_to_exp (Å)'],
    var_name='Metric',
    value_name='Value'
)

# Rename metrics for readable facet titles
df_melted['Metric'] = df_melted['Metric'].replace({
    'avg_plddt': 'pLDDT / 100',
    'ptm': 'pTM Score',
    'rmsd_to_exp (Å)': 'RMSD (Å)'
})

# --- 3. Create the Interactive Altair Plot ---
print(f"Generating interactive plot...")

# Define the custom color scale
color_scale = alt.Scale(domain=['Native', 'Designed'],
                        range=['red', 'orange'])

# Create the base chart
base = alt.Chart(df_melted).mark_circle(size=80, opacity=0.7).encode(
    # X-axis: Native vs. Designed (no title, labels at bottom)
    x=alt.X('model_type:N', title=None, axis=alt.Axis(labels=True, ticks=False, title="")),

    # Y-axis: The metric's value
    y=alt.Y('Value:Q', title='Value'),

    # Color based on type
    color=alt.Color('model_type:N', scale=color_scale, legend=alt.Legend(title="Model Type")),

    # Show this information on hover
    tooltip=[
        alt.Tooltip('sequence_name:N', title='Model'),
        alt.Tooltip('Metric:N', title='Metric'),
        alt.Tooltip('Value:Q', title='Value', format='.3f') # Three decimals for pLDDT/100, pTM and RMSD
    ]
).properties(
    # Width of each facet panel, in pixels
    width=100
).interactive() # Make the chart interactive (zoom/pan)

# Create the final faceted chart
chart = base.facet(
    # Create one column for each "Metric".
    # The header for each facet will be the metric's name
    column=alt.Column('Metric:N', header=alt.Header(
        titleOrient="top",
        labelOrient="top"
    ))
).resolve_scale(
    # Make the Y-axis independent for each plot
    y='independent'
)

# --- 4. Save and Display the Plot ---
json_path = os.path.join(predictions_dir, plot_output_file_json)
chart.save(json_path)

print(f"✅ Successfully saved interactive plot to: {json_path}")

# Display the chart in the Colab output
chart
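
The plotting cell reshapes the metrics table from wide format (one column per metric) to long format (one row per model and metric) so that Altair can facet on a single `Metric` column with independent y-axes. The toy example below uses made-up values, not results, with the same column names, to show what `DataFrame.melt` produces.

```python
import numpy as np
import pandas as pd

# Made-up values, for illustration only
df = pd.DataFrame({
    "sequence_name": ["protein_native", "protein_sample1", "protein_sample2"],
    "avg_plddt": [92.0, 88.0, 75.0],
    "ptm": [0.91, 0.86, 0.70],
    "rmsd_to_exp (Å)": [0.5, 0.9, 2.4],
})
df["avg_plddt"] = df["avg_plddt"] / 100.0  # same rescaling as the plotting cell
df["model_type"] = np.where(df["sequence_name"] == "protein_native", "Native", "Designed")

# Wide -> long: one row per (model, metric) pair
df_long = df.melt(
    id_vars=["sequence_name", "model_type"],
    value_vars=["avg_plddt", "ptm", "rmsd_to_exp (Å)"],
    var_name="Metric",
    value_name="Value",
)
print(df_long)
```

Each of the three models now contributes three rows, one per metric, which is what `alt.Column('Metric:N')` splits into separate panels.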

In [None]:
#@title 4) Display the Aligned 3D Structure {run: "auto"}
import py3Dmol
import glob
import matplotlib.pyplot as plt
from colabfold.colabfold import plot_plddt_legend
from colabfold.colabfold import pymol_color_list, alphabet_list
import sys
from pathlib import Path

#@markdown ### 1. PDB Location
#@markdown ---
#@markdown Folder where the aligned structures were saved:
aligned_structures_folder = "aligned" #@param {type:"string"}
#@markdown ---
#@markdown **Jobname of the sequence to display:**
#@markdown (e.g., "1LVM_A_native", "1LVM_A_sample1")
jobname = "protein_sample1" #@param {type:"string"}
#@markdown ---
#@markdown ### 2. Display Options
#@markdown ---
color = "lDDT" #@param ["chain", "lDDT", "rainbow"]
show_sidechains = False #@param {type:"boolean"}
show_mainchains = False #@param {type:"boolean"}
#@markdown ---
#@markdown **Overlay the reference structure?**
show_reference = True #@param {type:"boolean"}
#@markdown Path to the saved reference PDB:
reference_pdb_path = "aligned/protein_ref.pdb" #@param {type:"string"}
#@markdown ---

# --- Find the aligned PDB file ---
pdb_pattern = f"{aligned_structures_folder}/{jobname}_aligned.pdb"
pdb_file_list = glob.glob(pdb_pattern)

def show_pdb(
    predicted_pdb_path,
    show_reference=False,
    reference_pdb_path=None,
    show_sidechains=False,
    show_mainchains=False,
    color="lDDT"
):

    view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js',)

    # --- 1. Add Reference Structure (if requested) ---
    if show_reference:
        try:
            view.addModel(open(reference_pdb_path,'r').read(),'pdb')
            # Style the *first* model (index 0) as gray
            view.setStyle({'model': 0}, {'cartoon': {'color': 'gray'}})
        except FileNotFoundError:
            print(f"⚠️ Warning: Reference PDB not found at {reference_pdb_path}. Skipping.")

    # --- 2. Add Predicted Structure ---
    view.addModel(open(predicted_pdb_path,'r').read(),'pdb')
    # Style the *last added* model (index -1)
    model_style = {'model': -1}

    if color == "lDDT":
        view.setStyle(model_style, {'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min':50,'max':90}}})
    elif color == "rainbow":
        view.setStyle(model_style, {'cartoon': {'color':'spectrum'}})
    elif color == "chain":
        # Simple chain coloring
        view.setStyle(model_style, {'cartoon': {'color':'chain'}})

    if show_sidechains:
        BB = ['C','O','N']
        view.addStyle({'and':[model_style, {'resn':["GLY","PRO"],'invert':True},{'atom':BB,'invert':True}]},
                            {'stick':{'colorscheme':f"WhiteCarbon",'radius':0.3}})
        view.addStyle({'and':[model_style, {'resn':"GLY"},{'atom':'CA'}]},
                            {'sphere':{'colorscheme':f"WhiteCarbon",'radius':0.3}})
        view.addStyle({'and':[model_style, {'resn':"PRO"},{'atom':['C','O'],'invert':True}]},
                            {'stick':{'colorscheme':f"WhiteCarbon",'radius':0.3}})
    if show_mainchains:
        BB = ['C','O','N','CA']
        view.addStyle({'and':[model_style, {'atom':BB}]},
                            {'stick':{'colorscheme':f"WhiteCarbon",'radius':0.3}})

    view.zoomTo()
    return view

# --- Display the structure ---
if not pdb_file_list:
    print(f"🔥 Error: Could not find aligned PDB file.")
    print(f"   Searched for pattern: {pdb_pattern}")
    print(f"   Please check 'aligned_structures_folder' and 'jobname'.")
    print(f"   (Did you run the RMSD script to generate the aligned files?)")
else:
    pdb_to_show = pdb_file_list[0]
    print(f"Displaying: {pdb_to_show}")
    if show_reference:
        print(f"Overlaying: {reference_pdb_path}")

    view = show_pdb(
        pdb_to_show,
        show_reference=show_reference,
        reference_pdb_path=reference_pdb_path,
        show_sidechains=show_sidechains,
        show_mainchains=show_mainchains,
        color=color
    )
    view.show()

    if color == "lDDT":
        plot_plddt_legend().show()
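
The `lDDT` colour option works because ColabFold writes per-residue pLDDT into the B-factor column of its PDB files, and `PDBIO` keeps that column when the aligned copies are saved. The sketch below reads those values back from the fixed-width PDB fields (atom name in columns 13-16, residue number in 23-26, B-factor in 61-66) and lists residues below a threshold. It assumes a single-chain model (residue numbers are used as keys), and the helper names and the 70 cutoff, borrowed from the common AlphaFold convention for confident predictions, are assumptions.

```python
def ca_plddt_from_pdb(pdb_text):
    """Map residue number -> B-factor (pLDDT for ColabFold models) for CA atoms."""
    plddt = {}
    for line in pdb_text.splitlines():
        # Fixed-width PDB fields: atom name cols 13-16, resSeq 23-26, B-factor 61-66
        if line.startswith("ATOM") and line[12:16].strip() == "CA":
            plddt[int(line[22:26])] = float(line[60:66])
    return plddt


def low_confidence_residues(plddt, threshold=70.0):
    """Residue numbers whose pLDDT falls below `threshold` (assumed cutoff)."""
    return sorted(res for res, value in plddt.items() if value < threshold)


# Hypothetical usage on a file written by the RMSD cell:
# with open("aligned/protein_sample1_aligned.pdb") as fh:
#     print(low_confidence_residues(ca_plddt_from_pdb(fh.read())))
```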