# *De-novo* co-design of protein sequences & structures with PT-DiT

This notebook provides examples of utilizing PT-DiT, a pre-trained multimodal diffusion model, to co-design protein sequences (represented as amino acids) and structures (represented as ProTokens).

## 1. Import Libraries

In [None]:
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False"
import sys
sys.path.append('..')

import jax
import jax.numpy as jnp
import pickle as pkl
import numpy as np
from tqdm import tqdm
import argparse

from functools import partial
from src.model.diffusion_transformer import DiffusionTransformer
from train.schedulers import GaussianDiffusion
import datetime
from flax.jax_utils import replicate
from functools import reduce

from configs.global_config import global_config
from configs.dit_config import dit_config
# This notebook only runs inference, so dropout is switched off globally.
global_config.dropout_flag = False

### Load embedding
# Each pickle holds one embedding table; both are converted to float32 JAX arrays.
with open('../embeddings/protoken_emb.pkl', 'rb') as protoken_file:
    protoken_table = pkl.load(protoken_file)
    protoken_emb = jnp.array(protoken_table, dtype=jnp.float32)

with open('../embeddings/aatype_emb.pkl', 'rb') as aatype_file:
    aatype_table = pkl.load(aatype_file)
    aatype_emb = jnp.array(aatype_table, dtype=jnp.float32)

## 2. Preparation

### 2.1 Define Constants

Here we define constants for PT-DiT inference:
* `NRES`: length of proteins 
* `NSAMPLE_PER_DEVICE`: the number of proteins in each GPU device
* `DIM_EMB`: the size of raw embedding (concatenating ProTokens & amino acids)
* `NDEVICES`: the number of available GPU devices
* `BATCH_SIZE`: the number of proteins generated in each batch

In [2]:
#### constants
NRES = 256                      # residues per generated protein
NSAMPLE_PER_DEVICE = 8          # proteins sampled on each device
NDEVICES = len(jax.devices())   # devices visible to JAX

# Total number of proteins generated per batch, across all devices.
BATCH_SIZE = NDEVICES * NSAMPLE_PER_DEVICE

# Width of the raw embedding: ProToken channels followed by amino-acid channels.
DIM_EMB = sum(table.shape[-1] for table in (protoken_emb, aatype_emb))

### 2.2 Define Functional Utils for Pre/Post-processing

Here we define functional utils for pre-processing inputs and post-processing outputs, for example:
* `protoken_emb_distance_fn`: calculate the negative cosine similarity between two ProToken embeddings (lower values mean more similar)
* `aatype_emb_distance_fn`: calculate the squared Euclidean distance between two amino acid embeddings
* `aatype_index_to_resname`: convert amino acid indexes into corresponding amino acid symbols
* `resname_to_aatype_index`: convert amino acid symbols into corresponding amino acid indexes

In [3]:
#### function utils 

def split_multiple_rng_keys(rng_key, num_keys):
    """Split `rng_key` into `num_keys` working keys plus one carry-over key.

    Returns:
        A pair `(keys, next_key)`: `keys` stacks the first `num_keys` results
        of the split, and `next_key` is the remaining one, meant to replace
        `rng_key` in the caller.
    """
    all_keys = jax.random.split(rng_key, num_keys + 1)
    working_keys = all_keys[:num_keys]
    carry_key = all_keys[num_keys]
    return working_keys, carry_key

def flatten_list_of_dicts(list_of_dicts):
    """Turn dicts of equal-length sequences into one flat list of records.

    Example: `[{a: [1, 2, 3, 4]}] -> [{a: 1}, {a: 2}, {a: 3}, {a: 4}]`.

    Args:
        list_of_dicts: iterable of dicts whose values are indexable sequences.
            Within one dict, the record count is the length of its first
            value; every other value is indexed over that same range.

    Returns:
        A new list with one dict per record, in input order. A dict with no
        keys holds no records and contributes nothing.
    """
    records = []
    for columns in list_of_dicts:
        if not columns:
            # Nothing to index; avoids a bare StopIteration from next(iter(...)).
            continue
        n_records = len(next(iter(columns.values())))
        # extend() keeps the whole flattening linear in the number of records,
        # unlike repeated list concatenation.
        records.extend(
            {key: values[i] for key, values in columns.items()}
            for i in range(n_records)
        )
    return records

