In [None]:
import os
import gc
import glob
import sys
import random
import string
import tqdm
import json
import time
import sqlite3
import warnings
import pandas as pd
import numpy as np

from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem.Draw import IPythonConsole
from rdkit import RDLogger

from SmilesPE.pretokenizer import atomwise_tokenizer
from SmilesPE.pretokenizer import kmer_tokenizer
from SmilesPE.spe2vec import Corpus

from sklearn.metrics import accuracy_score
from sklearn.metrics import ConfusionMatrixDisplay
from sklearn.metrics import classification_report

from scipy.cluster.hierarchy import dendrogram, linkage
from scipy.spatial.distance import squareform

from multiprocessing import Pool

from fastai import *
from fastai.text import *
#from utils import *
import torch

sys.path.append('//')
import supp_utils as su

# Select the compute device: CUDA when PyTorch reports a usable GPU, otherwise CPU.
_cuda_available = torch.cuda.is_available()
device = torch.device("cuda") if _cuda_available else torch.device("cpu")
# Echo the chosen device and the CUDA availability flag as the cell output.
device, _cuda_available

In [None]:
# --- User configuration -------------------------------------------------
input_file_test = "" # Format cid in each line or cid - name in each line
cluster_distance_file = ""


Number_of_workers = 8
gpu_id = 0  # index of the GPU to use; None falls back to GPU 0

# A missing GPU index defaults to 0, as in the original cell.
if gpu_id is None:
    gpu_id = 0

# Only touch the CUDA runtime when a GPU is actually visible: calling
# torch.cuda.set_device on a CPU-only machine raises instead of falling back.
if torch.cuda.is_available():
    device = "cuda:" + str(gpu_id)
    torch.cuda.set_device(device)
else:
    device = "cpu"

spe_token_path = "pretraining_tokens.txt"

tokenization = "SPE"
model_path = "models/"
pretraining_new_wt = "_model_clas"
batch_size = 64

In [None]:
# Install "ignore" filters: first for DeprecationWarning, then the blanket
# filter for the base Warning class (the default category of filterwarnings).
for _warning_category in (DeprecationWarning, Warning):
    warnings.filterwarnings("ignore", category=_warning_category)
# Raise the RDKit logger threshold to CRITICAL to remove rdkit warnings.
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)

In [None]:
# Load the two protein-function YAML files and the cluster distance table.
def _read_yaml(path):
    """Parse the YAML file at `path` with yaml.safe_load and return the result."""
    with open(path, 'r') as stream:
        return yaml.safe_load(stream)

protein_function = _read_yaml("protein_function_0.005_renamed.yaml")
protein_function_large = _read_yaml("protein_function_0.005_large_sc_all_500.yaml")

# The CSV has no header row; columns are named cluster1, cluster2, distance.
cluster_distance = pd.read_csv(
    cluster_distance_file,
    header=None,
    names=["cluster1", "cluster2", "distance"],
)

In [None]:
# Read the input file: one compound per line, either "<cid>" or "<cid> <name ...>".
input_cids = []
_parsed_names = []  # one slot per CID; None when the line carried no name
with open(input_file_test, "r") as _input_handle:  # context manager closes the file
    for entry in _input_handle:
        fields = entry.split()
        if not fields:
            # Blank (or whitespace-only) line: nothing to parse.
            continue
        input_cids.append(fields[0])
        # Everything after the CID is the (possibly multi-word) name.
        _parsed_names.append(' '.join(fields[1:]) if len(fields) > 1 else None)

# `names` is indexed in parallel with `input_cids` further down, so it must
# either be empty (no names supplied at all) or have exactly one entry per CID.
# Lines without a name fall back to the CID string itself.
if any(name is not None for name in _parsed_names):
    names = [name if name is not None else cid
             for cid, name in zip(input_cids, _parsed_names)]
else:
    names = []

# Fetch SMILES for the requested CIDs via the project helper and put them in
# the two-column frame consumed by the DataBunch below.
cid_smiles = su.get_smiles_from_cid(input_cids,type_smiles="isomeric",get_from="SDF",folder_file_name="",save_output=False,remove_sdf=False)
test_df = pd.DataFrame([(cid_smiles[cid],cid) for cid in cid_smiles], columns=['Smiles','CID'])

