03/01/23: Trying to load a ranked transformer for testing in plotly/dash

In [1]:
import os, torch, torch.nn.functional as F
import pytorch_lightning as pl
import re

from pathlib import Path
from textwrap import wrap
from matplotlib import pyplot as plt
from models.ranked_transformer import HsqcRankedTransformer
from models.ranked_double_transformer import DoubleTransformer
from datasets.generic_index_dataset import GenericIndexedModule
from datasets.dataset_utils import pad
from utils.ranker import RankingSet

In [2]:
# Locate the checkpoint file to load for this experiment.
folder = Path("/data/smart4.5/new_split")
experiment = Path("lr_1e-5")
checkpoint_dir = folder / experiment / "checkpoints"
# os.listdir() returns entries in arbitrary order, so order the checkpoints by
# the integers embedded in their names (e.g. "epoch=0-step=3432.ckpt" ->
# [0, 3432]); chkpts[-1] is then reliably the furthest-trained checkpoint.
chkpts = sorted(
    (f for f in os.listdir(checkpoint_dir) if re.search("epoch", f)),
    key=lambda name: [int(num) for num in re.findall(r"\d+", name)],
)
chkpt = chkpts[-1] if chkpts else None
# Stays None when no file name matched "epoch", so the print makes the miss visible.
full_path = str(checkpoint_dir / chkpt) if chkpt is not None else None
print(full_path)

/data/smart4.5/new_split/lr_1e-5/checkpoints/epoch=0-step=3432.ckpt


In [3]:
# Restore the model from the checkpoint path selected above and move it to the GPU
# (no data is loaded in this cell). NOTE: full_path is None when no checkpoint matched.
model = HsqcRankedTransformer.load_from_checkpoint(full_path).cuda()
# Switch to evaluation mode; as the cell's last expression this also echoes the module structure.
model.eval()

HsqcRankedTransformer(
  (enc): CoordinateEncoder()
  (loss): BCEWithLogitsLoss()
  (fc): Linear(in_features=128, out_features=6144, bias=True)
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=1024, bias=True)
        (dropout): Dropout(p=0, inplace=False)
        (linear2): Linear(in_features=1024, out_features=128, bias=True)
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0, inplace=False)
        (dropout2): Dropout(p=0, inplace=False)
      )
      (1): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_feature

In [6]:
# Configure the data module. Every feature name is paired with its handler:
# only "HSQC" has one (`pad`); the other three are given None.
SMILES_dataset_path = "tempdata/SMILES_dataset"
feature_spec = [
    ("HSQC", pad),
    ("R2-6144FP", None),
    ("Chemical", None),
    ("SMILES", None),
]
features = [name for name, handler in feature_spec]
feature_handlers = [handler for name, handler in feature_spec]
dm = GenericIndexedModule(
    SMILES_dataset_path,
    features,
    feature_handlers,
    batch_size=64,
    len_override=200,
)
dm.setup("fit")

# Validation loader consumed by the inspection cell below.
val_dl = dm.val_dataloader()

In [8]:
# Peek at the first validation batch only: print the SMILES entry, the
# chemical entry and the HSQC tensor of its first sample, then stop.
for batch in val_dl:
    hsqc, fp, chemical, smiles = batch
    for field in (smiles, chemical, hsqc):
        print(field[0])
    # Forward pass kept disabled for now:
    # hsqc, fp = hsqc.to(model.device), fp.to(model.device)
    # out = model(hsqc)
    # out_fp = torch.where(out > 0, 1, 0)
    break

  sequence = pad_sequence([torch.tensor(v, dtype=torch.float) for v in sequence], batch_first=True)
  sequence = pad_sequence([torch.tensor(v, dtype=torch.float) for v in sequence], batch_first=True)
  sequence = pad_sequence([torch.tensor(v, dtype=torch.float) for v in sequence], batch_first=True)
  sequence = pad_sequence([torch.tensor(v, dtype=torch.float) for v in sequence], batch_first=True)


