In [93]:
import numpy as np 
from matplotlib import pyplot as plt
import pickle
from synnet.utils.data_utils import SkeletonSet
import random
random.seed(42)

from synnet.encoding.fingerprints import fp_4096
from synnet.utils.reconstruct_utils import fetch_oracle
from tqdm import tqdm
import pandas as pd

In [39]:


sts = pickle.load(open('/home/msun415/SynTreeNet/results/viz/skeletons-valid.pkl', 'rb'))
all_skeletons = pickle.load(open('/home/msun415/SynTreeNet/results/viz/skeletons-train.pkl','rb'))


In [23]:

sk_set = SkeletonSet().load_skeletons(sts)
sk_set.embed_skeletons()


begin computing similarity matrix


100%|██████████| 625521/625521 [00:39<00:00, 15970.36it/s]


KeyError: 0

In [187]:
mol_sks = [(sk.index, sk.tree.nodes[sk.tree_root]['smiles']) for smi in sk_set.lookup for sk in sk_set.lookup[smi]]

k = 100

vis_class_criteria = lambda i:len(all_skeletons[list(all_skeletons)[i]])
top_indices = sorted(range(len(all_skeletons)), key=vis_class_criteria)[-k:]
new_class = dict(zip(top_indices[::-1], range(len(top_indices))))

# for visualizing sk, restrict to only 10 from each class
new_mol_sks = []
count = dict(zip(range(len(top_indices)), [0 for _ in top_indices]))
for mol_sk in mol_sks:
    if mol_sk[0] not in new_class:
        continue
    if new_class[mol_sk[0]] not in count:
        continue
    if count[new_class[mol_sk[0]]] >= 30:
        continue
    new_mol_sks.append((new_class[mol_sk[0]], mol_sk[1]))
    count[new_class[mol_sk[0]]] += 1

mol_sks = new_mol_sks

# random.shuffle(mol_sks)
# mol_sks = mol_sks[:1000]

In [188]:
# lines = open('/home/msun415/polymer_walk/datasets/lohc.txt').readlines()
# smis = []
# mol_sks = []
# scores = []
# for l in lines[1:]:
#     smis.append(l.split(',')[0])
#     mol_sks.append((0, l.split(',')[0]))
#     scores.append(float(l.split(',')[2]))

In [189]:
fp_dist = np.zeros((len(mol_sks), len(mol_sks)))
fps = [np.array(fp_4096(mol_sk[1])) for mol_sk in mol_sks]
for i in tqdm(range(len(mol_sks))):
    for j in range(len(mol_sks)):
        fp_dist[i, j] = np.abs(fps[i]-fps[j]).sum()

max_dist = sk_set.sim.max()
max_fp_dist = fp_dist.max()


100%|██████████| 1314/1314 [00:11<00:00, 113.58it/s]


In [190]:
lambd = 0.5
dists = np.zeros((len(mol_sks), len(mol_sks)))
for i in tqdm(range(len(mol_sks))):
    for j in range(len(mol_sks)):
        index1 = mol_sks[i][0]
        index2 = mol_sks[j][0]
        dists[i][j] = lambd*sk_set.sim[index1][index2]/max_dist + (1-lambd)*fp_dist[i][j]/max_fp_dist

dists_fp = np.zeros((len(mol_sks), len(mol_sks)))
for i in tqdm(range(len(mol_sks))):
    for j in range(len(mol_sks)):  
        dists_fp[i][j] = fp_dist[i][j]/max_fp_dist

100%|██████████| 1314/1314 [00:01<00:00, 690.76it/s]
100%|██████████| 1314/1314 [00:00<00:00, 1533.41it/s]


In [191]:
from sklearn.manifold import MDS
# ms = MDS(n_components=2, dissimilarity='precomputed', verbose=1)
# coords = ms.fit_transform(dists)
ms = MDS(n_components=2, dissimilarity='precomputed', verbose=1)
coords_fp = ms.fit_transform(dists_fp)

