In [None]:
import sys, os
# Expose only GPU index 1 to this process. This assignment happens before
# `import torch` below, so keep it above the torch import.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
# NOTE(review): star import. `mpnn_dir` (used just below) and, further down the
# notebook, `motifs` and `dir_out_rosetta` are not defined anywhere in the
# visible cells, so they presumably come from `constant` — confirm, and prefer
# importing those names explicitly.
from constant import *

# Extend the import path with `mpnn_dir` before importing `protein_mpnn_utils`
# (presumably that module lives in this directory — confirm).
sys.path.append(mpnn_dir)


import torch, scipy
from protein_mpnn_utils import parse_PDB, tied_featurize
from protein_mpnn_utils import ProteinMPNN
import gc
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
from tqdm import tqdm

# Default torch device for the notebook: CUDA when available, otherwise CPU.
# Used as the default `device` argument of `load_model`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Alphabet of ProteinMPNN (20 standard amino acids followed by X for unknown).
# The position of a letter in this string is its column index in the
# probability arrays, so X is column 20.
AAs = 'ACDEFGHIKLMNPQRSTVWYX'
# Lookup table: one-letter amino acid code -> column index.
aa2id_np = {aa:i for i,aa in enumerate(AAs)}

# 0-based, half-open slice bounds of the designed VR4 within chain A:
# the region is seq_A[start:end], i.e. indices 233..244 inclusive.
start, end = 233,245

In [None]:
def load_model(model_path=f'{mpnn_dir}/vanilla_model_weights/v_48_020.pt', device=device):
    """
    Build a ProteinMPNN network and restore its weights from a checkpoint.
    Inputs:
        model_path: path to the checkpoint file (must contain 'num_edges' and 'model_state_dict')
        device: torch device the network is moved to before it is returned
    Returns:
        the ProteinMPNN model, in eval mode, on `device`
    """
    # Deserialize on the CPU; the network is moved to `device` only at the very end.
    ckpt = torch.load(model_path, map_location='cpu')

    # Architecture settings; only the neighbour count is read from the checkpoint.
    architecture = dict(
        num_letters=21,
        node_features=128,
        edge_features=128,
        hidden_dim=128,
        num_encoder_layers=3,
        num_decoder_layers=3,
        vocab=21,
        k_neighbors=ckpt['num_edges'],
    )

    network = ProteinMPNN(**architecture)
    network.load_state_dict(ckpt['model_state_dict'])
    network.eval()
    network = network.to(device)
    return network

def get_MPNN_probabilities(model, pdb_path, mask_all_VR4=True, use_backbone_only=False):
    """
    Extract ProteinMPNN amino acid probabilities for every position of a two-chain PDB.

    Chain A is passed to the featurizer as the masked chain and chain B as the
    visible chain. Within chain A, every position outside the selected VR4
    residues is listed as fixed. The VR4 window is taken from the notebook-level
    `start` / `end` (0-based, half-open), the same bounds `get_prob_VR4` uses
    for scoring, so masking and scoring cannot drift apart.

    Inputs:
        model: ProteinMPNN model (e.g. from load_model)
        pdb_path: path to a PDB file containing chains A and B
        mask_all_VR4: if True, leave all VR4 residues unfixed; else leave only the
            linker residues (first two and last two VR4 positions) unfixed
        use_backbone_only: forwarded to model.conditional_probs as its last argument
    Returns:
        dict with keys:
            'autoregressive', 'conditional', 'unconditional': numpy arrays holding
                exp() of the model's log-probabilities with the batch axis squeezed
                out. The columns are not trimmed here (presumably one column per
                letter of the 21-letter alphabet — confirm); get_prob_VR4 keeps
                only the first 20.
            'sequence': full chain A sequence as parsed from the PDB (not only VR4)
    """
    gc.collect()

    # Parse PDB file; only the first parsed entry is used.
    pdb_dict = parse_PDB(pdb_path)[0]
    pdb_name = pdb_dict['name']
    seq_A = pdb_dict['seq_chain_A']
    seq_B = pdb_dict['seq_chain_B']

    chain_dict = {
        pdb_name: (["A"], ["B"])   # masked_chains, visible_chains
    }

    # 0-based chain A indices that stay designable (i.e. are NOT fixed).
    if mask_all_VR4:
        designable = set(range(start, end))
    else:
        # linker residues: first two and last two positions of VR4
        designable = {start, start + 1, end - 2, end - 1}

    # fixed_position_dict takes 1-based residue positions, hence the +1.
    fixed_position_dict = {
        pdb_name: {
            "A": [i + 1 for i in range(len(seq_A)) if i not in designable],
            "B": [i + 1 for i in range(len(seq_B))]
        }
    }

    # Featurize on the device the model actually lives on, so a model loaded with
    # a non-default `device` does not hit a device mismatch.
    model_device = next(model.parameters()).device
    batch = tied_featurize(
        batch = [pdb_dict],
        device = model_device,
        chain_dict = chain_dict,
        fixed_position_dict = fixed_position_dict,
    )

    # Positional unpacking of the featurizer's output tuple.
    X = batch[0]
    S = batch[1]
    mask = batch[2]
    chain_M = batch[4]
    chain_encoding_all = batch[5]
    chain_M_pos = batch[10]
    residue_idx = batch[12]

    # Random noise passed to the autoregressive and conditional passes. It is
    # unseeded, so repeated calls can differ (presumably it sets the decoding
    # order — confirm against ProteinMPNN).
    randn_1 = torch.randn(chain_M.shape, device=X.device)

    def _to_probs(log_probs):
        # log-probabilities -> probabilities, drop the batch axis, move to numpy
        return torch.exp(log_probs).squeeze(0).cpu().numpy()

    # Get logits
    with torch.no_grad():

        # P(AA_i | structure, AA_1, ..., AA_{i-1})
        # NOTE(review): this pass uses chain_M*mask, so chain_M_pos (presumably the
        # fixed-position mask built from fixed_position_dict) is not applied here,
        # unlike the conditional pass below — confirm this is intended.
        probs_autoregressive = _to_probs(
            model(X, S, mask, chain_M*mask, residue_idx, chain_encoding_all, randn_1)
        )

        # P(AA_i | structure, all other native amino acids)
        probs_conditional = _to_probs(
            model.conditional_probs(
                X, S, mask, chain_M*chain_M_pos, residue_idx, chain_encoding_all,
                randn_1, use_backbone_only
            )
        )

        # P(AA_i | structure) independently for each position
        probs_unconditional = _to_probs(
            model.unconditional_probs(X, mask, residue_idx, chain_encoding_all)
        )

    return {
        'autoregressive': probs_autoregressive,
        'conditional': probs_conditional,
        'unconditional': probs_unconditional,
        'sequence': seq_A
    }

