In [1]:
import numpy as np
import os
import joblib
import jax
import jax.numpy as jnp
import re
import copy
import random
import haiku as hk
from tqdm import tqdm
from matplotlib import pyplot as plt

# from protein_mpnn_utils import loss_nll, loss_smoothed, _scores, _S_to_seq, StructureDataset
# from protein_mpnn_utils import gather_edges
# from protein_mpnn_utils import gather_nodes, gather_nodes_t, cat_neighbors_nodes
from utils import tied_featurize, parse_PDB, StructureDatasetPDB
from colabdesign.mpnn.modules import RunModel

In [2]:
# Restrict the process to the first GPU.
# NOTE(review): JAX reads this variable when its backend is first initialized;
# `jax` is imported in the cell above, so if anything initialized the backend
# already this has no effect. Consider setting it before `import jax`.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

# Parse the input structure

In [3]:
# Index-aligned three-letter and one-letter amino-acid codes, plus lookup maps.
aa3 = ("GLY ALA VAL LEU ILE PRO PHE TYR TRP SER "
       "THR CYS MET ASN GLN ASP GLU LYS ARG HIS").split()
aa = list('GAVLIPFYWSTCMNQDEKRH')
c3_1 = dict(zip(aa3, aa))  # three-letter -> one-letter, e.g. 'GLY' -> 'G'
c1_3 = dict(zip(aa, aa3))  # one-letter -> three-letter, e.g. 'G' -> 'GLY'

In [4]:
def parse_pdb(pdb_f, chainID="A"):
    '''parse pdb file

    Reads the ATOM records of one chain from the first model of a PDB file.

    Args:
        pdb_f - string, pdb file.
        chainID - string, chain ID. (optional)
    Returns:
        pdb_seq - string, one-letter sequence in pdb order
                  ('U' for residue names not in c3_1)
        pdb_idx - list, residue numbers in pdb; a residue with an insertion
                  code (e.g. 52A) repeats the number of its predecessor
        pdb_coord - dict, float arrays of shape (n_residues, 3) with the
                    coordinates of C, CA, CB, N,
                    if an atom does not have coordinate, it is set to [-1,-1,-1]
        pdb_coord_all - list, one list per residue with the coordinates of all
                        of its atoms in file order (first alternate location only)
    '''
    with open(pdb_f, 'r') as f:
        lines = f.readlines()
    pdb_seq = []
    pdb_idx = []
    pdb_coord = {'C': [], 'CA': [], 'CB': [], 'N': []}
    pdb_coord_all = []
    cur = None          # (resSeq, iCode) of the residue currently being filled
    seen_atoms = set()  # atom names already stored for that residue
    for line in lines:
        # Stop after the first model so multi-model files are not concatenated.
        if line.startswith('ENDMDL'):
            break
        if not (line.startswith('ATOM') and line[21] == chainID):
            continue
        # Key residues by number AND insertion code (column 27) so that
        # e.g. 52 and 52A stay separate residues.
        res_key = (int(line[22:26]), line[26])
        if res_key != cur:
            cur = res_key
            seen_atoms = set()
            pdb_seq.append(c3_1.get(line[17:20], 'U'))
            pdb_idx.append(res_key[0])
            for coords in pdb_coord.values():
                coords.append([-1, -1, -1])
            pdb_coord_all.append([])
        atom = line[12:16].strip()
        # A repeated atom name within one residue is an alternate location;
        # keep only the first one so coordinates are not overwritten/duplicated.
        if atom in seen_atoms:
            continue
        seen_atoms.add(atom)
        xyz = [float(line[30:38]), float(line[38:46]), float(line[46:54])]
        if atom in pdb_coord:
            pdb_coord[atom][-1] = xyz
        pdb_coord_all[-1].append(xyz)

    for atom in pdb_coord:
        # Force float dtype and (-1, 3) shape, even when every entry is the
        # integer placeholder (e.g. CB of an all-Gly chain) or the chain is empty.
        pdb_coord[atom] = np.array(pdb_coord[atom], dtype=float).reshape(-1, 3)
    pdb_seq = ''.join(pdb_seq)
    return pdb_seq, pdb_idx, pdb_coord, pdb_coord_all

In [5]:
def mk_msa(seqs, alphabet):
    '''one hot encode msa

    Args:
        seqs - list, msa sequences in character (all of the same length)
        alphabet - string, alphabet used in seqs

    Returns:
        msa_ori - int array (n_seqs, length), msa in number; characters not in
                  alphabet are mapped to the index of '-'
        np.eye(states)[msa_ori] - float array (n_seqs, length, len(alphabet)),
                                  one hot msa
    '''
    states = len(alphabet)
    # NOTE(review): if alphabet has no '-', find() returns -1 and unknown
    # characters silently land in the LAST state of the one-hot encoding.
    gap_ind = alphabet.find('-')
    a2n = {a: n for n, a in enumerate(alphabet)}
    # dtype=int keeps msa_ori usable as an index array even for empty input
    # (np.array([]) would be float and np.eye(...)[...] would raise IndexError).
    msa_ori = np.array([[a2n.get(aa, gap_ind) for aa in seq] for seq in seqs], dtype=int)
    return msa_ori, np.eye(states)[msa_ori]