In [None]:
train_valid_filename = "cid_cluster_0.005_renamed.txt"
def check_presence_of_cid_in_train_valid_set(input_cid_list,train_file,valid_file=None):
    """Report which input CIDs already occur in the training/validation files.

    Each file is read line by line; the first whitespace-separated field of
    every non-blank line is parsed as an integer CID.

    Parameters
    ----------
    input_cid_list : iterable of int
        CIDs to look up.
    train_file : str
        Path of the training CID file (this is the file that is read; the
        module-level filename is no longer used implicitly).
    valid_file : str, optional
        Path of an additional validation CID file. Skipped when None.

    Returns
    -------
    tuple (already_present, cid_train)
        already_present : list of the input CIDs found in the file(s), in
        input order. cid_train : list of every CID read, in file order.

    Raises
    ------
    ValueError
        If the first field of a non-blank line is not an integer.
    """
    cid_train = []
    for path in (train_file, valid_file):
        if path is None:
            continue
        with open(path, "r") as handle:  # closed even if parsing fails
            for entry in handle:
                fields = entry.split()
                if fields:  # skip blank / whitespace-only lines
                    cid_train.append(int(fields[0]))

    # Set membership keeps the lookup linear in the number of input CIDs.
    known_cids = set(cid_train)
    already_present = [cid for cid in input_cid_list if cid in known_cids]
    return (already_present,cid_train)

# Check the fetched CIDs (converted to int) against the training/validation CID file.
_fetched_cids_as_int = [int(cid) for cid in cid_smiles.keys()]
already_present_train, cid_train = check_presence_of_cid_in_train_valid_set(
    _fetched_cids_as_int, train_valid_filename
)

In [None]:
def get_accuracy(yhat,y):
    """Return the fraction of rows whose arg-max score equals the target.

    `yhat` is exponentiated (after a cast to float) before the per-row
    arg-max over axis 1 is taken; `y` holds the expected class index per row.
    Both are moved to CPU and converted to NumPy for the comparison.
    """
    scores = torch.exp(yhat.float()).cpu().detach().numpy()
    predicted = scores.argmax(axis=1)
    expected = y.cpu().detach().numpy()
    matches = expected == predicted
    return np.count_nonzero(matches) / len(matches)

In [None]:
# Build the tokenizer for the scheme selected in the configuration cell.
if tokenization == "SPE":
    # SPE tokenizer: needs the token file given by spe_token_path.
    MolTokenizer = su.molpmofit.MolTokenizer_spe_sos_eos
    tok = Tokenizer(partial(MolTokenizer,token_path=spe_token_path), n_cpus=Number_of_workers, pre_rules=[], post_rules=[])
else:
    # Atom-wise tokenizer: constructed without a token file.
    MolTokenizer = su.molpmofit.MolTokenizer_atomwise_sos_eos
    tok = Tokenizer(partial(MolTokenizer), n_cpus=Number_of_workers, pre_rules=[], post_rules=[])
# NOTE: `tok` used to be rebuilt here unconditionally with token_path=spe_token_path,
# which discarded the atom-wise branch above; each branch now keeps its own tokenizer.

In [None]:
# Read the classifier vocabulary (one token per line); the context manager
# closes the file instead of leaking the handle.
with open("models/text_class_vocab.txt","r") as _vocab_file:
    vocab = [vocab_token.strip() for vocab_token in _vocab_file]
vocab_class = text.transform.Vocab(vocab)
# The same frame is passed as both train and validation set; the 'CID' column
# is supplied as the label column and the class list is fixed to 1..241.
test_data_clas = TextClasDataBunch.from_df("", test_df, test_df, bs=batch_size, tokenizer=tok,
                              chunksize=50000, text_cols='Smiles',label_cols='CID', vocab=vocab_class, max_vocab=60000,
                                              include_bos=False,classes=list(range(1,242)))

