In [1]:
import numpy as np 
from matplotlib import pyplot as plt
import pickle
from synnet.utils.data_utils import SkeletonSet
import random
random.seed(42)

from synnet.encoding.fingerprints import fp_4096
from synnet.utils.reconstruct_utils import fetch_oracle
from tqdm import tqdm
import pandas as pd

23:40:00 rdkit INFO: Enabling RDKit 2023.09.5 jupyter extensions


In [5]:


sts = pickle.load(open('/home/msun415/SynTreeNet/results/viz/skeletons-valid.pkl', 'rb'))
all_skeletons = pickle.load(open('/home/msun415/SynTreeNet/results/viz/skeletons-train.pkl','rb'))


In [6]:

sk_set = SkeletonSet().load_skeletons(sts)
sk_set.embed_skeletons()


begin computing similarity matrix


100%|██████████| 625521/625521 [01:03<00:00, 9783.68it/s] 


In [7]:
mol_sks = [(sk.index, sk.tree.nodes[sk.tree_root]['smiles']) for smi in sk_set.lookup for sk in sk_set.lookup[smi]]

k = 100

vis_class_criteria = lambda i:len(all_skeletons[list(all_skeletons)[i]])
top_indices = sorted(range(len(all_skeletons)), key=vis_class_criteria)[-k:]
new_class = dict(zip(top_indices[::-1], range(len(top_indices))))

# for visualizing sk, restrict to only 10 from each class
new_mol_sks = []
count = dict(zip(range(len(top_indices)), [0 for _ in top_indices]))
for mol_sk in mol_sks:
    if mol_sk[0] not in new_class:
        continue
    if new_class[mol_sk[0]] not in count:
        continue
    if count[new_class[mol_sk[0]]] >= 30:
        continue
    new_mol_sks.append((new_class[mol_sk[0]], mol_sk[1]))
    count[new_class[mol_sk[0]]] += 1

mol_sks = new_mol_sks

# random.shuffle(mol_sks)
# mol_sks = mol_sks[:1000]

In [8]:
fp_dist = np.zeros((len(mol_sks), len(mol_sks)))
fps = [np.array(fp_4096(mol_sk[1])) for mol_sk in mol_sks]
for i in tqdm(range(len(mol_sks))):
    for j in range(len(mol_sks)):
        fp_dist[i, j] = np.abs(fps[i]-fps[j]).sum()

max_dist = sk_set.sim.max()
max_fp_dist = fp_dist.max()


100%|██████████| 793/793 [00:13<00:00, 56.76it/s]


In [9]:
lambd = 0.5
dists = np.zeros((len(mol_sks), len(mol_sks)))
for i in tqdm(range(len(mol_sks))):
    for j in range(len(mol_sks)):
        index1 = mol_sks[i][0]
        index2 = mol_sks[j][0]
        dists[i][j] = lambd*sk_set.sim[index1][index2]/max_dist + (1-lambd)*fp_dist[i][j]/max_fp_dist

dists_fp = np.zeros((len(mol_sks), len(mol_sks)))
for i in tqdm(range(len(mol_sks))):
    for j in range(len(mol_sks)):  
        dists_fp[i][j] = fp_dist[i][j]/max_fp_dist

100%|██████████| 793/793 [00:00<00:00, 1111.42it/s]
100%|██████████| 793/793 [00:00<00:00, 2540.18it/s]


In [10]:
from sklearn.manifold import MDS
# ms = MDS(n_components=2, dissimilarity='precomputed', verbose=1)
# coords = ms.fit_transform(dists)
ms = MDS(n_components=2, dissimilarity='precomputed', verbose=1)
coords_fp = ms.fit_transform(dists_fp)

In [29]:
scores = np.array(scores)
normalized_scores = (scores - min(scores)) / (max(scores) - min(scores))
quantiles = pd.qcut(normalized_scores, 4, labels=False)
fig = plt.Figure()
ax = fig.add_subplot(1,1,1)  
scatter = ax.scatter(coords_fp[:,0], coords_fp[:,1], c=quantiles, cmap=cmap, edgecolor='k', s=10)
fig.colorbar(scatter, ax=ax, label='Score')
# ax.set_title(f'MDS Plot of Fingerprint-Property relationship')
ax.set_xlabel('MDS Dimension 1')
ax.set_ylabel('MDS Dimension 2')
fig.savefig(f"/home/msun415/SynTreeNet/fp-wt-h2.png")

QED


In [None]:
import pandas as pd
plt.rc('text', usetex=False)
plt.rc('font', family='serif')
plt.rcParams['font.size'] = 12  # Default font size for all text
plt.rcParams['axes.titlesize'] = 20  # Font size for axes titles
plt.rcParams['axes.labelsize'] = 14  # Font size for x and y labels
plt.rcParams['xtick.labelsize'] = 10  # Font size for x-tick labels
plt.rcParams['ytick.labelsize'] = 10  # Font size for y-tick labels
plt.rcParams['legend.fontsize'] = 14  # Font size for legends
cmap = plt.cm.get_cmap('viridis', 4)
# 'drd2','gsk','jnk'

fig = plt.Figure(figsize=(20,10))
fig.suptitle('MDS Plot of Structure-Property Relationships')
for i, prop in enumerate(['qed','drd2','gsk','jnk']):
    oracle = fetch_oracle(prop)
    prop = prop.upper()
    print(prop)
    scores = [oracle(mol_sk[1]) for mol_sk in tqdm(mol_sks)]    
    
    ax = fig.add_subplot(2,4,i+1)
    scores = np.array(scores)
    normalized_scores = (scores - min(scores)) / (max(scores) - min(scores))
    quantiles = pd.qcut(normalized_scores, 4, labels=False)
    scatter = ax.scatter(coords[:,0], coords[:,1], c=quantiles, cmap=cmap, edgecolor='k', s=50)
    cbar = fig.colorbar(scatter, ax=ax, label=f'{prop} Quantile', ticks=[0,1,2,3])
    cbar.set_ticklabels(range(4))
    ax.set_title(f'MDS Plot of (Tree+FP)-({prop}) relationship')
    ax.set_xlabel('MDS Dimension 1')
    ax.set_ylabel('MDS Dimension 2')    
    fig.savefig(f"/home/msun415/SynTreeNet/mds-{prop}-0.5.png",bbox_inches='tight')

    fig = plt.Figure()
    ax = fig.add_subplot(1,1,1)  
    scatter = ax.scatter(coords_fp[:,0], coords_fp[:,1], c=quantiles, cmap=cmap, edgecolor='k', s=50)
    cbar = fig.colorbar(scatter, ax=ax, label=f'{prop} Quantile', ticks=[0,1,2,3])
    cbar.set_ticklabels(range(4))
    ax.set_title(f'MDS Plot of (FP)-({prop}) relationship')
    ax.set_xlabel('MDS Dimension 1')
    ax.set_ylabel('MDS Dimension 2')    
    fig.savefig(f"/home/msun415/SynTreeNet/mds-{prop}-fp.png",bbox_inches='tight')



 17%|█▋        | 134/793 [00:00<00:02, 241.03it/s]