<a href="https://colab.research.google.com/github/sokrypton/af2bind/blob/main/af2bind_large_pdb.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### AF2BIND: Prediction of ligand-binding sites using AlphaFold2

AF2BIND is a simple and fast notebook that runs inference on the output obtained from [AlphaFold2](https://github.com/deepmind/alphafold).

<!--<img src="https://raw.githubusercontent.com/artemg97/af2bind_prod/main/logo.png" width="300">.-->

<figure>
<img src='https://raw.githubusercontent.com/artemg97/af2bind_prod/main/logo.png'  width="300" height="150"/>
</figure>



For more details, see the preprint:

**AF2BIND: Predicting ligand-binding sites using the pair representation of AlphaFold2**
* Artem Gazizov, Anna Lian, Casper Alexander Goverde, Sergey Ovchinnikov, Nicholas F. Polizzi
* https://doi.org/10.1101/2023.10.15.562410


In [None]:
%%time
#@title Install AlphaFold2 (~2 mins)
#@markdown Please execute this cell by pressing the *Play* button on
#@markdown the left.

!pip install bio
!pip install biopython

import os, time
# One-time setup. The existence of ./params is used as the "already installed"
# marker, so re-running this cell skips everything below.
if not os.path.isdir("params"):
  # Start all downloads in a background subshell (note the trailing "&") so the
  # pip install further down can run while they are in progress. In order, the
  # subshell: creates ./params, installs aria2, fetches the AlphaFold parameter
  # tarball, creates ./af2bind_params, fetches and unzips the AF2BIND weights,
  # untars the AlphaFold parameters into ./params and finally touches
  # params/done.txt as its completion marker.
  # NOTE(review): the steps are chained with ";", so done.txt is touched even
  # when an earlier step failed, and "mkdir params" runs first, so the guard
  # above is satisfied on a re-run after a failed download. A failed download
  # is therefore not detected here.
  print("installing ColabDesign")
  os.system("(mkdir params; apt-get install aria2 -qq; \
  aria2c -q -x 16 https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar; \
  mkdir af2bind_params; \
  wget -qnc https://github.com/sokrypton/af2bind/raw/main/attempt_7_2k_lam0-03.zip; unzip attempt_7_2k_lam0-03.zip -d af2bind_params; \
  tar -xf alphafold_params_2021-07-14.tar -C params; touch params/done.txt )&")

  # Install ColabDesign pinned to tag v1.1.1, then symlink the installed
  # package into the working directory.
  os.system("pip -q install git+https://github.com/sokrypton/ColabDesign.git@v1.1.1")
  os.system("ln -s /usr/local/lib/python3.*/dist-packages/colabdesign colabdesign")

  # Wait for the background subshell: poll every 5 seconds until its
  # completion marker appears. There is no timeout on this loop.
  if not os.path.isfile("params/done.txt"):
    print("downloading params")
    while not os.path.isfile("params/done.txt"):
      time.sleep(5)

import gc
from colabdesign import mk_afdesign_model, clear_mem
from IPython.display import HTML
from google.colab import files
import numpy as np

from colabdesign.af.alphafold.common import residue_constants
import pandas as pd
from google.colab import data_table
data_table._DEFAULT_FORMATTERS[float] = lambda x: f"{x:.3f}"
from IPython.display import display, HTML
import jax, pickle
import jax.numpy as jnp
from scipy.special import expit as sigmoid
import plotly.express as px

import py3Dmol
import matplotlib.pyplot as plt
from scipy.special import softmax
import copy
from colabdesign.shared.protein import renum_pdb_str
from colabdesign.af.alphafold.common import protein


# Inverse of residue_constants.restype_order (value -> key). af2bind_inference
# uses it to turn an "aatype" entry back into a residue label, falling back to
# "X" when the entry is not present.
aa_order = {v:k for k,v in residue_constants.restype_order.items()}

def get_pdb(pdb_code=""):
  """Resolve ``pdb_code`` to the name of a PDB file on local disk.

  Resolution order:
    1. ``None`` or ``""``: open the Colab upload dialog and store the first
       uploaded file as ``tmp.pdb``.
    2. an existing file path: returned unchanged.
    3. a 4-character string: treated as an RCSB entry and fetched with wget.
    4. anything else: treated as an AlphaFold DB accession and fetched with wget.

  The wget exit status is not inspected, so the returned name may refer to a
  file that was never downloaded.
  """
  # Nothing specified: ask the user for a file and keep the first upload.
  if pdb_code is None or pdb_code == "":
    uploaded = files.upload()
    first_name = list(uploaded)[0]
    with open("tmp.pdb","wb") as handle:
      handle.write(uploaded[first_name])
    return "tmp.pdb"

  # Already on disk: use it as-is.
  if os.path.isfile(pdb_code):
    return pdb_code

  # Four characters: RCSB entry.
  if len(pdb_code) == 4:
    os.system(f"wget -qnc https://files.rcsb.org/view/{pdb_code}.pdb")
    return f"{pdb_code}.pdb"

  # Fallback: AlphaFold DB model (fragment F1, model v4).
  os.system(f"wget -qnc https://alphafold.ebi.ac.uk/files/AF-{pdb_code}-F1-model_v4.pdb")
  return f"AF-{pdb_code}-F1-model_v4.pdb"

def af2bind(outputs, mask_sidechains=True, seed=0):
  """Score every target residue with the AF2BIND linear head.

  Parameters:
  outputs (dict): model outputs; only ``outputs["representations"]["pair"]``
    is read. The last 20 positions along both leading axes are treated as the
    probe (af2bind_inference appends a 20-residue binder); everything before
    them is the target. Presumably shaped (L+20, L+20, C) - TODO confirm.
  mask_sidechains (bool): selects the "nosc" parameter set when True, the
    regular one otherwise.
  seed (int): suffix of the parameter file to load.

  Returns:
  dict: ``"p_bind"`` (one sigmoid score per target residue) and
    ``"p_bind_aa"`` (per-residue contribution for each of the 20 probe
    positions, before the sigmoid).
  """
  pair = outputs["representations"]["pair"]
  # target->probe block and probe->target block (transposed so that the
  # target residue is the leading axis in both)
  pair_A = pair[:-20,-20:]
  pair_B = pair[-20:,:-20].swapaxes(0,1)
  pair_A = pair_A.reshape(pair_A.shape[0],-1)
  pair_B = pair_B.reshape(pair_B.shape[0],-1)
  x = np.concatenate([pair_A,pair_B],-1)

  # get params
  if mask_sidechains:
    model_type = f"split_nosc_pair_A_split_nosc_pair_B_{seed}"
  else:
    model_type = f"split_pair_A_split_pair_B_{seed}"
  # NOTE(security): pickle.load executes arbitrary code from the file. Only
  # load parameter files obtained from a trusted source.
  with open(f"af2bind_params/attempt_7_2k_lam0-03/{model_type}.pickle","rb") as handle:
    params_ = pickle.load(handle)
  params_ = dict(**params_["~"], **params_["linear"])
  # jax.tree_map is deprecated and removed in recent JAX releases;
  # jax.tree_util.tree_map is the stable spelling of the same function.
  p = jax.tree_util.tree_map(np.asarray, params_)

  # get predictions
  x = (x - p["mean"]) / p["std"]
  # keep the per-feature terms (instead of a dot product) and spread the bias
  # evenly over the features so that partial sums can be reported per probe
  x = (x * p["w"][:,0]) + (p["b"] / x.shape[-1])
  # (residue, A/B half, probe position, channel) -> sum over half and channel
  p_bind_aa = x.reshape(x.shape[0],2,20,-1).sum((1,3))
  p_bind = sigmoid(p_bind_aa.sum(-1))
  return {"p_bind":p_bind, "p_bind_aa":p_bind_aa}

In [None]:
#@title **Set up a pipeline for large PDBs 🧬 (~ 2 mins)**
#@markdown - Install Merizo and define the functions for subgraph search (optional)
import ipywidgets as widgets
from Bio import PDB
import os
import shutil
from collections import ChainMap
from Bio.PDB import PDBParser, PDBIO, Select
import networkx as nx
from itertools import combinations
import warnings
import csv
import re
import io

warnings.filterwarnings('ignore')

# One-time Merizo setup: clone the repository and install its requirements
# plus the extra packages listed below. The presence of the ./Merizo directory
# is the only "already installed" check, and the exit status of the commands
# is not inspected.
if not os.path.isdir("Merizo"):
  os.system("git  clone https://github.com/psipred/Merizo")
  os.system("pip install -r Merizo/requirements.txt")
  os.system("pip install torch torchvision torchaudio scipy matplotlib einops networkx rotary-embedding-torch natsort")



def setup_directory(directory_path):
    """
    Ensure ``directory_path`` exists and is empty.

    Any existing directory at that path is deleted together with its contents
    before a fresh one (including missing parents) is created.
    """
    stale = os.path.exists(directory_path)
    if stale:
        # Start from a clean slate: drop whatever was there before
        shutil.rmtree(directory_path)

    os.makedirs(directory_path)
    print(f"## Directory '{directory_path}' created.")


def split_pdb_by_chains(input_pdb, chain_ids):

  """
  Splits a PDB file into one file per requested chain.

  Parameters:
  input_pdb (str): Name of the input PDB file WITHOUT the ".pdb" extension.
      The same value is used as the sub-directory name under
      SINGLE_CHAINS_PATH where the per-chain files are written.
  chain_ids (str or iterable of str): Chain IDs, either as a comma-separated
      string (e.g. "A,B") or as an iterable (e.g. ['A', 'B']). Surrounding
      whitespace is ignored, empty entries are dropped and duplicates are
      collapsed while keeping the first-seen order.

  Returns:
  list: Paths of the written files, one per chain, in the order of chain_ids.
      A chain that does not occur in the input still gets an (empty) file.
  """

  single_chain_pdbs_list=[]

  # Accept both "A,B" and ['A', 'B']; "A, B" must not produce a chain " B"
  if isinstance(chain_ids, str):
    chain_ids = chain_ids.split(',')
  chain_ids = list(dict.fromkeys(c.strip() for c in chain_ids if c.strip()))

  # The output directory is recreated from scratch on every call
  save_dir=f"{SINGLE_CHAINS_PATH}/{input_pdb}"
  setup_directory(save_dir)

  # Read the input PDB file
  with open(input_pdb+".pdb", 'r') as pdb_file:
    pdb_lines = pdb_file.readlines()

  # Create a dictionary to store the lines for each chain
  chain_data = {chain: [] for chain in chain_ids}

  # Only ATOM and HETATM records are kept; the chain identifier sits in
  # column 22 (index 21). Truncated records are skipped instead of raising.
  for line in pdb_lines:
    if line.startswith(('ATOM', 'HETATM')) and len(line) > 21:
      chain_id = line[21]
      if chain_id in chain_data:
        chain_data[chain_id].append(line)

  # Write out each chain to a new file
  for chain in chain_ids:
    output_filename = f"{save_dir}/chain_{chain}.pdb"
    with open(output_filename, 'w') as output_file:
      output_file.writelines(chain_data[chain])
    print(f"Chain {chain} written to {output_filename}")
    single_chain_pdbs_list.append(output_filename)

  return single_chain_pdbs_list


def extract_residues(pdb_path, residue_range, chain_id, output_name,domain_n):
    """
    Write the selected residues of one chain to a new PDB file.

    Parameters:
    pdb_path (str): Path to the input PDB file.
    residue_range (str): Residue numbers to keep, as "_"-separated segments;
        each segment is either "start-end" (inclusive) or a single number,
        e.g. "1-50_75_100-120".
    chain_id (str): Chain to take residues from; also the chain ID written.
    output_name (str): Path of the PDB file to write.
    domain_n: Domain label, used only in the progress message.

    Only the first model of the input is used. Collecting the same residue
    numbers from several models would add duplicate residue ids to the output
    chain, which Biopython rejects.
    """
    parser = PDB.PDBParser(QUIET=True)
    structure = parser.get_structure("protein", pdb_path)

    # Expand the range string into a set of residue numbers
    residue_ids = set()
    for segment in residue_range.split('_'):
        if '-' in segment:
            start, end = map(int, segment.split('-'))
            residue_ids.update(range(start, end + 1))
        else:
            # It's a single residue
            residue_ids.add(int(segment))

    # Collect matching residues of the requested chain into a fresh chain
    new_chain = PDB.Chain.Chain(chain_id)
    first_model = next(iter(structure), None)
    if first_model is not None:
        for chain in first_model:
            if chain.id != chain_id:
                continue
            for residue in chain:
                # residue.id[1] is the residue number
                if residue.id[1] in residue_ids:
                    new_chain.add(residue)

    # Wrap the chain in a minimal structure so PDBIO can write it
    new_model = PDB.Model.Model(0)
    new_model.add(new_chain)
    new_structure = PDB.Structure.Structure('selected_residues')
    new_structure.add(new_model)

    # Local is named pdb_io so it does not shadow the imported io module
    pdb_io = PDB.PDBIO()
    pdb_io.set_structure(new_structure)
    pdb_io.save(output_name)

    print("")
    print(f'## domain {chain_id}{domain_n} saved to {output_name}')
    print("")



def get_domains(target_pdb, single_chain_pdbs_list):
    my_domains_pdb_path = {}
    my_domains_res_range = {}

    # Run merizo on each single chain
    for chain_path in single_chain_pdbs_list:
        print("")
        print(f"{target_pdb}, chain {chain_path} is being processed")

        chain = chain_path.split("/")[-1].split("chain_")[-1][0]

        # Run domain separation for a particular chain
        !python Merizo/predict.py -i {chain_path} --pdb_chain {chain} --save_domains > /dev/null

        # Get domain info path for a particular chain
        domain_info_path = os.path.join(SINGLE_CHAINS_PATH, target_pdb, f"chain_{chain}_merizo_v2.domains")

        with open(domain_info_path, 'r') as file:
            # For each domain create a pdb
            for line in file:

                domain_n = line.strip().split("\t")[-2]
                domain_r = line.strip().split("\t")[-1]

                current_domain_out_path = os.path.join(DOMAINS_PATH, target_pdb, f"chain_{chain}_domain_{domain_n}.pdb")
                extract_residues(chain_path, domain_r, chain, current_domain_out_path,domain_n)
                key_n = chain + domain_n
                my_domains_pdb_path[key_n] = current_domain_out_path

                my_arr = []
                for d in domain_r.split("_"):
                    my_arr.append([int(n) for n in d.split("-")])

                my_domains_res_range[key_n] = {"res": my_arr}

    return my_domains_res_range, my_domains_pdb_path

def get_heavy_atom_coordinates(pdb_filename):
    """
    Collect non-hydrogen atom coordinates per residue from a PDB file.

    Returns a dict keyed by (chain id, residue number) whose values are lists
    of atom coordinates. Residues with a non-blank hetero flag are skipped.
    When several models contain the same key, the entry from the later model
    replaces the earlier one.
    """
    structure = PDBParser().get_structure("protein", pdb_filename)

    coords = {}
    for model in structure:
        for chain in model:
            for residue in chain:
                full_id = residue.get_id()
                # A blank first field marks a standard (non-hetero) residue
                if full_id[0] != ' ':
                    continue
                coords[(chain.id, full_id[1])] = [
                    atom.get_coord() for atom in residue if atom.element != 'H'
                ]
    return coords


def calculate_distance(coord1, coord2):
    """
    Euclidean distance between two coordinate sequences.

    Components are paired positionally; extra components of the longer
    sequence are ignored.
    """
    squared_total = 0
    for a, b in zip(coord1, coord2):
        squared_total = squared_total + (a - b) ** 2
    return squared_total ** 0.5



def find_contacting_domains(all_chains_domains, input_pdb_file, threshold=5.0, min_contact_residues=5):
    """
    Find, for every domain, the domains it is in contact with.

    Two residues are in contact when any pair of their heavy atoms is within
    `threshold`. Two domains are in contact when at least
    `min_contact_residues` residue pairs (one residue from each domain) are in
    contact. Each residue pair is counted once, however many of its atoms are
    close.

    Parameters:
    - all_chains_domains: List of dictionaries mapping a domain key to
      {"res": [[start, end], ...]} with inclusive residue-number bounds.
      Everything but the last character of the key is used as the chain ID.
    - input_pdb_file: Path to the input PDB file containing 3D coordinates.
    - threshold: Distance threshold to consider residues as in contact.
    - min_contact_residues: Minimum number of contacting residue pairs
      required to consider two domains in contact.

    Returns:
    - contacting_domains: Dictionary where keys are domain identifiers and
      values are sets of contacting domains (symmetric).
    """

    # Get heavy atom coordinates from PDB file
    coords = get_heavy_atom_coordinates(input_pdb_file)

    # Initialize contacting domains dictionary
    contacting_domains = {domain: set() for chain_domains in all_chains_domains for domain in chain_domains}

    # Flatten list of domain dictionaries into a single dictionary
    chain_domains = dict(ChainMap(*all_chains_domains))

    def residue_atoms(domain, data):
        # Heavy-atom coordinate lists of the domain's residues found in the PDB
        # NOTE(review): the chain ID is derived as domain[:-1], which assumes
        # a single-character domain label - confirm against the key producer.
        chain_id = domain[:-1]
        found = []
        for interval in data['res']:
            # interval[-1] also accepts a one-element interval
            for resnum in range(interval[0], interval[-1] + 1):
                atoms = coords.get((chain_id, resnum))
                if atoms is not None:
                    found.append(atoms)
        return found

    def residues_touch(atoms_a, atoms_b):
        # True when any atom pair of the two residues is within the threshold
        return any(calculate_distance(a, b) <= threshold
                   for a in atoms_a for b in atoms_b)

    def domains_touch(residues_a, residues_b):
        # Count contacting residue pairs, stopping once enough are found
        contact_count = 0
        for atoms_a in residues_a:
            for atoms_b in residues_b:
                if residues_touch(atoms_a, atoms_b):
                    contact_count += 1
                    if contact_count >= min_contact_residues:
                        return True
        return contact_count >= min_contact_residues

    # Resolve coordinates once per domain rather than once per domain pair
    atoms_by_domain = {domain: residue_atoms(domain, data)
                       for domain, data in chain_domains.items()}

    # The criterion is symmetric, so each unordered pair is evaluated once
    for domain1, domain2 in combinations(chain_domains, 2):
        if domains_touch(atoms_by_domain[domain1], atoms_by_domain[domain2]):
            contacting_domains[domain1].add(domain2)
            contacting_domains[domain2].add(domain1)

    return contacting_domains



def find_residues_contacting_domains(all_chains_domains, input_pdb_file, threshold=5.0):
    """
    List, for every domain, its residues that touch a residue of another domain.

    Two residues touch when any pair of their heavy atoms is within
    `threshold`. The result maps each domain key to a sorted list of unique
    residue numbers (empty when the domain touches nothing).
    """
    coords = get_heavy_atom_coordinates(input_pdb_file)

    # One accumulator per domain, in the order the domains are listed
    touching = {domain: set() for per_chain in all_chains_domains for domain in per_chain}

    # Flatten the per-chain dictionaries into one domain -> data mapping
    chain_domains = dict(ChainMap(*all_chains_domains))

    # Every unordered domain pair is visited exactly once
    for (domain1, data1), (domain2, data2) in combinations(chain_domains.items(), 2):
        # Everything but the last character of the key is the chain ID
        chain1, chain2 = domain1[:-1], domain2[:-1]

        for span1 in data1['res']:
            for span2 in data2['res']:
                for res_1 in range(span1[0], span1[1] + 1):
                    atoms1 = coords.get((chain1, res_1))
                    if atoms1 is None:
                        continue  # residue absent from the PDB
                    for res_2 in range(span2[0], span2[1] + 1):
                        atoms2 = coords.get((chain2, res_2))
                        if atoms2 is None:
                            continue  # residue absent from the PDB

                        # One close atom pair is enough to record both residues
                        if any(calculate_distance(a, b) <= threshold
                               for a in atoms1 for b in atoms2):
                            touching[domain1].add(res_1)
                            touching[domain2].add(res_2)

    return {domain: sorted(found) for domain, found in touching.items()}


def get_pdb_length(pdb_file_path):
    """
    Count the distinct (chain id, residue number) pairs in a PDB file.

    All models are scanned; a pair that occurs in several models is counted once.
    """
    structure = PDB.PDBParser(QUIET=True).get_structure('structure', pdb_file_path)

    unique_residues = {
        (chain.id, residue.id[1])
        for model in structure
        for chain in model
        for residue in chain
    }
    return len(unique_residues)


def process_graph(contacting_domains, all_chains_domains_paths, max_sum, threshold):
    """
    Group contacting domains into connected sets that fit a residue budget.

    A graph is built with one node per domain (node attribute 'value' = the
    domain length from get_pdb_length) and one edge per contact.

    Parameters:
    - contacting_domains: dict mapping a domain key to the set of domains it
      contacts.
    - all_chains_domains_paths: dict mapping a domain key to its PDB path.
    - max_sum: maximum total length allowed for a connected group.
    - threshold: contacting pairs whose combined length is greater than this,
      and single domains longer than this, are reported separately.

    Returns:
    - (contacting_domain_combinations, single_greater_th): the first is a list
      of "_"-joined domain keys, made of the maximal connected groups within
      max_sum followed by the contacting pairs above threshold. The second is
      the list of single domains above threshold.

    Note: every node combination is enumerated, so the cost grows
    exponentially with the number of domains.
    """
    # Create a graph from contacting_domains
    G = nx.Graph()

    # Add nodes with values
    for node in contacting_domains:
        G.add_node(node, value=get_pdb_length(all_chains_domains_paths[node]))

    # Add edges
    for node, neighbors in contacting_domains.items():
        for neighbor in neighbors:
            G.add_edge(node, neighbor)

    node_values = nx.get_node_attributes(G, 'value')

    def get_all_connected_subgraphs(graph):
        nodes = list(graph.nodes())
        subgraphs = []

        # Generate all possible subgraphs
        for size in range(1, len(nodes) + 1):
            for node_combination in combinations(nodes, size):
                # Combinations over budget are rejected by the filter below
                # anyway, so skip the connectivity check for them.
                if sum(node_values.get(n, 0) for n in node_combination) > max_sum:
                    continue
                subgraph = graph.subgraph(node_combination)
                if nx.is_connected(subgraph):
                    subgraphs.append(subgraph)

        return subgraphs

    def filter_and_sort_subgraphs(subgraphs, max_sum):
        valid_subgraphs = []
        seen_subgraphs = []

        # Largest first, so any group that is a subset of an already accepted
        # group can be dropped
        subgraphs.sort(key=lambda sg: (len(sg.nodes()), sum(nx.get_node_attributes(sg, 'value').values())), reverse=True)

        for subgraph in subgraphs:
            total_value = sum(nx.get_node_attributes(subgraph, 'value').values())
            subgraph_nodes = set(subgraph.nodes())

            # Check if subgraph_nodes is a subset of any seen subgraph nodes
            if any(subgraph_nodes <= seen for seen in seen_subgraphs):
                continue

            if total_value <= max_sum:
                valid_subgraphs.append((subgraph, total_value))
                seen_subgraphs.append(subgraph_nodes)

        return valid_subgraphs

    def find_pairs_with_value_greater_than(graph, threshold):
        # Contacting pairs whose combined length exceeds the threshold
        return [
            (node1, node2)
            for node1, node2 in graph.edges()
            if graph.nodes[node1]['value'] + graph.nodes[node2]['value'] > threshold
        ]

    def find_single_domains_with_value_greater_than(graph, threshold):
        # Single domains that exceed the threshold on their own
        return [node for node in graph.nodes() if graph.nodes[node]['value'] > threshold]

    # Get all connected subgraphs
    all_connected_subgraphs = get_all_connected_subgraphs(G)

    # Filter subgraphs by the sum of their node attributes and sort them
    filtered_sorted_subgraphs = filter_and_sort_subgraphs(all_connected_subgraphs, max_sum)

    my_batch = []
    contacting_domain_combinations=[]
    # Collect all valid subgraphs
    for subgraph, total_value in filtered_sorted_subgraphs:
        nodes = subgraph.nodes(data=False)

        my_batch.append(list(nodes))

        print("")
        print("Subgraph has following domains:", nodes)
        print("Total length of subgraph:", total_value)
        print("")

    # Find pairs with value greater than the threshold
    pairs_greater_th = find_pairs_with_value_greater_than(G, threshold)

    for pairs in pairs_greater_th:
        my_batch.append(list(pairs))
        print("")
        # Report the threshold in effect rather than a fixed number
        print(f"Pairs with length greater than {threshold} res. :",pairs)
        print("")

    # Find single domains with value greater than the threshold
    single_greater_th = find_single_domains_with_value_greater_than(G, threshold)
    print("")
    print(f"Single domains greater than {threshold} res.:",single_greater_th)
    print("")

    print(my_batch)
    for domains in my_batch:
        contacting_domain_combinations.append("_".join(domains))

    return contacting_domain_combinations, single_greater_th

def copy_pdb(true_pdb_path, output_pdb_path):
    """
    Copy a PDB file, reporting the outcome on stdout.

    Parameters:
    true_pdb_path (str): The path to the source PDB file.
    output_pdb_path (str): The path to the destination PDB file.

    A failure is printed rather than raised, so the caller is not interrupted.
    """
    try:
        shutil.copyfile(true_pdb_path, output_pdb_path)
    except Exception as e:
        print(f"An error occurred: {e}")
    else:
        print(f"File copied from {true_pdb_path} to {output_pdb_path}.")

def merge_pdb_files(output_file, *input_files):
    """
    Merge several PDB files into one single-model PDB file.

    Residues are pooled per chain ID across all inputs, in the order the files
    are given, so chains that share an ID end up in one output chain.
    """
    # chain ID -> copies of its residues, in input order
    residues_by_chain = {}
    for pdb_file in input_files:
        structure = PDB.PDBParser(QUIET=True).get_structure('temp_structure', pdb_file)
        for chain in structure.get_chains():
            pooled = residues_by_chain.setdefault(chain.id, [])
            # Copies are used, so the parsed structure is not modified
            pooled.extend(residue.copy() for residue in chain.get_residues())

    # Assemble a one-model structure holding one chain per collected ID
    merged_structure = PDB.Structure.Structure('merged_structure')
    merged_model = PDB.Model.Model(0)
    merged_structure.add(merged_model)
    for chain_id, residues in residues_by_chain.items():
        merged_chain = PDB.Chain.Chain(chain_id)
        for residue in residues:
            merged_chain.add(residue)
        merged_model.add(merged_chain)

    # Write the merged structure to disk
    writer = PDB.PDBIO()
    writer.set_structure(merged_structure)
    writer.save(output_file)

def save_set_to_csv(s, filename):
    """
    Write the elements of a set to a CSV file, one per row, in sorted order.

    Parameters:
    s (set): The values to save.
    filename (str): The name of the CSV file to write.
    """
    # Sort first, so a sorting failure happens before the file is touched
    ordered = sorted(s)

    with open(filename, 'w', newline='') as file:
        csv.writer(file).writerows([number] for number in ordered)


def save_dict_to_csv(data, filename):
    """
    Save a dictionary to a CSV file with keys as columns and list elements as rows.

    Columns shorter than the longest one are padded with empty strings. An
    empty dictionary produces a file that contains only the (empty) header row.

    Parameters:
    data (dict): The dictionary to save. Keys are column names and values are
        indexable sequences of row values.
    filename (str): The name of the file to save the CSV as.
    """
    columns = list(data.values())

    # default=0 keeps an empty dictionary from raising ValueError in max()
    max_len = max((len(column) for column in columns), default=0)

    # Build the rows, padding exhausted columns with ''
    rows = [
        [column[i] if i < len(column) else '' for column in columns]
        for i in range(max_len)
    ]

    # Write to CSV
    with open(filename, 'w', newline='') as csv_file:
        writer = csv.writer(csv_file)
        writer.writerow(data.keys())  # Write the header
        writer.writerows(rows)        # Write the data rows



def create_or_replace_dir(directory):
    """
    Create `directory`, deleting any previous directory at that path first.

    A failure to delete the old directory is printed, not raised; the
    subsequent creation is attempted regardless.
    """
    if not os.path.exists(directory):
        print(f"Directory '{directory}' does not exist. Creating...")
    else:
        print(f"Directory '{directory}' already exists. Deleting and creating again.")
        try:
            # Remove the existing directory and its contents
            shutil.rmtree(directory)
        except OSError as e:
            print(f"Error: {directory} : {e.strerror}")

    os.makedirs(directory)


# Function to update B-factors and save the structure
def update_b_factors_and_save(structure, df, output_path):
    """
    Store per-residue 'p(bind)' values in the B-factor column and save the structure.

    For every residue, the first row of `df` whose 'chain' and 'resi' match is
    looked up; its 'p(bind)' value times 100 becomes the B-factor of all atoms
    of that residue. Residues without a matching row get 0. The structure is
    modified in place and then written to `output_path`.
    """
    class BFactorSelect(Select):
        # Keeps every atom; only the B-factors differ from the input
        def accept_atom(self, atom):
            return True

    for model in structure:
        for chain in model:
            chain_id = chain.get_id()
            for residue in chain:
                res_id = residue.get_id()[1]
                # Rows that match both chain and residue number
                row_mask = (df['chain'] == chain_id) & (df['resi'] == res_id)
                if row_mask.any():
                    b_factor = df.loc[row_mask, 'p(bind)'].values[0]
                else:
                    b_factor = 0.0

                scaled = b_factor*100
                for atom in residue:
                    atom.set_bfactor(scaled)

    # Save the updated structure
    pdb_io = PDBIO()
    pdb_io.set_structure(structure)
    pdb_io.save(output_path, BFactorSelect())
    print(f"Structure saved to {output_path}")

def remove_residues(input_pdb_path, output_pdb_path, residues_to_remove):
    """
    Write a copy of a PDB file without the given residue numbers.

    Parameters:
    - input_pdb_path: Path to the input PDB file.
    - output_pdb_path: Path to save the modified PDB file.
    - residues_to_remove: Container of residue numbers to drop. Only the
      residue number is compared, so a number is dropped from every chain
      that contains it.
    """
    class ResidueSelector(PDB.Select):
        # Accepts every residue whose number is not listed for removal
        def __init__(self, residues_to_remove):
            self.residues_to_remove = residues_to_remove

        def accept_residue(self, residue):
            return residue.get_id()[1] not in self.residues_to_remove

    structure = PDB.PDBParser(QUIET=True).get_structure('structure', input_pdb_path)

    # The selector filters residues while the structure is written out
    writer = PDB.PDBIO()
    writer.set_structure(structure)
    writer.save(output_pdb_path, ResidueSelector(residues_to_remove))


def af2bind_inference(pdb_filename, target_chain, mask_sidechains=True, mask_sequence=False, preds_pdb_path="",residues_ignore=[]):
    """
    Run AF2BIND on one PDB file and write the predictions to a PDB file.

    Parameters:
    pdb_filename (str): Path to the input PDB file.
    target_chain (str): Chain selection passed to the model as `chain`.
    mask_sidechains (bool): Passed to the model as `rm_target_sc`; also
        selects the matching AF2BIND parameter set. Honoured as given.
    mask_sequence (bool): Passed to the model as `rm_target_seq`. Honoured as
        given.
    preds_pdb_path (str): Directory in which "<input name>_pred.pdb" is written.
    residues_ignore (list): Residue numbers whose prediction is set to 0 in
        the written PDB file. Matching is by residue number only. The list is
        never modified.

    Returns:
    pandas.DataFrame: One row per target residue with columns rank, chain,
        resi, resn, p(bind) and arr_i (position in the prediction array),
        sorted by p(bind) in descending order. The table holds the raw
        predictions; `residues_ignore` only affects the written PDB file.
    """
    print(f"\n# running af2bind on: {pdb_filename}")
    print(f"# chain(s): {target_chain}")

    clear_mem()

    af_model = mk_afdesign_model(protocol="binder", debug=True)
    af_model.prep_inputs(pdb_filename=pdb_filename,
                         chain=target_chain,
                         binder_len=20,
                         rm_target_sc=mask_sidechains,
                         rm_target_seq=mask_sequence)

    # split: give each of the 20 binder positions its own residue index,
    # spaced 50 apart
    r_idx = af_model._inputs["residue_index"][-20] + (1 + np.arange(20)) * 50
    af_model._inputs["residue_index"][-20:] = r_idx.flatten()

    # One binder position per amino acid type
    af_model.set_seq("ACDEFGHIKLMNPQRSTVWY")
    af_model.predict(verbose=False)

    o = af2bind(af_model.aux["debug"]["outputs"],
                mask_sidechains=mask_sidechains)
    pred_bind = o["p_bind"].copy()

    # Per-residue table; arr_i remembers the position in pred_bind so rows can
    # be mapped back after sorting
    labels = ["chain","resi","resn","p(bind)","arr_i"]
    data = []
    for i in range(af_model._target_len):
        c = af_model._pdb["idx"]["chain"][i]
        r = af_model._pdb["idx"]["residue"][i]
        a = aa_order.get(af_model._pdb["batch"]["aatype"][i], "X")
        p = pred_bind[i]
        data.append([c,r,a,p,i])

    df = pd.DataFrame(data, columns=labels)
    df_sorted = df.sort_values("p(bind)", ascending=False, ignore_index=True).rename_axis('rank').reset_index()

    # Zero the predictions of ignored residues (affects the written PDB only)
    if len(residues_ignore) > 0:
        arr_i_list = df_sorted.loc[df_sorted['resi'].isin(residues_ignore), 'arr_i'].tolist()
        for index in arr_i_list:
            pred_bind[index] = 0

    preds_adj = pred_bind.copy()

    # Put the predictions where the model keeps plddt, so they end up in the
    # B-factor column below
    L = af_model._target_len
    aux = copy.deepcopy(af_model.aux["all"])
    aux["plddt"][:, :L] = preds_adj
    out_name = pdb_filename.split("/")[-1].split(".pdb")[0] + "_pred.pdb"

    # Hide the binder atoms
    aux["atom_mask"][:,L:] = 0
    x = {k:[] for k in ["aatype",
                      "residue_index",
                      "atom_positions",
                      "atom_mask",
                      "b_factors"]}
    asym_id = []
    for i in range(af_model._target_len):
        for k in ["aatype","atom_mask"]: x[k].append(aux[k][0,i])

        # Coordinates and numbering come from the input PDB, not the prediction
        x["atom_positions"].append(af_model._pdb["batch"]["all_atom_positions"][i])
        x["residue_index"].append(af_model._pdb["idx"]["residue"][i])
        x["b_factors"].append(x["atom_mask"][-1] * aux["plddt"][0,i] * 100.0)
        asym_id.append(af_model._pdb["idx"]["chain"][i])
    x = {k:np.array(v) for k,v in x.items()}

    # fix the chains: walk the ATOM lines and advance to the next residue
    # (and its chain ID) whenever the residue number changes
    (n,resnum_) = (0,None)
    pdb_lines = []
    for line in protein.to_pdb(protein.Protein(**x)).splitlines():
        if line[:4] == "ATOM":
            resnum = int(line[22:22+5])
            if resnum_ is None: resnum_ = resnum
            if resnum != resnum_:
                n += 1
                resnum_ = resnum
            pdb_lines.append("%s%s%4i%s" % (line[:21],asym_id[n],resnum,line[26:]))

    with open(f"{preds_pdb_path}/{out_name}","w") as handle:
        handle.write("\n".join(pdb_lines))


    return df_sorted

def convert_filenames(filename):
    """
    Return the domain key of a single-domain file name, or "" otherwise.

    A name ending in "_<uppercase letter><digit>.pdb" yields that
    letter+digit key. A name that ends in two such parts (a combination), or
    that does not match at all, yields "".
    """
    found = re.search(r'_([A-Z]\d)(_[A-Z]\d)?\.pdb$', filename)

    # No match, or a second domain part: not a single-domain file
    if not found or found.group(2):
        return ""
    return found.group(1)

def get_low_bfactor_stretches(pdb_filename, bfactor_threshold=70.0, min_stretch_length=7):
    """
    Find runs of residues whose B-factor lies below a threshold.

    A residue counts as "low" when at least one of its atoms has a B-factor
    strictly below ``bfactor_threshold``. For every chain of every model the
    following runs are reported, in this order:
      1. the run that starts at the first residue and follows increasing
         residue numbers,
      2. the run that starts at the last residue and follows decreasing
         residue numbers (its numbers are therefore listed in descending order),
      3. every run found by scanning the chain from start to end.
    Runs shorter than ``min_stretch_length`` are discarded. Because the three
    searches overlap, the same residues can appear in more than one run.

    Parameters:
    pdb_filename (str): Path to the PDB file.
    bfactor_threshold (float): The B-factor threshold. Default is 70.0.
    min_stretch_length (int): Minimum number of residues a run must have. Default is 7.

    Returns:
    list: List of lists of residue sequence numbers, one sublist per run.
    """
    # Load the file content and hand it to the parser as an in-memory stream
    with open(pdb_filename, 'r') as handle:
        pdb_text = handle.read()

    structure = PDBParser(QUIET=True).get_structure('structure', io.StringIO(pdb_text))

    def has_low_bfactor(residue):
        # True as soon as one atom of the residue is below the threshold
        return any(atom.get_bfactor() < bfactor_threshold for atom in residue)

    def scan_chain(chain):
        # Walk the chain in iteration order, emitting every sufficiently long run
        run = []
        for residue in chain:
            if has_low_bfactor(residue):
                run.append(residue.get_id()[1])
                continue
            if len(run) >= min_stretch_length:
                yield run
            run = []

        if len(run) >= min_stretch_length:
            yield run

    def walk_from_terminus(chain, start_residue, direction):
        # 'N' steps towards higher residue numbers, anything else towards lower ones
        step = 1 if direction == 'N' else -1
        by_number = {res.get_id()[1]: res for res in chain.get_residues()}
        run = []
        cursor = start_residue
        while cursor:
            if not has_low_bfactor(cursor):
                break
            number = cursor.get_id()[1]
            run.append(number)
            # A missing neighbouring residue number ends the walk
            cursor = by_number.get(number + step, None)
        return run

    stretches = []

    for model in structure:
        for chain in model:
            chain_residues = list(chain.get_residues())
            if not chain_residues:
                continue

            # Runs anchored at both ends of the chain (N-terminus first)
            terminal_runs = (
                walk_from_terminus(chain, chain_residues[0], 'N'),
                walk_from_terminus(chain, chain_residues[-1], 'C'),
            )
            stretches.extend(run for run in terminal_runs if len(run) >= min_stretch_length)

            # Runs anywhere in the chain (already length-filtered by the scan)
            stretches.extend(scan_chain(chain))

    return stretches


def af2bind_pipeline(preds_path, target_chain, out_process_pdb_file, pdb_filepath, low_bfactor_residues):
    """
    Run AF2BIND inference on every prepared domain PDB and lump the results.

    Each PDB listed in ``out_process_pdb_file["all_domains_af2bind_run"]`` is
    passed through ``remove_residues`` (same input and output path) with
    ``low_bfactor_residues``, then through ``af2bind_inference``. For files that
    ``convert_filenames`` identifies as a single domain, p(bind) of that domain's
    contacting residues is set to 0. The per-file tables are merged by taking the
    maximum p(bind) per (chain, resi) and written to
    ``<preds_path>/<pdb_i>/domains_lumped_<pdb_i>.csv``; the lumped table is also
    handed to ``update_b_factors_and_save`` to write ``<input>_lumped_pred.pdb``.

    Reads the module-level flag ``af2_struct``: when it is truthy, a copy of
    ``pdb_filepath`` without ``low_bfactor_residues`` is written first and used
    as the structure for the lumped PDB.

    Parameters:
    preds_path (str): Directory under which the per-target output folder is created.
    target_chain (str): Chain identifier forwarded to ``af2bind_inference``.
    out_process_pdb_file (dict): Result of ``process_pdb_file``; the keys "pdb_i",
        "all_domains_af2bind_run", "contacting_residues" and "input_pdb_file" are read.
    pdb_filepath (str): Path of the full input PDB (only used when ``af2_struct`` is set).
    low_bfactor_residues (set): Residue numbers to strip before inference.

    Returns:
    None. Exceptions are caught and reported with print(), so a failed run
    shows up as missing output files rather than as a raised error.
    """
    try:

        pdb_i = out_process_pdb_file["pdb_i"]

        # Inference: create folder for af2bind inference pdb
        preds_pdb_path = preds_path + "/" + pdb_i
        create_or_replace_dir(preds_pdb_path)
        print("")

        if af2_struct:
            pdb_filepath_b_high = preds_pdb_path + "/" + pdb_i + "_high_pLDDT.pdb"
            remove_residues(input_pdb_path=pdb_filepath, output_pdb_path=pdb_filepath_b_high, residues_to_remove=low_bfactor_residues)

        # Iterate through pdbs of single domains & combination of multiple domains
        all_df = []

        for pdb_filename in out_process_pdb_file["all_domains_af2bind_run"]:

            # Remove low pLDDT stretches and update the existing file.
            # Uses the function argument rather than a module-level global, so the
            # function works with whatever set the caller passes in.
            remove_residues(input_pdb_path=pdb_filename, output_pdb_path=pdb_filename, residues_to_remove=low_bfactor_residues)

            # Domain label such as "A1"; "" for combinations of domains
            domain_name_single = convert_filenames(pdb_filename)
            residues_ignore = []

            # Get contacting residues for single domains
            if domain_name_single != "":
                print(f"## single domain {domain_name_single}, contacting residues will be zeroed!")
                residues_ignore = out_process_pdb_file["contacting_residues"][domain_name_single]
                print("## ignored contacting residues: ",residues_ignore)

            df_sorted = af2bind_inference(pdb_filename, target_chain, mask_sidechains=True, mask_sequence=False, preds_pdb_path=preds_pdb_path, residues_ignore=residues_ignore)

            # Set contacting residues to 0 for single domains
            df_sorted.loc[df_sorted['resi'].isin(residues_ignore), 'p(bind)'] = 0

            all_df.append(df_sorted.copy())

        # Lumping all the predictions: keep the highest p(bind) seen for each residue
        concatenated_df = pd.concat(all_df, ignore_index=True)
        grouped = concatenated_df.groupby(['chain', 'resi'])
        max_pbind_df = grouped['p(bind)'].max().reset_index()
        sorted_pbind_df = max_pbind_df.sort_values(by='p(bind)', ascending=False)

        sorted_pbind_df.to_csv(f"{preds_pdb_path}/domains_lumped_{pdb_i}.csv", index=False)

        print("n. res: ", len(sorted_pbind_df))

        # Save lumped pdb file
        pdb_parser = PDBParser(QUIET=True)
        if af2_struct:
            structure = pdb_parser.get_structure('protein', pdb_filepath_b_high)
        else:
            structure = pdb_parser.get_structure('protein', out_process_pdb_file["input_pdb_file"])

        out_name = out_process_pdb_file["input_pdb_file"].split("/")[-1].split(".pdb")[0] + "_lumped_pred.pdb"
        update_b_factors_and_save(structure, sorted_pbind_df, f"{preds_pdb_path}/{out_name}")

        # Save contacting residues
        save_dict_to_csv(data=out_process_pdb_file["contacting_residues"],filename=f'{preds_pdb_path}/contacting_residues.csv')

        # Save removed low pLDDT stretches
        save_set_to_csv(s=low_bfactor_residues,filename=f'{preds_pdb_path}/low_plddt_streches.csv')

        print("### end ####")
        print("")

    except ValueError as e:
        print(f"ValueError occurred: {e}")
    except Exception as e:
        print(f"An unexpected error occurred: {e}")



def process_pdb_file(target_pdb, chain_ids):
    """
    Split a structure into domains and prepare the PDBs AF2BIND will be run on.

    The stages run strictly in order: chain splitting, domain detection,
    contact detection between domains (domain pairs, then residues), contact
    graph processing, and finally writing the domain / domain-combination PDBs.
    The structure file itself is addressed as ``<target_pdb>.pdb``.

    Parameters:
        target_pdb (str): Identifier of the target; ".pdb" is appended to
            obtain the input file name.
        chain_ids: Chain selection forwarded unchanged to ``split_pdb_by_chains``.

    Returns:
        dict: Keys "all_domains_af2bind_run" (paths of the PDBs to run),
        "pdb_i" (the identifier), "contacting_residues" and "input_pdb_file".
    """
    input_pdb_file = f"{target_pdb}.pdb"

    # Step 1: one PDB per requested chain
    print("## Splitting PDB by chains....")
    chain_pdbs = split_pdb_by_chains(input_pdb=target_pdb, chain_ids=chain_ids)

    # Step 2: domain residue ranges and the path of each domain PDB
    print("## Getting domain residue ranges and paths....")
    domain_residues, domain_paths = get_domains(target_pdb, chain_pdbs)
    # The contact helpers below are given the residue mapping wrapped in a list
    domain_residues = [domain_residues]

    # Step 3: which domains touch each other
    print("## Finding contacting domains....")
    domain_contacts = find_contacting_domains(domain_residues, input_pdb_file)
    print(f"Contacting domains: {domain_contacts}\n")

    # Step 4: which residues sit at those interfaces
    print("## Finding contacting residues in contacting domains....")
    interface_residues = find_residues_contacting_domains(domain_residues, input_pdb_file)

    # Step 5: choose domain combinations and stand-alone domains from the contact graph
    print("## Processing contact graph....")
    combinations, large_singles = process_graph(
        domain_contacts,
        domain_paths,
        MAX_SUM_RES_SUBGRAPH,
        THRESHOLD_SINGLE_PAIRS
    )

    # Step 6: write the PDB files for the chosen combinations / single domains
    print("## Creating PDBs for contacting domains....")
    run_pdbs = create_contacting_pdbs(
        target_pdb,
        combinations,
        domain_paths,
        large_singles
    )

    print("Done!")

    return dict(
        all_domains_af2bind_run=run_pdbs,
        pdb_i=target_pdb,
        contacting_residues=interface_residues,
        input_pdb_file=input_pdb_file,
    )


def create_contacting_pdbs(target_pdb, contacting_domain_combinations, all_chains_domains_paths, single_greater_than_th):
    """
    Write the PDB files AF2BIND is run on: merged combinations, then single domains.

    All files go to ``<CONTACTING_DOMAINS>/<target_pdb>/`` and are named
    ``<target_pdb>_<domains>.pdb``.

    Parameters:
        target_pdb (str): Identifier of the target, used for folder and file names.
        contacting_domain_combinations (list): Combination labels; the domain
            names inside a label are separated by "_" (e.g. "A1_A2").
        all_chains_domains_paths (dict): Mapping of domain names to PDB paths.
        single_greater_than_th (list): Domain names to be copied as stand-alone PDBs.

    Returns:
        list: Paths of the created files, combinations first, in input order.
    """
    def destination(file_name):
        # Every output of this function lives in the per-target folder
        return f"{CONTACTING_DOMAINS}/{target_pdb}/{file_name}"

    created = []

    # Merged PDB for each combination of contacting domains
    for combination in contacting_domain_combinations:
        members = combination.split("_")
        label = "_".join(members)
        merged_path = destination(f"{target_pdb}_{label}.pdb")
        member_paths = [all_chains_domains_paths[member] for member in members]
        merge_pdb_files(merged_path, *member_paths)
        created.append(merged_path)

    # Plain copy for each stand-alone domain
    for single in single_greater_than_th:
        copied_path = destination(f"{target_pdb}_{single}.pdb")
        copy_pdb(all_chains_domains_paths[single], copied_path)
        created.append(copied_path)

    print(f"Single domains greater than threshold: {single_greater_than_th}")
    return created


def remove_residues_from_pdb(pdb_path, residues_to_remove, output_path):
    """
    Removes specified residues from a PDB file and saves the modified structure.

    Residues are matched by the integer part of their id only, in every chain
    of every model; chain identity is not taken into account.

    Parameters:
    pdb_path (str): Path to the input PDB file.
    residues_to_remove (set): Set of residue numbers to be removed.
                              Example: {100, 50} or {1, 2, 3, 4, 5}
    output_path (str): Path to save the modified PDB file. May be the same as
                       pdb_path, in which case the input file is overwritten.
    """
    # Parse the input PDB file
    structure = PDBParser(QUIET=True).get_structure("structure", pdb_path)

    for model in structure:
        for chain in model:
            # Iterate over a snapshot because detach_child() mutates the chain.
            # Going through the Entity API (instead of overwriting
            # chain.child_list) keeps the chain's own bookkeeping consistent.
            for residue in list(chain):
                if residue.id[1] in residues_to_remove:
                    chain.detach_child(residue.id)

    # Save the modified structure (named pdb_writer so the stdlib `io` module,
    # used elsewhere in this notebook, is not shadowed)
    pdb_writer = PDBIO()
    pdb_writer.set_structure(structure)
    pdb_writer.save(output_path)

In [None]:
#@title **Run AF2Bind 🔬**

# Everything assigned in this cell is a module-level variable, so no `global`
# declarations are needed here. Helper functions defined in earlier cells read
# some of these names directly (e.g. MAX_SUM_RES_SUBGRAPH,
# THRESHOLD_SINGLE_PAIRS, CONTACTING_DOMAINS, af2_struct).

# Subgraph search parameters for large PDBs
MAX_SUM_RES_SUBGRAPH = 300 # max sum for all subgraphs
THRESHOLD_SINGLE_PAIRS = 300 # threshold for single domains & pairs of domains

# Paths
SINGLE_CHAINS_PATH="af2bind_temp/single_chains"
DOMAINS_PATH="af2bind_temp/domains"
CONTACTING_DOMAINS="af2bind_temp/contacting_domains"
PREDS_PATH="af2bind_temp/preds"

# Inputs
target_pdb = "7LQ6" #@param {type:"string"}
target_chain = "A" #@param {type:"string"}

#@markdown - Please indicate target pdb (or uniprot ID to download from AlphaFoldDB) and chain.
#@markdown - Leave pdb blank for custom upload prompt.
large_pdb = True # @param {"type":"boolean"}
#@markdown - (optional) For large PDBs (>300 res), we offer a separate pipeline that splits the PDB into domains and finds the optimal combinations to amplify the signal.
af2_struct = False # @param {"type":"boolean"}
#@markdown - (optional) For large AF2-predicted structures (>300 res), low pLDDT regions can be removed.

mask_sidechains = True
mask_sequence = False

# Normalise the form inputs; an empty chain falls back to "A"
target_pdb = target_pdb.replace(" ","")
target_chain = target_chain.replace(" ","")
if target_chain == "":
  target_chain = "A"

pdb_filename = get_pdb(target_pdb)
print(pdb_filename)
# A blank target is given the placeholder name "tmp" for all output paths
if(target_pdb==""):
  target_pdb="tmp"


# Prepare Temporary Directories
try:
  os.makedirs("af2bind_temp", exist_ok=True)
  os.makedirs(SINGLE_CHAINS_PATH, exist_ok=True)
  os.makedirs(DOMAINS_PATH, exist_ok=True)
  os.makedirs(CONTACTING_DOMAINS, exist_ok=True)
  os.makedirs(PREDS_PATH, exist_ok=True)

except Exception as e:
    print(f"An error occurred: {e}")

# Per-target working directories (comment in the original: "delete the
# directory, if it already exists")
setup_directory(f"{DOMAINS_PATH}/{target_pdb}")
setup_directory(f"{SINGLE_CHAINS_PATH}/{target_pdb}")
setup_directory(f"{CONTACTING_DOMAINS}/{target_pdb}")


if(large_pdb):

  print("###")
  print("## The large PDB pipeline is applied")
  print("###\n")


  out_process_pdb_file= process_pdb_file(target_pdb,target_chain)

  if af2_struct:
    low_bfactor_stretches = get_low_bfactor_stretches(out_process_pdb_file["input_pdb_file"])
  else:
    low_bfactor_stretches = set()

  # One flat set of residue numbers out of the list of stretches
  flattened_low_bfactor_stretches = set(item for sublist in low_bfactor_stretches for item in sublist)

  af2bind_pipeline(
      PREDS_PATH,
      target_chain,
      out_process_pdb_file,
      out_process_pdb_file["input_pdb_file"],
      flattened_low_bfactor_stretches
  )

  csv_file_path=f"{PREDS_PATH}/{target_pdb}/domains_lumped_{target_pdb}.csv"
  # af2bind_pipeline reports its errors with print() instead of raising, so a
  # failed run is only visible through the missing result file.
  if not os.path.isfile(csv_file_path):
    raise FileNotFoundError(
        f"AF2BIND pipeline did not produce {csv_file_path}; see the messages printed above."
    )
  df = pd.read_csv(csv_file_path)
  df['rank'] = range(len(df))
  df = df[['rank'] + [col for col in df.columns if col != 'rank']]  # Ensures 'rank' is the first column
  df.to_csv('results.csv')
  data_table.enable_dataframe_formatter()
  df_sorted = df.sort_values("p(bind)",ascending=False, ignore_index=True)
  display(data_table.DataTable(df_sorted, min_width=100, num_rows_per_page=15, include_index=False))

  top_n = 15
  top_n_idx = df['p(bind)'].argsort()[::-1][:top_n]  # Get indices of the top N p(bind) values

  # Joining the terms avoids a dangling " +" when fewer than top_n residues exist
  top_resi = [df['resi'].iloc[i] for i in top_n_idx]
  pymol_cmd = "select ch" + str(target_chain) + "," + " +".join(f" resi {r}" for r in top_resi)

  print("\n🧪 Pymol Selection Cmd:")
  print(pymol_cmd)

  #clear the cache
  gc.collect()
  jax.clear_caches()

else:

  if af2_struct:
    # Strip low-pLDDT stretches; the input file is overwritten in place
    low_bfactor_stretches = get_low_bfactor_stretches(pdb_filename)
    flattened_low_bfactor_stretches = set(item for sublist in low_bfactor_stretches for item in sublist)
    remove_residues_from_pdb(pdb_filename,flattened_low_bfactor_stretches,pdb_filename)

  else:
    low_bfactor_stretches = set()


  clear_mem()
  af_model = mk_afdesign_model(protocol="binder", debug=True)
  af_model.prep_inputs(pdb_filename=pdb_filename,
                      chain=target_chain,
                      binder_len=20,
                      rm_target_sc=mask_sidechains,
                      rm_target_seq=mask_sequence)

  # split: give each of the 20 binder positions its own residue index, 50 apart
  r_idx = af_model._inputs["residue_index"][-20] + (1 + np.arange(20)) * 50
  af_model._inputs["residue_index"][-20:] = r_idx.flatten()

  # The 20 binder positions are set to the 20 amino acids, one each
  af_model.set_seq("ACDEFGHIKLMNPQRSTVWY")
  af_model.predict(verbose=False)

  o = af2bind(af_model.aux["debug"]["outputs"],
              mask_sidechains=mask_sidechains)
  pred_bind = o["p_bind"].copy()
  pred_bind_aa = o["p_bind_aa"].copy()

  #######################################################
  labels = ["chain","resi","resn","p(bind)"]
  data = []
  for i in range(af_model._target_len):
    c = af_model._pdb["idx"]["chain"][i]
    r = af_model._pdb["idx"]["residue"][i]
    a = aa_order.get(af_model._pdb["batch"]["aatype"][i],"X")
    p = pred_bind[i]
    data.append([c,r,a,p])

  df = pd.DataFrame(data, columns=labels)
  df.to_csv('results.csv')

  data_table.enable_dataframe_formatter()
  df_sorted = df.sort_values("p(bind)",ascending=False, ignore_index=True).rename_axis('rank').reset_index()
  display(data_table.DataTable(df_sorted, min_width=100, num_rows_per_page=15, include_index=False))

  top_n = 15
  top_n_idx = pred_bind.argsort()[::-1][:top_n]

  # Joining the terms avoids a dangling " +" when fewer than top_n residues exist
  top_resi = [af_model._pdb["idx"]["residue"][i] for i in top_n_idx]
  pymol_cmd = "select ch" + str(target_chain) + "," + " +".join(f" resi {r}" for r in top_resi)

  print("\n🧪Pymol Selection Cmd:")
  print(pymol_cmd)

In [None]:
#@title **Display Structure** (Colored by Confidence)

if large_pdb:

  # Lumped prediction written by the large-PDB pipeline; its B-factor column
  # is shown divided by 100 in the hover labels below
  out_pdb_path = f"{PREDS_PATH}/{target_pdb}/{target_pdb}_lumped_pred.pdb"

  hbondCutoff = 4.0

  # Read the PDB file
  with open(out_pdb_path, 'r') as file:
      pdb_str = file.read()

  # Initialize 3Dmol view
  view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js', width=800, height=600)
  view.addModel(pdb_str, 'pdb', {'hbondCutoff': hbondCutoff})

  # Define the color scheme for B-factors:
  color_scheme = {'prop': 'b', 'gradient': 'linear', 'min': 0, 'max': 100, 'colors': ['white', 'blue']}
  view.setStyle({'cartoon': {'colorscheme': color_scheme}})

  # Add hoverable labels showing B-factors
  view.setHoverable({}, True,
      '''function(atom,viewer,event,container){
          if(!atom.label){
              atom.label=viewer.addLabel(atom.chain+"/"+atom.resi+"/"+atom.resn+" "+(atom.b/100.0).toFixed(3),{
                  position:atom,
                  backgroundColor:'white',
                  backgroundOpacity:0.75,
                  borderColor:'black',
                  borderThickness:2.0,
                  fontColor:'black'
              });
          }
      }''',
      '''function(atom,viewer){
          if(atom.label){
              viewer.removeLabel(atom.label);
              delete atom.label;
          }
      }'''
  )

  # Zoom to the view and display
  view.zoomTo()
  view.show()

else:

  use_native_coordinates = True
  show_ligand = False


  preds_adj = pred_bind.copy()

  # replace plddt and coordinates of prediction
  L = af_model._target_len
  aux = copy.deepcopy(af_model.aux["all"])
  aux["plddt"][:,:L] = preds_adj
  if show_ligand:
    af_model.save_pdb("output.pdb",aux={"all":aux})
  else:
    # Mask out everything after the target so only the target is written
    aux["atom_mask"][:,L:] = 0
    x = {k:[] for k in ["aatype",
                        "residue_index",
                        "atom_positions",
                        "atom_mask",
                        "b_factors"]}
    asym_id = []
    for i in range(af_model._target_len):
      for k in ["aatype","atom_mask"]: x[k].append(aux[k][0,i])
      if use_native_coordinates:
        x["atom_positions"].append(af_model._pdb["batch"]["all_atom_positions"][i])
      else:
        x["atom_positions"].append(aux["atom_positions"][0,i])
      x["residue_index"].append(af_model._pdb["idx"]["residue"][i])
      # B-factor = p(bind) * 100, zeroed for atoms absent from the mask
      x["b_factors"].append(x["atom_mask"][-1] * aux["plddt"][0,i] * 100.0)
      asym_id.append(af_model._pdb["idx"]["chain"][i])
    x = {k:np.array(v) for k,v in x.items()}

    # fix the chains: rewrite the chain column of every ATOM record from
    # asym_id, advancing to the next entry whenever the residue number changes
    (n,resnum_) = (0,None)
    pdb_lines = []
    for line in protein.to_pdb(protein.Protein(**x)).splitlines():
      if line[:4] == "ATOM":
        resnum = int(line[22:22+5])
        if resnum_ is None: resnum_ = resnum
        if resnum != resnum_:
          n += 1
          resnum_ = resnum
        pdb_lines.append("%s%s%4i%s" % (line[:21],asym_id[n],resnum,line[26:]))
    with open("output.pdb","w") as handle:
      handle.write("\n".join(pdb_lines))

  hbondCutoff = 4.0
  view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js',width=800,height=600)
  # Context manager so the file handle is closed deterministically
  with open("output.pdb",'r') as handle:
    pdb_str = handle.read()
  view.addModel(pdb_str,'pdb',{'hbondCutoff':hbondCutoff})
  color_scheme = {'prop': 'b', 'gradient': 'linear', 'min': 0, 'max': 100, 'colors': ['white', 'blue']}
  view.setStyle({'cartoon': {'colorscheme': color_scheme}})

  # add sidechains for residues predicted to bind (p > 0.5)
  for i in range(af_model._target_len):
    c = af_model._pdb["idx"]["chain"][i]
    r = int(af_model._pdb["idx"]["residue"][i])
    p = pred_bind[i]
    if p > 0.5:
      view.addStyle({'and':[{'chain':c},{'resi':r},{'resn':["GLY","PRO"],'invert':True},{'atom':['C','O','N'],'invert':True}]},
                    {'stick':{'colorscheme':color_scheme,'radius':0.3}})
      view.addStyle({'and':[{'chain':c},{'resi':r},{'resn':"GLY"},{'atom':'CA'}]},
                    {'sphere':{'colorscheme':color_scheme,'radius':0.3}})
      view.addStyle({'and':[{'chain':c},{'resi':r},{'resn':"PRO"},{'atom':['C','O'],'invert':True}]},
                    {'stick':{'colorscheme':color_scheme,'radius':0.3}})

  view.setHoverable({}, True,
                '''function(atom,viewer,event,container){if(!atom.label){atom.label=viewer.addLabel(atom.chain+"/"+atom.resi+"/"+atom.resn+" "+(atom.b/100.0).toFixed(3),{position:atom,backgroundColor:'white',backgroundOpacity:0.75,borderColor:'black',borderThickness:2.0,fontColor:'black'});}}''',
                '''function(atom,viewer){if(atom.label){viewer.removeLabel(atom.label);delete atom.label;}}''')

  view.zoomTo()
  view.show()

def plot_plddt_legend(dpi=100):
  """
  Draw a stand-alone colour legend for the p(bind) scale.

  Parameters:
  dpi (int): Resolution of the legend figure. Default is 100.

  Returns:
  The pyplot module, so the caller can chain ``.show()``.
  """
  legend_labels = ['p(bind):','0.00','0.25','0.50','0.75','1.00']
  swatch_colors = ['#FDFDFD', '#CCE5FF', '#99CCFF',  '#66B3FF',  '#3399FF',  '#0000FF']

  plt.figure(figsize=(1,0.1),dpi=dpi)

  # Zero-sized bars: nothing is drawn, they only provide one legend handle per colour
  for swatch in swatch_colors:
    plt.bar(0, 0, color=swatch)

  plt.legend(
      legend_labels,
      frameon=False,
      loc='center',
      ncol=6,
      handletextpad=1,
      columnspacing=1,
      markerscale=0.5,
  )
  plt.axis(False)
  return plt

plot_plddt_legend().show()

In [None]:
#@title **Download Predictions**
from google.colab import files
import zipfile

# (path on disk, name inside the archive) for every file to ship
if(large_pdb):
  download_items = [
      (f"{PREDS_PATH}/{target_pdb}/{target_pdb}_lumped_pred.pdb", f"{target_pdb}_lumped_pred.pdb"),
      (f"{PREDS_PATH}/{target_pdb}/domains_lumped_{target_pdb}.csv", f"domains_lumped_{target_pdb}.csv"),
  ]
else:
  download_items = [
      ("output.pdb", "output.pdb"),
      ("results.csv", "results.csv"),
  ]

available_items = []
for src_path, arcname in download_items:
  if os.path.isfile(src_path):
    available_items.append((src_path, arcname))
  else:
    print(f"Skipping missing file: {src_path}")

if not available_items:
  raise FileNotFoundError("No prediction files found; run the previous cells first.")

# Mode 'w' recreates output.zip on every run. The previous `zip -r` call
# updated an existing archive, so files from an earlier run stayed inside it.
with zipfile.ZipFile('output.zip', 'w', compression=zipfile.ZIP_DEFLATED) as zipf:
  for src_path, arcname in available_items:
    zipf.write(src_path, arcname=arcname)

files.download('output.zip')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
#@title **Activation analysis** (optional)
#@markdown - **Not applicable to the large PDB pipeline.**

pbind_cutoff = 0.5 # @param ["0.0", "0.5", "0.9"] {type:"raw"}

if large_pdb:
  # The large PDB pipeline does not define pred_bind / pred_bind_aa in this
  # notebook's namespace; without this guard the cell would fail on a fresh
  # kernel or silently plot values left over from an earlier run.
  print("Activation analysis is not available for the large PDB pipeline; skipping.")
else:
  # Row order of the heatmap and column order of pred_bind_aa
  blosum_map = list("CSTAGPDEQNHRKMILVWYF")
  cs_label_list = list("ACDEFGHIKLMNPQRSTVWY")

  # Reorder the amino-acid axis from cs_label_list order into blosum_map order
  indices_A_Y_mapping = np.array([cs_label_list.index(letter) for letter in blosum_map])
  pred_bind_aa_blosum = pred_bind_aa[:,indices_A_Y_mapping]

  # Keep only residues whose p(bind) exceeds the chosen cutoff
  filt = pred_bind > pbind_cutoff
  pred_bind_aa_blosum = pred_bind_aa_blosum[filt]
  res_labels = np.array(af_model._pdb["idx"]["residue"])[filt]
  chain_labels = np.array(af_model._pdb["idx"]["chain"])[filt]

  # Columns are labelled "<chain>_<residue number>"
  fig = px.imshow(pred_bind_aa_blosum.T,
                  labels=dict(x="positions", y="amino acids", color="pref"),
                  y=blosum_map,
                  x=[f"{y}_{x}" for x,y in zip(res_labels,chain_labels)],
                  zmin=-1,
                  zmax=1,
                  template="simple_white",
                  color_continuous_scale=["red", "white", "blue"],
                )
  fig.show()