# MapDiff sequence inference from a PDB file

In [None]:
import os
import torch
import torch.nn.functional as F
import numpy as np
from hydra import initialize, compose
from model.egnn_pytorch.egnn_net import EGNN_NET
from model.ipa.ipa_net import IPANetPredictor
from model.prior_diff import Prior_Diff
from utils import enable_dropout
from dataloader.collator import CollatorDiff
from data.generate_graph_cath import pdb2graph, get_processed_graph, amino_acids_type
from tqdm import tqdm

# Initialize config and load trained model weights

In [None]:
# Compose the Hydra "inference" config from the local `conf` directory.
with initialize(version_base=None, config_path="conf"):
    cfg = compose(config_name="inference")

# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Build the two sub-networks from the config; every hyper-parameter is read
# from `cfg.model`, so it must match the settings the checkpoint was trained with.
egnn = EGNN_NET(
    input_feat_dim=cfg.model.input_feat_dim,
    hidden_channels=cfg.model.hidden_dim,
    edge_attr_dim=cfg.model.edge_attr_dim,
    dropout=cfg.model.drop_out,
    n_layers=cfg.model.depth,
    update_edge=cfg.model.update_edge,
    norm_coors=cfg.model.norm_coors,
    update_coors=cfg.model.update_coors,
    update_global=cfg.model.update_global,
    embedding=cfg.model.embedding,
    embedding_dim=cfg.model.embedding_dim,
    norm_feat=cfg.model.norm_feat,
    embed_ss=cfg.model.embed_ss,
)
ipa = IPANetPredictor(dropout=cfg.model.ipa_drop_out)

# Wrap both networks in the diffusion model and move it to the target device.
model = Prior_Diff(
    egnn,
    ipa,
    timesteps=cfg.diffusion.timesteps,
    objective=cfg.diffusion.objective,
    noise_type=cfg.diffusion.noise_type,
    sample_method=cfg.diffusion.sample_method,
    min_mask_ratio=cfg.mask_prior.min_mask_ratio,
    dev_mask_ratio=cfg.mask_prior.dev_mask_ratio,
    marginal_dist_path=cfg.dataset.marginal_train_dir,
).to(device)

# load trained model
# `map_location=device` remaps the stored tensors onto the device selected
# above; without it, a checkpoint saved from a CUDA run cannot be deserialized
# on a CPU-only machine even though the CPU fallback is otherwise supported.
checkpoint = torch.load(cfg.test_model.path, map_location=device)
# strict=True: the checkpoint's 'model' entry must match the model's keys exactly.
model.load_state_dict(checkpoint['model'], strict=True)

# MapDiff sequence inference

In [None]:
# load data
# `os.listdir` returns entries in arbitrary order and includes hidden files and
# sub-directories (e.g. `.ipynb_checkpoints`, `.DS_Store`). Sort for reproducible
# output and keep only regular, non-hidden files so the PDB parser is never
# handed something that is not a structure file.
pdb_dir = "data/sample_pdb/"
pdb_files = sorted(
    name for name in os.listdir(pdb_dir)
    if not name.startswith('.') and os.path.isfile(os.path.join(pdb_dir, name))
)

# The collator is constructed without arguments, so a single instance is
# reused for every file instead of being rebuilt per iteration.
collator = CollatorDiff()

for pdb_file in tqdm(pdb_files):
    # Parse the structure into a graph and collate it as a batch of one.
    graph = get_processed_graph(pdb2graph(os.path.join(pdb_dir, pdb_file)))
    g_batch, ipa_batch = collator([graph])
    g_batch, ipa_batch = g_batch.to(device), ipa_batch.to(device)

    # predict
    model.eval()
    with torch.no_grad():
        # Called after eval() on purpose: presumably re-activates the dropout
        # layers so the ensemble passes below differ (Monte-Carlo dropout) -
        # TODO confirm against utils.enable_dropout.
        enable_dropout(model)

        # Draw `ensemble_num` independent samples and keep only their logits;
        # the sampled graph returned alongside is not used here.
        ens_logits = []
        for _ in range(cfg.diffusion.ensemble_num):
            logits, _ = model.mc_ddim_sample(g_batch, ipa_batch, diverse=True, step=cfg.diffusion.ddim_steps)
            ens_logits.append(logits)

        # Average the logits over the ensemble dimension.
        mean_sample_logits = torch.stack(ens_logits).mean(dim=0).cpu()
        # NOTE(review): g_batch.x is used as the ground-truth label and is
        # argmax-ed over dim 1 below, so it is assumed to be a per-residue
        # one-hot / class-probability matrix - confirm against the data pipeline.
        true_label = g_batch.x.cpu()

        # Class indices are computed once and reused for both the sequence
        # strings and the recovery rate.
        pred_idx = mean_sample_logits.argmax(dim=1)
        true_idx = true_label.argmax(dim=1)
        true_sample_seq = ''.join(amino_acids_type[i] for i in true_idx.tolist())
        pred_sample_seq = ''.join(amino_acids_type[i] for i in pred_idx.tolist())

        # Perplexity is exp of the mean cross-entropy over all positions.
        ll_fullseq = F.cross_entropy(mean_sample_logits, true_label, reduction='mean').item()
        perplexity = np.exp(ll_fullseq)
        # Fraction of positions where the predicted residue matches the label,
        # converted to a plain Python float.
        sample_recovery = (pred_idx == true_idx).float().mean().item()

        print(f'PDB file: {pdb_file}')
        print(f'Sequence length: {len(pred_sample_seq)}')
        print(f'Sample perplexity: {perplexity}')
        print(f'Sample recovery rate {sample_recovery:.4f}')
        print(f'Pred sequence: {pred_sample_seq}')
        print(f'True sequence: {true_sample_seq}')