## Finding ProteinMPNN Relevant Code

Preparing simplified functions from the PMPNN script

In [1]:
import numpy as np
import torch
from collections import namedtuple

from ProteinMPNN.protein_mpnn_utils import tied_featurize, parse_PDB, parse_fasta, _scores
from ProteinMPNN.protein_mpnn_utils import StructureDatasetPDB, ProteinMPNN

  from .autonotebook import tqdm as notebook_tqdm


#### Scoring code

Input:
    pdb_path
    fasta_path (optional)

Output:
    logprob

In [2]:
#arguments
class Arguments():
    def __init__(self,
                 pdb_path: str, # Path to a single PDB to be designed
                 fasta_path: str="", # Path to file containing one sequence to be scored in fasta format - currently incompatible with multiple sequences
                 ca_only: bool=False, # Parse CA-only structures and use CA-only models
                 backbone_noise: float=0, # Standard deviation of Gaussian noise to add to backbone atoms
                 max_length: int=200000, # Max sequence length
                 model_path: str='ProteinMPNN/vanilla_model_weights/v_48_020.pt' # Path to model weights folder
                ):

        self.pdb_path = pdb_path
        self.fasta_path = fasta_path
        self.ca_only = ca_only
        self.backbone_noise = backbone_noise
        self.max_length = max_length
        self.model_path = model_path

# args = Arguments(pdb_path="examples/pdbs/5L33.pdb",
#                  fasta_path="examples/fastas/5L33-mut_seq.fasta")
# args = Arguments(pdb_path="examples/pdbs/6MRR.pdb",
#                  fasta_path="examples/fastas/6MRR-mut_seq.fasta")

In [3]:
def pmpnn_score_pdb_seq(args):
    
    #data objects to initialize
    alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
    alphabet_dict = dict(zip(alphabet, range(21)))
    
    
    # initialize the model
    device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")

    ## load the checkpoint
    checkpoint = torch.load(args.model_path, map_location=device)

    hidden_dim = 128
    num_layers = 3

    model = ProteinMPNN(ca_only=args.ca_only,
                        num_letters=21,
                        node_features=hidden_dim,
                        edge_features=hidden_dim,
                        hidden_dim=hidden_dim,
                        num_encoder_layers=num_layers,
                        num_decoder_layers=num_layers,
                        augment_eps=args.backbone_noise,
                        k_neighbors=checkpoint['num_edges'])
    
    
    # prepare pdb input for tied_featurize: sequence, coordinates, metadata extracted from pdb and saved to dict
    pdb_dict_list = parse_PDB(args.pdb_path, ca_only=args.ca_only)
    dataset_valid = StructureDatasetPDB(pdb_dict_list, truncate=None, max_length=args.max_length)

    batch_clones = [dataset_valid[0]] # if only one pdb passed as input and batch size = 1
    
    
    # load variables to pass to model to get log_probs
    tf_output = tied_featurize(batch=batch_clones,
                               device=device,
                               chain_dict=None,
                               fixed_position_dict=None,
                               omit_AA_dict=None,
                               tied_positions_dict=None,
                               pssm_dict=None,
                               bias_by_res_dict=None,
                               ca_only=args.ca_only)

    ## save return (20 outputs) to named tuple
    tfOutputTuple = namedtuple("tfOutputTuple", ["X", "S", "mask", "lengths", "chain_M",
                                                 "chain_encoding_all", "chain_list_list",
                                                 "visible_list_list", "masked_list_list",
                                                 "masked_chain_length_list_list", "chain_M_pos",
                                                 "omit_AA_mask", "residue_idx", "dihedral_mask",
                                                 "tied_pos_list_of_lists_list", "pssm_coef",
                                                 "pssm_bias", "pssm_log_odds_all", "bias_by_res_all",
                                                 "tied_beta"])
    tf = tfOutputTuple(*tf_output)
    
    
    # read in sequence from fasta if given
    if args.fasta_path:
        fasta_names, fasta_seqs = parse_fasta(args.fasta_path, omit=["/"])
        assert len(fasta_seqs) == 1 ## currently only compatible with one pdb in, one pdb out
        fasta_seq = fasta_seqs[0]
        input_seq_length = len(fasta_seq)

        # update tf.S to be input sequence – otherwise is sequence read from pdb
        S_input = torch.tensor([alphabet_dict[AA] for AA in fasta_seq], device=device)[None,:].repeat(tf.X.shape[0], 1)
        tf.S[:,:input_seq_length] = S_input #assumes that S and S_input are alphabetically sorted for masked_chains

    ## TO DO: compatability with scoring multiple sequences

    
    # score sequence for pdb (log probs)
    randn_1 = torch.randn(tf.chain_M.shape, device=tf.X.device)
    # get log probs
    log_probs = model(tf.X, tf.S, tf.mask, tf.chain_M*tf.chain_M_pos, tf.residue_idx, tf.chain_encoding_all, randn_1)
    mask_for_loss = tf.mask*tf.chain_M*tf.chain_M_pos
    scores = _scores(tf.S, log_probs, mask_for_loss)
    
    return scores

In [4]:
# scoring pdb given a sequence
pmpnn_score_pdb_seq(args=Arguments(pdb_path="examples/pdbs/5L33.pdb",
                    fasta_path="examples/fastas/5L33-mut_seq.fasta")
    )

tensor([3.5213], grad_fn=<DivBackward0>)

In [5]:
# scoring pdb given a sequence
pmpnn_score_pdb_seq(args=Arguments(pdb_path="examples/pdbs/6MRR.pdb",
                    fasta_path="examples/fastas/6MRR-mut_seq.fasta")
    )

tensor([4.0384], grad_fn=<DivBackward0>)

In [6]:
# scoring pdb given no sequence – score sequence in pdb
pmpnn_score_pdb_seq(args=Arguments(pdb_path="examples/pdbs/5L33.pdb"))

tensor([3.4667], grad_fn=<DivBackward0>)

In [7]:
# scoring pdb given no sequence – score sequence in pdb
pmpnn_score_pdb_seq(args=Arguments(pdb_path="examples/pdbs/6MRR.pdb"))

tensor([3.5551], grad_fn=<DivBackward0>)