def get_prob_VR4(
    pdb_path, mask_all_VR4, use_backbone_only
):
    """
    Get the average probability ProteinMPNN assigns to the native VR4 residues.

    Relies on notebook-level names at call time: `model` (the loaded ProteinMPNN),
    `start` / `end` (0-based, half-open VR4 bounds) and `aa2id_np`
    (amino acid letter -> column index).

    Inputs:
        pdb_path: path to PDB file
        mask_all_VR4: if True, mask all residues in VR4 and average over all of them;
            else, mask only linker residues and average over the first two and
            last two VR4 positions
        use_backbone_only: if True, use backbone only for conditional probabilities
    Returns:
        p1, p2, p3: average native-residue probabilities (floats) for the VR4 region
            from the autoregressive, conditional, and unconditional outputs
    """
    probs_dict = get_MPNN_probabilities(
        model,
        pdb_path,
        mask_all_VR4=mask_all_VR4,
        use_backbone_only=use_backbone_only
    )
    # 'sequence' is the full chain A sequence; cut out the VR4 peptide.
    pep = probs_dict['sequence'][start:end]

    # Positions (relative to the VR4 slice) that are averaged when only the
    # linker residues were masked: first two and last two.
    linker_positions = {0, 1, len(pep) - 2, len(pep) - 1}

    scores = []
    for mode in ['autoregressive', 'conditional', 'unconditional']:
        # VR4 rows only; keep the 20 standard amino acid columns (drops column 20, 'X').
        vr4_probs = probs_dict[mode][start:end, :20]
        # Probability assigned to the residue actually present at each VR4 position.
        native_probs = [float(vr4_probs[i, aa2id_np[aa]]) for i, aa in enumerate(pep)]
        if not mask_all_VR4:
            native_probs = [v for i, v in enumerate(native_probs) if i in linker_positions]
        scores.append(float(np.average(native_probs)))

    p1, p2, p3 = scores
    return p1, p2, p3

In [5]:
model = load_model()

In [None]:
# For every key of `motifs`, score each design listed in that motif's CSV with
# ProteinMPNN and write the three scores back into the same CSV.
for p in list(motifs.keys()):
    # Per-motif score table; one row per design, identified by `description`.
    df = pd.read_csv(f'score_rosetta_ESM2_MPNN-BBonly_{p}.csv')
    scores = []
    for f in tqdm(df.description.tolist()):
        # The PDB file name is the description with its last five characters dropped
        # (presumably a suffix such as a structure index — confirm).
        pdb_path = f"{dir_out_rosetta}/AAV9_{p}/output_design/{f[:-5]}.pdb"
        # All VR4 residues masked; backbone-only conditional probabilities.
        score = get_prob_VR4(pdb_path, mask_all_VR4=True, use_backbone_only=True)
        scores.append(score)
    # Transpose the list of (p1, p2, p3) tuples into three per-design columns.
    # NOTE(review): this unpacking raises if the CSV has no rows.
    p1,p2,p3 = zip(*scores)
    
    df['MPNN_autoregressive'] = p1
    df['MPNN_conditional'] = p2
    df['MPNN_unconditional'] = p3
    # Overwrites the input CSV in place with the added/updated MPNN columns.
    df.to_csv(f'score_rosetta_ESM2_MPNN-BBonly_{p}.csv', index=False)