# Protein structure reconstruction from a contact map

## Setup

In [1]:
# @title Setup

# @markdown [Get your API key here](https://chroma-weights.generatebiomedicines.com) and enter it below before running.

# Colab only: allow third-party ipywidgets (e.g. the NGLWidget viewer that shows
# proteins in the cells below) to render in cell output.
from google.colab import output

output.enable_custom_widget_manager()

import os

# cuBLAS requires this workspace setting for deterministic results once
# torch.use_deterministic_algorithms(True) is enabled below. It is set before
# torch is imported so it is in place before CUDA initializes.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
import contextlib  # NOTE(review): not used in the visible cells

# Colab form field. Never commit a real key; clear it before sharing the notebook.
api_key = "YOUR_API_KEY"  # @param {type:"string"}

# Install Chroma from GitHub HEAD.
# NOTE(review): the install is unpinned, and all of its output (including errors)
# is discarded. A failed install therefore only surfaces later, when `chroma` is
# imported.
!pip install git+https://github.com/generatebio/chroma.git > /dev/null 2>&1

import torch

# Prefer deterministic kernels. With warn_only=True, ops that have no
# deterministic implementation emit a warning instead of raising.
torch.use_deterministic_algorithms(True, warn_only=True)

import warnings
from tqdm import tqdm, TqdmExperimentalWarning

# Silence tqdm's "experimental feature" warning.
warnings.filterwarnings("ignore", category=TqdmExperimentalWarning)
from functools import partialmethod

# Monkey-patch tqdm so its progress bars default to leave=False, i.e. a bar is
# cleared when it finishes instead of piling up in the cell output. A caller can
# still pass leave=True explicitly.
tqdm.__init__ = partialmethod(tqdm.__init__, leave=False)

# Colab file-download helper and widget toolkit, used by the download buttons.
from google.colab import files
import ipywidgets as widgets


def create_button(filename, description=""):
    """Show a widget button that triggers a browser download of ``filename``.

    Args:
        filename: Path (inside the Colab runtime) of the file to download
            when the button is clicked.
        description: Label shown on the button.
    """

    def _download(_clicked_button):
        # ipywidgets passes the clicked button to the handler; it is not needed.
        files.download(filename)

    download_button = widgets.Button(description=description)
    display(download_button)
    download_button.on_click(_download)


def render(protein, trajectories=None, output="protein.cif"):
    """Display a protein, save it to CIF and offer download buttons.

    Args:
        protein: Chroma ``Protein``. It is shown with ``display``, printed
            (chain summary and sequence) and written to ``output`` with
            ``to_CIF``.
        trajectories: Optional mapping. When given, ``trajectories["trajectory"]``
            is also written to CIF next to ``output`` and offered for download.
        output: Destination path of the sample CIF file.
    """
    display(protein)
    print(protein)
    protein.to_CIF(output)
    create_button(output, description="Download sample")
    if trajectories is not None:
        # Build the trajectory path from the file stem. Using
        # str.replace(".cif", ...) would leave a path without a lowercase ".cif"
        # unchanged, so the trajectory would overwrite the sample file. It would
        # also rewrite ".cif" when it appears in a directory name.
        stem, ext = os.path.splitext(output)
        traj_output = f"{stem}_trajectory{ext or '.cif'}"
        trajectories["trajectory"].to_CIF(traj_output)
        create_button(traj_output, description="Download trajectory")


import locale

# Force UTF-8 as the preferred encoding (presumably a workaround for Colab
# runtimes whose locale is not UTF-8). The replacement keeps the real signature's
# optional `do_setlocale` argument, because callers such as
# `locale.getpreferredencoding(False)` pass it. A zero-argument lambda raises
# TypeError for those callers.
locale.getpreferredencoding = lambda do_setlocale=True: "UTF-8"

from chroma import Chroma, Protein, conditioners
from chroma.models import graph_classifier, procap
from chroma.utility.api import register_key
from chroma.utility.chroma import letter_to_point_cloud, plane_split_protein

register_key(api_key)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed torch's RNGs (all devices) so the stochastic sampling in later cells is
# reproducible. torch.use_deterministic_algorithms alone does not fix the random
# draws. NOTE(review): any non-torch RNGs the library may use are not covered.
SEED = 0
torch.manual_seed(SEED)

chroma = Chroma(device=device)



Using cached data from /tmp/chroma_weights/90e339502ae6b372797414167ce5a632/weights.pt
Loaded from cache
Using cached data from /tmp/chroma_weights/03a3a9af343ae74998768a2711c8b7ce/weights.pt
Loaded from cache


## Contact map conditional generation

### Conditional and unconditional generation

In [51]:
from contact_map_conditioner import ContactMapConditioner

