In [None]:
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False"
import sys
sys.path.append('..')

import jax
import jax.numpy as jnp
import pickle as pkl
import numpy as np
from tqdm import tqdm
import argparse

from functools import partial
from src.model.diffusion_transformer import DiffusionTransformer
from train.schedulers import GaussianDiffusion
import datetime
from flax.jax_utils import replicate
from functools import reduce

from configs.global_config import global_config
from configs.dit_config import dit_config
# presumably switches dropout off in the shared model config for inference — confirm in configs.global_config
global_config.dropout_flag = False

### Load embedding codebooks
# NOTE: pickle.load can execute arbitrary code; only load these files from a trusted source.
def _load_f32_embedding(path):
    """Unpickle the array stored at ``path`` and return it as a float32 JAX array."""
    with open(path, 'rb') as fh:
        return jnp.array(pkl.load(fh), dtype=jnp.float32)

protoken_emb = _load_f32_embedding('../embeddings/protoken_emb.pkl')
aatype_emb = _load_f32_embedding('../embeddings/aatype_emb.pkl')

## Preparation

In [2]:
#### constants
NRES = 512  # padded residue length for every sample (ProToken paper, 2024.07, tested Nres = 256)
NSAMPLE_PER_DEVICE = 8
NDEVICES = len(jax.devices())
# one global batch = NSAMPLE_PER_DEVICE samples on each of the NDEVICES pmap devices
BATCH_SIZE = NDEVICES * NSAMPLE_PER_DEVICE
# per-residue embedding = ProToken part ++ amino-acid-type part (40 = 32 + 8 for the shipped codebooks)
DIM_EMB = protoken_emb.shape[-1] + aatype_emb.shape[-1]

In [3]:
#### function utils 

def split_multiple_rng_keys(rng_key, num_keys):
    """Split ``rng_key`` into ``num_keys`` subkeys plus one extra key.

    Returns:
        A pair ``(keys, next_key)``: ``keys`` holds the first ``num_keys`` of the
        ``num_keys + 1`` split subkeys and ``next_key`` is the remaining last one
        (presumably kept by the caller for further splitting).
    """
    subkeys = jax.random.split(rng_key, num=num_keys + 1)
    return subkeys[:num_keys], subkeys[num_keys]

def flatten_list_of_dicts(list_of_dicts):
    """Explode each dict of equal-length sequences into one dict per index, concatenated in order.

    Example: ``[{'a': [1, 2], 'b': [3, 4]}] -> [{'a': 1, 'b': 3}, {'a': 2, 'b': 4}]``

    The record count of each dict is the length of its first value; an empty dict
    contributes no records. Runs in time linear in the total number of records.
    """
    flattened = []
    for d in list_of_dicts:
        if not d:
            # no values to take a length from -> nothing to emit
            continue
        n_records = len(next(iter(d.values())))
        flattened.extend({k: v[i] for k, v in d.items()} for i in range(n_records))
    return flattened

def protoken_emb_distance_fn(x, y):
    """Negative cosine similarity along the last axis (smaller means more similar).

    Inputs broadcast against each other; 1e-6 is added to each norm to avoid division by zero.
    """
    eps = 1e-6
    x_unit = x / (jnp.linalg.norm(x, axis=-1, keepdims=True) + eps)
    y_unit = y / (jnp.linalg.norm(y, axis=-1, keepdims=True) + eps)
    cosine = jnp.sum(x_unit * y_unit, axis=-1)
    return -cosine

def aatype_emb_distance_fn(x, y):
    """Squared Euclidean distance along the last axis (inputs broadcast against each other)."""
    diff = x - y
    return jnp.sum(jnp.square(diff), axis=-1)

# One-letter codes of the 20 standard amino acids; the position in this list is the
# aatype index used by both conversion helpers below (single source of truth).
_RESTYPES = [
    'A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I', 'L', 'K', 'M', 'F', 'P',
    'S', 'T', 'W', 'Y', 'V'
]
_RESTYPE_TO_INDEX = {res: idx for idx, res in enumerate(_RESTYPES)}


