<a href="https://colab.research.google.com/github/sokrypton/ColabDesign/blob/v1.0.9/colabdesign/mpnn/test_jax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title install
%%bash
pip -q install git+https://github.com/sokrypton/ColabDesign.git@v1.0.9
# for debugging: expose the installed package as ./colabdesign
# Ask pip where the package was installed instead of hard-coding the
# python3.7 dist-packages directory, which only exists when the runtime's
# interpreter is Python 3.7.
pkg_root="$(pip show colabdesign | sed -n 's/^Location: //p')"
# -f/-n: replace an existing link so re-running this cell does not fail.
ln -sfn "${pkg_root}/colabdesign" colabdesign

In [1]:
import numpy as np
import os
import joblib
import jax
import jax.numpy as jnp
import re
import copy
import random
import haiku as hk
from tqdm import tqdm
from matplotlib import pyplot as plt

from colabdesign.mpnn.utils import tied_featurize, parse_PDB, StructureDatasetPDB
from colabdesign.mpnn.modules import RunModel

# model

In [2]:
class MPNN_wrapper:
    """Convenience wrapper around the JAX MPNN model from ``colabdesign.mpnn``.

    Construction loads a pickled checkpoint, derives the model configuration
    from it and attaches the weights to a ``RunModel`` instance.
    ``prep_logits_input`` converts a PDB file into the feature dict consumed
    by the model and ``get_logits`` runs the model on such a dict.
    """

    def __init__(self, params_path="/content/colabdesign/mpnn/jax_weights", model_name="v_48_002"):
        """Load ``<params_path>/<model_name>.pkl`` and build the model.

        Args:
            params_path: Directory that holds the checkpoint files. A trailing
                ``/`` is appended to the stored ``self.params_path`` when the
                given (non-empty) path lacks one.
            model_name: Checkpoint file name without the ``.pkl`` extension.
        """
        # Keep the stored path ending in '/', but never index into an empty
        # string (``params_path[-1]`` would raise IndexError on '').
        if params_path and not params_path.endswith('/'):
            params_path += '/'
        self.params_path = params_path
        self.model_name = model_name

        backbone_noise = 0.00  # Standard deviation of Gaussian noise to add to backbone atoms
        hidden_dim = 128
        num_layers = 3

        checkpoint_path = os.path.join(self.params_path, f'{self.model_name}.pkl')

        # NOTE(security): joblib.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from a trusted source.
        checkpoint = joblib.load(checkpoint_path)
        params = jax.tree_util.tree_map(jnp.array, checkpoint['model_state_dict'])
        print('Number of edges:', checkpoint['num_edges'])
        noise_level_print = checkpoint['noise_level']
        print(f'Training noise level: {noise_level_print}A')

        # The neighbour count comes from the checkpoint; everything else is
        # fixed by the constants above.
        config = {'num_letters': 21,
                  'node_features': hidden_dim,
                  'edge_features': hidden_dim,
                  'hidden_dim': hidden_dim,
                  'num_encoder_layers': num_layers,
                  'num_decoder_layers': num_layers,
                  'augment_eps': backbone_noise,
                  'k_neighbors': checkpoint['num_edges'],
                  'dropout': 0.0
                  }

        model = RunModel(config)
        model.params = params
        self.model = model

        self.alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
        self.max_length = 20000

    @staticmethod
    def _parse_chain_ids(chain_spec):
        """Split a chain specification such as ``"A,B"`` into ``['A', 'B']``.

        Every run of non-letter characters acts as a separator. Empty pieces
        (produced by leading or trailing separators) are dropped so that no
        empty chain id is passed on. An empty specification yields ``[]``.
        """
        if not chain_spec:
            return []
        pieces = re.sub("[^A-Za-z]+", ",", chain_spec).split(",")
        return [piece for piece in pieces if piece]

    def _prepare_structure(self, pdb_path, designed_chain_list, fixed_chain_list, ishomomer):
        """Parse ``pdb_path`` and build the dictionaries needed for featurisation.

        Returns:
            Tuple ``(chain_list, dataset_valid, chain_id_dict, tied_positions_dict)``
            where ``tied_positions_dict`` is ``None`` unless ``ishomomer`` is true.
        """
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # chain order is the same on every run (a set gives no such guarantee).
        chain_list = list(dict.fromkeys(designed_chain_list + fixed_chain_list))

        pdb_dict_list = parse_PDB(pdb_path, input_chain_list=chain_list)
        dataset_valid = StructureDatasetPDB(pdb_dict_list, truncate=None, max_length=self.max_length)

        chain_id_dict = {pdb_dict_list[0]['name']: (designed_chain_list, fixed_chain_list)}

        if ishomomer:
            tied_positions_dict = self.make_tied_positions_for_homomers(pdb_dict_list)
        else:
            tied_positions_dict = None
        return chain_list, dataset_valid, chain_id_dict, tied_positions_dict

    def set_pred_params(self, pdb_path, designed_chain,
                        fixed_chain='', ishomomer=False,
                        sampling_temp=0.1, omit_AAs='X'):
        """Store prediction settings and the parsed structure on the instance.

        Args:
            pdb_path: Path of the PDB file to parse.
            designed_chain: Chain ids to design, e.g. ``"A"`` or ``"A,B"``.
            fixed_chain: Chain ids to keep fixed (same format).
            ishomomer: When true, tied positions are built across all chains.
            sampling_temp: Stored as ``self.sampling_temp``.
            omit_AAs: Letters of ``self.alphabet`` to flag in ``self.omit_AAs_np``.

        Besides the plain settings this sets ``self.dataset_valid``,
        ``self.chain_id_dict`` and ``self.tied_positions_dict`` so the parsing
        work done here is not thrown away.
        """
        self.pdb_path = pdb_path
        self.designed_chain = designed_chain
        self.fixed_chain = fixed_chain
        self.ishomomer = ishomomer
        self.batch_size = 1
        self.sampling_temp = sampling_temp

        self.designed_chain_list = self._parse_chain_ids(designed_chain)
        self.fixed_chain_list = self._parse_chain_ids(fixed_chain)

        # omit AAs: 1.0 for every alphabet letter listed in omit_AAs, else 0.0
        self.omit_AAs_np = np.array([AA in omit_AAs for AA in self.alphabet]).astype(np.float32)

        # prepare input
        (self.chain_list,
         self.dataset_valid,
         self.chain_id_dict,
         self.tied_positions_dict) = self._prepare_structure(self.pdb_path,
                                                             self.designed_chain_list,
                                                             self.fixed_chain_list,
                                                             self.ishomomer)

    def prep_logits_input(self, pdb_path, designed_chain,
                          fixed_chain='', ishomomer=False):
        """Featurise a PDB file into the input dict expected by ``get_logits``.

        Args:
            pdb_path: Path of the PDB file to parse.
            designed_chain: Chain ids to design, e.g. ``"A"`` or ``"A,B"``.
            fixed_chain: Chain ids to keep fixed (same format).
            ishomomer: When true, tied positions are built across all chains.

        Returns:
            Dict with the keys ``'X'``, ``'S'``, ``'mask'``, ``'chain_M'``,
            ``'residue_idx'`` and ``'chain_encoding_all'`` taken from
            ``tied_featurize``; ``'chain_M'`` is ``chain_M * chain_M_pos``.
        """
        designed_chain_list = self._parse_chain_ids(designed_chain)
        fixed_chain_list = self._parse_chain_ids(fixed_chain)

        (_, dataset_valid,
         chain_id_dict, tied_positions_dict) = self._prepare_structure(pdb_path,
                                                                       designed_chain_list,
                                                                       fixed_chain_list,
                                                                       ishomomer)

        protein = dataset_valid[0]
        batch_clones = [copy.deepcopy(protein)]
        fixed_positions_dict = None
        omit_AA_dict = None
        pssm_dict = None
        bias_by_res_dict = None

        (X, S, mask, lengths, chain_M, chain_encoding_all, chain_list_list,
         visible_list_list, masked_list_list, masked_chain_length_list_list,
         chain_M_pos, omit_AA_mask, residue_idx, dihedral_mask,
         tied_pos_list_of_lists_list, pssm_coef, pssm_bias,
         pssm_log_odds_all, bias_by_res_all, tied_beta) = tied_featurize(batch_clones,
                                                                         chain_id_dict, fixed_positions_dict,
                                                                         omit_AA_dict, tied_positions_dict,
                                                                         pssm_dict, bias_by_res_dict)

        return {'X': X,
                'S': S,
                'mask': mask,
                'chain_M': chain_M*chain_M_pos,
                'residue_idx': residue_idx,
                'chain_encoding_all': chain_encoding_all}

    def get_logits(self, decode_order, input, seq=None):
        """Run the model on a feature dict and return its first output.

        Args:
            decode_order: Per-residue values; a leading axis is added and the
                result is passed to the model under the key ``'randn'``.
            input: Feature dict as returned by ``prep_logits_input``. It is
                NOT modified; the model receives a shallow copy.
            seq: Optional sequence string. When given, its letters are mapped
                to indices in ``self.alphabet`` and replace ``'S'`` with an
                int32 array of shape ``(1, len(seq))``.

        Returns:
            Element ``[0]`` of ``self.model.apply(...)``.

        Raises:
            ValueError: If ``seq`` contains a letter missing from ``self.alphabet``.
        """
        # Work on a copy so the caller's dict does not silently pick up the
        # 'S' and 'randn' entries written below.
        model_input = dict(input)
        if seq is not None:
            S = np.asarray([self.alphabet.index(a) for a in seq], dtype=np.int32)
            model_input['S'] = jnp.array(S[None, :])
        model_input['randn'] = jnp.expand_dims(jnp.asarray(decode_order), 0)
        key = jax.random.PRNGKey(0)  # fixed key, as in the original implementation
        return self.model.apply(self.model.params, key, model_input)[0]

    @staticmethod
    def make_tied_positions_for_homomers(pdb_dict_list):
        """Tie every residue position across all chains of each parsed structure.

        Args:
            pdb_dict_list: Dicts that contain a ``'name'`` entry and one
                ``'seq_chain_<id>'`` entry per chain.

        Returns:
            Dict mapping each structure name to a list with one element per
            position ``1..L`` (1-based), where ``L`` is the sequence length of
            the alphabetically first chain. Each element maps every chain id
            to ``[position]``.
        """
        prefix = 'seq_chain_'
        tied_by_name = {}
        for entry in pdb_dict_list:
            # Take everything after the prefix as the chain id (not only the
            # last character) so the lookup below uses the exact key.
            chain_ids = sorted(key[len(prefix):] for key in entry if key.startswith(prefix))  # A, B, C, ...
            chain_length = len(entry[prefix + chain_ids[0]])
            tied_by_name[entry['name']] = [
                {chain: [position] for chain in chain_ids}  # value needs to be a list
                for position in range(1, chain_length + 1)
            ]
        return tied_by_name