# Pairwise distance map: the (real-valued) "contact map" used for conditioning.
def distance(X, eps=1e-6):
    """Pairwise Euclidean distances between residues, separately per atom type.

    Args:
        X: Coordinates whose axis 1 indexes residues and whose last axis holds
            xyz. For Chroma backbones this is
            [batch, num_residue, num_atom_type, 3].
        eps: Added to the squared distance before the square root. This keeps
            the gradient of sqrt finite on the zero diagonal.

    Returns:
        D with D[b, i, j, ...] = sqrt(||X[b, i, ...] - X[b, j, ...]||^2 + eps).
        For backbones its shape is [batch, num_residue, num_residue, num_atom_type].
    """
    displacement = X[:, :, None] - X[:, None]
    squared_norm = displacement.pow(2).sum(dim=-1)
    return squared_norm.add(eps).sqrt()

# choose a protein with pdb id
PDB_ID = '2HDA'
protein = Protein(PDB_ID, canonicalize=True, device=device)

# or use a randomly generated protein
# protein = chroma.sample(
#     chain_lengths=[50],
#     langevin_factor=8,
#     inverse_temperature=8,
#     sde_func="langevin",
#     steps=500,
# )

# X: spatial coordinates of protein backbone, [batch_size, num_residue, num_atom_type, 3]
X, C, S = protein.to_XCS(device=device)

# contact map of selected protein: [batch_size, num_residue, num_residue, num_atom_type]
D_inter = distance(X)

# Custom contact map conditioner. D_inter[..., 1:2] keeps a single atom type
# (index 1; presumably CA given ca_only=True. TODO confirm the atom ordering).
noise_schedule = chroma.backbone_network.noise_schedule
contact_conditioner = ContactMapConditioner(
    D_inter[..., 1:2],
    noise_schedule,
    weight=0.05,
    eps=1e-6,
    ca_only=True
)

# Shared sampler settings, so the two runs below differ only in the conditioner.
sample_kwargs = dict(
    chain_lengths=[X.size(1)],  # same number of residues as the target
    langevin_factor=8,
    inverse_temperature=8,
    sde_func="langevin",
    steps=500,
)

# do generation
# contact map conditioned protein
contact_cond_protein = chroma.sample(conditioner=contact_conditioner, **sample_kwargs)

# unconditioned (random) protein with the same number of residues, as a baseline
random_protein = chroma.sample(**sample_kwargs)

Integrating SDE:   0%|          | 0/500 [00:00<?, ?it/s]

Potts Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sequential decoding:   0%|          | 0/64 [00:00<?, ?it/s]

Integrating SDE:   0%|          | 0/500 [00:00<?, ?it/s]

Potts Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sequential decoding:   0%|          | 0/64 [00:00<?, ?it/s]

### Visualization

In [52]:
# ground truth protein
render(protein)

NGLWidget()

Protein: 2HDA
> Chain A (64 residues)
MGGGVTIFVALYDYEARTTEDLSFKKGERFQIINNTEGDWWEARSIATGKNGYIPSNYVAPADS




Button(description='Download sample', style=ButtonStyle())

In [53]:
# reconstructed protein based on contact map
render(contact_cond_protein)

NGLWidget()

Protein: system
> Chain A (64 residues)
AAPPGPAGSCASGLPTDRPAAVCQCDGCETLVGASPDERTPVALCCGPDGTACQYGGGAAAPSA




Button(description='Download sample', style=ButtonStyle())

In [54]:
# randomly generated protein
render(random_protein)

NGLWidget()

Protein: system
> Chain A (64 residues)
EETKKKKELEDLCKKAVEQNIFERYQKILEKLSKEVRPLTEEEKKAIDLYDNCLYLKGKKKKKS




Button(description='Download sample', style=ButtonStyle())

### Evaluation

#### Contact map error

In [55]:
# Compare each generated structure's distance map with the ground truth (D_inter).
# The error is the mean absolute difference over all residue pairs and all atom
# types in X, even though the conditioner used only one atom type. It is in the
# coordinate units of X (presumably Angstrom).
# Index [0] instead of unpacking into `X_cond, _, _`: assigning `_` shadows
# IPython's output cache, and IPython then stops updating `_`, `__` and `___`.
X_cond = contact_cond_protein.to_XCS()[0]
D_cond = distance(X_cond)
err = (D_cond - D_inter).abs().mean()

X_rand = random_protein.to_XCS()[0]
D_rand = distance(X_rand)
random_err = (D_rand - D_inter).abs().mean()

print(f"random error: {random_err.item():.4} cond error: {err.item():.4}")

random error: 6.179 cond error: 2.463


#### Structural similarity

Use TM-align to compute the TM-score, RMSD, and sequence identity between the target protein and each generated protein (contact-map conditioned and unconditioned).

