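The purpose of this notebook is to prepare data for ranking visualizations with plotly/dash. It does three things:

- runs the trained model on a few random validation samples;
- retrieves candidate molecules from the ranking set for each prediction;
- pickles the results (SMILES, fingerprints, ranks, and rendered structure images).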

In [1]:
import os

# Move from the notebook's directory up two levels, presumably the repository root,
# so that the project packages imported below (models, datasets, utils, notebooks)
# resolve. Guarded so that re-running this cell does not climb two more levels
# every time; the root is computed once per kernel session.
if "REPO_ROOT" not in globals():
    REPO_ROOT = os.path.abspath(os.path.join(os.getcwd(), "..", ".."))
os.chdir(REPO_ROOT)

In [2]:
import torch, torch.nn as nn, torch.nn.functional as F
import pytorch_lightning as pl
import re, random
from textwrap import wrap
from matplotlib import pyplot as plt
from models.ranked_transformer import HsqcRankedTransformer
from models.ranked_double_transformer import DoubleTransformer
from datasets.hsqc_folder_dataset import FolderDataModule
from utils.ranker import RankingSet
import glob

from rdkit import Chem
from rdkit.Chem import AllChem, Draw
from rdkit.Chem.Draw import IPythonConsole, SimilarityMaps

In [3]:
# model path: locate the single checkpoint written by the chosen experiment run.
folder = "/data/smart4.5/pre_exp_v2"
path = "j1_both_pair_*"
# Run directory names in this project contain "[...]" (see the printed path), which
# glob treats as a character class. Escape the literal folder so only `path` and
# "epoch*" act as patterns.
chkpt_candidates = glob.glob(os.path.join(glob.escape(folder), path, "checkpoints", "epoch*"))
assert len(chkpt_candidates) == 1, (
    f"expected exactly one checkpoint under {folder}/{path}, "
    f"found {len(chkpt_candidates)}: {chkpt_candidates}"
)
chkpt = chkpt_candidates[0]
print(chkpt)

/data/smart4.5/pre_exp_v2/j1_both_pair_[09_07_2022_06:23]_[bs=128_epochs=300_lr=1e-05_hsqc_heads=8_hsqc_layers=8_ms_heads=8_ms_layers=8_dropout=0.3_fc_dim=256_hsqc_dim_coords=112,112,32_hsqc_dim_model=256_hsqc_dropout=0_hsqc_ff_dim=1024_hsqc_lr=0.001_hsqc_out_dim=6144]/checkpoints/epoch=299-step=23100.ckpt


In [4]:
# Ranking set for the validation split, plus the lookup table that ranker.lookup
# is read from (presumably fingerprint -> SMILES; see its use further down).
ranking_dir = "/workspace/smart4.5/tempdata/hyun_pair_ranking_set_07_22"
ranks = os.path.join(ranking_dir, "val_pair.pt")
lookup = os.path.join(ranking_dir, "fp_lookup.pkl")
ranker = RankingSet(file_path=ranks, retrieve_path=lookup)

In [5]:
# Load the trained HSQC+MS model for inference and build the matching data module.
model = DoubleTransformer.load_from_checkpoint(chkpt).cuda()
model.eval()  # inference mode (e.g. dropout disabled) for the predictions below
# Plain string: the original f-prefix had no placeholders.
my_dir = "/workspace/smart4.5/tempdata/hyun_fp_data/hsqc_ms_pairs"
dm = FolderDataModule(dir=my_dir, do_hyun_fp=True, input_src=["HSQC", "MS"], batch_size=64)
# Presumably the "fit" stage is what creates dm.val, which later cells index directly.
dm.setup("fit")
val_dl = dm.val_dataloader()

Choose which validation samples to show

In [6]:
num_samples: int = 10  # how many random validation samples to retrieve and visualize

In [7]:
# Pick a reproducible random subset of validation indices to visualize.
import random

size_val = len(dm.val)
random.seed(123)  # fixed seed so every run selects the same samples
candidate_indices = range(size_val)
sample_idxs = random.sample(population=candidate_indices, k=num_samples)
print(sample_idxs)

[107, 548, 178, 834, 545, 220, 78, 776, 1098, 1151]


In [8]:
# extracting samples from the dataset
def extract_sample(idx, src, threshold=0.5):
    """Run the model on one dataset item and binarize its fingerprint prediction.

    Args:
        idx: index into ``src``.
        src: indexable dataset whose items unpack to ``(hsqc, ms, label)`` tensors.
        threshold: probability cut-off applied after the sigmoid; entries with
            probability >= threshold become 1.0, others 0.0. Defaults to 0.5.

    Returns:
        ``(ds_out, ds_label)``: the thresholded 0/1 float prediction for the sample
        (batch dimension removed) and its label, both moved to the model's device.
    """
    hsqc, ms, label = src[idx]
    device = model.device  # follow the model instead of hard-coding .cuda()
    hsqc, ms, label = hsqc.to(device), ms.to(device), label.to(device)
    # Inference only: don't build an autograd graph for every sample.
    with torch.no_grad():
        # Call the module (not .forward) so nn.Module hooks run; add a batch dim of 1
        # and take the single result back out.
        out = model(hsqc.unsqueeze(0), ms.unsqueeze(0))[0]
    ds_out = (torch.sigmoid(out) >= threshold).float()
    return ds_out, label
def smiles_to_image(smiles, size=(700, 600)):
    """Render a SMILES string as an image of its 2D structure via RDKit.

    Args:
        smiles: SMILES string to draw.
        size: image size passed to ``Draw.MolToImage``; defaults to (700, 600).

    Returns:
        The image produced by ``Chem.Draw.MolToImage``.

    Raises:
        ValueError: if RDKit cannot parse ``smiles``.
    """
    mol = Chem.MolFromSmiles(smiles)
    # MolFromSmiles returns None (it does not raise) on unparsable input; fail here
    # with the offending string instead of a generic error from the drawer.
    if mol is None:
        raise ValueError(f"could not parse SMILES: {smiles!r}")
    return Chem.Draw.MolToImage(mol, size=size)
from notebooks.dataset_building import fingerprint_utils
def smiles_to_fp(smiles):
    """Compute the project fingerprint for ``smiles`` via ``fingerprint_utils.FP_generator``."""
    radius = 2  # second positional argument of FP_generator; presumably a Morgan radius — TODO confirm
    fingerprint = fingerprint_utils.FP_generator(smiles, radius)
    return fingerprint

In [9]:
def generate_retrieval_image(sample_idx):
    """Run retrieval for one validation sample and collect everything needed to plot it.

    Predicts a binarized fingerprint for ``dm.val[sample_idx]``, compares it with the
    label fingerprint, looks up the label's SMILES, and renders the molecules that
    ``ranker.retrieve`` returns for the prediction.

    Args:
        sample_idx: index into ``dm.val``.

    Returns:
        dict with keys:
            original_smiles: list of SMILES stored in ``ranker.lookup`` for the label.
            single_smiles: first of those SMILES (the one drawn as ``base_image``).
            cossim: cosine similarity between the binarized prediction and the label.
            out_fp / label_fp: nonzero-index form of prediction and label.
            rank: value reported by ``ranker.batched_rank`` for this sample.
            base_image: rendering of ``single_smiles``.
            imgs / smiles / fps: one list per retrieved group, holding the images,
                SMILES and nonzero-index fingerprints of that group's molecules.

    Raises:
        KeyError: if the label fingerprint has no SMILES in ``ranker.lookup``.
    """
    ds_out, ds_label = extract_sample(sample_idx, dm.val)
    # Nonzero-index form, which is also the key format used for ranker.lookup below.
    nz_out = ranker.normalized_to_nonzero(ds_out)
    nz_label = ranker.normalized_to_nonzero(ds_label)
    similarity = F.cosine_similarity(ds_out.unsqueeze(0), ds_label.unsqueeze(0))

    label_smiles = ranker.lookup.get(nz_label)
    original_smiles = [] if label_smiles is None else list(label_smiles)
    # Without this check a missing label surfaced as a cryptic TypeError
    # (list(None)) and an empty entry as an IndexError.
    if not original_smiles:
        raise KeyError(f"label fingerprint of sample {sample_idx} has no SMILES in ranker.lookup")
    single = original_smiles[0]
    im = smiles_to_image(single)

    results = ranker.retrieve(ds_out)  # groups of SMILES strings for the prediction
    rank = ranker.batched_rank(ds_out.unsqueeze(0), ds_label.unsqueeze(0))[0].item()
    out_imgs = []
    out_smiles = []
    out_fps = []
    for group in results:
        smiles = list(group)
        images = [smiles_to_image(s) for s in smiles]
        fps = [ranker.normalized_to_nonzero(torch.tensor(smiles_to_fp(s))) for s in smiles]
        out_imgs.append(images)
        out_smiles.append(smiles)
        out_fps.append(fps)
    return {
        "original_smiles": original_smiles,
        "single_smiles": single,
        "cossim": similarity.item(),
        "out_fp": nz_out,
        "label_fp": nz_label,
        "rank": rank,
        "base_image": im,
        "imgs": out_imgs,
        "smiles": out_smiles,
        "fps": out_fps,
    }
# sample index -> retrieval results for that sample (evaluated in sample_idxs order)
out_obj = dict(zip(sample_idxs, map(generate_retrieval_image, sample_idxs)))

In [10]:
# Persist the retrieval results (presumably loaded later by the plotly/dash visualization).
import pickle, os

out_path = "/workspace/smart4.5/ignore"
file_name = "dump.pkl"
os.makedirs(out_path, exist_ok=True)
dump_file = os.path.join(out_path, file_name)
with open(dump_file, "wb") as f:
    pickle.dump(out_obj, f)