def protoken_emb_distance_fn(x, y):
    """Return the negative cosine similarity of `x` and `y` along the last axis.

    Both inputs are scaled by `1 / (L2 norm + 1e-6)` before the dot product,
    so the result is about -1 for parallel vectors and exactly 0 when either
    vector is all zeros. Smaller values mean "closer".
    """
    def _unit(v):
        # The epsilon keeps the division finite for all-zero vectors.
        return v / (jnp.linalg.norm(v, axis=-1, keepdims=True) + 1e-6)

    cosine = jnp.sum(_unit(x) * _unit(y), axis=-1)
    return -cosine

def aatype_emb_distance_fn(x, y):
    """Return the squared Euclidean distance of `x` and `y` along the last axis."""
    delta = x - y
    return jnp.sum(jnp.square(delta), axis=-1)

def aatype_index_to_resname(aatype_index):
    """Convert an iterable of amino-acid indexes into a one-letter sequence string.

    Each element is passed through `int()` and used to index the 20-letter
    alphabet below, so index 0 is 'A' and index 19 is 'V'.
    """
    alphabet = (
        'A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I',
        'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V',
    )
    return "".join(alphabet[int(idx)] for idx in aatype_index)

def resname_to_aatype_index(resnames):
    """Convert one-letter amino-acid symbols into an int32 index array.

    Inverse of the index-to-letter mapping over the same 20-letter alphabet.
    A symbol that is not in the alphabet raises `ValueError`.
    """
    alphabet = (
        'A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I',
        'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V',
    )
    position_of = alphabet.index
    indexes = [position_of(symbol) for symbol in resnames]
    return np.array(indexes, dtype=np.int32)

### 2.3 Load Model & Parameters

In [4]:
#### load model & params
dit_model = DiffusionTransformer(config=dit_config, global_config=global_config)

num_diffusion_timesteps = 500
scheduler = GaussianDiffusion(num_diffusion_timesteps=num_diffusion_timesteps)

#### rng keys
# Fixed seeds for the JAX key and for NumPy's global generator.
rng_key = jax.random.PRNGKey(8888)
np.random.seed(7777)

##### load params
ckpt_path = '../ckpts/PT_DiT_params_2000000.pkl'
with open(ckpt_path, "rb") as ckpt_file:
    params = pkl.load(ckpt_file)
    # Convert every leaf of the checkpoint tree to a JAX array.
    params = jax.tree_util.tree_map(jnp.array, params)

##### replicate params
# One copy per device, as required by the pmap-ed inference functions.
params = replicate(params)

### 2.4 Define Functional Utils for Inference