In [3]:
# Build the wrapper with its default checkpoint directory and model name;
# the constructor prints the checkpoint's edge count and training noise level.
wrapper = MPNN_wrapper()



Number of edges: 48
Training noise level: 0.02A


## get_logits

In [4]:
H_batch = []
config = wrapper.prep_logits_input(pdb_path='/content/colabdesign/mpnn/1P3J.pdb', designed_chain='A')

# Fixed seed so the decoding order is identical on every Restart & Run All;
# change the value to try a different order.
seed = 0
num_residues = len(config["residue_idx"][0])
order = jax.random.normal(jax.random.PRNGKey(seed), (num_residues,))
logits = wrapper.get_logits(order, config)

In [5]:
# Last expression of the cell: display the logits computed in the previous cell.
logits

DeviceArray([[[ 0.39064622, -1.0170941 , -0.8448451 , ..., -1.3881565 ,
               -0.5436939 , -2.0760546 ],
              [-1.0061326 , -0.29365587, -0.62056535, ..., -0.44944283,
                0.54547304, -1.0063171 ],
              [-0.52949464, -0.24696444, -1.0139545 , ..., -0.8602449 ,
               -1.0075833 , -0.9707053 ],
              ...,
              [ 0.07505904, -1.871599  ,  1.0982891 , ..., -2.1363168 ,
               -1.7764524 , -2.2191148 ],
              [-0.51556355, -1.1978966 , -1.0540409 , ..., -1.3301657 ,
               -0.21174452, -1.8705697 ],
              [-0.66337556,  0.18112892, -0.86197484, ..., -1.3362936 ,
               -1.0082724 , -1.3131907 ]]], dtype=float32)