In [None]:
import os
import gc
import glob
import sys
import random
from random import sample
import string
import tqdm
import json
import pandas as pd
import numpy as np

from multiprocessing import Pool
from functools import partial

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
import torch.optim as optim

from rdkit import Chem

import sklearn
from sklearn.metrics import accuracy_score
from sklearn.metrics import ConfusionMatrixDisplay
from sklearn.metrics import classification_report

import codecs
from SmilesPE.pretokenizer import atomwise_tokenizer
from SmilesPE.pretokenizer import kmer_tokenizer
from SmilesPE.learner import *
from SmilesPE.tokenizer import *

import matplotlib.pyplot as plt

# Supp script path
sys.path.append('')
import supp_utils as su

# Select the GPU when torch can see one, otherwise fall back to the CPU.
cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if cuda_available else "cpu")
device, cuda_available

In [None]:
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore")

In [None]:
# --- Configuration: clustering threshold and input/output paths ---
distance_threshold = 0.005
cid_smiles_file = "cid_smiles_sanitized_canonical.txt" # CID-SMILES information for the datapoints obtained from quantmap data

cid_cluster_filename = "cid_cluster_" + str(distance_threshold) + ".txt"
node_details_filename = "clustering_details_" + str(distance_threshold) + ".csv"
# Line i holds the CID of dendrogram leaf i. Read inside `with` so the file
# handle is closed instead of leaking for the lifetime of the kernel.
with open("cid_order_file.txt", "r") as cid_order_handle:
    cid_order_list = [int(line) for line in cid_order_handle]
cluster_distance_filename = "cluster_distance_" + str(distance_threshold) + ".csv"
subset_folder = "subset_data/"
support_above_folder = "cluster_support_above_100/"
cluster_distance_output_filename = support_above_folder + "cluster_distance_" + str(distance_threshold) + ".csv"

In [None]:
def cid_cluster_file(input_filename):
    """Parse a ``"<cid> <cluster>"`` text file into a ``{cid: cluster}`` dict of ints.

    Each line is split on whitespace; the first field is the CID and the second
    its cluster id. If a CID appears more than once, its last cluster id wins.
    """
    mapping = {}
    with open(input_filename, "r") as handle:
        lines = handle.readlines()
    for line in lines:
        fields = line.split()
        mapping[int(fields[0])] = int(fields[1])
    return mapping

def read_cluster_distance_file(input_filename):
    """Read a ``cluster1,cluster2,distance`` CSV into a list of dicts.

    Each returned dict has int ``"cluster1"`` / ``"cluster2"`` and a float
    ``"distance"``, in file order. The first line is skipped only when it is a
    header (its first field does not parse as an integer), so files written
    with or without a header are both read completely. Blank lines are ignored.
    """
    cluster_distance = []
    with open(input_filename) as f:
        for line_number, entry in enumerate(f):
            entry_split = entry.strip().split(",")
            if entry_split == [""]:
                continue  # blank line
            if line_number == 0:
                try:
                    int(entry_split[0])
                except ValueError:
                    continue  # header row, e.g. "cluster1,cluster2,distance"
            cluster_distance.append({"cluster1":int(entry_split[0]),"cluster2":int(entry_split[1]),"distance":float(entry_split[2])})

    return cluster_distance

# Full (unfiltered) inputs: pairwise cluster distances and the CID -> cluster map.
cluster_distance = read_cluster_distance_file(cluster_distance_filename)
cid_cluster_all = cid_cluster_file(cid_cluster_filename)

In [None]:
# Keep only clusters whose support (number of CIDs) reaches the threshold.
cluster_lower_threshold = 100

# Group CIDs by cluster id; keys are seeded from the distinct cluster ids.
cluster_cids_all = {cluster_id: [] for cluster_id in map(int, set(cid_cluster_all.values()))}
for cid, cluster_id in cid_cluster_all.items():
    cluster_cids_all[cluster_id].append(cid)

cluster_cids = {
    cluster_id: members
    for cluster_id, members in cluster_cids_all.items()
    if len(members) >= cluster_lower_threshold
}

# CID -> cluster map restricted to CIDs in the well-supported clusters.
cid_cluster = {
    cid: cluster_id
    for cid, cluster_id in cid_cluster_all.items()
    if cluster_id in cluster_cids
}