Here we define functional utils for PT-DiT inference, for example:
* `clamp_x0_fn`: to enable clamping trick introduced at [arXiv:2205.14217v1](https://arxiv.org/abs/2205.14217)
* `denoise_step`: sample $p(\mathbf{z}_{t-1}|\mathbf{z}_t)$
* `q_sample`: sample $q(\mathbf{z}_t|\mathbf{z}_0)$
* `noise_step`: sample $q(\mathbf{z}_t | \mathbf{z}_{t-1})$

and define `pjit = pmap + jit` (see the `jax` documentation) versions of these functions to enable fast, batched inference.

In [None]:
##### main inference functions
# When True, the functions below handle both the ProToken and the amino-acid
# channels of the embedding; when False, only the ProToken channels.
infer_protuple = True

# Compiled forward pass of the diffusion transformer.
jit_apply_fn = jax.jit(dit_model.apply)

def clamp_x0_fn(x0):
    """Snap every position of `x0` onto its nearest codebook embedding.

    The leading `protoken_emb.shape[-1]` channels are matched against
    `protoken_emb` with `protoken_emb_distance_fn`. When `infer_protuple` is
    truthy, the remaining channels are matched against `aatype_emb` with
    `aatype_emb_distance_fn` and the two clamped parts are concatenated.
    Otherwise only the clamped ProToken embedding is returned.
    """
    n_protoken_channels = protoken_emb.shape[-1]
    # Leading singleton axes so a codebook broadcasts against x0[..., None, :].
    lead_axes = (1,) * (len(x0.shape) - 1)

    protoken_query = x0[..., None, :n_protoken_channels]
    protoken_codebook = protoken_emb.reshape(lead_axes + protoken_emb.shape)
    nearest_protoken = jnp.argmin(
        protoken_emb_distance_fn(protoken_query, protoken_codebook), axis=-1)
    clamped_protoken = protoken_emb[nearest_protoken]

    if not bool(infer_protuple):
        return clamped_protoken

    aatype_query = x0[..., None, n_protoken_channels:]
    aatype_codebook = aatype_emb.reshape(lead_axes + aatype_emb.shape)
    nearest_aatype = jnp.argmin(
        aatype_emb_distance_fn(aatype_query, aatype_codebook), axis=-1)
    return jnp.concatenate([clamped_protoken, aatype_emb[nearest_aatype]], axis=-1)

def denoise_step(params, x, seq_mask, t, residue_index, rng_key,
                 clamp_x0_fn=None):
    """Run one reverse-diffusion step on `x` at timestep `t`.

    The model is applied to `x` plus a learned indicator vector. Its output
    goes to `scheduler.p_mean_variance`, and the next `x` is drawn as
    `mean + exp(0.5 * log_variance) * normal_noise`.

    Args:
        params: parameter tree; reads `params['params']` entries `model`,
            `protoken_indicator` and (if `infer_protuple`) `aatype_indicator`.
        x: current noisy embedding; its first axis is the batch axis.
        seq_mask: forwarded to the model unchanged.
        t: scalar timestep, broadcast to one entry per batch element.
        residue_index: forwarded to the model as `tokens_rope_index`.
        rng_key: PRNG key; split once inside.
        clamp_x0_fn: optional callable forwarded to the scheduler.

    Returns:
        `(new_x, new_rng_key)`.
    """
    model_vars = params['params']
    t_batch = jnp.full((x.shape[0],), t)

    if bool(infer_protuple):
        indicator = jnp.concatenate(
            [model_vars['protoken_indicator'], model_vars['aatype_indicator']],
            axis=-1)
    else:
        indicator = model_vars['protoken_indicator']

    eps_prime = jit_apply_fn({'params': model_vars['model']},
                             x + indicator[None, ...],
                             seq_mask, t_batch,
                             tokens_rope_index=residue_index)
    mean, _, log_variance = scheduler.p_mean_variance(
        x, t_batch, eps_prime, clip=False, clamp_x0_fn=clamp_x0_fn)

    next_key, normal_key = jax.random.split(rng_key)
    gaussian = jax.random.normal(normal_key, x.shape)
    return mean + jnp.exp(0.5 * log_variance) * gaussian, next_key

def q_sample(x, t, rng_key):
    """Noise `x` to timestep `t` via `scheduler.q_sample`.

    `t` is a scalar broadcast to one entry per batch element (first axis of
    `x`). Returns `(x_t, new_rng_key)`.
    """
    next_key, normal_key = jax.random.split(rng_key)
    t_batch = jnp.full((x.shape[0], ), t)
    noise = jax.random.normal(normal_key, x.shape, dtype=jnp.float32)
    return scheduler.q_sample(x, t_batch, noise), next_key

def noise_step(x, t, rng_key):
    """Apply one forward noising step at timestep `t` via `scheduler.q_sample_step`.

    `t` is a scalar broadcast to one entry per batch element (first axis of
    `x`). Returns `(noised_x, new_rng_key)`.
    """
    next_key, normal_key = jax.random.split(rng_key)
    t_batch = jnp.full((x.shape[0], ), t)
    noise = jax.random.normal(normal_key, x.shape)
    return scheduler.q_sample_step(x, t_batch, noise), next_key

def index_from_embedding(x):
    """Decode embeddings into nearest-codebook indexes.

    Args:
        x: embedding array, (B, Nres, Nemb).

    Returns:
        A dict with `'protoken_indexes'`, plus `'aatype_indexes'` when
        `infer_protuple` is truthy. Each entry is the argmin over the
        codebook of the matching distance function.
    """
    n_protoken_channels = protoken_emb.shape[-1]

    protoken_dist = protoken_emb_distance_fn(
        x[..., None, :n_protoken_channels], protoken_emb[None, None, ...])
    decoded = {'protoken_indexes': jnp.argmin(protoken_dist, axis=-1)}

    if bool(infer_protuple):
        aatype_dist = aatype_emb_distance_fn(
            x[..., None, n_protoken_channels:], aatype_emb[None, None, ...])
        decoded['aatype_indexes'] = jnp.argmin(aatype_dist, axis=-1)

    return decoded
    
def _pmap_jit(fn, in_axes=0):
    """Wrap `fn` as `pmap(jit(fn))` over the device axis named "i"."""
    return jax.pmap(jax.jit(fn), axis_name="i", in_axes=in_axes)


# In each in_axes tuple, None marks the scalar timestep `t`, which is broadcast
# to all devices; every other argument is split along its leading axis.
pjit_denoise_step = _pmap_jit(
    partial(denoise_step, clamp_x0_fn=None),
    in_axes=(0, 0, 0, None, 0, 0),
)
pjit_denoise_step_clamped = _pmap_jit(
    partial(denoise_step, clamp_x0_fn=clamp_x0_fn),
    in_axes=(0, 0, 0, None, 0, 0),
)
pjit_q_sample = _pmap_jit(q_sample, in_axes=(0, None, 0))
pjit_noise_step = _pmap_jit(noise_step, in_axes=(0, None, 0))
pjit_index_from_embedding = _pmap_jit(index_from_embedding)

## 3. De-novo Design

### 3.1 Prepare Main Inference Function

Here we define the main inference function of PT-DiT, and introduce two hyper-parameters that may affect the results:
* `n_eq_steps`: the number of equilibrium steps in the predictor-corrector trick ([ICLR2021](https://openreview.net/forum?id=PxTIG12RRHS)); in general, more `n_eq_steps` leads to higher sampling quality
* `phasing_time`: we split diffusion inference into two phases. In the first phase (large noise scale) we do not use the clamping trick ([arXiv:2205.14217v1](https://arxiv.org/abs/2205.14217)), to encourage diversity; in the second phase (small noise scale) we use the clamping trick to ensure robustness

In [6]:
# Number of denoise/re-noise round trips performed at every timestep.
n_eq_steps = 50 ### more n_eq_steps -> higher quality
# Timesteps t < phasing_time use the clamped denoiser in the equilibration loop.
phasing_time = 250 ### controls balance between diversity & quality, larger phasing time -> higher quality, lower diversity
def run_infer(x, seq_mask, residue_index, rng_keys):
    """Run the full reverse-diffusion sampling loop and decode the result.

    For each timestep `t` from `num_diffusion_timesteps` down to 1, perform
    `n_eq_steps` rounds of (denoise, re-noise) at `t`, then one final denoise.
    The equilibration rounds use the clamped denoiser when `t < phasing_time`
    and the unclamped one otherwise.

    Args:
        x: initial noise, already reshaped with a leading device axis.
        seq_mask: sequence mask with the same leading axes.
        residue_index: residue indexes with the same leading axes.
        rng_keys: one PRNG key per device.

    Returns:
        A dict with `'embedding'`, `'seq_mask'`, `'residue_index'`, plus the
        index entries produced by `pjit_index_from_embedding`.
    """
    # NOTE(review): t runs over 500..1 inclusive; confirm the scheduler
    # expects 1-based timesteps.
    for t in tqdm(range(num_diffusion_timesteps, 0, -1)):
        if t < phasing_time:
            equilibrate_fn = pjit_denoise_step_clamped
        else:
            equilibrate_fn = pjit_denoise_step

        for _ in range(n_eq_steps):
            x, rng_keys = equilibrate_fn(params, x, seq_mask, t, residue_index, rng_keys)
            x, rng_keys = pjit_noise_step(x, t, rng_keys)

        # NOTE(review): this closing step is always unclamped, even when
        # t < phasing_time; confirm that is intended.
        x, rng_keys = pjit_denoise_step(params, x, seq_mask, t, residue_index, rng_keys)

    decoded_indexes = pjit_index_from_embedding(x)
    return {
        'embedding': x,
        'seq_mask': seq_mask,
        'residue_index': residue_index,
        **decoded_indexes,
    }

### 3.2 Run Inference

In [None]:
# Root key for this sampling run.
rng_key = jax.random.PRNGKey(8888)

# Draw the initial noise from the dedicated subkey. `rng_key` is split again
# later for the per-device keys, so it must not be consumed here: reusing a
# JAX key makes the "independent" random streams correlated.
rng_key, normal_key = jax.random.split(rng_key)
x = jax.random.normal(normal_key, shape=(BATCH_SIZE, NRES, DIM_EMB), dtype=jnp.float32)
seq_mask = jnp.ones((BATCH_SIZE, NRES), dtype=jnp.bool_)
residue_index = jnp.tile(jnp.arange(NRES, dtype=jnp.int32)[None, ...], (BATCH_SIZE, 1))

### reshape inputs
def reshape_func(arr):
    """Split the leading batch axis into (NDEVICES, per-device batch)."""
    return arr.reshape(NDEVICES, arr.shape[0] // NDEVICES, *arr.shape[1:])

# jax.tree_util.tree_map matches the rest of this notebook and also exists in
# JAX releases that predate the `jax.tree` namespace.
x, seq_mask, residue_index = jax.tree_util.tree_map(reshape_func, (x, seq_mask, residue_index))

print(x.shape, x.dtype)
print(seq_mask.shape, seq_mask.dtype)
print(residue_index.shape, residue_index.dtype)

In [None]:
# One independent key per device; `rng_key` is replaced by the carry-over key.
rng_keys, rng_key = split_multiple_rng_keys(rng_key, num_keys=NDEVICES)

# Full sampling loop; returns the final embeddings plus decoded indexes.
ret = run_infer(x=x, seq_mask=seq_mask, residue_index=residue_index, rng_keys=rng_keys)

In [9]:
# Create the output directory first, so the long sampling run is not lost to a
# FileNotFoundError when the results are written.
os.makedirs('results/denovo_design', exist_ok=True)

# Merge the (device, per-device batch) axes back into one batch axis and
# convert every array to nested Python lists for pickling.
# NOTE: this overwrites `ret` with plain lists, so re-running this cell without
# re-running inference will fail on `.shape`.
ret = jax.tree_util.tree_map(lambda arr: np.array(arr).reshape(-1, *arr.shape[2:]).tolist(), ret)
with open('results/denovo_design/result.pkl', 'wb') as f:
    pkl.dump(ret, f)

# One dict per generated protein; this is the file the decoder consumes.
ret_ = flatten_list_of_dicts([ret])
with open('results/denovo_design/result_flatten.pkl', 'wb') as f:
    pkl.dump(ret_, f)

## 4. Decode Structures

Here we decode 3D coordinates of protein structures from generated ProTokens (with `decode_structure.py` script), and save designed sequences & structures in .pdb format. pdb files are saved in `results/denovo_design/pdb`

In [None]:
# Repository root relative to this notebook.
working_dir = '../'

# Shell command: point PYTHONPATH at the PROTOKEN package, then run the decoder
# script on the flattened results. Backslash-newline pairs inside the literal
# are line continuations, so the `python ...` invocation is a single line.
decode_cmd = f'''export PYTHONPATH={working_dir}/PROTOKEN
          python {working_dir}/PROTOKEN/scripts/decode_structure.py\
                --decoder_config {working_dir}/PROTOKEN/config/decoder.yaml\
                --vq_config {working_dir}/PROTOKEN/config/vq.yaml\
                --input_path results/denovo_design/result_flatten.pkl\
                --output_dir results/denovo_design/pdb\
                --load_ckpt_path {working_dir}/ckpts/protoken_params_100000.pkl\
                --padding_len {NRES}'''

# The exit status is the value of the cell (0 means success).
os.system(decode_cmd)