**For Predicting**
===
> -*- coding: utf-8 -*-

> Author: xinghd

In [None]:
from utils.utils import Trainer, Predictor, TrainReader,TestReader,PredictReader, data_collate, collate_fn, seed_everything
from model.model import Encoder, Decoder, ModelCat

%reload_ext autoreload
%autoreload 2

In [None]:
import os
import json
import random
import time
import timeit
import pickle
import numpy as np
import pandas as pd
from rdkit import Chem
from rdkit.Chem import Descriptors
from Bio import SeqIO
from torch import nn
from sklearn.model_selection import train_test_split, StratifiedKFold, KFold
import torch
from torch.utils.data import DataLoader, random_split
from utils.word2vec import seq_to_kmers, get_protein_embedding
from gensim.models import Word2Vec
from torch.utils import data as torch_data
import warnings

In [None]:
# Call the project helper `seed_everything` (imported from utils.utils) with a
# fixed seed of 42 before any data loading or model construction.
# NOTE(review): presumably this seeds every RNG used downstream so runs are
# reproducible — confirm against the helper's implementation.
seed_everything(seed=42)

In [None]:
# The configuration below is hard-wired to CUDA (PreCFG.DEVICE is 'cuda:0'), so
# fail fast when no GPU is visible. An explicit raise is used instead of
# `assert` because assertions are removed when Python runs with -O, which
# would silently skip this check.
if not torch.cuda.is_available():
    raise RuntimeError("Must have a avaliable gpu")

class PreCFG:
    """Static settings for a prediction run.

    All values are class attributes and are read directly from the class
    (e.g. ``PreCFG.BATCH_SIZE``); the class is never instantiated in this
    file. ``fasteFile`` and ``compound`` only exist as attributes when
    ``useFasteFile`` is true.
    """

    # ---- Input selection -------------------------------------------------
    # False: read enzyme-substrate pairs from `csvFile`.
    # True:  read candidate enzymes from `fasteFile` and pair each of them
    #        with the single substrate given in `compound`.
    useFasteFile = False
    # Path to the .csv file of enzyme-substrate pairs
    csvFile = './data/example2.csv'

    if useFasteFile:
        # Path to the .faste file of candidate enzymes
        fasteFile = './data/example1.faste'
        # Substrate structure in SMILES (simplified molecular input line
        # entry system) format; must be filled in by the user.
        compound = ''

    # ---- Paths -----------------------------------------------------------
    DATA_ROOT = r'./data/'  # data root
    word2vec_path = './model/model_pretrained/word2vec_pretrained.model'  # Word2Vec model
    state_dict_path = './model/model_pretrained/PU-EPP_pretrained.pt'  # model weights

    # ---- Hardware --------------------------------------------------------
    gpu_number = torch.cuda.device_count()  # number of visible CUDA devices
    DEVICE = torch.device('cuda:0')  # device the model is moved to

    # ---- Model dimensions (passed to Encoder / Decoder in get_model) ------
    protein_dim = 100
    atom_dim = 46
    hid_dim = 128
    norm_shape = 128

    # ---- Inference -------------------------------------------------------
    BATCH_SIZE = 4  # batch size of the prediction DataLoader


In [None]:
# Validate the input-selection settings before any heavy work starts.
# Explicit raises are used instead of `assert`, because assertions are removed
# when Python runs with -O; ValueError matches the error type get_model raises
# for a bad configuration value.
if PreCFG.useFasteFile:
    # FASTA mode needs both the enzyme file and a non-empty substrate SMILES.
    if not (PreCFG.fasteFile and PreCFG.compound):
        raise ValueError('Please specify the molecular structure of the substrate and the file path to the fasteFile of candidate enzymes')
elif not PreCFG.csvFile:
    # CSV mode only needs the path to the pairs file.
    raise ValueError('Please specify the file path to the .csv file of enzyme-substrate pairse')

