In [1]:
import sys
import os
sys.path.append('../')
from tqdm.auto import tqdm
from collections import defaultdict
import py3Dmol
import torch
from torch.utils import data as torch_data
from common.utils import cuda, read_model_state
from data_loaders.enzyme_rxn_dataloader import EnzymeReactionSiteTypeDatasetForInference, enzyme_rxn_collate_extract, enzyme_rxn_collate
from model_structure.enzyme_site_model import EnzymeActiveSiteClsModel




In [2]:
def enzyme_rxn_collate_extract_for_inference(batch):
    """Collate a list of samples while keeping their structure file names.

    ``'pdb_name'`` is removed from each sample before delegating to
    ``enzyme_rxn_collate`` and re-attached to the result as a plain list
    (one entry per sample, in batch order). When the collated ``'targets'``
    is a ``(targets, size)`` tuple it is unpacked into ``'targets'`` and a
    flattened ``'protein_len'``.

    Args:
        batch: list of per-sample dicts, each containing a ``'pdb_name'`` key.
            The dicts themselves are left unmodified.

    Returns:
        The collated batch dict with an added ``'pdb_name'`` list.

    Raises:
        KeyError: if a sample has no ``'pdb_name'`` key.
    """
    # Work on shallow copies: popping from the caller's dicts would strip
    # 'pdb_name' from samples the dataset may cache and hand out again.
    samples = [dict(sample) for sample in batch]
    pdb_name = [sample.pop('pdb_name') for sample in samples]
    batch_data = enzyme_rxn_collate(samples)
    assert isinstance(batch_data, dict)
    targets = batch_data.get('targets')
    if isinstance(targets, tuple):
        targets, size = targets
        batch_data['targets'] = targets
        batch_data['protein_len'] = size.view(-1)
    batch_data['pdb_name'] = pdb_name
    return batch_data

def init_dataset_and_model(dataset_path, model_checkpoint_path, device='cuda:0',
                           rxn_model_path='../checkpoints/reaction_attn_net/model-ReactionMGMTurnNet_train_in_uspto_at_2023-04-05-23-46-25',
                           dataset_workers=12,
                           loader_workers=4):
    """Build the inference dataset, its test DataLoader and a trained site-type model.

    Args:
        dataset_path: dataset location passed to
            ``EnzymeReactionSiteTypeDatasetForInference``.
        model_checkpoint_path: checkpoint read with ``read_model_state`` and
            loaded into the model.
        device: preferred device; falls back to CPU when CUDA is unavailable.
        rxn_model_path: reaction-model path handed to
            ``EnzymeActiveSiteClsModel`` (previously hard-coded).
        dataset_workers: ``nb_workers`` for dataset preprocessing.
        loader_workers: ``num_workers`` for the test DataLoader.

    Returns:
        ``(model, test_dataset, test_dataloader)``. The model is moved to the
        resolved device and put in eval mode. ``test_dataset`` is the third
        split returned by ``dataset.split()``.
    """
    device = torch.device(device if torch.cuda.is_available() else 'cpu')

    dataset = EnzymeReactionSiteTypeDatasetForInference(path=dataset_path,
                                save_precessed=False,
                                debug=False,
                                verbose=1,
                                lazy=True,
                                nb_workers=dataset_workers)
    _, _, test_dataset = dataset.split()
    # batch_size=1 / shuffle=False: keep one structure per batch in a stable order.
    test_dataloader = torch_data.DataLoader(
        test_dataset,
        batch_size=1,
        collate_fn=enzyme_rxn_collate_extract_for_inference,
        shuffle=False,
        num_workers=loader_workers)

    model = EnzymeActiveSiteClsModel(
        rxn_model_path=rxn_model_path,
        num_active_site_type=dataset.num_active_site_type,
        from_scratch=True)
    model_state, _ = read_model_state(model_save_path=model_checkpoint_path)
    model.load_state_dict(model_state)
    print('Loaded checkpoint from {}'.format(model_checkpoint_path))
    model.to(device)
    model.eval()

    return model, test_dataset, test_dataloader
    
