In [1]:
import os
# Move two directories up from where the kernel started (assumed to be this
# notebook's folder — Jupyter's default) so that the project-local imports in
# the next cell and the relative "./tempdata/..." paths presumably resolve.
# The starting directory is remembered on the first run so that re-running
# this cell does not keep climbing the directory tree.
if "NOTEBOOK_DIR" not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_DIR, "../.."))

In [2]:
import json
import re, random
from textwrap import wrap

import torch, torch.nn as nn, torch.nn.functional as F
import pytorch_lightning as pl
from matplotlib import pyplot as plt
from rdkit import Chem
from rdkit.Chem import AllChem, Draw
from rdkit.Chem.Draw import IPythonConsole, SimilarityMaps

from models.ranked_transformer import HsqcRankedTransformer
from models.ranked_double_transformer import DoubleTransformer
from datasets.hsqc_folder_dataset import FolderDataModule
from utils.ranker import RankingSet
from notebooks.dataset_building import fingerprint_utils

In [3]:
# model path
folder = "/data/smart4.5/pre_exp_v2/"
path = "j1_both_pair_[09_07_2022_06:23]_[bs=128_epochs=300_lr=1e-05_hsqc_heads=8_hsqc_layers=8_ms_heads=8_ms_layers=8_dropout=0.3_fc_dim=256_hsqc_dim_coords=112,112,32_hsqc_dim_model=256_hsqc_dropout=0_hsqc_ff_dim=1024_hsqc_lr=0.001_hsqc_out_dim=6144]"
ckpt_dir = os.path.join(folder, path, "checkpoints")
# Pick the first checkpoint whose file name contains "epoch". os.listdir order
# is filesystem-dependent, so sort it to choose the same file on every machine/run.
chkpt = next((f for f in sorted(os.listdir(ckpt_dir)) if "epoch" in f), None)
# None (no matching checkpoint) is kept as the sentinel; it is printed below.
full_path = os.path.join(ckpt_dir, chkpt) if chkpt is not None else None
print(full_path)

/data/smart4.5/pre_exp_v2/j1_both_pair_[09_07_2022_06:23]_[bs=128_epochs=300_lr=1e-05_hsqc_heads=8_hsqc_layers=8_ms_heads=8_ms_layers=8_dropout=0.3_fc_dim=256_hsqc_dim_coords=112,112,32_hsqc_dim_model=256_hsqc_dropout=0_hsqc_ff_dim=1024_hsqc_lr=0.001_hsqc_out_dim=6144]/checkpoints/epoch=299-step=23100.ckpt


In [5]:
# load model and data
# Restore the trained DoubleTransformer on the GPU and switch it to eval mode,
# then build the HSQC+MS data module and its validation loader.
model = DoubleTransformer.load_from_checkpoint(full_path)
model = model.cuda()
model.eval()
my_dir = "/workspace/smart4.5/tempdata/hyun_fp_data/hsqc_ms_pairs"
dm = FolderDataModule(
    dir=my_dir,
    do_hyun_fp=True,
    input_src=["HSQC", "MS"],
    batch_size=64,
)
dm.setup("fit")
val_dl = dm.val_dataloader()

In [6]:
# Ranking set used for retrieval; `lookup` presumably backs `ranker.lookup`
# (fingerprint -> SMILES) used further below — TODO confirm in utils.ranker.
ranking_dir = "./tempdata/hyun_pair_ranking_set_07_22"
ranks = f"{ranking_dir}/val_pair.pt"
lookup = f"{ranking_dir}/fp_lookup.pkl"
ranker = RankingSet(file_path=ranks, retrieve_path=lookup)

In [7]:
# extracting samples from the dataset
# extracting samples from the dataset
def extract_sample(idx, threshold=0.5):
    """Run the model on one item of ``dm.train`` and binarize the prediction.

    Args:
        idx: index into ``dm.train``; each item unpacks to ``(hsqc, ms, label)``.
        threshold: cut-off applied to ``sigmoid(model output)``; entries
            ``>= threshold`` become 1.0, all others 0.0 (default 0.5, as before).

    Returns:
        ``(ds_out, ds_label)`` — the binarized float prediction and the label,
        both on the model's device.
    """
    hsqc, ms, label = dm.train[idx]
    # Follow the model instead of hard-coding .cuda().
    device = model.device
    hsqc, ms, label = hsqc.to(device), ms.to(device), label.to(device)
    # Pure inference: no_grad avoids building an autograd graph on every call.
    with torch.no_grad():
        # Add a batch dimension of 1, then take the single result back out.
        out = model(hsqc.unsqueeze(0), ms.unsqueeze(0))[0]
    ds_out = (torch.sigmoid(out) >= threshold).float()
    return ds_out, label