In [None]:
def reassign_cluster_id(cid_cluster_dict,cluster_distance=None,output_path="",addon=""):
    """Re-index cluster ids to 0..k-1 and write the re-indexed files.

    The distinct cluster ids in ``cid_cluster_dict`` are sorted and mapped onto
    consecutive integers starting at 0. Writes:

    * ``output_path + <cid_cluster_filename stem> + addon + ".txt"``: one
      ``"<cid> <new_cluster>"`` line per CID, in ``cid_cluster_dict`` order.
    * if ``cluster_distance`` is given,
      ``output_path + <cluster_distance_filename stem> + addon + ".csv"``: a
      ``cluster1,cluster2,distance`` header, then one row per pair whose two
      clusters both occur in ``cid_cluster_dict``. Pairs involving a cluster
      that is absent (e.g. dropped by the support filter) cannot be re-indexed
      and are skipped instead of raising KeyError.

    The stems are the module-level ``cid_cluster_filename`` and
    ``cluster_distance_filename`` with their last four characters dropped.

    Args:
        cid_cluster_dict: mapping CID -> original cluster id. Not modified.
        cluster_distance: optional list of ``{"cluster1", "cluster2", "distance"}`` dicts.
        output_path: prefix (a folder ending in a separator) for output files.
        addon: suffix appended to the file-name stems.
    """
    # Original cluster ids, in ascending order, mapped onto 0..k-1.
    reassign_dict = {cluster: new_id for new_id, cluster in enumerate(sorted(set(cid_cluster_dict.values())))}

    with open(output_path + cid_cluster_filename[:-4] + str(addon) + ".txt","w") as of:
        for cid, cluster in cid_cluster_dict.items():
            of.write(str(cid) + " " + str(reassign_dict[cluster]) + "\n")

    if cluster_distance is not None:
        with open(output_path + cluster_distance_filename[:-4] + str(addon) + ".csv","w") as of:
            # Header row, matching the cluster1,cluster2,distance format that
            # read_cluster_distance_file expects.
            of.write("cluster1,cluster2,distance\n")
            for dicts in cluster_distance:
                cluster1, cluster2 = dicts["cluster1"], dicts["cluster2"]
                if cluster1 not in reassign_dict or cluster2 not in reassign_dict:
                    continue  # pair references a filtered-out cluster
                of.write(str(reassign_dict[cluster1]) + "," + str(reassign_dict[cluster2]) + "," + str(dicts["distance"]) + "\n")

# Re-index the well-supported clusters to 0..k-1 and write the files into support_above_folder.
reassign_cluster_id(cid_cluster, cluster_distance=cluster_distance, output_path=support_above_folder)

In [None]:
# Reload the re-indexed distance and CID -> cluster files from support_above_folder.
cluster_distance = read_cluster_distance_file(support_above_folder + cluster_distance_filename)
cid_cluster = cid_cluster_file(support_above_folder + cid_cluster_filename)

In [None]:
def write_smiles_cluster_file(cid_smiles_file,cid_cluster_dict,output_filename,path):
    """Write a ``"<SMILES> <cluster>"`` line for every CID that has a known SMILES.

    Args:
        cid_smiles_file: path to a whitespace-separated ``<cid> <smiles>`` file.
        cid_cluster_dict: mapping CID -> cluster id; output follows its order.
        output_filename: name of the file to write.
        path: prefix (a folder ending in a separator) prepended to ``output_filename``.

    Returns:
        List of CIDs from ``cid_cluster_dict`` with no SMILES entry (not written).
    """
    cid_smiles = {}
    with open(cid_smiles_file, "r") as smiles_handle:
        for entry in smiles_handle:
            fields = entry.split()
            cid_smiles[int(fields[0])] = fields[1]

    non_found_cid_smiles = []
    with open(path + output_filename,"w") as of:
        for cid, cluster in cid_cluster_dict.items():
            # Only a missing CID is an expected miss; any other error (e.g. a
            # failed write) now propagates instead of being swallowed.
            if cid not in cid_smiles:
                non_found_cid_smiles.append(cid)
                continue
            of.write(str(cid_smiles[cid]) + " " + str(cluster) + "\n")
    return non_found_cid_smiles

In [None]:
smiles_cluster_filename = "smiles_cluster_sanitized_" + str(distance_threshold) + ".txt"
non_found_cid_smiles = write_smiles_cluster_file(
    cid_smiles_file,
    cid_cluster,
    smiles_cluster_filename,
    support_above_folder,
)

### Get distance between clusters

In [None]:
# Load the merge table: per row a node id, its left and right child ids and the
# merge distance. The first line is a header and is skipped.
with open(node_details_filename,"r") as f:
    rows = f.readlines()[1:]