In [None]:
# install TM-align
# skip if already installed
!wget https://zhanggroup.org/RNA-align/TMalign/TMalign.zip
!unzip TMalign.zip

In [49]:
# verify installation
import subprocess, shlex

def extract_metrics(tm_output):
    """Parse RMSD, TM-scores and sequence identity from a TM-align text report.

    Args:
        tm_output: Raw stdout of a TM-align run, either as ``bytes`` (what
            ``subprocess.check_output`` returns) or as ``str``.

    Returns:
        dict with keys:
          'rmsd':     float, the RMSD reported on the "Aligned length=" line.
          'tm_score': (float, float), the first two "TM-score=" values in report
                      order, i.e. normalized by the length of Chain_1 and of
                      Chain_2 respectively.
          'seq_id':   float, the "Seq_ID=n_identical/n_aligned=" value.

    Raises:
        ValueError: if any expected field is missing from the report.
    """
    import re

    text = tm_output.decode("utf-8", errors="replace") if isinstance(tm_output, bytes) else tm_output
    # Signed decimal with optional exponent; TM-align pads values with spaces.
    number = r"\s*([-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?)"

    def _field(label):
        # Value following `label`, e.g. "RMSD=   2.20" -> 2.2.
        match = re.search(re.escape(label) + number, text)
        if match is None:
            raise ValueError(f"TM-align output has no {label!r} field")
        return float(match.group(1))

    # Anchoring on "TM-score=" skips the prose line "...use TM-score normalized...".
    tm_scores = [float(value) for value in re.findall(r"TM-score=" + number, text)]
    if len(tm_scores) < 2:
        raise ValueError(f"expected two 'TM-score=' values in TM-align output, found {len(tm_scores)}")

    return {
        'rmsd': _field("RMSD="),
        'tm_score': (tm_scores[0], tm_scores[1]),
        'seq_id': _field("Seq_ID=n_identical/n_aligned="),
    }


# Smoke-test the TM-align install on the two example structures in ./TMalign,
# then show the raw report followed by the parsed metrics.
tm_align = './TMalign/TMalign'
example_pair = "./TMalign/PDB1.pdb ./TMalign/PDB2.pdb"
cmd = f"{tm_align} {example_pair}"
argv = shlex.split(cmd)
test_tm_output = subprocess.check_output(argv)
test_metrics = extract_metrics(test_tm_output)
print(test_tm_output.decode(), test_metrics, sep="\n")


 **********************************************************************
 * TM-align (Version 20190425): protein and RNA structure alignment   *
 * References: Y Zhang, J Skolnick. Nucl Acids Res 33, 2302-9 (2005)  *
 *             S Gong, C Zhang, Y Zhang. Bioinformatics, bz282 (2019) *
 * Please email comments and suggestions to yangzhanglab@umich.edu    *
 **********************************************************************

Name of Chain_1: ./TMalign/PDB1.pdb (to be superimposed onto Chain_2)
Name of Chain_2: ./TMalign/PDB2.pdb
Length of Chain_1: 250 residues
Length of Chain_2: 166 residues

Aligned length= 119, RMSD=   2.20, Seq_ID=n_identical/n_aligned= 0.824
TM-score= 0.42654 (if normalized by length of Chain_1, i.e., LN=250, d0=5.85)
TM-score= 0.61629 (if normalized by length of Chain_2, i.e., LN=166, d0=4.80)
(You should use TM-score normalized by length of the reference structure)

(":" denotes residue pairs of d <  5.0 Angstrom, "." denotes other aligned residues)
CQDVVQDV

In [56]:
# Write the ground truth and both generated structures to PDB for TM-align.
for structure, pdb_path in (
    (protein, 'gt.pdb'),
    (contact_cond_protein, 'cond.pdb'),
    (random_protein, 'rand.pdb'),
):
    structure.to_PDB(pdb_path)

# gt.pdb is passed as Chain_1 (the structure superimposed onto Chain_2), so
# tm_score[0] is the TM-score normalized by the ground-truth length.
metrics_by_name = {}
for name in ("cond", "rand"):
    cmd = f"{tm_align} gt.pdb {name}.pdb"
    metrics_by_name[name] = extract_metrics(subprocess.check_output(shlex.split(cmd)))
metrics_cond, metrics_rand = metrics_by_name["cond"], metrics_by_name["rand"]

print(f"cond: {metrics_cond}")
print(f"rand: {metrics_rand}")

cond: {'rmsd': 2.78, 'tm_score': (0.56338, 0.53962), 'seq_id': 0.053}
rand: {'rmsd': 4.0, 'tm_score': (0.24682, 0.23707), 'seq_id': 0.069}