In [8]:
def molDiff(smiles1, smiles2, fp_function=None):
    """Draw an RDKit fingerprint similarity map of ``smiles2`` against ``smiles1``.

    Args:
        smiles1: SMILES of the reference molecule.
        smiles2: SMILES of the probe molecule.
        fp_function: fingerprint function handed to RDKit; ``None`` (default)
            uses ``SimilarityMaps.GetMorganFingerprint``, the previous fixed choice.

    Returns:
        The result of ``SimilarityMaps.GetSimilarityMapForFingerprint``; the
        caller in this notebook unpacks it as ``(fig, sim)``.

    Raises:
        ValueError: if either SMILES string cannot be parsed.
    """
    # Resolved here rather than as a default argument so the name is only
    # looked up when the function actually runs.
    if fp_function is None:
        fp_function = SimilarityMaps.GetMorganFingerprint
    mols = []
    for smi in (smiles1, smiles2):
        mol = Chem.MolFromSmiles(smi)
        if mol is None:  # RDKit reports a parse failure by returning None
            raise ValueError(f"could not parse SMILES: {smi!r}")
        mols.append(mol)
    ref_mol, probe_mol = mols
    return SimilarityMaps.GetSimilarityMapForFingerprint(ref_mol, probe_mol, fp_function)

In [18]:
def plot_smiles(smiles):
    """Display a drawing of ``smiles`` inline, titled with the wrapped SMILES.

    The figure is closed right after ``display``: with the inline backend,
    figures still open when a cell finishes are rendered a second time, and
    calling this in a loop would otherwise keep accumulating open figures.
    """
    fig, ax = plt.subplots(1, 1, facecolor="white")
    ax.imshow(smiles_to_image(smiles))
    ax.set_title("\n".join(wrap(smiles, width=35)))
    display(fig)
    plt.close(fig)