def aatype_index_to_resname(aatype_index):
    """Convert an iterable of aatype indices into a one-letter amino-acid string.

    Each element is passed through ``int()`` so NumPy/JAX integer scalars work.
    Raises IndexError for an index outside the alphabet.
    """
    return "".join(_RESTYPES[int(i)] for i in aatype_index)


def resname_to_aatype_index(resnames):
    """Convert a one-letter amino-acid string (or iterable of codes) into an int32 index array.

    Raises:
        ValueError: if a code is not one of the 20 standard one-letter codes.
    """
    try:
        return np.array([_RESTYPE_TO_INDEX[a] for a in resnames], dtype=np.int32)
    except KeyError as err:
        raise ValueError(f'unknown residue code {err.args[0]!r}') from None

In [4]:
#### load model & diffusion scheduler
num_diffusion_timesteps = 500
dit_model = DiffusionTransformer(config=dit_config, global_config=global_config)
scheduler = GaussianDiffusion(num_diffusion_timesteps=num_diffusion_timesteps)

#### rng keys (JAX key) and the global NumPy seed
rng_key = jax.random.PRNGKey(8888)
np.random.seed(7777)

##### load params: unpickle the checkpoint and convert every leaf to a JAX array
# NOTE: pickle.load can execute arbitrary code; only load trusted checkpoints.
ckpt_path = '../ckpts/PT_DiT_params_1000000.pkl'
with open(ckpt_path, "rb") as f:
    params = jax.tree_util.tree_map(jnp.array, pkl.load(f))

In [5]:
##### main inference functions
jit_apply_fn = jax.jit(dit_model.apply)
# When True, embeddings carry an amino-acid-type part after the ProToken part and both
# parts are decoded (the ODE drift also adds the aatype indicator).
infer_protuple = True

def index_from_embedding(x):
    """Map continuous embeddings back to their nearest codebook indices.

    x: (B, Nres, Nemb) per the original note; the first ``protoken_emb.shape[-1]`` channels
    are matched against ``protoken_emb`` by negative cosine similarity and, when
    ``infer_protuple`` is set, the remaining channels against ``aatype_emb`` by squared
    Euclidean distance.

    Returns a dict with 'protoken_indexes' and optionally 'aatype_indexes'.
    """
    n_protoken = protoken_emb.shape[-1]
    # distances of every residue to every codebook entry; argmin over the codebook axis
    protoken_dist = protoken_emb_distance_fn(x[..., None, :n_protoken],
                                             protoken_emb[None, None, ...])
    out = {'protoken_indexes': jnp.argmin(protoken_dist, axis=-1)}
    if not bool(infer_protuple):
        return out
    aatype_dist = aatype_emb_distance_fn(x[..., None, n_protoken:],
                                         aatype_emb[None, None, ...])
    out['aatype_indexes'] = jnp.argmin(aatype_dist, axis=-1)
    return out

pjit_index_from_embedding = jax.pmap(jax.jit(index_from_embedding), axis_name="i")

## Probability-Flow ODE (PF-ODE) Utilities

In [6]:
from diffrax import diffeqsolve, Dopri5, ODETerm, SaveAt, PIDController