In [None]:
# Build the AWD-LSTM classifier without pretrained weights, then load the
# checkpoint named in the configuration cell (pretraining_new_wt) rather than
# repeating the name as a literal here.
learner = text_classifier_learner(test_data_clas, AWD_LSTM, pretrained=False, drop_mult=0.2)
learner.load(pretraining_new_wt, purge=True);

In [None]:
# Prediction of clusters
cid_prediction = {}

# Resolve names by CID instead of by loop position: the order (or size) of
# `cid_smiles` is not guaranteed to match the input file, and a positional
# lookup would then attach the wrong name to a compound.
_name_by_cid = {}
for _listed_cid, _listed_name in zip(input_cids, names):
    _name_by_cid.setdefault(str(_listed_cid), _listed_name)  # first occurrence wins

for cid in cid_smiles:
    smiles = cid_smiles[cid]
    results = learner.predict(smiles)
    prob = results[2].cpu().detach().numpy()
    # NOTE(review): results[1] is stored as the cluster id, while the DataBunch
    # was built with classes 1..241 - confirm whether this value is the class
    # label itself or a 0-based index into that class list.
    predictions = results[1].cpu().detach().numpy().tolist()
    record = {"prediction":predictions}
    if len(names) > 0:
        # Fall back to the CID string if no name was listed for this compound.
        record["name"] = _name_by_cid.get(str(cid), str(cid))
    record["softmax_probability"] = max(prob)
    cid_prediction[cid] = record

In [None]:
# Get distance between clusters
predicted_clusters = []
for cid in cid_prediction:
    cluster = cid_prediction[cid]["prediction"]
    predicted_clusters.append(cluster)

# Index the distance table once: unordered cluster pair -> distance.
# setdefault keeps the first row for a pair, matching the previous
# "take the first matching row" behaviour.
_pair_distance = {}
for _c1, _c2, _dist in zip(cluster_distance["cluster1"].tolist(),
                           cluster_distance["cluster2"].tolist(),
                           cluster_distance["distance"].tolist()):
    _pair_distance.setdefault(frozenset((_c1, _c2)), _dist)

# Several compounds may share a cluster; each distinct cluster is visited once
# (first-appearance order preserved).
_unique_clusters = list(dict.fromkeys(predicted_clusters))

cluster_distance_dicts = []
_emitted_pairs = set()  # ordered (cluster1, cluster2) pairs already appended
for clust1 in _unique_clusters:
    for clust2 in _unique_clusters:
        if clust1 != clust2:
            _key = frozenset((clust1, clust2))
            if _key not in _pair_distance:
                raise IndexError("no distance listed for clusters %r and %r in %r"
                                 % (clust1, clust2, cluster_distance_file))
            distance = _pair_distance[_key]
        else:
            distance = 0
        # Store both orientations so later lookups need not care about order.
        for _a, _b in ((clust1, clust2), (clust2, clust1)):
            if (_a, _b) not in _emitted_pairs:
                _emitted_pairs.add((_a, _b))
                cluster_distance_dicts.append({"cluster1":_a,"cluster2":_b,"distance":distance})

In [None]:
# Get distance between cids
# Index the cluster-pair records once so each CID pair is a dictionary lookup
# instead of a scan of the whole list (first record for a pair wins).
_cluster_pair_distance = {}
for lists in cluster_distance_dicts:
    _cluster_pair_distance.setdefault((lists['cluster1'], lists['cluster2']), lists["distance"])

cid_distance_dicts = []
for cid1 in cid_prediction:
    clust1 = cid_prediction[cid1]["prediction"]
    for cid2 in cid_prediction:
        clust2 = cid_prediction[cid2]["prediction"]
        _key = (clust1, clust2)
        # Pairs without a recorded cluster distance are skipped, as before.
        if _key in _cluster_pair_distance:
            cid_distance_dicts.append({"cid1":cid1,"cid2":cid2,"distance":_cluster_pair_distance[_key]})

In [None]:
# Make distance matrix
# Index the CID-pair records once (first record for a pair wins).
_cid_pair_distance = {}
for lists in cid_distance_dicts:
    _cid_pair_distance.setdefault((lists['cid1'], lists['cid2']), lists["distance"])