node_details = []
for entry in rows:
    x = entry.split(",")
    node_details.append({
        'node_id': int(x[0]),
        'left': int(x[1]),
        'right': int(x[2]),
        'distance': float(x[3]),
    })

# Cluster id of every CID (unfiltered), used for the histogram.
cluster_count = list(cid_cluster_all.values())

In [None]:
# Distribution of cluster ids over all CIDs (bar height = number of CIDs per bin).
x = cluster_count
ax = plt.gca()
ax.hist(x, bins=100, density=False)  # raw counts, not a density
ax.set_ylabel('Frequency')
ax.set_xlabel('Cluster number');

In [None]:
def find_cid_positions_in_node(node):
    """Return one representative CID per cluster found under a dendrogram node.

    Collects the leaves below ``node`` recursively, using the module-level
    ``node_details`` merge table. It keeps only leaf CIDs present in
    ``cid_cluster``, then de-duplicates so each cluster contributes its first
    CID only. Child-node results are memoised in the module-level ``node_cids`` dict.

    Ids below ``total_samples`` are leaves and index into ``cid_order_list``;
    other ids are internal nodes that must have an entry in ``node_details``.

    Raises:
        KeyError: if ``node`` (or an internal descendant) is missing from ``node_details``.
    """
    if node in node_cids:
        return node_cids[node]

    merge = next((dicts for dicts in node_details if dicts["node_id"] == node), None)
    if merge is None:
        # Report the bad id explicitly rather than failing on an unset child id.
        raise KeyError("node %r not found in node_details" % (node,))

    cids = []
    for child in (merge["left"], merge["right"]):
        if child < total_samples:
            cid = cid_order_list[child]
            if cid in cid_cluster:  # leaf survived the support filter
                cids.append(cid)
        else:
            if child not in node_cids:
                node_cids[child] = find_cid_positions_in_node(child)
            cids.extend(node_cids[child])

    # Keep the first CID seen for each cluster (set for O(1) membership).
    output_cids = []
    seen_clusters = set()
    for cid in cids:
        cluster = cid_cluster[cid]
        if cluster not in seen_clusters:
            seen_clusters.add(cluster)
            output_cids.append(cid)

    return output_cids

def get_cluster_distance_two_cids(left_cids,right_cids,distance):
    """Record ``distance`` for every new cluster pair spanned by two CID lists.

    For each (left, right) CID combination not already seen in either order,
    look up both clusters. If that cluster pair (in either order) is not yet
    recorded, append ``{"cluster1", "cluster2", "distance"}`` to the
    module-level ``cluster_distance`` list and remember the pair keys in
    ``already_found_clusters`` and ``already_found_cids``. Returns None.
    """
    for left in left_cids:
        for right in right_cids:
            cid_key = str(left) + "_" + str(right)
            if cid_key in already_found_cids or str(right) + "_" + str(left) in already_found_cids:
                continue
            left_cluster = cid_cluster[left]
            right_cluster = cid_cluster[right]
            cluster_key = str(left_cluster) + "_" + str(right_cluster)
            if cluster_key in already_found_clusters or str(right_cluster) + "_" + str(left_cluster) in already_found_clusters:
                continue
            cluster_distance.append({"cluster1":left_cluster,"cluster2":right_cluster,"distance":distance})
            already_found_clusters.append(cluster_key)
            already_found_cids.append(cid_key)


In [None]:
# For every leaf CID in dendrogram order: True if it survived the support
# filter (is a key of cid_cluster), False otherwise.
cid_order = {cid: cid in cid_cluster for cid in cid_order_list}

In [None]:
total_clusters = len(cluster_cids_all)
total_samples = len(cid_cluster_all)

# Internal node ids, in merge-table order.
all_nodes = [dicts["node_id"] for dicts in node_details]

# Memo of representative CIDs per internal node; the recursive helper also
# fills entries for child nodes as it goes.
node_cids = {}
loop = tqdm.tqdm(all_nodes,total=len(all_nodes),leave=False)
for node in loop:
    if node in node_cids:
        continue
    node_cids[node] = find_cid_positions_in_node(node)

In [None]:
# Representative CIDs over all internal nodes and the clusters they belong to.
# Prints: #clusters represented, #CIDs after the support filter, #distinct representative CIDs.
one_list = [cid for node_cid_list in node_cids.values() for cid in node_cid_list]
represented_cids = set(one_list)
one_cluster = [cid_cluster[cid] for cid in represented_cids]

print (len(set(one_cluster)),len(cid_cluster),len(represented_cids))