In [192]:
plt.rc('text', usetex=False)
plt.rc('font', family='serif')
plt.rcParams['font.size'] = 12  # Default font size for all text
plt.rcParams['axes.titlesize'] = 20  # Font size for axes titles
plt.rcParams['axes.labelsize'] = 14  # Font size for x and y labels
plt.rcParams['xtick.labelsize'] = 10  # Font size for x-tick labels
plt.rcParams['ytick.labelsize'] = 10  # Font size for y-tick labels
plt.rcParams['legend.fontsize'] = 14  # Font size for legends
cmap = plt.cm.get_cmap('viridis', k)
fig = plt.Figure()
ax = fig.add_subplot(1,1,1)  
scatter = ax.scatter(coords_fp[:,0], coords_fp[:,1], c=[mol_sk[0] for mol_sk in mol_sks], cmap=cmap, edgecolor='k', s=50)
cbar = fig.colorbar(scatter, ax=ax, label='Class', ticks=(range(0,100,10) if k == 100 else range(k)))
cbar.set_ticklabels([i for i in (range(0,100,10) if k == 100 else range(k))])
# ax.set_title(f'MDS Plot of FP-Skeleton relationship')
ax.set_xlabel('MDS Dimension 1')
ax.set_ylabel('MDS Dimension 2')
fig.savefig(f"/home/msun415/SynTreeNet/mds-sks-{k}.png",bbox_inches='tight')


In [193]:
scores = np.array(scores)
normalized_scores = (scores - min(scores)) / (max(scores) - min(scores))
quantiles = pd.qcut(normalized_scores, 4, labels=False)
fig = plt.Figure()
ax = fig.add_subplot(1,1,1)  
scatter = ax.scatter(coords_fp[:,0], coords_fp[:,1], c=quantiles, cmap=cmap, edgecolor='k', s=10)
fig.colorbar(scatter, ax=ax, label='Score')
# ax.set_title(f'MDS Plot of Fingerprint-Property relationship')
ax.set_xlabel('MDS Dimension 1')
ax.set_ylabel('MDS Dimension 2')
fig.savefig(f"/home/msun415/SynTreeNet/fp-wt-h2.png")

NameError: name 'scores' is not defined

In [None]:

fig = plt.Figure(figsize=(20,10))
fig.suptitle('MDS Plot of Structure-Property Relationships')
for i, prop in enumerate(['qed','drd2','gsk','jnk']):
    oracle = fetch_oracle(prop)
    scores = [oracle(mol_sk[1]) for mol_sk in tqdm(mol_sks)]    
    
    ax = fig.add_subplot(2,4,i+1)
    scores = np.array(scores)
    normalized_scores = (scores - min(scores)) / (max(scores) - min(scores))
    quantiles = pd.qcut(normalized_scores, 4, labels=False)
    scatter = ax.scatter(coords[:,0], coords[:,1], c=quantiles, cmap=cmap, edgecolor='k', s=20)
    fig.colorbar(scatter, ax=ax, label='Score')
    ax.set_title(f'(Skeleton, Fingerprint)-Property ({prop})', fontsize=10)
    ax.set_xlabel('MDS Dimension 1')
    ax.set_ylabel('MDS Dimension 2')
    # fig.savefig(f"/home/msun415/SynTreeNet/mds-{prop}-0.5.png")

    ax = fig.add_subplot(2,4,4+i+1)  
    scatter = ax.scatter(coords_fp[:,0], coords_fp[:,1], c=quantiles, cmap=cmap, edgecolor='k', s=20)
    fig.colorbar(scatter, ax=ax, label='Score')
    ax.set_title(f'Fingerprint-Property ({prop})', fontsize=10)
    ax.set_xlabel('MDS Dimension 1')
    ax.set_ylabel('MDS Dimension 2')
    fig.savefig(f"/home/msun415/SynTreeNet/mds-all.png")



100%|██████████| 1000/1000 [00:01<00:00, 615.45it/s]
Found local copy...
100%|██████████| 1000/1000 [00:10<00:00, 92.01it/s]
Found local copy...
100%|██████████| 1000/1000 [00:57<00:00, 17.33it/s]
Found local copy...
100%|██████████| 1000/1000 [01:06<00:00, 15.10it/s]


In [15]:
scores

array([0.  , 0.01, 0.02, ..., 0.  , 0.  , 0.  ])