def calculate_one_data(dataset, data_index):
    """Fetch sample ``data_index`` from ``dataset`` and collate it as a one-item batch.

    The sample is wrapped in a single-element list and passed through
    ``enzyme_rxn_collate_extract``. That function's result is returned unchanged.
    """
    single_item_batch = [dataset[data_index]]
    return enzyme_rxn_collate_extract(single_item_batch)


def inference(model, batch_one_data, device='cuda:0'):
    """Predict a site-type class for each position of one collated batch.

    Args:
        model: the active-site classifier. It is called as
            ``model(batch_one_data)`` and must return a
            ``(protein_node_logits, protein_mask)`` pair.
        batch_one_data: a collated batch, e.g. from ``calculate_one_data``.
        device: preferred device. When CUDA is unavailable it falls back to
            CPU and the batch is left where it is.

    Returns:
        The argmax over the last axis of the logits (presumably one class
        index per residue), or ``None`` if the forward pass raised an
        ``Exception``.
    """
    device = torch.device(device if torch.cuda.is_available() else 'cpu')
    with torch.no_grad():
        if device.type == 'cuda':
            batch_one_data = cuda(batch_one_data, device=device)
        try:
            protein_node_logits, _ = model(batch_one_data)
        except Exception as exc:  # let KeyboardInterrupt / SystemExit propagate
            print(f'error in this data: {exc!r}')
            return None
        # softmax is monotonic, so taking argmax of the raw logits picks the
        # same class without an extra pass (and avoids underflow ties).
        return torch.argmax(protein_node_logits, dim=-1)
    

def show_structure(structure_path,
                   structure_fname,
                   site_labels,
                   view_size=(450, 450),
                   res_colors=None):
    """Render a PDB structure as a cartoon coloured by per-residue site label.

    Every ATOM record is coloured by the label of its residue. The CA atom of
    each residue with a non-zero label also gets a text label.

    Args:
        structure_path: directory containing the structure file.
        structure_fname: file name of the PDB-format structure inside it.
        site_labels: indexable of labels, indexed by PDB residue number minus
            one. Every label must be a key of ``res_colors``.
        view_size: ``(width, height)`` of the py3Dmol viewer.
        res_colors: mapping from site label to colour. Defaults to
            0 -> 'white' (non-active-site residue), 1 -> 'red' (binding site),
            2 -> 'yellow' (active site), 3 -> 'blue' (other site).
    """
    if res_colors is None:
        res_colors = {
            0: 'white',   # non-active-site residue
            1: 'red',     # Binding Site
            2: 'yellow',  # Active Site
            3: 'blue',    # Other Site
        }
    with open(os.path.join(structure_path, structure_fname)) as ifile:
        system = ifile.read()

    view = py3Dmol.view(width=view_size[0], height=view_size[1])
    view.addModelsAsFrames(system)

    for line in system.split("\n"):
        if not line.startswith("ATOM"):
            continue
        # Use the serial number written in the file (PDB columns 7-11) rather
        # than a running counter. A counter drifts as soon as a TER/HETATM
        # record consumes a serial between ATOM records.
        serial = int(line[6:11])
        res_idx = int(line[22:26].strip()) - 1
        label = site_labels[res_idx]
        color = res_colors[label]
        view.setStyle({'model': -1, 'serial': serial}, {"cartoon": {'color': color}})
        atom_name = line[12:16].strip()
        if atom_name == 'CA' and label != 0:
            residue_name = line[17:20].strip()
            x = float(line[30:38])
            y = float(line[38:46])
            z = float(line[46:54])
            # NOTE(review): the label shows the 0-based index (PDB residue
            # number minus one), matching site_labels rather than the file.
            view.addLabel(f'{residue_name} {res_idx}', {"fontSize": 15, "position": {"x": x, "y": y, "z": z}, "fontColor": color, "fontOpacity":1.0 ,"backgroundColor": 'white', "backgroundOpacity": 0.0})
    view.zoomTo()
    view.show()