In [None]:
# Walk the merge table in reverse order. For every merge above the clustering
# threshold, get_cluster_distance_two_cids records the merge distance for each
# not-yet-recorded pair of clusters represented on its two sides.
already_found_clusters = []
cluster_distance = []
already_found_cids = []
node_details_inverted = node_details[::-1]

loop = tqdm.tqdm(node_details_inverted,total=len(node_details_inverted),leave=False)
for dicts in loop:
    distance = dicts["distance"]
    if distance > distance_threshold:
        left_node = dicts["left"]
        right_node = dicts["right"]

        # Ids below total_samples are leaves (positions in cid_order_list); all
        # other ids are internal nodes. Use >= so that node id == total_samples
        # is treated as internal, as in find_cid_positions_in_node. With >, it
        # would be indexed into cid_order_list as if it were a leaf.
        if left_node >= total_samples:
            left_cids = node_cids[left_node]
            left_condition = True
        else:
            left_cids = [cid_order_list[left_node]]
            left_condition = cid_order[left_cids[0]]  # leaf counts only if it passed the support filter

        if right_node >= total_samples:
            right_cids = node_cids[right_node]
            right_condition = True
        else:
            right_cids = [cid_order_list[right_node]]
            right_condition = cid_order[right_cids[0]]

        if left_condition and right_condition and len(left_cids) > 0 and len(right_cids) > 0:
            # Appends to cluster_distance / already_found_* in place; returns None.
            get_cluster_distance_two_cids(left_cids,right_cids,distance)

In [None]:
# Persist the recomputed inter-cluster distances (header + one row per pair).
# `with` guarantees the buffered rows are flushed and the file closed.
with open(cluster_distance_output_filename,"w") as outfile:
    outfile.write("cluster1,cluster2,distance\n")
    for dicts in cluster_distance:
        outfile.write(str(dicts["cluster1"]) + "," + str(dicts["cluster2"]) + "," + str(dicts["distance"]) + "\n")

### Selection of clusters for pilot runs (subset data)

In [None]:
# Clusters nearest to a reference cluster: take the five smallest distances
# involving `cluster_of_interst`, then keep the clusters from at most the first
# four matching pairs (in cluster_distance order).
cluster_of_interst = 1
first_five_clusters = []

distance_list = sorted(
    entry["distance"]
    for entry in cluster_distance
    if entry["cluster1"] == cluster_of_interst or entry["cluster2"] == cluster_of_interst
)[:5]

count = 0
for entry in cluster_distance:
    involves_reference = entry["cluster1"] == cluster_of_interst or entry["cluster2"] == cluster_of_interst
    if not involves_reference:
        continue
    if entry["distance"] not in distance_list or count > 3:
        continue
    first_five_clusters += [entry["cluster1"], entry["cluster2"]]
    print (entry["distance"])
    count += 1
first_five_clusters = sorted(set(first_five_clusters))
print (first_five_clusters)

In [None]:
# Clusters involved in the five globally smallest inter-cluster distances.
# Pairs are added until at least five distinct clusters are collected; the
# sorted result is then truncated to five.
cluster_lower_distance_list = sorted(float(entry["distance"]) for entry in cluster_distance)[:5]

second_lower_distane_clusters = []
for entry in cluster_distance:
    if entry["distance"] not in cluster_lower_distance_list:
        continue
    if len(set(second_lower_distane_clusters)) >= 5:
        continue
    second_lower_distane_clusters.extend((entry["cluster1"], entry["cluster2"]))
    print (entry["distance"] )
second_lower_distane_clusters = sorted(set(second_lower_distane_clusters))[:5]
print (second_lower_distane_clusters)

In [None]:
def distance_between_two_clusters(input_cluster1,input_cluster2):
    """Print the recorded distance of every pair linking the two given clusters.

    Scans the module-level ``cluster_distance`` list. A pair matches when both
    ``input_cluster1`` and ``input_cluster2`` equal one of its two cluster ids
    (either position). Nothing is returned.
    """
    def _involves(pair, cluster_id):
        return pair["cluster1"] == cluster_id or pair["cluster2"] == cluster_id

    for pair in cluster_distance:
        if not _involves(pair, input_cluster1):
            continue
        if not _involves(pair, input_cluster2):
            continue
        print (pair["distance"])
    