def smiles_to_image(smiles, size=(700, 600)):
    """Render ``smiles`` with RDKit and return the image object.

    Args:
        smiles: SMILES string to draw.
        size: ``(width, height)`` passed to ``MolToImage``; default (700, 600)
            matches the previous fixed size.

    Returns:
        The image from ``Draw.MolToImage`` (callers use its ``.save`` method).

    Raises:
        ValueError: if the SMILES cannot be parsed (RDKit returns None for it).
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"could not parse SMILES: {smiles!r}")
    return Draw.MolToImage(mol, size=size)
def smiles_to_fp(smiles):
    """Compute the project fingerprint of a SMILES string.

    Thin wrapper around ``fingerprint_utils.FP_generator(smiles, 2)``. The ``2``
    is presumably the fingerprint radius — TODO confirm in fingerprint_utils.
    Callers wrap the result in ``torch.tensor``, so it is expected to be
    array-like.
    """
    return fingerprint_utils.FP_generator(smiles, 2)

In [25]:
def generate_retrieval_image(sample_idx, out_folder):
    """Predict a fingerprint for one training sample and export its retrievals.

    Runs ``extract_sample(sample_idx)``, recovers the ground-truth SMILES through
    ``ranker.lookup``, retrieves candidates for the prediction with
    ``ranker.retrieve`` and writes into ``out_folder`` (created if missing):

    * ``img.png`` - drawing of the first ground-truth SMILES;
    * ``img_{i}_{j}.png`` - drawing of the j-th SMILES of the i-th retrieval result;
    * ``data.json`` - non-zero forms of prediction and label, their cosine
      similarity, the label's rank, and SMILES/fingerprints of every result.

    Args:
        sample_idx: index of the sample in ``dm.train``.
        out_folder: destination directory.

    Returns:
        The dictionary written to ``data.json``, so the calling cell displays it.

    Raises:
        KeyError: if the label fingerprint has no entry in ``ranker.lookup``.
    """
    ds_out, ds_label = extract_sample(sample_idx)
    nz_out = ranker.normalized_to_nonzero(ds_out)
    nz_label = ranker.normalized_to_nonzero(ds_label)
    similarity = F.cosine_similarity(ds_out.unsqueeze(0), ds_label.unsqueeze(0))

    lookup_hit = ranker.lookup.get(nz_label, None)
    if lookup_hit is None:
        # Previously surfaced as an opaque "'NoneType' object is not iterable".
        raise KeyError(f"label fingerprint of sample {sample_idx} not found in ranker.lookup")
    original_smiles = list(lookup_hit)
    # The lookup value is a collection of SMILES; draw the first one.
    single = original_smiles[0]
    im = smiles_to_image(single)

    results = ranker.retrieve(ds_out)
    rank = ranker.batched_rank(ds_out.unsqueeze(0), ds_label.unsqueeze(0))[0].item()

    out_imgs, out_smiles, out_fps = [], [], []
    for candidates in results:
        smiles = list(candidates)
        out_imgs.append([smiles_to_image(s) for s in smiles])
        out_smiles.append(smiles)
        out_fps.append([ranker.normalized_to_nonzero(torch.tensor(smiles_to_fp(s))) for s in smiles])

    out_obj = {
        "out_fp": nz_out,
        "label_fp": nz_label,
        "cossim": similarity.item(),
        "original_smiles": original_smiles,
        "single_smiles": single,
        "rank": rank,
        "ranked_cts": [len(v) for v in out_imgs],
        "ranked_smiles": out_smiles,
        "ranked_fps": out_fps,
    }

    os.makedirs(out_folder, exist_ok=True)
    im.save(os.path.join(out_folder, "img.png"))
    for i, images in enumerate(out_imgs):
        for j, image in enumerate(images):
            image.save(os.path.join(out_folder, f"img_{i}_{j}.png"))
    with open(os.path.join(out_folder, "data.json"), "w") as f:
        json.dump(out_obj, f)
    return out_obj

generate_retrieval_image(2, "/workspace/smart4.5/ignore/out")

{'out_fp': (650, 656, 695, 807, 1019, 1057, 1380, 1652, 1873, 2097, 2139, 2362, 2383, 2411, 2572, 2573, 2793, 3112, 3121, 3138, 3208, 3389, 3413, 3652, 3655, 3748, 3856, 3969, 4032, 4076, 4107, 4283, 4308, 4434, 4539, 4573, 4698, 4993, 5008, 5188, 5213, 5224, 5291, 5332, 5465, 5471, 5495, 5593, 5658, 5756, 5888, 6112), 'label_fp': (650, 656, 695, 807, 1019, 1057, 1380, 1652, 1873, 2086, 2097, 2139, 2193, 2250, 2362, 2381, 2383, 2411, 2572, 2573, 2793, 3112, 3121, 3138, 3167, 3174, 3208, 3389, 3413, 3433, 3652, 3655, 3748, 3856, 3932, 3965, 3969, 4032, 4076, 4308, 4316, 4370, 4390, 4434, 4448, 4484, 4539, 4573, 4656, 4692, 4697, 4698, 4993, 5008, 5053, 5188, 5211, 5213, 5224, 5291, 5332, 5465, 5471, 5593, 5629, 5658, 5732, 5756, 5962, 6071, 6112, 6116), 'cossim': 0.7844645380973816, 'original_smiles': ['CC1OC(Oc2c(-c3ccc(O)c(O)c3)oc3cc(O)cc(O)c3c2=O)C(O)C(OC(=O)c2cc(O)c(O)c(O)c2)C1O'], 'single_smiles': 'CC1OC(Oc2c(-c3ccc(O)c(O)c3)oc3cc(O)cc(O)c3c2=O)C(O)C(OC(=O)c2cc(O)c(O)c(O)c2)C1O', '

In [None]:
%matplotlib inline
# visualizing label and retrieved molecules
SAMPLE_IDX = 2  # index into dm.train of the sample to inspect

ds_out, ds_label = extract_sample(SAMPLE_IDX)
similarity = F.cosine_similarity(ds_out.unsqueeze(0), ds_label.unsqueeze(0))
nz_pred, nz_label = ranker.normalized_to_nonzero(ds_out), ranker.normalized_to_nonzero(ds_label)

original_smiles = ranker.lookup.get(nz_label, None)
if original_smiles is None:
    raise KeyError(f"label fingerprint of sample {SAMPLE_IDX} not found in ranker.lookup")
single = list(original_smiles)[0]

results = ranker.retrieve(ds_out)
rank = ranker.batched_rank(ds_out.unsqueeze(0), ds_label.unsqueeze(0))[0].item()
print("Cossim of output and label", similarity.item())
print("Ground Truth:", original_smiles)
print("Rank:", rank)

plot_smiles(single)

print("=== Retrieval ===")
for i, v in enumerate(results):
    print(f"*** Rank {i:3d} | {v} ***")
    first = list(v)[0]
    plot_smiles(first)

    # Similarity map of the retrieved molecule against the ground truth.
    fig, sim = molDiff(single, first)
    display(fig)
    # Close it, otherwise the inline backend renders it again when the cell ends.
    plt.close(fig)
    print("*** ***")