In [22]:
import numpy as np 
from matplotlib import pyplot as plt
import pickle
from synnet.utils.data_utils import SkeletonSet
import random
random.seed(42)

from synnet.encoding.fingerprints import fp_4096
from synnet.utils.reconstruct_utils import fetch_oracle
from tqdm import tqdm
import pandas as pd
cmap = plt.cm.get_cmap('viridis', 4)

In [1]:


sts = pickle.load(open('/home/msun415/SynTreeNet/results/viz/skeletons-valid.pkl', 'rb'))



18:20:17 rdkit INFO: Enabling RDKit 2023.09.5 jupyter extensions


In [2]:

for st in sts:
    random.shuffle(sts[st])
    sts[st] = sts[st][:5]
sk_set = SkeletonSet().load_skeletons(sts)

In [4]:
sk_set.embed_skeletons()

begin computing similarity matrix


100%|██████████| 625521/625521 [01:04<00:00, 9656.23it/s] 


In [4]:

mol_sks = [(sk.index, sk.tree.nodes[sk.tree_root]['smiles']) for smi in sk_set.lookup for sk in sk_set.lookup[smi]]




18:22:21 faiss.loader INFO: Loading faiss with AVX2 support.
18:22:21 faiss.loader INFO: Successfully loaded faiss with AVX2 support.


In [28]:
lines = open('/home/msun415/polymer_walk/datasets/lohc.txt').readlines()
smis = []
mol_sks = []
scores = []
for l in lines[1:]:
    smis.append(l.split(',')[0])
    mol_sks.append((0, l.split(',')[0]))
    scores.append(float(l.split(',')[2]))

In [12]:
fp_dist = np.zeros((len(mol_sks), len(mol_sks)))
fps = [np.array(fp_4096(mol_sk[1])) for mol_sk in mol_sks]
for i in tqdm(range(len(mol_sks))):
    for j in range(len(mol_sks)):
        fp_dist[i, j] = np.abs(fps[i]-fps[j]).sum()

# max_dist = sk_set.sim.max()
max_fp_dist = fp_dist.max()


100%|██████████| 3039/3039 [01:38<00:00, 30.73it/s]


In [13]:
# lambd = 1.0
# dists = np.zeros((len(mol_sks), len(mol_sks)))
# for i in tqdm(range(len(mol_sks))):
#     for j in range(len(mol_sks)):
#         index1 = mol_sks[i][0]
#         index2 = mol_sks[j][0]
#         dists[i][j] = lambd*sk_set.sim[index1][index2]/max_dist + (1-lambd)*fp_dist[i][j]/max_fp_dist

dists_fp = np.zeros((len(mol_sks), len(mol_sks)))
for i in tqdm(range(len(mol_sks))):
    for j in range(len(mol_sks)):  
        dists_fp[i][j] = fp_dist[i][j]/max_fp_dist

100%|██████████| 3039/3039 [00:04<00:00, 692.29it/s]


In [14]:
from sklearn.manifold import MDS
# ms = MDS(n_components=2, dissimilarity='precomputed', verbose=1)
# coords = ms.fit_transform(dists)
ms = MDS(n_components=2, dissimilarity='precomputed', verbose=1)
coords_fp = ms.fit_transform(dists_fp)

In [29]:
scores = np.array(scores)
normalized_scores = (scores - min(scores)) / (max(scores) - min(scores))
quantiles = pd.qcut(normalized_scores, 4, labels=False)
fig = plt.Figure()
ax = fig.add_subplot(1,1,1)  
scatter = ax.scatter(coords_fp[:,0], coords_fp[:,1], c=quantiles, cmap=cmap, edgecolor='k', s=10)
fig.colorbar(scatter, ax=ax, label='Score')
ax.set_title(f'MDS Plot of Fingerprint-Property relationship')
ax.set_xlabel('MDS Dimension 1')
ax.set_ylabel('MDS Dimension 2')
fig.savefig(f"/home/msun415/SynTreeNet/fp-wt-h2.png")

In [21]:


for prop in ['qed','drd2','gsk','jnk']:
    oracle = fetch_oracle(prop)
    scores = [oracle(mol_sk[1]) for mol_sk in tqdm(mol_sks)]    
    fig = plt.Figure()
    ax = fig.add_subplot(1,1,1)
    scores = np.array(scores)
    normalized_scores = (scores - min(scores)) / (max(scores) - min(scores))
    quantiles = pd.qcut(normalized_scores, 4, labels=False)
    scatter = ax.scatter(coords[:,0], coords[:,1], c=quantiles, cmap=cmap, edgecolor='k', s=10)
    fig.colorbar(scatter, ax=ax, label='Score')
    ax.set_title(f'MDS Plot of (Skeleton, Fingerprint)-Property ({prop}) relationship')
    ax.set_xlabel('MDS Dimension 1')
    ax.set_ylabel('MDS Dimension 2')
    fig.savefig(f"/home/msun415/SynTreeNet/mds-{prop}-1.0.png")

    # fig = plt.Figure()
    # ax = fig.add_subplot(1,1,1)  
    # scatter = ax.scatter(coords_fp[:,0], coords_fp[:,1], c=quantiles, cmap=cmap, edgecolor='k', s=10)
    # fig.colorbar(scatter, ax=ax, label='Score')
    # ax.set_title(f'MDS Plot of Fingerprint-Property ({prop}) relationship')
    # ax.set_xlabel('MDS Dimension 1')
    # ax.set_ylabel('MDS Dimension 2')
    # fig.savefig(f"/home/msun415/SynTreeNet/mds-{prop}-fp.png")



 11%|█         | 85/793 [00:00<00:03, 200.84it/s]

100%|██████████| 793/793 [00:03<00:00, 231.90it/s]
Found local copy...
100%|██████████| 793/793 [00:15<00:00, 50.11it/s]
Found local copy...
100%|██████████| 793/793 [01:15<00:00, 10.56it/s]
Found local copy...
100%|██████████| 793/793 [01:20<00:00,  9.89it/s]


In [None]:
rxns = 