cid_distance_matrix = []
for cid1 in input_cids:
    row_list = []
    for cid2 in input_cids:
        # input_cids holds strings read from the file; the records are matched
        # against their integer form, as in the original comparison.
        _key = (int(cid1), int(cid2))
        if _key not in _cid_pair_distance:
            # A silently skipped pair would leave a ragged matrix that only
            # fails later, far from the cause - report it here instead.
            raise ValueError("no distance available for CID pair (%s, %s); "
                             "was a prediction made for both compounds?" % (cid1, cid2))
        distance = _cid_pair_distance[_key]
        if distance == 0 and cid1 != cid2:
            # Different compounds in the same cluster get a small non-zero distance.
            distance += 0.001
        row_list.append(distance)
    cid_distance_matrix.append(row_list)
dm_array = np.array(cid_distance_matrix)

In [None]:
# Default figure geometry for the plots in this notebook.
plt.rcParams.update({'figure.figsize': [8, 4], 'figure.dpi': 150})
def make_plot(distance_matrix,labels,figure_name="chemical_distance.png"):
    """Draw a complete-linkage dendrogram, save it to `figure_name` and show it.

    `distance_matrix` is converted with squareform before clustering and
    `labels` supplies one leaf label per row. Each merge whose distance is
    above 0.001 is annotated with that distance.
    """
    condensed = squareform(distance_matrix)
    tree = dendrogram(linkage(condensed, "complete"), labels=labels,leaf_font_size=12,orientation="left")
    for leaf_axis_coords, distance_coords in zip(tree['icoord'], tree['dcoord']):
        merge_distance = distance_coords[1]
        if not merge_distance > 0.001:
            continue
        midpoint = 0.5 * sum(leaf_axis_coords[1:3])
        plt.annotate("%.3g" % merge_distance, (merge_distance, midpoint), xytext=(0, +12),
                     textcoords='offset points',
                     va='top', ha='center',fontsize=12)

    plt.xlabel("Distance",fontsize=12)
    plt.ylabel("Chemicals",fontsize=12)
    plt.title("Chemical distance",fontsize=15)
    plt.savefig(figure_name)
    plt.show()
    

In [None]:
# Dendrogram with the CIDs as leaf labels.
make_plot(dm_array, input_cids, figure_name="chemical_distance_cids.png")

In [None]:
# Same dendrogram labelled with chemical names, only when names were supplied.
if names:
    make_plot(dm_array, names, figure_name="chemical_distance_names.png")

In [None]:
# Group the CIDs by their predicted cluster: {cluster: [cid, ...]}.
cluster_cid = {}
for cid, record in cid_prediction.items():
    predicted_cluster = record["prediction"]
    cluster_cid.setdefault(predicted_cluster, []).append(cid)

In [None]:
# Get function
def get_function_from_prediction(protein_function_file):
    """Print a report for every predicted cluster in the global `cluster_cid`.

    For each cluster the report lists its CIDs, the matching chemical names
    (only when the global `names` is non-empty; names are located through the
    position of str(cid) in the global `input_cids`), the cluster id, and a
    numbered list of the entries stored under that cluster in
    `protein_function_file`.
    """
    for cluster, cids in cluster_cid.items():
        print ("CIDs = " + str(cids)[1:-1])
        if len(names) > 0:
            chemical_names = [names[input_cids.index(str(cid))] for cid in cids]
            print ("Chemical names = " + str(chemical_names)[1:-1])
        print ("Cluster predicted= " + str(cluster))
        print ("\nFunction\n")
        for position, entry in enumerate(protein_function_file[cluster], start=1):
            print (str(position) + ".) " + str(entry)[1:-1] + "\n")
        print ("\n\n\n")

In [None]:
# Display the per-CID prediction records as the cell output.
cid_prediction

In [None]:
# Report using the annotations loaded from protein_function_0.005_renamed.yaml.
get_function_from_prediction(protein_function_file=protein_function)

In [None]:
# Report using the annotations loaded from protein_function_0.005_large_sc_all_500.yaml.
get_function_from_prediction(protein_function_file=protein_function_large)