In [2]:
import os
import sys
import h5py
import json
import numpy as np
import torch as pt
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from glob import glob

from src.dataset import StructuresDataset, collate_batch_features, select_by_sid, select_by_interface_types
from src.data_encoding import encode_structure, encode_features, extract_topology, categ_to_resnames, resname_to_categ
from src.structure import data_to_structure, encode_bfactor, concatenate_chains, split_by_chain
from src.structure_io import save_pdb, read_pdb
from src.scoring import bc_scoring, bc_score_names

In [3]:
# data parameters
data_path = "examples/issue_20_04_2023"

In [4]:
# model parameters
# R3
#save_path = "model/save/i_v3_0_2021-05-27_14-27"  # 89
#save_path = "model/save/i_v3_1_2021-05-28_12-40"  # 90
# R4
#save_path = "model/save/i_v4_0_2021-09-07_11-20"  # 89
save_path = "model/save/i_v4_1_2021-09-07_11-21"  # 91

# select saved model
model_filepath = os.path.join(save_path, 'model_ckpt.pt')
#model_filepath = os.path.join(save_path, 'model.pt')

In [5]:
# add module to path
if save_path not in sys.path:
    sys.path.insert(0, save_path)
    
# load functions
from config import config_model, config_data
from data_handler import Dataset
from model import Model

In [6]:
# define device
device = pt.device("cpu")

# create model
model = Model(config_model)

# reload model
model.load_state_dict(pt.load(model_filepath, map_location=pt.device("cpu")))

# set model to inference
model = model.eval().to(device)

In [7]:
# find pdb files and ignore already predicted oins
pdb_filepaths = glob(os.path.join('Test18PDB1', "*.pdb1"), recursive=True)
pdb_filepaths = [fp for fp in pdb_filepaths if "_i" not in fp]
# pdb_filepaths = ['1uud.pdb']

# create dataset loader with preprocessing
dataset = StructuresDataset(pdb_filepaths, with_preprocessing=True)

# debug print
print(len(dataset))

18


In [8]:
# run model on all subunits
with pt.no_grad():
    for subunits, filepath in tqdm(dataset):
        # concatenate all chains together
        structure = concatenate_chains(subunits)

        # encode structure and features
        X, M = encode_structure(structure)
        #q = pt.cat(encode_features(structure), dim=1)
        q = encode_features(structure)[0]

        # extract topology
        ids_topk, _, _, _, _ = extract_topology(X, 64)

        # pack data and setup sink (IMPORTANT)
        X, ids_topk, q, M = collate_batch_features([[X, ids_topk, q, M]])

        # run model
        z = model(X.to(device), ids_topk.to(device), q.to(device), M.float().to(device))

        # for all predictions
        for i in range(z.shape[1]):
            # prediction
            p = pt.sigmoid(z[:,i])
            # encode result
            structure = encode_bfactor(structure, p.cpu().numpy())

            # save results
            output_filepath = 'Test18predictions/' + filepath[:-4]+'_i{}.pdb'.format(i)
            save_pdb(split_by_chain(structure), output_filepath)

  0%|          | 0/18 [00:00<?, ?it/s]

Test18PDB1/379d.pdb1


  6%|▌         | 1/18 [00:11<03:09, 11.14s/it]

Test18PDB1/1fmn.pdb1


 11%|█         | 2/18 [00:21<02:49, 10.61s/it]

Test18PDB1/2tob.pdb1


 17%|█▋        | 3/18 [00:28<02:12,  8.86s/it]

Test18PDB1/1qdn.pdb1


 22%|██▏       | 4/18 [01:18<05:52, 25.16s/it]

Test18PDB1/2mis.pdb1


 28%|██▊       | 5/18 [01:25<04:04, 18.82s/it]

Test18PDB1/364d.pdb1


 33%|███▎      | 6/18 [01:40<03:28, 17.38s/it]

Test18PDB1/1nem.pdb1


 39%|███▉      | 7/18 [01:47<02:32, 13.87s/it]

Test18PDB1/5v3f.pdb1


 44%|████▍     | 8/18 [02:02<02:23, 14.35s/it]

Test18PDB1/430d.pdb1


 50%|█████     | 9/18 [02:10<01:52, 12.52s/it]

Test18PDB1/1ddy.pdb1


 56%|█████▌    | 10/18 [02:21<01:35, 11.91s/it]

Test18PDB1/6ez0.pdb1


 61%|██████    | 11/18 [02:28<01:12, 10.42s/it]

Test18PDB1/4pqv.pdb1


 67%|██████▋   | 12/18 [02:45<01:13, 12.30s/it]

Test18PDB1/2juk.pdb1


 72%|███████▏  | 13/18 [02:51<00:52, 10.47s/it]

Test18PDB1/1f1t.pdb1


 78%|███████▊  | 14/18 [03:00<00:39,  9.97s/it]

Test18PDB1/2pwt.pdb1


 83%|████████▎ | 15/18 [03:14<00:33, 11.12s/it]

Test18PDB1/4f8u.pdb1


 89%|████████▉ | 16/18 [03:25<00:22, 11.13s/it]

Test18PDB1/4yaz.pdb1


 94%|█████████▍| 17/18 [03:45<00:13, 13.80s/it]

Test18PDB1/5bjo.pdb1


100%|██████████| 18/18 [04:04<00:00, 13.58s/it]