def ode_drift(x, t, seq_mask, residue_index):
    """Drift dx/dt of the probability-flow ODE at continuous time ``t``.

    Args:
        x: current state, batch on axis 0 (a (NSAMPLE_PER_DEVICE, NRES, DIM_EMB) block per
           device under pmap — assumed from the callers in this notebook).
        t: scalar time in [0, scheduler.num_timesteps].
        seq_mask, residue_index: forwarded to the DiT model.

    Returns:
        0.5 * beta_t * (-x + eps' / sqrt(1 - alpha_bar_t)), i.e. -0.5 * beta_t * (x + s) with
        s = -eps' / sqrt(1 - alpha_bar_t), where eps' is the model output
        (presumably the predicted noise).
    """
    t_arr = jnp.full((x.shape[0],), t)
    # learned indicator vectors are added to the input before the model call
    indicator = params['params']['protoken_indicator']
    if bool(infer_protuple):
        indicator = jnp.concatenate([indicator, params['params']['aatype_indicator']],
                                    axis=-1)
    eps_prime = jit_apply_fn({'params': params['params']['model']}, x + indicator[None, ...],
                             seq_mask, t_arr, tokens_rope_index=residue_index)

    # The schedule tables are indexed by floor(t). The solver evaluates t == num_timesteps
    # exactly (t1 of the forward solve, t0 of the backward solve), one past the last entry
    # assuming the tables hold num_timesteps values; clamp explicitly instead of relying on
    # JAX's implicit out-of-bounds gather clamping. Assumes both tables have the same length.
    t_idx = jnp.clip(jnp.int32(t), 0, len(scheduler.betas) - 1)
    beta_t = scheduler.betas[t_idx]
    sqrt_one_minus_alphas_cumprod_t = scheduler.sqrt_one_minus_alphas_cumprod[t_idx]

    return 0.5 * beta_t * (-x + 1.0 / sqrt_one_minus_alphas_cumprod_t * eps_prime)

In [7]:
# `method` is not read by solve_ode; the solver used below is diffrax's Dopri5.
rtol, atol, method = 1e-5, 1e-5, "RK45"

def solve_ode(t_0, t_1, dt0, x_0, seq_mask, residue_index):
    """Integrate the PF-ODE from ``t_0`` to ``t_1`` starting at ``x_0``.

    Uses Dopri5 with a PID step-size controller (tolerances ``rtol``/``atol``) and at most
    65536 steps; ``dt0`` is the initial step, and its sign gives the integration direction.
    Returns the last saved state, i.e. the state at ``t_1`` under diffrax's default SaveAt.
    """
    def _vector_field(t, y, args):
        # diffrax calls vector fields as f(t, y, args); ode_drift expects (x, t, ...)
        return jax.jit(ode_drift)(y, t, seq_mask, residue_index)

    solution = diffeqsolve(
        ODETerm(_vector_field), Dopri5(),
        t0=t_0, t1=t_1, y0=x_0, dt0=dt0,
        stepsize_controller=PIDController(rtol=rtol, atol=atol),
        max_steps=65536,
    )
    return solution.ys[-1]

# t_0, t_1, dt0 are broadcast to all devices; x_0, seq_mask, residue_index are split on axis 0
pjit_solve_ode = jax.pmap(jax.jit(solve_ode), axis_name='i', in_axes=(None, None, None, 0, 0, 0))

## Latent Interpolation

In [None]:
### Example: MurD close -> open

### obtain ProTokens: run the ProToken batch-inference script on the PDBs in raw_pdbs
import subprocess

working_dir = '../'
protoken_root = os.path.join(working_dir, 'PROTOKEN')
raw_pdb_dir = os.path.join(working_dir, 'example_scripts/results/latent_interpolation/raw_pdbs')
# argv list + explicit PYTHONPATH (same value the shell `export` used to set); check=True
# stops here on failure instead of failing later on a missing pickle.
subprocess.run(
    ['python', os.path.join(protoken_root, 'scripts/infer_batch.py'),
     '--encoder_config', os.path.join(protoken_root, 'config/encoder.yaml'),
     '--decoder_config', os.path.join(protoken_root, 'config/decoder.yaml'),
     '--vq_config', os.path.join(protoken_root, 'config/vq.yaml'),
     '--pdb_dir_path', raw_pdb_dir,
     '--save_dir_path', raw_pdb_dir,
     '--load_ckpt_path', os.path.join(working_dir, 'ckpts/protoken_params_100000.pkl')],
    env={**os.environ, 'PYTHONPATH': protoken_root},
    check=True,
)