In [6]:
# The 20 amino-acid one-letter codes in alphabetical order, then '-' for gaps.
norm_ab = 'ACDEFGHIKLMNPQRSTVWY' + '-'

In [7]:
# Sequence, residue numbers and backbone/CB coordinates of chain A.
pdb_seq, pdb_idx, pdb_coord, pdb_coord_all = parse_pdb(pdb_f='./1P3J.pdb', chainID='A')

In [8]:
# One-hot matrix of the single PDB sequence (index [1] = one-hot, [0] = first sequence).
pdbseq_oh = mk_msa(seqs=[pdb_seq], alphabet=norm_ab)[1][0]

# Load the ProteinMPNN model

In [9]:
class MPNN_wrapper:
    '''Wrapper around a ColabDesign ``RunModel`` built from a ProteinMPNN checkpoint.

    Args:
        params_path - string, directory containing ``<model_name>.pkl``.
        model_name - string, checkpoint name (e.g. "v_48_020").
    '''

    def __init__(self, params_path, model_name):
        self.params_path = params_path
        self.model_name = model_name  # @param ["v_48_002", "v_48_010", "v_48_020", "v_48_030"]

        backbone_noise = 0.00  # Standard deviation of Gaussian noise to add to backbone atoms
        hidden_dim = 128
        num_layers = 3

        # endswith() does not raise on an empty path (unlike params_path[-1]);
        # an empty path is left as-is so the checkpoint resolves relative to cwd.
        if self.params_path and not self.params_path.endswith('/'):
            self.params_path += '/'
        checkpoint_path = self.params_path + f'{self.model_name}.pkl'

        # NOTE(review): joblib.load unpickles the file and can execute arbitrary
        # code -- only load checkpoints from a trusted source.
        checkpoint = joblib.load(checkpoint_path)
        params = jax.tree_util.tree_map(jnp.array, checkpoint['model_state_dict'])
        print('Number of edges:', checkpoint['num_edges'])
        noise_level_print = checkpoint['noise_level']
        print(f'Training noise level: {noise_level_print}A')

        config = {'num_letters': 21,
                  'node_features': hidden_dim,
                  'edge_features': hidden_dim,
                  'hidden_dim': hidden_dim,
                  'num_encoder_layers': num_layers,
                  'num_decoder_layers': num_layers,
                  'augment_eps': backbone_noise,
                  'k_neighbors': checkpoint['num_edges'],
                  'dropout': 0.0
                  }

        model = RunModel(config)
        model.params = params
        self.model = model

        self.alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
        self.max_length = 20000

    @staticmethod
    def _parse_chain_list(chains):
        '''Split a chain spec such as "A,B" or "A B" into ['A', 'B'].

        Any run of non-letter characters separates chains; empty pieces (from
        leading/trailing separators) are dropped, so '' gives [].
        '''
        return [chain for chain in re.split("[^A-Za-z]+", chains) if chain]

    def _prepare_structure(self, pdb_path, designed_chain_list, fixed_chain_list, ishomomer):
        '''Parse pdb_path and build the dicts that tied_featurize consumes.

        Returns:
            (chain_list, pdb_dict_list, dataset_valid, chain_id_dict, tied_positions_dict)
        '''
        # dict.fromkeys de-duplicates while keeping first-seen order; the order of
        # list(set(...)) of str varies between interpreter runs (hash randomization).
        chain_list = list(dict.fromkeys(designed_chain_list + fixed_chain_list))
        pdb_dict_list = parse_PDB(pdb_path, input_chain_list=chain_list)
        dataset_valid = StructureDatasetPDB(pdb_dict_list, truncate=None, max_length=self.max_length)
        chain_id_dict = {pdb_dict_list[0]['name']: (designed_chain_list, fixed_chain_list)}
        if ishomomer:
            tied_positions_dict = self.make_tied_positions_for_homomers(pdb_dict_list)
        else:
            tied_positions_dict = None
        return chain_list, pdb_dict_list, dataset_valid, chain_id_dict, tied_positions_dict

    def set_pred_params(self, pdb_path, designed_chain,
                        fixed_chain='', ishomomer=False,
                        sampling_temp=0.1, omit_AAs='X'):
        '''Parse a structure and store prediction settings on the instance.

        Sets pdb_path, designed/fixed chain strings and lists, chain_list,
        ishomomer, batch_size (always 1), sampling_temp, omit_AAs_np, and the
        prepared pdb_dict_list, dataset_valid, chain_id_dict and tied_positions_dict.
        '''
        self.pdb_path = pdb_path
        self.designed_chain = designed_chain
        self.fixed_chain = fixed_chain
        self.ishomomer = ishomomer
        self.batch_size = 1
        self.sampling_temp = sampling_temp

        self.designed_chain_list = self._parse_chain_list(designed_chain)
        self.fixed_chain_list = self._parse_chain_list(fixed_chain)

        # 1.0 at alphabet positions whose letter occurs in omit_AAs, else 0.0
        self.omit_AAs_np = np.array([AA in omit_AAs for AA in self.alphabet]).astype(np.float32)

        # Keep the prepared structure data on the instance for later use.
        (self.chain_list, self.pdb_dict_list, self.dataset_valid,
         self.chain_id_dict, self.tied_positions_dict) = self._prepare_structure(
            pdb_path, self.designed_chain_list, self.fixed_chain_list, ishomomer)

    def prep_logits_input(self, pdb_path, designed_chain,
                          fixed_chain='', ishomomer=False):
        '''Featurize the first structure in pdb_path into the dict get_logits consumes.

        Returns:
            dict with keys 'X', 'S', 'mask', 'chain_M' (= chain_M * chain_M_pos),
            'residue_idx', 'chain_encoding_all', taken from tied_featurize.
        '''
        designed_chain_list = self._parse_chain_list(designed_chain)
        fixed_chain_list = self._parse_chain_list(fixed_chain)
        _, _, dataset_valid, chain_id_dict, tied_positions_dict = self._prepare_structure(
            pdb_path, designed_chain_list, fixed_chain_list, ishomomer)

        # Featurize only the first structure, as a batch of one.
        batch_clones = [copy.deepcopy(dataset_valid[0])]
        fixed_positions_dict = None
        omit_AA_dict = None
        pssm_dict = None
        bias_by_res_dict = None

        (X, S, mask, lengths, chain_M, chain_encoding_all, chain_list_list,
         visible_list_list, masked_list_list, masked_chain_length_list_list,
         chain_M_pos, omit_AA_mask, residue_idx, dihedral_mask,
         tied_pos_list_of_lists_list, pssm_coef, pssm_bias,
         pssm_log_odds_all, bias_by_res_all, tied_beta) = tied_featurize(batch_clones,
                                                                         chain_id_dict, fixed_positions_dict,
                                                                         omit_AA_dict, tied_positions_dict,
                                                                         pssm_dict, bias_by_res_dict)

        return {'X': X,
                'S': S,
                'mask': mask,
                'chain_M': chain_M*chain_M_pos,
                'residue_idx': residue_idx,
                'chain_encoding_all': chain_encoding_all}

    def get_logits(self, decode_order, input, seq=None, key=None):
        '''Run the model once and return the first element of its output.

        Args:
            decode_order - array, passed to the model as 'randn' with a leading
                           batch axis added (presumably shape (L,)).
            input - dict from prep_logits_input; it is NOT modified.
            seq - optional string over self.alphabet; replaces input['S'].
            key - optional jax PRNG key; defaults to jax.random.PRNGKey(0).
        Returns:
            element [0] of self.model.apply(...) (presumably the logits).
        '''
        # Shallow copy so 'randn' / 'S' never leak into the caller's dict.
        inputs = dict(input)
        if seq is not None:
            S = np.asarray([self.alphabet.index(a) for a in seq], dtype=np.int32)
            inputs['S'] = jnp.array(S[None, :])
        inputs['randn'] = jnp.expand_dims(decode_order, 0)
        if key is None:
            key = jax.random.PRNGKey(0)
        return self.model.apply(self.model.params, key, inputs)[0]

    @staticmethod
    def make_tied_positions_for_homomers(pdb_dict_list):
        '''Tie position i of every chain together, for each structure.

        Chain letters are the last character of keys starting with 'seq_chain';
        the number of tied positions (1-based) is the length of the sequence of
        the alphabetically first chain.

        Returns:
            dict mapping result['name'] to a list of {chain: [i]} dicts.
        '''
        my_dict = {}
        for result in pdb_dict_list:
            all_chain_list = sorted(item[-1:] for item in result if item[:9] == 'seq_chain')  # A, B, C, ...
            chain_length = len(result[f"seq_chain_{all_chain_list[0]}"])
            my_dict[result['name']] = [
                {chain: [i] for chain in all_chain_list}  # values need to be lists
                for i in range(1, chain_length + 1)
            ]
        return my_dict

In [10]:
# Load the v_48_020 checkpoint from ./jax_weights (prints edge count and training noise).
wrapper = MPNN_wrapper(model_name='v_48_020', params_path='./jax_weights')

Number of edges: 48
Training noise level: 0.2A


## Compute logits with `get_logits`

In [None]:
H_batch = []
# Feature dict for chain A: X, S, mask, chain_M, residue_idx, chain_encoding_all.
config = wrapper.prep_logits_input(pdb_path='./1P3J.pdb',
                                   designed_chain='A')

# A fixed seed makes the random decoding order -- and therefore the logits --
# reproducible on Restart & Run All; change it to sample a different order.
seed = 0
# Size the order from the featurized structure rather than from the separate
# parse_pdb() result, so the two parsers can never disagree on length.
# assumes config['S'] is (batch, L) as in get_logits' S[None, :] -- TODO confirm
num_res = config['S'].shape[1]
order = jax.random.normal(jax.random.PRNGKey(seed), (num_res,))
logits = wrapper.get_logits(order, config)