# Zero-shot co-engineering of protein sequences & structures with PT-DiT

This notebook provides examples of using the RePaint algorithm ([arXiv:2201.09865v4](https://arxiv.org/abs/2201.09865)) with PT-DiT, a pre-trained multimodal diffusion model, to co-engineer protein sequences (represented as amino acids) and structures (represented as ProTokens).

## 1. Import Libraries

In [None]:
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False"
import sys
sys.path.append('..')

import jax
import jax.numpy as jnp
import pickle as pkl
import numpy as np
from tqdm import tqdm
import argparse

from functools import partial
from src.model.diffusion_transformer import DiffusionTransformer
from train.schedulers import GaussianDiffusion
import datetime
from flax.jax_utils import replicate
from functools import reduce

from configs.global_config import global_config
from configs.dit_config import dit_config
# Inference only: switch dropout off in the shared model config.
global_config.dropout_flag = False

### Load embedding tables (pickled arrays) and cast them to float32 JAX arrays.
def _load_embedding_table(path):
    """Unpickle the array stored at `path` and return it as a float32 jnp array."""
    with open(path, 'rb') as fh:
        table = pkl.load(fh)
    return jnp.array(table, dtype=jnp.float32)

protoken_emb = _load_embedding_table('../embeddings/protoken_emb.pkl')
aatype_emb = _load_embedding_table('../embeddings/aatype_emb.pkl')

## 2. Preparation

### 2.1 Define Constants

Here we define constants for PT-DiT inference:
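* `NRES`: the padded protein length; every input is zero-padded to this many residues
* `NSAMPLE_PER_DEVICE`: the number of proteins generated on each GPU device
* `DIM_EMB`: the size of the raw embedding, i.e. the ProToken embedding (`DIM_EMB_PTK`) concatenated with the amino-acid embedding (`DIM_EMB_AA`)
* `NDEVICES`: the number of available GPU devices
* `BATCH_SIZE`: the number of proteins generated in each batch (`NSAMPLE_PER_DEVICE * NDEVICES`)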

In [15]:
#### constants
NRES = 512                  # padded protein length; inputs are zero-padded to this
NSAMPLE_PER_DEVICE = 8      # proteins generated on each device
NDEVICES = len(jax.devices())
BATCH_SIZE = NDEVICES * NSAMPLE_PER_DEVICE

# Width of each embedding table; the model input is their concatenation.
DIM_EMB_PTK, DIM_EMB_AA = protoken_emb.shape[-1], aatype_emb.shape[-1]
DIM_EMB = DIM_EMB_PTK + DIM_EMB_AA

### 2.2 Define Functional Utils for Pre/Post-processing

Here we define functional utils for pre-processing inputs and post-processing outputs, for example:
* `protoken_emb_distance_fn`: calculate the negative cosine similarity between two ProToken embeddings (smaller means closer)
* `aatype_emb_distance_fn`: calculate the squared Euclidean distance between two amino acid embeddings
* `aatype_index_to_resname`: convert amino acid indexes into corresponding amino acid symbols
* `resname_to_aatype_index`: convert amino acid symbols into corresponding amino acid indexes

In [16]:
#### function utils 

def split_multiple_rng_keys(rng_key, num_keys):
    """Split `rng_key` into `num_keys` fresh subkeys plus one carry-over key.

    Returns:
        (keys, next_key): `keys` stacks the first `num_keys` subkeys; `next_key`
        is the remaining subkey, to be used for later splits.
    """
    all_keys = jax.random.split(rng_key, num_keys + 1)
    fresh_keys, carry_key = all_keys[:num_keys], all_keys[num_keys]
    return fresh_keys, carry_key

def flatten_list_of_dicts(list_of_dicts):
    """Explode a list of column-oriented dicts into one flat list of row dicts.

    Each input dict maps keys to equal-length sequences. Element ``i`` of every
    sequence is gathered into the ``i``-th row dict. Rows produced by successive
    input dicts are concatenated in order::

        [{'a': [1, 2], 'b': [3, 4]}] -> [{'a': 1, 'b': 3}, {'a': 2, 'b': 4}]

    The row count of each dict is taken from its first value. An empty dict
    contributes no rows.
    """
    rows = []
    for columns in list_of_dicts:
        if not columns:
            # Previously next(iter({}.values())) raised StopIteration here.
            continue
        n_rows = len(next(iter(columns.values())))
        # extend() keeps this linear; reduce(+) over lists was quadratic.
        rows.extend({k: v[i] for k, v in columns.items()} for i in range(n_rows))
    return rows

def protoken_emb_distance_fn(x, y):
    """Return the negative cosine similarity of `x` and `y` along the last axis.

    Both inputs are L2-normalised (with a 1e-6 guard against a zero norm) before
    the dot product. Smaller values therefore mean more similar embeddings.
    Inputs broadcast against each other.
    """
    def _unit(v):
        return v / (jnp.linalg.norm(v, axis=-1, keepdims=True) + 1e-6)

    similarity = jnp.sum(_unit(x) * _unit(y), axis=-1)
    return -similarity

def aatype_emb_distance_fn(x, y):
    """Return the squared Euclidean distance between `x` and `y` along the last axis."""
    delta = x - y
    return jnp.square(delta).sum(axis=-1)

def aatype_index_to_resname(aatype_index):
    """Map a sequence of amino-acid indexes (0-19) to a one-letter residue string.

    Each element is passed through `int()` first, so NumPy/JAX scalars are accepted.
    """
    restypes = 'ARNDCQEGHILKMFPSTWYV'
    letters = (restypes[int(idx)] for idx in aatype_index)
    return ''.join(letters)

def resname_to_aatype_index(resnames):
    """Convert one-letter residue names into amino-acid indexes.

    Args:
        resnames: iterable of one-letter codes (e.g. a sequence string).

    Returns:
        np.int32 array with one index per residue, in the order of `restypes`.

    Raises:
        ValueError: if a residue is not one of the 20 standard one-letter codes.
    """
    restypes = [
        'A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I', 'L', 'K', 'M', 'F', 'P',
        'S', 'T', 'W', 'Y', 'V'
    ]
    # O(1) dict lookup instead of a linear list.index() per residue.
    restype_to_index = {name: i for i, name in enumerate(restypes)}
    indexes = []
    for pos, name in enumerate(resnames):
        try:
            indexes.append(restype_to_index[name])
        except KeyError:
            raise ValueError(
                f'unknown residue {name!r} at position {pos}; '
                f'expected one of {"".join(restypes)}'
            ) from None
    return np.array(indexes, dtype=np.int32)

### 2.3 Load Model & Parameters

In [17]:
#### build the model and the diffusion scheduler
num_diffusion_timesteps = 500
dit_model = DiffusionTransformer(config=dit_config, global_config=global_config)
scheduler = GaussianDiffusion(num_diffusion_timesteps=num_diffusion_timesteps)

#### rng keys: JAX key for sampling, plus a fixed seed for NumPy's global RNG
rng_key = jax.random.PRNGKey(8888)
np.random.seed(7777)

##### load the pre-trained checkpoint, convert leaves to jnp arrays and
##### replicate it across local devices for the pmapped steps
ckpt_path = '../ckpts/PT_DiT_params_2000000.pkl'
with open(ckpt_path, "rb") as ckpt_file:
    raw_params = pkl.load(ckpt_file)
params = replicate(jax.tree_util.tree_map(jnp.array, raw_params))

### 2.4 Define Functional Utils for Inference

Here we define functional utils for PT-DiT inference, for example:
* `clamp_x0_fn`: to enable clamping trick introduced at [arXiv:2205.14217v1](https://arxiv.org/abs/2205.14217)
* `denoise_step`: sample $p_\theta(\mathbf{z}_{t-1}|\mathbf{z}_t)$
* `q_sample`: sample $q(\mathbf{z}_t|\mathbf{z}_0)$
* `noise_step`: sample $q(\mathbf{z}_t | \mathbf{z}_{t-1})$

We also define `pjit_*` versions of these functions (`jax.pmap` over `jax.jit`; see the `jax` documentation, and not to be confused with `jax.experimental.pjit`) to enable fast, batched multi-device inference.

In [None]:
##### main inference functions
# When True, every residue embedding is [ProToken emb | aatype emb] and the
# amino-acid branches in the helpers below are active.
infer_protuple = True
# Jit-compiled DiT forward pass, used inside `denoise_step`.
jit_apply_fn = jax.jit(dit_model.apply)

def clamp_x0_fn(x0):
    """Snap a predicted clean embedding `x0` onto the nearest codebook entries.

    The ProToken slice (the first `protoken_emb.shape[-1]` channels) is replaced by
    the closest `protoken_emb` row under `protoken_emb_distance_fn`. When
    `infer_protuple` is set, the remaining channels are likewise replaced by the
    closest `aatype_emb` row under `aatype_emb_distance_fn`, and both parts are
    concatenated.
    """
    n_ptk = protoken_emb.shape[-1]
    lead = (1,) * (len(x0.shape) - 1)

    ptk_dist = protoken_emb_distance_fn(x0[..., None, :n_ptk],
                                        protoken_emb.reshape(lead + protoken_emb.shape))
    ptk_snapped = protoken_emb[jnp.argmin(ptk_dist, axis=-1)]
    if not bool(infer_protuple):
        return ptk_snapped

    aa_dist = aatype_emb_distance_fn(x0[..., None, n_ptk:],
                                     aatype_emb.reshape(lead + aatype_emb.shape))
    aa_snapped = aatype_emb[jnp.argmin(aa_dist, axis=-1)]
    return jnp.concatenate([ptk_snapped, aa_snapped], axis=-1)

def denoise_step(params, x, seq_mask, t, residue_index, rng_key,
                 clamp_x0_fn=None):
    """Take one reverse-diffusion step from step `t`.

    Adds the indicator embedding(s) stored in `params` to `x` and predicts the
    noise with the DiT. The scheduler turns that prediction into a posterior
    mean/log-variance, optionally clamping the x0 estimate with `clamp_x0_fn`,
    and a Gaussian sample is drawn.

    Returns:
        (x_prev, rng_key): the new sample and the advanced RNG key.
    """
    batch_t = jnp.full((x.shape[0],), t)
    p = params['params']
    indicator = p['protoken_indicator']
    if bool(infer_protuple):
        indicator = jnp.concatenate([indicator, p['aatype_indicator']], axis=-1)

    eps_prime = jit_apply_fn({'params': p['model']}, x + indicator[None, ...],
                             seq_mask, batch_t, tokens_rope_index=residue_index)
    mean, _, log_variance = scheduler.p_mean_variance(
        x, batch_t, eps_prime, clip=False, clamp_x0_fn=clamp_x0_fn)

    rng_key, normal_key = jax.random.split(rng_key)
    noise = jax.random.normal(normal_key, x.shape)
    return mean + jnp.exp(0.5 * log_variance) * noise, rng_key

def q_sample(x, t, rng_key):
    """Noise clean `x` directly to diffusion step `t` via `scheduler.q_sample`.

    Returns:
        (x_t, rng_key): the noised sample and the advanced RNG key.
    """
    rng_key, normal_key = jax.random.split(rng_key)
    eps = jax.random.normal(normal_key, x.shape, dtype=jnp.float32)
    timesteps = jnp.full((x.shape[0],), t)
    return scheduler.q_sample(x, timesteps, eps), rng_key

def noise_step(x, t, rng_key):
    """Apply a single forward-noising step to `x` at step `t` via `scheduler.q_sample_step`.

    Returns:
        (x_t, rng_key): the re-noised sample and the advanced RNG key.
    """
    rng_key, normal_key = jax.random.split(rng_key)
    noise = jax.random.normal(normal_key, x.shape)
    x_next = scheduler.q_sample_step(x, jnp.full((x.shape[0],), t), noise)
    return x_next, rng_key

def index_from_embedding(x):
    """Decode embeddings back to discrete indexes by nearest-codebook lookup.

    Args:
        x: array of shape (..., Nres, Nemb). Any number of leading axes is
            accepted; previously a rank-3 input was assumed.

    Returns:
        dict with 'protoken_indexes' (argmin of `protoken_emb_distance_fn` over
        `protoken_emb`). When `infer_protuple` is set it also holds
        'aatype_indexes' (argmin of `aatype_emb_distance_fn` over `aatype_emb`).
        Each entry has shape x.shape[:-1].
    """
    n_ptk = protoken_emb.shape[-1]
    # x[..., None, :] is (..., Nres, 1, D). The (K, D) codebook broadcasts against
    # the trailing axes, giving (..., Nres, K) distances for any rank of x, so no
    # hard-coded [None, None, ...] reshape is needed.
    ret = {
        'protoken_indexes': jnp.argmin(
            protoken_emb_distance_fn(x[..., None, :n_ptk], protoken_emb), axis=-1),
    }
    if bool(infer_protuple):
        ret['aatype_indexes'] = jnp.argmin(
            aatype_emb_distance_fn(x[..., None, n_ptk:], aatype_emb), axis=-1)
    return ret
    
# Per-device wrappers: pmap over jit-compiled step functions. Arguments with
# in_axes=0 carry a leading device axis; `t` (in_axes=None) is shared by all devices.
_step_in_axes = (0, 0, 0, None, 0, 0)   # (params, x, seq_mask, t, residue_index, rng_key)
_noise_in_axes = (0, None, 0)           # (x, t, rng_key)

def _pmap_jit(fn, in_axes=0):
    """Return `jax.pmap(jax.jit(fn))` with axis name "i" and the given `in_axes`."""
    return jax.pmap(jax.jit(fn), axis_name="i", in_axes=in_axes)

pjit_denoise_step = _pmap_jit(partial(denoise_step, clamp_x0_fn=None), _step_in_axes)
pjit_denoise_step_clamped = _pmap_jit(partial(denoise_step, clamp_x0_fn=clamp_x0_fn), _step_in_axes)
pjit_q_sample = _pmap_jit(q_sample, _noise_in_axes)
pjit_noise_step = _pmap_jit(noise_step, _noise_in_axes)
pjit_index_from_embedding = _pmap_jit(index_from_embedding)

### 2.5 Define Functional Utils for RePaint Algorithm

Here we prepare the main inference function based on the RePaint algorithm ([arXiv:2201.09865v4](https://arxiv.org/abs/2201.09865)). We define the following auxiliary function and hyper-parameters:
* `make_repaint_info`: prepares the repaint context and repaint mask used by the RePaint algorithm
* `n_eq_steps`: the number of equilibrium (resampling) steps per diffusion step, analogous to the predictor-corrector trick ([ICLR 2021](https://openreview.net/forum?id=PxTIG12RRHS)); in general, more `n_eq_steps` lead to higher sampling quality
* `phasing_time`: inference runs in two phases. In the first phase (large noise scale, `t > phasing_time`) the clamping trick ([arXiv:2205.14217v1](https://arxiv.org/abs/2205.14217)) is not used, to encourage diversity; in the second phase (small noise scale, `t <= phasing_time`) the clamping trick is used to ensure robustness

In [19]:
def make_repaint_info(aatypes, protokens, aatype_context_ids, protoken_context_ids):
    """Build the RePaint context embedding and the boolean mask of fixed channels.

    Args:
        aatypes: per-residue amino-acid indexes (length Nres).
        protokens: per-residue ProToken indexes (length Nres).
        aatype_context_ids: residue positions whose amino-acid channels are fixed.
        protoken_context_ids: residue positions whose ProToken channels are fixed.

    Returns:
        (repaint_context, repaint_mask). `repaint_context` concatenates each
        residue's ProToken and amino-acid embeddings. `repaint_mask` is a bool
        array of shape (Nres, DIM_EMB). It is True on the first DIM_EMB_PTK
        channels of residues in `protoken_context_ids` and on the last DIM_EMB_AA
        channels of residues in `aatype_context_ids`.

    Raises:
        ValueError: if `aatypes` and `protokens` differ in length.
    """
    protoken_context = protoken_emb[protokens]
    aatype_context = aatype_emb[aatypes]
    # Explicit check instead of `assert` (which is stripped under `python -O`).
    if len(protoken_context) != len(aatype_context):
        raise ValueError('seq_len mismatch: {} != {}'.format(len(protoken_context), len(aatype_context)))
    seq_len = len(protoken_context)

    repaint_context = np.concatenate([protoken_context, aatype_context], axis=-1)

    # Vectorised membership test instead of a Python `in` per residue.
    residue_ids = np.arange(seq_len)
    fix_ptk = np.isin(residue_ids, np.asarray(protoken_context_ids))
    fix_aa = np.isin(residue_ids, np.asarray(aatype_context_ids))

    # Shape is (seq_len, DIM_EMB) even when seq_len == 0.
    repaint_mask = np.zeros((seq_len, DIM_EMB), dtype=np.bool_)
    repaint_mask[:, :DIM_EMB_PTK] = fix_ptk[:, None]
    repaint_mask[:, DIM_EMB_PTK:] = fix_aa[:, None]
    return repaint_context, repaint_mask

In [20]:
n_eq_steps = 50 ### more eq steps -> higher quality; with RePaint we recommend more eq steps
phasing_time = 250 ### balances diversity & quality: larger phasing time -> higher quality, lower diversity
def run_infer(x, seq_mask, residue_index, rng_keys,
              repaint_context=None, repaint_mask=None, repaint_time_steps=None):
    """Run the full reverse diffusion with RePaint-style resampling.

    Steps run from t = T down to 1, where T = `num_diffusion_timesteps`. Each step
    does `n_eq_steps` rounds of:

    1. denoise x_t -> x_{t-1};
    2. re-noise back to step t;
    3. if t is a repaint step, overwrite the masked channels with the context
       forward-noised to step t.

    A final denoise then advances to t-1. Steps with t > `phasing_time` use the
    un-clamped denoiser; the remaining steps use the clamped one.

    Args:
        x: initial noise, per-device batched (the example cells build it as
            (NDEVICES, NSAMPLE_PER_DEVICE, NRES, DIM_EMB)).
        seq_mask, residue_index: per-device model inputs.
        rng_keys: one PRNG key per device.
        repaint_context: clean context embedding, or None for unconditional
            sampling. An extra trailing axis (indexed by t-1) makes it step-dependent.
        repaint_mask: weights in [0, 1], where 1 means "impose the context".
            Required with `repaint_context`; may also carry a trailing time axis.
        repaint_time_steps: steps t (1..T) at which the context is imposed.
            Defaults to every step 1..T.

    Returns:
        dict with 'embedding', 'seq_mask', 'residue_index' plus the decoded index
        arrays from `pjit_index_from_embedding`.

    Raises:
        ValueError: if `repaint_context` is given without `repaint_mask`.
    """
    use_repaint = repaint_context is not None
    if use_repaint and repaint_mask is None:
        raise ValueError('repaint_mask is required when repaint_context is given')

    if repaint_time_steps is None:
        # Every step actually visited (t = T..1). The previous default
        # np.arange(T) = 0..T-1 silently skipped t = T.
        repaint_steps = set(range(1, num_diffusion_timesteps + 1))
    else:
        # Set lookup instead of scanning an array on every step.
        repaint_steps = set(np.asarray(repaint_time_steps).ravel().tolist())

    for ti in tqdm(range(num_diffusion_timesteps)):
        t = num_diffusion_timesteps - ti
        # Chosen once per step, outside the eq loop, so it is also defined for the
        # final denoise below when n_eq_steps == 0.
        denoise_fn = pjit_denoise_step if t > phasing_time else pjit_denoise_step_clamped

        repaint_now = use_repaint and t in repaint_steps
        if repaint_now:
            # Context/mask with more axes than x are step-dependent: pick slice t-1.
            repaint_context_ = repaint_context[..., t-1] if len(repaint_context.shape) > len(x.shape) \
                                else repaint_context
            repaint_mask_ = repaint_mask[..., t-1] if len(repaint_mask.shape) > len(x.shape) \
                                else repaint_mask

        for _ in range(n_eq_steps):
            x, rng_keys = denoise_fn(params, x, seq_mask, t, residue_index, rng_keys)
            x, rng_keys = pjit_noise_step(x, t, rng_keys)

            if repaint_now:
                # Impose the known channels, forward-noised to the same step t.
                repaint_context_t, rng_keys = pjit_q_sample(repaint_context_, t, rng_keys)
                x = repaint_mask_ * repaint_context_t + (1 - repaint_mask_) * x

        x, rng_keys = denoise_fn(params, x, seq_mask, t, residue_index, rng_keys)

    ret = {'embedding': x, 'seq_mask': seq_mask, 'residue_index': residue_index}
    ret.update(pjit_index_from_embedding(x))

    return ret

## 3. Example Application 1: (Contextual) Inverse Folding

The following code provides examples of (contextual) inverse folding with PT-DiT.

### 3.1 Obtain ProTokens from Structures (in .pdb format)

To use PT-DiT for (contextual) inverse folding, we first need to encode protein backbone structures into ProTokens. Here we use the script `infer_batch.py`.

In [None]:
### example: 8CYK

### obtain ProTokens: encode the .pdb files in results/inverse_folding with
### PROTOKEN's infer_batch.py. `check=True` raises here if encoding fails, instead
### of the next cell failing later on a missing file (os.system ignored the exit code).
### sys.executable runs the script with the same interpreter as this notebook.
import subprocess

working_dir = '../'
subprocess.run(
    [sys.executable, f'{working_dir}/PROTOKEN/scripts/infer_batch.py',
     '--encoder_config', f'{working_dir}/PROTOKEN/config/encoder.yaml',
     '--decoder_config', f'{working_dir}/PROTOKEN/config/decoder.yaml',
     '--vq_config', f'{working_dir}/PROTOKEN/config/vq.yaml',
     '--pdb_dir_path', f'{working_dir}/example_scripts/results/inverse_folding',
     '--save_dir_path', f'{working_dir}/example_scripts/results/inverse_folding',
     '--load_ckpt_path', f'{working_dir}/ckpts/protoken_params_100000.pkl'],
    env={**os.environ, 'PYTHONPATH': f'{working_dir}/PROTOKEN'},
    check=True,
)

In [None]:
### load context: ProTokens and amino-acid types written by the encoding step
context_path = './results/inverse_folding/stage_1/generator_inputs/8CYK_B.pkl'
with open(context_path, 'rb') as fh:
    data_dict = pkl.load(fh)

protoken_context = data_dict['protokens'].astype(np.int32)
aatype_context = data_dict['aatype'].astype(np.int32)
seq_len = len(data_dict['protokens'])

# all residues are real (padding happens later); residues are numbered 0..seq_len-1
seq_mask = np.ones(seq_len, dtype=np.bool_)
residue_index = np.arange(seq_len, dtype=np.int32)
input_dict = dict(seq_mask=seq_mask, residue_index=residue_index)

for name, arr in input_dict.items():
    print(name, arr.shape, arr.dtype)

### 3.2 Make RePaint Information

The `make_repaint_info` function defines the repainting context and mask for the RePaint algorithm. For inverse folding tasks, all structural contexts (represented as ProTokens) remain fixed. Partial sequence contexts can be specified through the `aatype_context_resids` variable.

In [None]:
# Inverse folding: every residue's ProToken (the whole backbone) is fixed context.
protoken_context_resids = np.arange(seq_len)
#### contextual inverse folding: put sequence context here
aatype_context_resids = []

repaint_context, repaint_mask = make_repaint_info(
    aatype_context, protoken_context, aatype_context_resids, protoken_context_resids)

# float mask so it can be blended arithmetically with the sample
repaint_dict = dict(repaint_context=repaint_context,
                    repaint_mask=repaint_mask.astype(np.float32))
for name, arr in repaint_dict.items():
    print(name, arr.shape, arr.dtype)

### 3.3 Run Inference

In [None]:
### preprocessing inputs

def reshape_tile_pad_x(x):
    """Pad `x` to NRES residues along axis 0 and replicate it for every sample.

    Args:
        x: array whose first axis is the residue axis (length <= NRES).

    Returns:
        Array of shape (NDEVICES, NSAMPLE_PER_DEVICE, NRES, *x.shape[1:]).
        Padding is zero-filled (False for boolean inputs).

    Raises:
        ValueError: if `x` has more than NRES residues (np.pad would otherwise
            fail with an unrelated "negative values" message).
    """
    n_res = x.shape[0]
    if n_res > NRES:
        raise ValueError(f'input has {n_res} residues, more than NRES={NRES}; increase NRES')
    padded = np.pad(x, ((0, NRES - n_res),) + ((0, 0),) * (x.ndim - 1))
    tiled = np.tile(padded[None, ...], (BATCH_SIZE,) + (1,) * padded.ndim)
    return tiled.reshape(NDEVICES, NSAMPLE_PER_DEVICE, *padded.shape)

# Pad every per-residue array to NRES and replicate it to
# (NDEVICES, NSAMPLE_PER_DEVICE, NRES, ...) for the pmapped steps.
input_dict = jax.tree.map(
    lambda x: jnp.array(reshape_tile_pad_x(x)), input_dict
)
repaint_dict = jax.tree.map(
    lambda x: jnp.array(reshape_tile_pad_x(x)), repaint_dict
)

# Initial Gaussian noise for every device/sample. The split order below fixes
# which keys are consumed, so keep it unchanged for reproducibility.
init_key, rng_key = jax.random.split(rng_key)
x = jax.random.normal(init_key, (NDEVICES, NSAMPLE_PER_DEVICE, NRES, DIM_EMB))
input_dict['x'] = x

# One PRNG key per device (leading axis NDEVICES); `rng_key` is carried forward.
rng_keys, rng_key = split_multiple_rng_keys(rng_key, NDEVICES)
rng_keys = jnp.reshape(rng_keys, (NDEVICES, -1))

for k, v in input_dict.items(): print(k, v.shape, v.dtype)
for k, v in repaint_dict.items(): print(k, v.shape, v.dtype)

In [None]:
# Sample with the backbone (ProTokens) imposed as RePaint context.
ret = run_infer(
    input_dict['x'], input_dict['seq_mask'], input_dict['residue_index'], rng_keys,
    repaint_context=repaint_dict['repaint_context'],
    repaint_mask=repaint_dict['repaint_mask'],
)

def _merge_device_axis(arr):
    """Merge the (NDEVICES, NSAMPLE_PER_DEVICE) leading axes into BATCH_SIZE and return nested lists."""
    return np.array(arr.reshape(BATCH_SIZE, *arr.shape[2:])).tolist()

ret = jax.tree_util.tree_map(_merge_device_axis, ret)
with open('results/inverse_folding/result.pkl', 'wb') as fh:
    pkl.dump(ret, fh)

# one dict per generated sample
ret_ = flatten_list_of_dicts([ret])
with open('results/inverse_folding/result_flatten.pkl', 'wb') as fh:
    pkl.dump(ret_, fh)

### 3.4 Summary

In [None]:
# Print each sample's designed sequence, trimmed to the real (unpadded) length.
for sample_id, sample in enumerate(ret_):
    protoken_idx = np.array(sample['protoken_indexes'])[:seq_len]
    aatype_idx = np.array(sample['aatype_indexes'])[:seq_len]
    designed_seq = aatype_index_to_resname(aatype_idx)
    print(f'seq{sample_id}: {designed_seq}')

## 4. Example Application 2: Contextual Protein Design

The following code provides examples of contextual protein design with PT-DiT. The algorithm co-designs structures and sequences while accommodating optional structural and sequence constraints as design context.

### 4.1 Obtain ProTokens from Structures (in .pdb format)

To use PT-DiT for (contextual) protein design, we first need to encode template protein backbone structures into ProTokens, which serve as context during inference. Here we use the script `infer_batch.py`.

In [None]:
### example: 5jxe_G_VH

### obtain ProTokens: encode the .pdb files in results/contextual_scaffolding with
### PROTOKEN's infer_batch.py. `check=True` raises here if encoding fails, instead
### of the next cell failing later on a missing file (os.system ignored the exit code).
### sys.executable runs the script with the same interpreter as this notebook.
import subprocess

working_dir = '../'
subprocess.run(
    [sys.executable, f'{working_dir}/PROTOKEN/scripts/infer_batch.py',
     '--encoder_config', f'{working_dir}/PROTOKEN/config/encoder.yaml',
     '--decoder_config', f'{working_dir}/PROTOKEN/config/decoder.yaml',
     '--vq_config', f'{working_dir}/PROTOKEN/config/vq.yaml',
     '--pdb_dir_path', f'{working_dir}/example_scripts/results/contextual_scaffolding',
     '--save_dir_path', f'{working_dir}/example_scripts/results/contextual_scaffolding',
     '--load_ckpt_path', f'{working_dir}/ckpts/protoken_params_100000.pkl'],
    env={**os.environ, 'PYTHONPATH': f'{working_dir}/PROTOKEN'},
    check=True,
)

In [None]:
### load context: ProTokens and amino-acid types written by the encoding step
context_path = './results/contextual_scaffolding/stage_1/generator_inputs/5jxe_G_VH.pkl'
with open(context_path, 'rb') as fh:
    data_dict = pkl.load(fh)

protoken_context = data_dict['protokens'].astype(np.int32)
aatype_context = data_dict['aatype'].astype(np.int32)
seq_len = len(data_dict['protokens'])

# all residues are real (padding happens later); residues are numbered 0..seq_len-1
seq_mask = np.ones(seq_len, dtype=np.bool_)
residue_index = np.arange(seq_len, dtype=np.int32)
input_dict = dict(seq_mask=seq_mask, residue_index=residue_index)

for name, arr in input_dict.items():
    print(name, arr.shape, arr.dtype)

### 4.2 Select CDR3 as context

In this demonstration, the CDR3 region of an antibody serves as the design context. We first parse structural annotations to identify CDR and FWR regions, then specifically extract CDR3 residues. 

Design constraints can be modified by adjusting the `aatype_context_resids` (sequence context) and `protoken_context_resids` (structural context) variables.

In [44]:
import ast

def parse_annotation(annotation_file, aatype_indexes):
    """Locate annotated sub-sequences (e.g. CDR / FWR regions) as residue positions.

    The first line of `annotation_file` must be a Python dict literal mapping a
    region name to its one-letter sub-sequence. Each sub-sequence is searched for
    in the sequence spelled by `aatype_indexes`; only the first occurrence is used.

    Returns:
        dict mapping each region name to ``np.arange(start, end)`` residue ids.

    Raises:
        ValueError: if a sub-sequence does not occur in the sequence.
    """
    seq_str = aatype_index_to_resname(aatype_indexes)
    with open(annotation_file, 'r') as fh:
        first_line = fh.readlines()[0]
    regions = ast.literal_eval(first_line)

    resids = {}
    for region, subseq in regions.items():
        begin = seq_str.find(subseq)
        if begin < 0:
            raise ValueError('can not find {} in {}'.format(subseq, seq_str))
        resids[region] = np.arange(begin, begin + len(subseq))
    return resids

In [None]:
annotation_path = './results/contextual_scaffolding/5jxe_G_VH_annotation.txt'
annotation_resid_dict = parse_annotation(annotation_path, aatype_context)

### select H-CDR3 as context: fix both its structure (ProTokens) and its sequence
cdr3_resids = annotation_resid_dict['H-CDR3']
protoken_context_resids = cdr3_resids
aatype_context_resids = cdr3_resids

repaint_context, repaint_mask = make_repaint_info(
    aatype_context, protoken_context, aatype_context_resids, protoken_context_resids)

# float mask so it can be blended arithmetically with the sample
repaint_dict = dict(repaint_context=repaint_context,
                    repaint_mask=repaint_mask.astype(np.float32))
for name, arr in repaint_dict.items():
    print(name, arr.shape, arr.dtype)

### 4.3 Run Inference

In [None]:
### preprocessing inputs

def reshape_tile_pad_x(x):
    """Pad `x` to NRES residues along axis 0 and replicate it for every sample.

    Args:
        x: array whose first axis is the residue axis (length <= NRES).

    Returns:
        Array of shape (NDEVICES, NSAMPLE_PER_DEVICE, NRES, *x.shape[1:]).
        Padding is zero-filled (False for boolean inputs).

    Raises:
        ValueError: if `x` has more than NRES residues (np.pad would otherwise
            fail with an unrelated "negative values" message).
    """
    n_res = x.shape[0]
    if n_res > NRES:
        raise ValueError(f'input has {n_res} residues, more than NRES={NRES}; increase NRES')
    padded = np.pad(x, ((0, NRES - n_res),) + ((0, 0),) * (x.ndim - 1))
    tiled = np.tile(padded[None, ...], (BATCH_SIZE,) + (1,) * padded.ndim)
    return tiled.reshape(NDEVICES, NSAMPLE_PER_DEVICE, *padded.shape)

# Pad every per-residue array to NRES and replicate it to
# (NDEVICES, NSAMPLE_PER_DEVICE, NRES, ...) for the pmapped steps.
input_dict = jax.tree.map(
    lambda x: jnp.array(reshape_tile_pad_x(x)), input_dict
)
repaint_dict = jax.tree.map(
    lambda x: jnp.array(reshape_tile_pad_x(x)), repaint_dict
)

# Initial Gaussian noise for every device/sample. The split order below fixes
# which keys are consumed, so keep it unchanged for reproducibility.
init_key, rng_key = jax.random.split(rng_key)
x = jax.random.normal(init_key, (NDEVICES, NSAMPLE_PER_DEVICE, NRES, DIM_EMB))
input_dict['x'] = x

# One PRNG key per device (leading axis NDEVICES); `rng_key` is carried forward.
rng_keys, rng_key = split_multiple_rng_keys(rng_key, NDEVICES)
rng_keys = jnp.reshape(rng_keys, (NDEVICES, -1))

for k, v in input_dict.items(): print(k, v.shape, v.dtype)
for k, v in repaint_dict.items(): print(k, v.shape, v.dtype)

In [None]:
# Sample with the H-CDR3 structure and sequence imposed as RePaint context.
ret = run_infer(
    input_dict['x'], input_dict['seq_mask'], input_dict['residue_index'], rng_keys,
    repaint_context=repaint_dict['repaint_context'],
    repaint_mask=repaint_dict['repaint_mask'],
)

def _merge_device_axis(arr):
    """Merge the (NDEVICES, NSAMPLE_PER_DEVICE) leading axes into BATCH_SIZE and return nested lists."""
    return np.array(arr.reshape(BATCH_SIZE, *arr.shape[2:])).tolist()

ret = jax.tree_util.tree_map(_merge_device_axis, ret)
with open('results/contextual_scaffolding/result.pkl', 'wb') as fh:
    pkl.dump(ret, fh)

# one dict per generated sample
ret_ = flatten_list_of_dicts([ret])
with open('results/contextual_scaffolding/result_flatten.pkl', 'wb') as fh:
    pkl.dump(ret_, fh)

### 4.4 Decode Structures

Here we decode the 3D coordinates of protein structures from the generated ProTokens (with the `decode_structure.py` script) and save the designed sequences & structures in .pdb format. The PDB files are saved in `results/contextual_scaffolding/pdb`.

In [None]:
working_dir = '../'

# Decode the generated ProTokens into 3D structures and write .pdb files.
# `check=True` raises if decoding fails (os.system ignored the exit code), and
# sys.executable runs the script with the same interpreter as this notebook.
# Relative paths are resolved against the notebook's working directory.
import subprocess

subprocess.run(
    [sys.executable, f'{working_dir}/PROTOKEN/scripts/decode_structure.py',
     '--decoder_config', f'{working_dir}/PROTOKEN/config/decoder.yaml',
     '--vq_config', f'{working_dir}/PROTOKEN/config/vq.yaml',
     '--input_path', 'results/contextual_scaffolding/result_flatten.pkl',
     '--output_dir', 'results/contextual_scaffolding/pdb',
     '--load_ckpt_path', f'{working_dir}/ckpts/protoken_params_100000.pkl',
     '--padding_len', str(NRES)],
    env={**os.environ, 'PYTHONPATH': f'{working_dir}/PROTOKEN'},
    check=True,
)