def farthest_cluster(input_cluster):
    """Return the cluster(s) at the largest recorded distance from ``input_cluster``.

    Only pairs in the module-level ``cluster_distance`` list that involve
    ``input_cluster`` are considered; all ties are returned, in list order. The
    running maximum starts at 0, so a cluster with no recorded pairs gives ``[]``.
    """
    related = [
        pair for pair in cluster_distance
        if pair["cluster1"] == input_cluster or pair["cluster2"] == input_cluster
    ]
    largest_distance = max([0] + [pair["distance"] for pair in related])

    output_clusters = []
    for pair in related:
        if pair["distance"] != largest_distance:
            continue
        output_clusters.extend(
            member for member in (pair["cluster1"], pair["cluster2"]) if member != input_cluster
        )
    return output_clusters

In [None]:
# For every cluster picked so far, collect the cluster(s) farthest from it;
# keep the five smallest distinct ids.
farthest_clusters_to_first_two_sets = sorted({
    far_cluster
    for reference_cluster in first_five_clusters + second_lower_distane_clusters
    for far_cluster in farthest_cluster(reference_cluster)
})[:5]

In [None]:
# Top the pilot selection up to `n_pilot_clusters` with clusters drawn at random
# from those not already picked by the three rules above.
n_pilot_clusters = 20
cluster_selection_seed = 42  # set to None for a different draw on every run

selected_cluster_from_three = list(set(first_five_clusters + second_lower_distane_clusters + farthest_clusters_to_first_two_sets))
remaining_count = max(0, n_pilot_clusters - len(selected_cluster_from_three))

# Candidate pool computed once and sorted, so a fixed seed reproduces the selection.
candidate_clusters = sorted(set(cid_cluster.values()) - set(selected_cluster_from_three))
if remaining_count > len(candidate_clusters):
    # Not enough clusters left: take all of them instead of looping forever.
    print ("Only " + str(len(candidate_clusters)) + " unselected clusters available (" + str(remaining_count) + " requested)")
    remaining_count = len(candidate_clusters)
random_cluster_selection = random.Random(cluster_selection_seed).sample(candidate_clusters, remaining_count)
# Report the number actually drawn; summing the three list lengths over-counts
# whenever those lists share clusters.
print ("Randomly chosen " + str(len(random_cluster_selection)) + " cluster")

In [None]:
# Sanity check: both counts agree, since the list was built from a set.
(len(set(selected_cluster_from_three)), len(selected_cluster_from_three))

In [None]:
# Final pilot selection: the random picks followed by the rule-based picks.
selected_clusters = list(random_cluster_selection)
selected_clusters.extend(selected_cluster_from_three)

In [None]:
# Number of distinct clusters in the pilot selection.
len(frozenset(selected_clusters))

In [None]:
# Record the chosen cluster ids (as used in cid_cluster, before the subset
# re-indexing), one per line. os.path.join avoids the doubled separator that
# subset_folder + "/" produced, since subset_folder already ends with "/".
selected_clusters_path = os.path.join(subset_folder, cid_cluster_filename[:-4] + "_selected_clusters.txt")
with open(selected_clusters_path,"w") as of:
    for cluster in selected_clusters:
        of.write(str(cluster) + "\n")

In [None]:
# Show the final pilot cluster ids.
print(selected_clusters)

In [None]:
# CIDs belonging to the pilot clusters, keeping their current cluster ids.
# The stored value is simply the CID's own cluster id: looking it up again via
# selected_clusters.index returned that same value at O(k) cost per CID.
# A set gives O(1) membership tests.
selected_cluster_set = set(selected_clusters)
subset_cid_cluster = {
    cid: cluster
    for cid, cluster in cid_cluster.items()
    if cluster in selected_cluster_set
}

In [None]:
# Number of CIDs in the pilot subset.
len(subset_cid_cluster.keys())

In [None]:
# Re-index the pilot subset's cluster ids (no distance file) and write the
# "_subset" CID -> cluster file into subset_folder.
reassign_cluster_id(subset_cid_cluster, None, subset_folder, "_subset")

In [None]:
# Same count as before: reassign_cluster_id builds a new mapping and leaves its
# input dict untouched.
len(subset_cid_cluster.keys())

In [None]:
# Reload the subset CID -> cluster map with the re-indexed cluster ids.
subset_cid_cluster_path = subset_folder + cid_cluster_filename[:-4] + "_subset.txt"
subset_cid_cluster = cid_cluster_file(subset_cid_cluster_path)

In [None]:
subset_smiles_filename = "smiles_cluster_subset_sanitized_" + str(distance_threshold) + ".txt"
non_found_cid_smiles = write_smiles_cluster_file(
    cid_smiles_file,
    subset_cid_cluster,
    subset_smiles_filename,
    subset_folder,
)