Br.CN1CCc2cc(O)cc3c2C1Cc1ccc(O)c(O)c1-3
"MLS000860019-01!R()-2,10,11-Trihydroxyaporphine hydrobromide"
tensor([[ 1.1852e+02,  6.6382e+00,  5.2922e+03],
        [ 1.1515e+02,  6.6885e+00,  4.6099e+03],
        [ 1.1366e+02,  6.4267e+00,  3.4196e+03],
        [ 1.1263e+02,  7.4786e+00,  8.7722e+03],
        [ 6.2320e+01,  3.1595e+00,  5.2923e+03],
        [ 5.2830e+01,  3.0983e+00, -8.7723e+03],
        [ 5.2830e+01,  2.5967e+00, -8.7723e+03],
        [ 4.3570e+01,  2.8437e+00,  2.1039e+04],
        [ 3.4300e+01,  3.1175e+00, -5.4442e+03],
        [ 3.4300e+01,  2.3950e+00, -5.4442e+03],
        [ 2.9990e+01,  3.2700e+00, -6.1525e+03],
        [ 2.9990e+01,  2.7900e+00, -5.4498e+03],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.000

In [7]:
# extracting samples from the dataset
def extract_sample(idx):
    """Run the model on one training sample and binarise its prediction.

    NOTE(review): this unpacks three fields per sample and passes two inputs
    to the model, while the data module configured in this notebook lists
    four features and the disabled call in the batch-peek cell is
    ``model(hsqc)`` - confirm which model/dataset this cell targets.

    Args:
        idx: Index into ``dm.train``; the item must unpack into
            ``(hsqc, ms, label)``.

    Returns:
        ``(ds_out, ds_label)``: the prediction thresholded at a sigmoid
        value of 0.5 (float tensor of 0.0/1.0) and the label tensor, both
        moved to the GPU.
    """
    hsqc, ms, label = dm.train[idx]
    hsqc, ms, label = hsqc.cuda(), ms.cuda(), label.cuda()
    # Inference only: disable autograd so no computation graph is built or kept.
    with torch.no_grad():
        # Call the module itself rather than .forward() so registered hooks
        # run. Each input gets a leading dimension of size 1; element 0 of the
        # result is kept (presumably the single batch entry).
        out = model(hsqc.unsqueeze(0), ms.unsqueeze(0))[0]
        ds_out = (torch.sigmoid(out) >= 0.5).float()
    return ds_out, label

In [None]:
from rdkit import Chem
from rdkit.Chem import AllChem, Draw
from rdkit.Chem.Draw import IPythonConsole, SimilarityMaps

In [8]:
def molDiff(smiles1, smiles2):
    """Build an RDKit similarity map between two molecules given as SMILES.

    Both strings are parsed with ``Chem.MolFromSmiles`` and passed, in order,
    to ``SimilarityMaps.GetSimilarityMapForFingerprint`` together with the
    Morgan fingerprint function.

    Returns:
        Whatever ``GetSimilarityMapForFingerprint`` returns; the caller in
        this notebook unpacks it as ``(fig, sim)``.
    """
    reference, probe = (Chem.MolFromSmiles(s) for s in (smiles1, smiles2))
    return SimilarityMaps.GetSimilarityMapForFingerprint(
        reference,
        probe,
        SimilarityMaps.GetMorganFingerprint,
    )

In [18]:
def plot_smiles(smiles):
    """Draw the molecule for a SMILES string and display it in the notebook.

    The molecule is rendered at 700x600 and shown on a white-background
    figure titled with the SMILES string, wrapped at 35 characters per line.

    Args:
        smiles: SMILES string of the molecule to draw.
    """
    mol = Chem.MolFromSmiles(smiles)
    im = Chem.Draw.MolToImage(mol, size=(700, 600))
    fig, ax = plt.subplots(1, 1, facecolor="white")
    ax.imshow(im)
    ax.set_title("\n".join(wrap(smiles, width=35)))
    display(fig)
    # display() has already rendered the figure. Close it so pyplot does not
    # keep it alive and the inline backend does not draw it a second time
    # when the cell finishes.
    plt.close(fig)
def smiles_to_image(smiles):
    """Render a SMILES string as a 700x600 molecule image.

    Returns:
        The image object produced by ``Chem.Draw.MolToImage``.
    """
    return Chem.Draw.MolToImage(Chem.MolFromSmiles(smiles), size=(700, 600))
from notebooks.dataset_building import fingerprint_utils
def smiles_to_fp(smiles):
    """Compute a fingerprint for a SMILES string via ``fingerprint_utils.FP_generator``.

    The second argument is fixed at 2 (presumably the fingerprint radius,
    matching the "R2" feature name used elsewhere in this notebook - confirm
    against ``fingerprint_utils``).
    """
    fp_arg = 2
    fingerprint = fingerprint_utils.FP_generator(smiles, fp_arg)
    return fingerprint

In [26]:
def generate_retrieval_image(sample_idx, out_folder):
    """Dump the retrieval results for one training sample into ``out_folder``.

    Files written:
        * ``img.png`` - drawing of one ground-truth molecule,
        * ``img_{i}_{j}.png`` - drawing of the j-th SMILES of the i-th
          retrieved entry,
        * ``data.json`` - fingerprints, cosine similarity, rank and SMILES.

    Args:
        sample_idx: Index forwarded to ``extract_sample``.
        out_folder: Directory to write into; created when missing.

    Raises:
        KeyError: If ``ranker.lookup`` holds no SMILES for the label
            fingerprint (previously an opaque ``TypeError``/``IndexError``).
    """
    import json

    ds_out, ds_label = extract_sample(sample_idx)
    nz_out = ranker.normalized_to_nonzero(ds_out)
    nz_label = ranker.normalized_to_nonzero(ds_label)
    similarity = F.cosine_similarity(ds_out.unsqueeze(0), ds_label.unsqueeze(0))

    # Fail early, before anything is written, when the label is unknown.
    found = ranker.lookup.get(nz_label, None)
    original_smiles = list(found) if found is not None else []
    if not original_smiles:
        raise KeyError(
            f"no SMILES found in ranker.lookup for the label fingerprint of sample {sample_idx}"
        )
    single = original_smiles[0]
    im = smiles_to_image(single)

    results = ranker.retrieve(ds_out)
    rank = ranker.batched_rank(ds_out.unsqueeze(0), ds_label.unsqueeze(0))[0].item()

    # One inner list per retrieved entry; the three structures stay index-aligned.
    out_smiles = [list(entry) for entry in results]
    out_imgs = [[smiles_to_image(s) for s in group] for group in out_smiles]
    out_fps = [
        [ranker.normalized_to_nonzero(torch.tensor(smiles_to_fp(s))) for s in group]
        for group in out_smiles
    ]

    # exist_ok=True already covers the "directory exists" case.
    os.makedirs(out_folder, exist_ok=True)
    out_obj = {
        "out_fp": nz_out,
        "label_fp": nz_label,
        "cossim": similarity.item(),
        "original_smiles": original_smiles,
        "single_smiles": single,
        "rank": rank,
        "ranked_cts": [len(group) for group in out_imgs],
        "ranked_smiles": out_smiles,
        "ranked_fps": out_fps,
    }
    im.save(os.path.join(out_folder, "img.png"))
    for i, group in enumerate(out_imgs):
        for j, image in enumerate(group):
            image.save(os.path.join(out_folder, f"img_{i}_{j}.png"))
    with open(os.path.join(out_folder, "data.json"), "w", encoding="utf-8") as f:
        json.dump(out_obj, f)
# Export the retrieval results (molecule images + data.json) for training sample 2.
generate_retrieval_image(2, "/workspace/smart4.5/ignore/out")

In [None]:
%matplotlib inline
# Show the ground-truth molecule for sample 2, then every retrieved entry
# together with its similarity map against the ground truth.
ds_out, ds_label = extract_sample(2)
similarity = F.cosine_similarity(ds_out.unsqueeze(0), ds_label.unsqueeze(0))
nz_pred = ranker.normalized_to_nonzero(ds_out)
nz_label = ranker.normalized_to_nonzero(ds_label)

# SMILES registered for the label fingerprint; `single` is the one that gets drawn.
original_smiles = ranker.lookup.get(nz_label, None)
single = list(original_smiles)[0]

results = ranker.retrieve(ds_out)
batch_ranks = ranker.batched_rank(ds_out.unsqueeze(0), ds_label.unsqueeze(0))
rank = batch_ranks[0].item()
for caption, value in (
    ("Cossim of output and label", similarity.item()),
    ("Ground Truth:", original_smiles),
    ("Rank:", rank),
):
    print(caption, value)

plot_smiles(single)

print("=== Retrieval ===")
for position, candidates in enumerate(results):
    print(f"*** Rank {position:3d} | {candidates} ***")
    # Only the first SMILES of each retrieved entry is drawn and compared.
    first = list(candidates)[0]
    plot_smiles(first)

    fig, sim = molDiff(single, first)
    display(fig)
    print("*** ***")