def show_structures(structure_path,
                    structure_fname,
                    gt_site_labels,
                    pred_site_labels,
                    view_size=(450, 450),
                    res_colors=None):
    """Render the ground-truth labelling and then the predicted labelling of one structure.

    Each view is preceded by a printed banner and drawn with ``show_structure``.

    Args:
        structure_path: directory containing the structure file.
        structure_fname: file name of the structure inside it.
        gt_site_labels: ground-truth per-residue labels.
        pred_site_labels: predicted per-residue labels.
        view_size: ``(width, height)`` for both viewers.
        res_colors: mapping from site label to colour, shared by both views.
            Defaults to 0 -> '#73B1FF' (non-active-site residue),
            1 -> '#FF0000' (binding site), 2 -> 'green' (active site),
            3 -> '#FFFF00' (other site).
    """
    if res_colors is None:
        # Fresh dict per call instead of a shared mutable default argument.
        res_colors = {
            0: '#73B1FF',  # non-active-site residue
            1: '#FF0000',  # Binding Site
            2: 'green',    # Active Site
            3: '#FFFF00',  # Other Site
        }
    panels = (
        (' Ground Truth Active Site ', gt_site_labels),
        (' Predicted Active Site ', pred_site_labels),
    )
    for title, labels in panels:
        print('#'*20 + title + '#'*20)
        show_structure(structure_path, structure_fname, labels,
                       view_size=view_size, res_colors=res_colors)

def show_one_with_index(data_index, model, dataset,
                        structure_path='../dataset/ec_site_dataset/structures/alphafolddb_download',
                        view_size=(900, 900)):
    """Run the model on one dataset sample and show ground truth vs. prediction.

    Args:
        data_index: index of the sample in ``dataset``.
        model: trained model passed to ``inference``.
        dataset: indexable dataset passed to ``calculate_one_data``.
        structure_path: directory holding the structure files named by the
            sample's ``'pdb_name'`` (previously hard-coded).
        view_size: ``(width, height)`` of each viewer.
    """
    batch_one_data = calculate_one_data(dataset, data_index)
    pred_active_labels = inference(model, batch_one_data=batch_one_data)
    if pred_active_labels is None:
        # inference() returns None when the forward pass fails; nothing to draw.
        print(f'Skip data index {data_index}: no prediction available.')
        return

    show_structures(structure_path=structure_path,
                    structure_fname=batch_one_data['pdb_name'][0],
                    gt_site_labels=batch_one_data['targets'].int().tolist(),
                    pred_site_labels=pred_active_labels.tolist(),
                    view_size=view_size)


In [3]:
DATASET_PATH = 'dataset/ec_site_dataset/uniprot_ecreact_cluster_split_merge_dataset_limit_100'
MODEL_CHECKPOINT_PATH = (
    '../checkpoints/enzyme_site_type_predition_model/'
    'train_in_uniprot_ecreact_cluster_split_merge_dataset_limit_100_at_2023-06-14-11-04-55/'
    'global_step_70000'
)
# NOTE(review): DATASET_PATH has no '../' prefix, unlike the other paths. The
# loading log below suggests the dataset class resolves it under the project
# directory — confirm.
# Only the model and the test split are used below; the DataLoader is discarded.
model, test_dataset, _ = init_dataset_and_model(
    dataset_path=DATASET_PATH,
    model_checkpoint_path=MODEL_CHECKPOINT_PATH,
)

  0%|          | 0/2595 [00:00<?, ?it/s]

  0%|          | 0/1011 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Loading structure check file ...

AA_sequence from structure file valid 108250/108549, 99.72%

Loading reaction check file ...

Reaction from csv file valid 108549/108549, 100.00%



Loading /home/xiaoruiwang/data/ubuntu_work_beta/single_step_work/ec_site_prediction/dataset/ec_site_dataset/structures/uniprot_ecreact_cluster_split_merge_dataset_limit_100_alphafolddb_proprecessed.pkl.gz: 100%|██████████| 108250/108250 [00:00<00:00, 577984.49it/s]


Loading /home/xiaoruiwang/data/ubuntu_work_beta/single_step_work/ec_site_prediction/dataset/ec_site_dataset/un…



Train reaction attention model from scratch...
Loaded checkpoint from ../checkpoints/enzyme_site_type_predition_model/train_in_uniprot_ecreact_cluster_split_merge_dataset_limit_100_at_2023-06-14-11-04-55/global_step_70000


In [4]:
# Visualise one example from test_dataset: ground-truth vs. predicted site labels.
data_index=856
show_one_with_index(data_index, model=model, dataset=test_dataset)

#################### Ground Truth Active Site ####################


#################### Predicted Active Site ####################