In [None]:
def get_model(cfg):
    """Build the prediction model described by ``cfg`` and load its weights.

    Args:
        cfg: Configuration object exposing ``protein_dim``, ``atom_dim``,
            ``hid_dim``, ``norm_shape``, ``DEVICE``, ``gpu_number`` and
            ``state_dict_path`` (``PreCFG`` in this file).

    Returns:
        The ``ModelCat`` network moved to ``cfg.DEVICE``; wrapped in
        ``nn.DataParallel`` over devices ``0..gpu_number-1`` when at least
        one GPU is reported.

    Raises:
        ValueError: If ``cfg.state_dict_path`` is set but does not exist.
    """
    net = ModelCat(
        Encoder(cfg.protein_dim, cfg.hid_dim, cfg.norm_shape),
        Decoder(cfg.atom_dim, cfg.hid_dim, cfg.norm_shape),
    ).to(cfg.DEVICE)

    # The wrapper is applied before loading, so the state dict is loaded
    # into the wrapped module.
    if cfg.gpu_number >= 1:
        net = nn.DataParallel(net, device_ids=list(range(cfg.gpu_number)))

    weights_path = cfg.state_dict_path
    if weights_path is None:
        # No checkpoint configured: return the freshly initialised network.
        return net
    if not os.path.exists(weights_path):
        raise ValueError('Wrong path')

    net.load_state_dict(torch.load(weights_path, map_location=cfg.DEVICE))
    print('success load state dict')
    return net

In [None]:
# Build the input DataFrame `df` from either a FASTA file plus one substrate,
# or from a CSV file of enzyme-substrate pairs.
if PreCFG.useFasteFile:
    # Parse the FASTA file once and reuse the records for both lookups.
    records = list(SeqIO.parse(PreCFG.fasteFile, "fasta"))
    # record id -> sequence
    seq_dict = {rec.id: str(rec.seq) for rec in records}
    # sequence -> record id (used later to label results); if two records
    # share a sequence, the later record's id overwrites the earlier one.
    mapp = {str(rec.seq): rec.id for rec in records}
    seq = list(seq_dict.values())
    com = PreCFG.compound
    # Pair every candidate sequence with the same substrate SMILES.
    df = pd.DataFrame({'smiles': [com] * len(seq), 'seq': seq})
else:
    df = pd.read_csv(PreCFG.csvFile)

# Keep only sequences of at most 1500 residues. The FASTA branch stores
# sequences in a 'seq' column, whereas the CSV is filtered on its 'Protein'
# column; hard-coding `df.Protein` made FASTA mode fail here.
seq_col = 'seq' if PreCFG.useFasteFile else 'Protein'
df = df[df[seq_col].map(lambda x: len(x) <= 1500)].reset_index(drop=True)

In [None]:
# Dataset over the input frame; the second positional argument is left as
# None and the third is the path to the pretrained Word2Vec model.
predata = PredictReader(df, None, PreCFG.word2vec_path)

# Network with pretrained weights, wrapped in the project's Predictor helper.
model = get_model(PreCFG)
pre = Predictor(model)

# Unshuffled loader over the prediction dataset.
test_dataloader = DataLoader(
    predata,
    batch_size=PreCFG.BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_fn,
)

In [None]:
# Run inference; `predict` returns two values that are stored as the
# 'y_pre' and 'score' columns of the result table.
y, s = pre.predict(test_dataloader)

# Work on a copy of the input frame so `df` itself is left untouched.
res = df.assign(y_pre=y, score=s)

# Drop the rows listed in `predata.weong_w2d`.
# NOTE(review): presumably these are rows the reader could not embed with
# Word2Vec — confirm against PredictReader.
res = res.drop(index=predata.weong_w2d.index)

In [None]:
# Write the predictions to ./results/<input file stem>_result.csv.
if PreCFG.useFasteFile:
    # Label each row with the FASTA record id of its sequence.
    res.insert(0, 'id', [mapp[i] for i in res.seq])
    source_path = PreCFG.fasteFile
else:
    source_path = PreCFG.csvFile

# Derive the output name from the input file name without its extension.
# splitext is used instead of fixed-width slicing ([:-6] / [:-4]) so that
# extensions of any length (e.g. '.fa', '.tsv') are stripped correctly.
result_name_pre = os.path.splitext(os.path.basename(source_path))[0]

# Make sure the output directory exists before writing.
os.makedirs('./results', exist_ok=True)
res.to_csv(f'./results/{result_name_pre}_result.csv', index=False)

In [None]:
# res[['id','seq', 'y_pre', 'score' ]].sort_values(by=['score'], ascending=False)