In [47]:
protoken_dir = 'results/latent_interpolation/raw_pdbs/stage_1/generator_inputs'
# first half of the batch holds the closed state, second half the open state
pdb_list = ['MurD_close.pdb', ] * (BATCH_SIZE // 2) + \
           ['MurD_open.pdb', ] * (BATCH_SIZE // 2)

def _load_generator_input(pdb_file):
    """Read the encoder output for ``pdb_file`` and pad it to NRES residues.

    Returns a dict with 'embedding' (NRES, DIM_EMB; ProToken ++ aatype embedding),
    'seq_mask' (bool) and 'residue_index' (int32), all zero-padded after seq_len.

    Raises:
        ValueError: if the structure is longer than NRES.
    """
    protoken_file = os.path.join(protoken_dir, pdb_file.replace('.pdb', '.pkl'))
    with open(protoken_file, 'rb') as f:
        data = pkl.load(f)

    seq_len = data['seq_len']
    if seq_len > NRES:
        raise ValueError(f'{protoken_file}: seq_len={seq_len} exceeds NRES={NRES}')
    pad = NRES - seq_len
    embedding = np.concatenate(
        [protoken_emb[data['protokens'].astype(np.int32)],
         aatype_emb[data['aatype'].astype(np.int32)]], axis=-1
    )
    return {'embedding': np.pad(embedding, ((0, pad), (0, 0))),
            'seq_mask': np.pad(data['seq_mask'], (0, pad)).astype(np.bool_),
            'residue_index': np.pad(data['residue_index'], (0, pad)).astype(np.int32)}

# each distinct file is read once, then repeated to fill the batch
records = {pdb_file: _load_generator_input(pdb_file) for pdb_file in dict.fromkeys(pdb_list)}
data_dicts = [records[pdb_file] for pdb_file in pdb_list]
data_dict = {k: np.stack([d[k] for d in data_dicts], axis=0) for k in data_dicts[0].keys()}

### for pmap: reshape inputs (BATCH_SIZE, ...) -> (NDEVICES, BATCH_SIZE // NDEVICES, ...)
reshape_func = lambda x: x.reshape(NDEVICES, x.shape[0] // NDEVICES, *x.shape[1:])
data_dict = jax.tree_util.tree_map(reshape_func, data_dict)

### Forward PF-ODE: data -> latent

In [48]:
x0 = data_dict['embedding']
# integrate the PF-ODE from t = 0 (data) to t = num_timesteps (latent) with initial step +1
t_start, t_end, dt_forward = 0, scheduler.num_timesteps, 1.0
xT = pjit_solve_ode(t_start, t_end, dt_forward, x0,
                    data_dict['seq_mask'], data_dict['residue_index'])
xT_np = np.array(xT)

### Interpolation

In [49]:
# interpolation weights 0 -> 1, one per batch slot
lambda_arr = np.linspace(0, 1, BATCH_SIZE)

# merge the pmap device axis, then take the two end-points: the first slot (MurD_close)
# and the last slot (MurD_open)
xT = xT.reshape(BATCH_SIZE, NRES, DIM_EMB)
xT_A, xT_B = xT[0], xT[-1]

# straight-line interpolation between the two latents, then back to the pmap layout
xT_interpolation = [(1.0 - lam) * xT_A + lam * xT_B for lam in lambda_arr]
xT_interpolation = jnp.array(xT_interpolation).reshape(NDEVICES, NSAMPLE_PER_DEVICE, NRES, DIM_EMB)

### Backward PF-ODE: latent -> data 

In [50]:
# integrate the PF-ODE backwards from t = num_timesteps (latent) to t = 0 (data), initial step -1
t_start, t_end, dt_backward = scheduler.num_timesteps, 0, -1.0
x0_interpolation = pjit_solve_ode(t_start, t_end, dt_backward, xT_interpolation,
                                  data_dict['seq_mask'], data_dict['residue_index'])

### Decode Structures

In [51]:
ret = dict(embedding=x0_interpolation,
           seq_mask=data_dict['seq_mask'],
           residue_index=data_dict['residue_index'])
# add the nearest-codebook indices decoded from the generated embeddings
ret.update(pjit_index_from_embedding(ret['embedding']))

# merge the pmap device axis into the batch axis and turn every array into nested lists
ret = jax.tree_util.tree_map(
    lambda arr: np.array(arr.reshape(BATCH_SIZE, *arr.shape[2:])).tolist(), ret)
with open('results/latent_interpolation/result.pkl', 'wb') as fh:
    pkl.dump(ret, fh)

# one record per sample; this file is the --input_path of decode_structure.py
ret_ = flatten_list_of_dicts([ret])
with open('results/latent_interpolation/result_flatten.pkl', 'wb') as fh:
    pkl.dump(ret_, fh)

os.makedirs('results/latent_interpolation/interpolation_pdb', exist_ok=True)

In [None]:
### decode the interpolated ProTokens back to PDB structures
import subprocess

working_dir = '../'
protoken_root = os.path.join(working_dir, 'PROTOKEN')
# argv list + explicit PYTHONPATH (same value the shell `export` used to set); check=True
# makes a failed decode raise instead of silently returning a non-zero status.
subprocess.run(
    ['python', os.path.join(protoken_root, 'scripts/decode_structure.py'),
     '--decoder_config', os.path.join(protoken_root, 'config/decoder.yaml'),
     '--vq_config', os.path.join(protoken_root, 'config/vq.yaml'),
     '--input_path', 'results/latent_interpolation/result_flatten.pkl',
     '--output_dir', 'results/latent_interpolation/interpolation_pdb',
     '--load_ckpt_path', os.path.join(working_dir, 'ckpts/protoken_params_100000.pkl'),
     '--padding_len', str(NRES)],
    env={**os.environ, 'PYTHONPATH': protoken_root},
    check=True,
)

## Latent Directed Evolution (Simulated)

In [None]:
### Example: cas12f

### obtain ProTokens: run the ProToken batch-inference script on the PDBs in cas12f/
import subprocess

working_dir = '../'
protoken_root = os.path.join(working_dir, 'PROTOKEN')
cas12f_dir = os.path.join(working_dir, 'example_scripts/results/latent_directed_evo/cas12f')
# argv list + explicit PYTHONPATH (same value the shell `export` used to set); check=True
# stops here on failure instead of failing later on a missing pickle.
subprocess.run(
    ['python', os.path.join(protoken_root, 'scripts/infer_batch.py'),
     '--encoder_config', os.path.join(protoken_root, 'config/encoder.yaml'),
     '--decoder_config', os.path.join(protoken_root, 'config/decoder.yaml'),
     '--vq_config', os.path.join(protoken_root, 'config/vq.yaml'),
     '--pdb_dir_path', cas12f_dir,
     '--save_dir_path', cas12f_dir,
     '--load_ckpt_path', os.path.join(working_dir, 'ckpts/protoken_params_100000.pkl')],
    env={**os.environ, 'PYTHONPATH': protoken_root},
    check=True,
)

### Load Dataset

In [None]:
def _read_fitness_file(path):
    """Parse a ``sequence;fitness`` text file into a list of ``(sequence, fitness)`` pairs.

    Both fields are whitespace-stripped and blank lines are skipped; a line that does not
    split into exactly two fields on ';' raises ValueError.
    """
    pairs = []
    with open(path, 'r') as fh:
        for line in fh:
            if not line.strip():
                continue
            seq, fitness = line.split(';')
            pairs.append((seq.strip(), float(fitness.strip())))
    return pairs

### load dataset (variant seqs/fitness)
neg_pairs = _read_fitness_file('results/latent_directed_evo/cas12f/cas12f_negative_sequences_with_fitness_scaled.txt')
pos_pairs = _read_fitness_file('results/latent_directed_evo/cas12f/cas12f_positive_sequences_with_fitness_scaled.txt')

ordered_seq_neg = [s for s, _ in neg_pairs]
ordered_seq_pos = [s for s, _ in pos_pairs]
seqs = ordered_seq_neg + ordered_seq_pos
# Keyed by the stripped sequence, the same form used for every later lookup;
# a sequence present in both files keeps its positive-file fitness.
seq_fitness_dict = dict(neg_pairs + pos_pairs)
ordered_seq_pos_set = set(ordered_seq_pos)
ordered_seq_neg_set = set(ordered_seq_neg)

assert seqs, 'no variants loaded'
print("# variants: {}".format(len(seqs)))
# ceil division, so no batch consists purely of padding
NUM_BATCHES = -(-len(seqs) // BATCH_SIZE)
# pad the last batch by repeating the final variant so every batch is full
seqs = seqs + [seqs[-1], ] * (NUM_BATCHES * BATCH_SIZE - len(seqs))
seq_len = len(seqs[0])
print("# BATCHES: {}, seq_len: {}".format(NUM_BATCHES, seq_len))

### Forward PF-ODE: data -> latent

In [None]:
# All variants share the ProTokens, mask and residue index of one structure file
# (af3_model.pkl); only the amino-acid-type part of the embedding changes per sequence.
# This is loop-invariant, so it is loaded and padded once.
protoken_file = 'results/latent_directed_evo/cas12f/stage_1/generator_inputs/af3_model.pkl'
with open(protoken_file, 'rb') as f:
    data = pkl.load(f)
seq_len = data['seq_len']
pad = NRES - seq_len
protoken_part = protoken_emb[data['protokens'].astype(np.int32)]
seq_mask_padded = np.pad(data['seq_mask'], (0, pad)).astype(np.bool_)
residue_index_padded = np.pad(data['residue_index'], (0, pad)).astype(np.int32)

### for pmap: reshape inputs (BATCH_SIZE, ...) -> (NDEVICES, BATCH_SIZE // NDEVICES, ...)
reshape_func = lambda x: x.reshape(NDEVICES, x.shape[0] // NDEVICES, *x.shape[1:])

seq_embedding_dict = {}
for b in tqdm(range(NUM_BATCHES)):
    batch_seqs = seqs[b*BATCH_SIZE:(b+1)*BATCH_SIZE]
    data_dicts = []
    for seq in batch_seqs:
        aatypes = resname_to_aatype_index(seq.strip())
        if len(aatypes) != protoken_part.shape[0]:
            raise ValueError(f'variant length {len(aatypes)} != structure length '
                             f'{protoken_part.shape[0]}')
        embedding = np.concatenate([protoken_part, aatype_emb[aatypes]], axis=-1)
        data_dicts.append(
            {'embedding': np.pad(embedding, ((0, pad), (0, 0))),
             'seq_mask': seq_mask_padded,
             'residue_index': residue_index_padded,}
        )

    data_dict = {k: np.stack([d[k] for d in data_dicts], axis=0) for k in data_dicts[0].keys()}
    data_dict = jax.tree_util.tree_map(reshape_func, data_dict)

    ### forward ODE: data -> Gaussian latent
    x0 = data_dict['embedding']
    xT = pjit_solve_ode(0, scheduler.num_timesteps, 1.0, x0, data_dict['seq_mask'], data_dict['residue_index'])
    # back to (BATCH_SIZE, ...) and drop the padded residues
    xT_np = np.array(xT).reshape(BATCH_SIZE, NRES, DIM_EMB)[:, :seq_len, :]

    seq_embedding_dict.update(zip(batch_seqs, xT_np))

In [12]:
# persist {sequence: latent embedding of shape (seq_len, DIM_EMB)}
with open('./results/latent_directed_evo/cas12f/seq_emb_dict.pkl', 'wb') as emb_file:
    pkl.dump(seq_embedding_dict, emb_file)

### Simulated Directed Evolution

In [31]:
from sklearn.ensemble import RandomForestRegressor

NROUND = 10               # active-learning rounds per simulation seed
N_SAMPLE_PER_ROUND = 10   # variants added to the training set each round
N_TOP_SEQS = 10           # size of the top-ranked set the per-round statistics are computed on
# Left open on purpose: the simulation cell writes one row per (seed, round) and closes it.
f_result = open(f'./results/latent_directed_evo/cas12f/simulation_result.csv', 'w')
# no spaces after commas, so the column names match the comma-only data rows
f_result.write('simulation_num,round_num,median_activity_scaled,top_activity_scaled,activity_binary_percentage\n')

In [32]:
def convert_embedding(x, max_order=2):
    """Summarise per-residue embeddings into a fixed-length feature vector.

    Args:
        x: array of shape (..., N, D); axis -2 is the residue axis being pooled.
        max_order: highest moment to include (>= 2 adds central moments).

    Returns:
        Array of shape (..., max_order * D) for max_order >= 1: the mean over residues
        followed by the central moments of order 2..max_order. For 2-D input this is a
        flat vector of length max_order * D.
    """
    center = np.mean(x, axis=-2, keepdims=True)  # (..., 1, D) so it broadcasts against x
    reps = [center[..., 0, :]]
    reps += [np.mean((x - center) ** o, axis=-2) for o in range(2, max_order + 1)]
    return np.concatenate(reps, axis=-1)

# fixed-length feature vector per variant, used as regressor input
seq_rep_dict = dict(zip(seq_embedding_dict.keys(), map(convert_embedding, seq_embedding_dict.values())))

In [None]:
### start simulation
# For each seed: start from a random sample of variants, then for NROUND rounds fit a random
# forest on all variants sampled so far, rank every variant by predicted fitness, log statistics
# of the top N_TOP_SEQS, and add the N_SAMPLE_PER_ROUND best-ranked unsampled variants.
random_seeds = np.arange(1, 11)
top_fitness_scaled = []
median_fitness_scaled = []
top_n_pos_rate = []

all_seqs = ordered_seq_neg + ordered_seq_pos
# features of every candidate; identical in every round, so built once
all_reps = np.array([seq_rep_dict[s] for s in all_seqs])

for simulation_round, seed in tqdm(enumerate(random_seeds), total=len(random_seeds)):
    top_fitness_scaled_seed = []
    median_fitness_scaled_seed = []
    top_n_pos_rate_seed = []
    np.random.seed(seed)

    X_train = []
    y_train = []
    # NOTE(review): randint samples with replacement, so the initial set may hold fewer than
    # N_SAMPLE_PER_ROUND distinct variants; kept as-is so results stay reproducible.
    sample_ids = set([int(_) for _ in np.random.randint(0, len(all_seqs), N_SAMPLE_PER_ROUND)])
    select_seqs = [all_seqs[i] for i in sample_ids]
    for r in range(NROUND):
        X_train += [seq_rep_dict[s] for s in select_seqs]
        y_train += [seq_fitness_dict[s] for s in select_seqs]

        # one global scale/shift over all features (not per-feature standardisation)
        x_scale, x_shift = np.std(X_train), np.mean(X_train)
        y_scale, y_shift = np.std(y_train), np.mean(y_train)

        X_train_ = (np.array(X_train) - x_shift) / x_scale
        y_train_ = (np.array(y_train) - y_shift) / y_scale

        model = RandomForestRegressor(n_estimators=100, criterion='friedman_mse', max_depth=None, min_samples_split=2,
                            min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_features=1.0,
                            max_leaf_nodes=None, min_impurity_decrease=0.0, bootstrap=True, oob_score=False,
                            n_jobs=16, random_state=1, verbose=0, warm_start=False, ccp_alpha=0.0,
                            max_samples=None)
        model.fit(X_train_, y_train_)
        predict_y = model.predict((all_reps - x_shift) / x_scale) * y_scale + y_shift

        # candidate indices ordered by predicted fitness, best first
        top_n_idx = [int(x) for x in np.argsort(predict_y)[::-1]]
        top_n_seqs = [all_seqs[top_n_idx[i]] for i in range(N_TOP_SEQS)]

        # fraction of the top-N_TOP_SEQS candidates that come from the positive set
        top_n_pos_rate_seed.append(np.sum([1 if s in ordered_seq_pos_set else 0 for s in top_n_seqs]) / N_TOP_SEQS)
        median_fitness_scaled_seed.append(np.median([seq_fitness_dict[s] for s in top_n_seqs]))
        top_fitness_scaled_seed.append(seq_fitness_dict[top_n_seqs[0]])

        f_result.write(f'{simulation_round+1},{r+1},{median_fitness_scaled_seed[-1]},{top_fitness_scaled_seed[-1]},{top_n_pos_rate_seed[-1]}\n')

        # next round's batch: the best-ranked candidates not sampled yet
        select_ids = []
        select_seqs = []
        for idx in top_n_idx:
            if idx not in sample_ids:
                select_ids.append(idx)
                select_seqs.append(all_seqs[idx])
            if len(select_ids) == N_SAMPLE_PER_ROUND:
                break

        sample_ids |= set(select_ids)

    top_n_pos_rate.append(top_n_pos_rate_seed)
    top_fitness_scaled.append(top_fitness_scaled_seed)
    median_fitness_scaled.append(median_fitness_scaled_seed)

f_result.close()

### Analysis

In [None]:
import matplotlib.pyplot as plt

### high-activity candidate rate in the top-n per round, mean ± std over simulation seeds
rounds = np.arange(NROUND)
rate_mean = np.mean(top_n_pos_rate, axis=0)
rate_std = np.std(top_n_pos_rate, axis=0)

fig, ax = plt.subplots(1, 1, figsize=(6, 3), dpi=300)
ax.plot(rounds, rate_mean, marker='o')
ax.errorbar(x=rounds, y=rate_mean, yerr=rate_std,
            marker='o', alpha=0.5, capsize=5.0)
ax.set_ylim(0.0, 1.0)
ax.set_xlabel('Round')
ax.set_ylabel('High Activity Candidate Rate')

In [None]:
import matplotlib.pyplot as plt

### fitness of the top-1 predicted candidate per round, mean ± std over simulation seeds
rounds = np.arange(NROUND)
top1_mean = np.mean(top_fitness_scaled, axis=0)
top1_std = np.std(top_fitness_scaled, axis=0)

fig, ax = plt.subplots(1, 1, figsize=(6, 3), dpi=300)
ax.plot(rounds, top1_mean, marker='o')
ax.errorbar(x=rounds, y=top1_mean, yerr=top1_std,
            marker='o', alpha=0.5, capsize=5.0)

ax.set_xlabel('Round')
ax.set_ylabel('Top-1 Fitness')

In [None]:
import matplotlib.pyplot as plt

### median fitness of the top-n candidates per round, mean ± std over simulation seeds
rounds = np.arange(NROUND)
median_mean = np.mean(median_fitness_scaled, axis=0)
median_std = np.std(median_fitness_scaled, axis=0)

fig, ax = plt.subplots(1, 1, figsize=(6, 3), dpi=300)
ax.plot(rounds, median_mean, marker='o')
ax.errorbar(x=rounds, y=median_mean, yerr=median_std,
            marker='o', alpha=0.5, capsize=5.0)

ax.set_xlabel('Round')
ax.set_ylabel('Top-n Median Fitness')