In [1]:

import torch
torch.set_printoptions(precision=10)

from models.optional_input_ranked_transformer import OptionalInputRankedTransformer
from datasets.optional_2d_folder_dataset import OptionalInputDataModule
import yaml
torch.set_float32_matmul_precision('medium')
from pathlib import Path

from rdkit import Chem
from rdkit.Chem import Draw
# load model 
from datasets.dataset_utils import specific_radius_mfp_loader, plot_NMR

            


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import numpy as np 
import random
# Seed every RNG this notebook touches so that "Restart & Run All" is reproducible.
# torch is seeded too: the model and anything else that draws from torch's RNG
# would otherwise differ between runs.
seed = 2
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)


In [3]:
# load model

# Choice 1: flexible model
model_path = Path("/home/ad.ucsd.edu/w6xu/model_weights/flexible_models_best_FP/r0_r2_FP_trial_1")
checkpoint_path = model_path / "checkpoints/epoch=35-all_inputs.ckpt"
# alternative checkpoint of the same run:
# checkpoint_path = model_path / "checkpoints/epoch=14-step=43515.ckpt"

# Choice 2: all-three nmr model
# model_path = Path("/root/MorganFP_prediction/reproduce_previous_works/weird_H_and_tautomer_cleaned/entropy_radius_exps_all_info/R0_to_R4_all_info_trial_1")
# checkpoint_path = model_path / "checkpoints/epoch=22-step=11408.ckpt"

hyperpaerameters_path = model_path / "hparams.yaml"

with open(hyperpaerameters_path, 'r') as file:
    hparams = yaml.safe_load(file)

# The fingerprint building type is the last "_"-separated token of the stored setting.
FP_building_type = hparams['FP_building_type'].split("_")[-1]
only_2d = not hparams['use_oneD_NMR_no_solvent']
print(FP_building_type)
# NOTE(review): assumes the last "_" token of FP_choice is a one-character prefix
# followed by the radius (e.g. "R2") — confirm against the training config.
max_radius = int(hparams['FP_choice'].split("_")[-1][1:])
print("max_radius: ", max_radius)
specific_radius_mfp_loader.setup(only_2d=only_2d, FP_building_type=FP_building_type)
# Reuse the parsed radius rather than re-parsing FP_choice a second time.
specific_radius_mfp_loader.set_max_radius(max_radius, only_2d=only_2d)

# These hparams are passed as overrides to load_from_checkpoint. checkpoint_path is
# removed so it is not defined twice; pop() also tolerates hparams files without it.
hparams.pop('checkpoint_path', None)
hparams['use_peak_values'] = False
hparams['skip_ranker'] = True


Normal
max_radius:  2


In [4]:
# Restore the trained weights; the edited hparams override the stored ones.
model = OptionalInputRankedTransformer.load_from_checkpoint(
    checkpoint_path,
    **hparams,
)


Initialized SignCoordinateEncoder[384] with dims [180, 180, 24] and 2 positional encoders. 24 bits are reserved for encoding the final bit


Using jaccard:  False
HsqcRankedTransformer saving args


In [31]:
# Run all inference in this notebook on the CPU.
model = model.to("cpu")

In [5]:
import torch.nn.functional as F
import heapq
import os
import pickle

def retrieve_top_k_by_rankingset(data, prediction_query, smiles_and_names, k=30):
    """Return the k rows of ``data`` with the highest dot product against the query.

    The query is L2-normalised along dim 1 and squeezed, then scored against every
    row of ``data`` with one matrix-vector product. Rows of ``data`` are used as-is,
    so the score is a true cosine similarity only if they are already normalised
    (NOTE(review): presumably true for the stacked ranking set — confirm).

    Args:
        data: 2-D tensor, one fingerprint per row.
        prediction_query: tensor of shape (1, D); normalisation uses dim=1.
        smiles_and_names: indexable by row index; entry i describes row i of ``data``.
        k: number of hits to return; clamped to the number of rows in ``data``.

    Returns:
        List of ``(score, smiles_and_names[row], data[row].nonzero())`` tuples, best
        first. ``score`` is a 0-d tensor.
    """
    query = F.normalize(prediction_query, dim=1, p=2.0).squeeze()
    scores = data @ query
    # torch.topk raises if k exceeds the number of candidates.
    k = min(k, scores.shape[0])
    # topk returns values in descending order already, so no extra sort is needed.
    values, indices = torch.topk(scores, k=k)
    return [
        (value, smiles_and_names[int(idx)], data[idx].nonzero())
        for value, idx in zip(values, indices)
    ]
        

In [8]:
# Metadata for each ranking-set row, indexed by row number.
# NOTE(review): pickle.load can execute arbitrary code — only load this file from a trusted source.
with open('/home/ad.ucsd.edu/w6xu/inference/SMILES_and_Names.pkl', 'rb') as pkl_file:
    smiles_and_names = pickle.load(pkl_file)

In [13]:
rankingset_path = f"/home/ad.ucsd.edu/w6xu/inference/max_radius_{max_radius}_stacked_together/FP.pt"

# Queries come from the CPU-resident model, so map the ranking set to the CPU as well.
# The `data @ query` product in retrieval then never mixes devices, even if FP.pt
# was saved from a CUDA tensor.
rankingset_data = torch.load(rankingset_path, map_location="cpu")



In [14]:
def unpack_inputs(inputs):
    """Split a batched, delimiter-separated NMR input back into its sections.

    Rows whose first three entries are all -1 / -2 bracket the HSQC rows, -3 / -4
    the C NMR rows and -5 / -6 the H NMR rows. Only batch element 0 is inspected.
    If a marker occurs more than once, its last occurrence is used.

    Returns:
        ``(hsqc, c_tensor, h_tensor)``: ``hsqc`` keeps every column of its rows;
        ``c_tensor`` and ``h_tensor`` keep column 0 only. A section whose start or
        end delimiter is missing (e.g. an input built with ``include_c_nmr=False``)
        is returned as ``None``. Previously this case raised UnboundLocalError.
    """
    rows = inputs[0]

    def _last_marker_row(marker):
        # Index of the last row whose first three entries all equal `marker`, or None.
        hits = (rows[:, :3] == marker).all(dim=1).nonzero()
        return hits[-1].item() if len(hits) else None

    def _bounds(start_marker, end_marker):
        # Half-open row range strictly between the two delimiters, or None if either is absent.
        start = _last_marker_row(start_marker)
        end = _last_marker_row(end_marker)
        if start is None or end is None:
            return None
        return start + 1, end

    hsqc_bounds = _bounds(-1, -2)
    c_bounds = _bounds(-3, -4)
    h_bounds = _bounds(-5, -6)

    hsqc = inputs[0, hsqc_bounds[0]:hsqc_bounds[1]] if hsqc_bounds else None
    c_tensor = inputs[0, c_bounds[0]:c_bounds[1], 0] if c_bounds else None
    h_tensor = inputs[0, h_bounds[0]:h_bounds[1], 0] if h_bounds else None
    return hsqc, c_tensor, h_tensor

# unpack_inputs(inputs)

# Cosine similarity between fingerprints (used to score predictions and retrievals).
def compute_cos_sim(fp1, fp2):
    """Cosine similarity of two 1-D fingerprint tensors, returned as a 0-d tensor.

    If either vector is all zeros the similarity is undefined. In that case 0.0 is
    returned instead of the NaN/inf that the bare division would produce.
    """
    denom = torch.norm(fp1) * torch.norm(fp2)
    if denom.item() == 0:
        return torch.tensor(0.0)
    return (fp1 @ fp2) / denom.item()

In [15]:
def get_delimeter(delimeter_name):
    """Return the float32 marker row ``[code, code, code]`` for a named section boundary.

    Each supported name maps to its own negative code (HSQC -1/-2, C NMR -3/-4,
    H NMR -5/-6, solvent -7/-8, ms -12/-13). Any other name raises ``Exception``.
    """
    marker_codes = (
        ("HSQC_start", -1),
        ("HSQC_end", -2),
        ("C_NMR_start", -3),
        ("C_NMR_end", -4),
        ("H_NMR_start", -5),
        ("H_NMR_end", -6),
        ("solvent_start", -7),
        ("solvent_end", -8),
        ("ms_start", -12),
        ("ms_end", -13),
    )
    for known_name, code in marker_codes:
        if delimeter_name == known_name:
            # Build a fresh tensor on every call so callers never share one object.
            return torch.tensor([code, code, code]).float()
    raise Exception(f"unknown {delimeter_name}")

 

In [16]:


# Inference on unknown molecules: predict a fingerprint from NMR inputs, then look up
# the closest entries in the ranking set.
hsqc_mode = None


def show_topK(inputs, k=5, mode=hsqc_mode, ground_truth_FP=None):
    """For a single model, print and draw the top-k ranking-set retrievals for one input.

    Args:
        inputs: un-batched tensor of delimiter-separated NMR rows (as produced by
            ``build_input``); a batch dimension is added here.
        k: number of retrievals to show.
        mode: unused here; kept for call-site compatibility (HSQC sign handling is
            done in ``build_input``).
        ground_truth_FP: optional 1-D float fingerprint of the true structure. When
            given, the cosine similarity of the prediction and of each retrieval to it
            is printed.
    """
    print("_________________________________________________________")

    # Inference only. Nothing else in this notebook switches the model to eval mode,
    # so do it here (disables any dropout), and skip autograd bookkeeping.
    model.eval()
    with torch.no_grad():
        logits = model(inputs.unsqueeze(0))
    pred = torch.sigmoid(logits)
    pred_FP = torch.where(pred.squeeze() > 0.5, 1, 0)
    if ground_truth_FP is not None:
        print("Prediction's cosine similarity to ground truth: ", compute_cos_sim(ground_truth_FP, pred_FP.to("cpu").float()))
        print("\n\n")
    topk = retrieve_top_k_by_rankingset(rankingset_data, pred, smiles_and_names, k=k)

    for rank, (value, (smile, name, _, _), retrieved_FP) in enumerate(topk, start=1):
        print(f"____________________________retival #{rank}, retrival cosine similarity to prediction: {value.item()}_____________________________")
        if ground_truth_FP is not None:
            # Densify the retrieved fingerprint to the ground truth's length instead of a
            # hard-coded 6144, so other fingerprint sizes (radii) also work.
            retrieved_dense = torch.zeros(ground_truth_FP.shape[-1])
            retrieved_dense[retrieved_FP.cpu().flatten()] = 1
            print("Retrival's cosine similarity to ground truth: ", compute_cos_sim(ground_truth_FP, retrieved_dense))
        print(f"SMILES: {smile}")
        print(f"Name {name}")

        mol = Chem.MolFromSmiles(smile)
        if mol is None:
            # RDKit returns None for unparsable SMILES; Draw.MolToImage would fail on it.
            print("(could not parse SMILES; skipping depiction)")
            continue
        Draw.MolToImage(mol).show()
        



In [23]:
def build_input(compound_dir, mode = hsqc_mode, include_hsqc = True, include_c_nmr = True, include_h_nmr = True):
    """Assemble one model input tensor from a compound directory of NMR text files.

    Reads ``HSQC.txt`` (comma-separated rows; the first two columns are swapped on
    load), ``C.txt`` and ``H.txt`` (comma-separated shifts, de-duplicated and sorted by
    ``np.unique``, zero-padded to 3 columns) and ``mw.txt`` (one float). Each included
    section is wrapped in its start/end delimiter rows. The molecular-weight row,
    wrapped in the ms delimiters, is always appended last.

    Args:
        compound_dir: directory path (str or os.PathLike); its last component is printed.
        mode: None, "no_sign" (absolute value of every HSQC entry) or "flip_sign"
            (negate HSQC column 2).
        include_hsqc, include_c_nmr, include_h_nmr: which sections to include. Only
            the files of included sections are read, so an HSQC-only compound does
            not need C.txt / H.txt on disk.

    Returns:
        2-D float32 tensor whose rows all have 3 columns.
    """
    compound_dir = os.fspath(compound_dir)
    print("\n\n")
    # basename(normpath(...)) tolerates trailing slashes and OS-specific separators.
    print(os.path.basename(os.path.normpath(compound_dir)))
    print("\n")

    def load_2d():
        # ndmin=2 keeps a single-peak file as shape (1, 3) instead of (3,).
        return torch.tensor(np.loadtxt(os.path.join(compound_dir, "HSQC.txt"), delimiter=",", ndmin=2)).float()

    def load_1d(nmr):
        vals = np.loadtxt(os.path.join(compound_dir, f"{nmr}.txt"), delimiter=",")
        vals = torch.tensor(np.unique(vals)).float()
        # Shifts go in column 0; two zero columns make the rows 3 wide like the HSQC rows.
        return F.pad(vals.view(-1, 1), (0, 2), "constant", 0)

    input_NMRs = []
    if include_hsqc:
        hsqc = load_2d()
        hsqc[:, [0, 1]] = hsqc[:, [1, 0]]  # swap the first two columns
        if mode == "no_sign":
            hsqc = torch.abs(hsqc)
        elif mode == "flip_sign":
            hsqc[:, 2] = -hsqc[:, 2]
        input_NMRs += [get_delimeter("HSQC_start"), hsqc, get_delimeter("HSQC_end")]
    if include_c_nmr:
        input_NMRs += [get_delimeter("C_NMR_start"), load_1d("C"), get_delimeter("C_NMR_end")]
    if include_h_nmr:
        input_NMRs += [get_delimeter("H_NMR_start"), load_1d("H"), get_delimeter("H_NMR_end")]

    with open(os.path.join(compound_dir, "mw.txt"), 'r') as file:
        mw = float(file.read())
    mol_weight = torch.tensor([mw, 0, 0]).float()

    # A single vstack also covers the case where every NMR section is excluded
    # (torch.vstack([]) would raise).
    return torch.vstack(input_NMRs + [get_delimeter("ms_start"), mol_weight, get_delimeter("ms_end")])

In [17]:
# Ground-truth structures of the test compounds. Their fingerprints are built with the
# loader configured from the model's hparams.
Kavaratamide_A_SMILES = "CCCCCCC[C@H](O)CC(=O)N[C@@H](C(C)C)C(=O)N(C)[C@@H](C)C(=O)O[C@@H](C(C)C)C(=O)N1[C@@H](C(C)C)C(OC)=CC1=O"
newCompoundA_SMILES = "O=C1C(NC2=O)CCC(O)N1C(CC3=CC=CC=C3)C(N(C)C(CC4=CC(Br)=C(O)C=C4)C(NC(C(CC)C)C(OC(C)C(C(NC2CCCC[NH3+])=O)NC(C(NC(C(COS(=O)([O-])=O)OC)=O)CC5=CC=CC=C5)=O)=O)=O)=O"

Kavaratamide_A_FP = specific_radius_mfp_loader.build_mfp_for_new_SMILES(Kavaratamide_A_SMILES)
newCompoundA_FP = specific_radius_mfp_loader.build_mfp_for_new_SMILES(newCompoundA_SMILES)

In [18]:
# Check the fingerprint dtype; it must match the float dtype used in compute_cos_sim.
newCompoundA_FP.dtype

torch.float32

In [19]:
# Look for the highest cosine similarity achievable in our ranking set.
def retrieve_based_on_groudth(groud_truth_FP, k=5):
    """Show the k ranking-set entries closest to a ground-truth fingerprint.

    This is the best retrieval possible for the compound in this ranking set. It serves
    as a reference point for the model-based retrievals shown by ``show_topK``.

    Args:
        groud_truth_FP: fingerprint of shape (1, D), as expected by
            ``retrieve_top_k_by_rankingset``.
        k: number of entries to show.
    """
    topk = retrieve_top_k_by_rankingset(rankingset_data, groud_truth_FP, smiles_and_names, k=k)

    for rank, (value, (smile, name, _, _), retrieved_FP) in enumerate(topk, start=1):
        print(f"____________________________retival #{rank}, cosine similarity: {value.item()}_____________________________")
        print("retrived FP", retrieved_FP.squeeze().tolist())
        print(f"SMILES: {smile}")
        print(f"Name {name}")
        mol = Chem.MolFromSmiles(smile)
        if mol is None:
            # RDKit returns None for unparsable SMILES; Draw.MolToImage would fail on it.
            print("(could not parse SMILES; skipping depiction)")
            continue
        Draw.MolToImage(mol).show()


In [20]:
# Best achievable retrievals for Kavaratamide A, using its true fingerprint as the query.
retrieve_based_on_groudth(Kavaratamide_A_FP[None], k=5)

____________________________retival #1, cosine similarity: 0.6377196311950684_____________________________
retrived FP [6, 7, 8, 9, 10, 11, 12, 14, 16, 17, 26, 27, 28, 29, 30, 33, 37, 38, 42, 44, 45, 62, 64, 66, 67, 68, 69, 70, 71, 77, 83, 109, 112, 113, 114, 117, 120, 124, 126, 128, 131, 134, 135, 136, 137, 138, 141, 144, 166, 167, 168, 169, 171, 172, 176, 177, 186, 193, 195, 201, 214, 216, 218, 219, 222, 224, 225, 238, 240, 241, 242, 258, 261, 266, 272, 284, 285, 349, 350, 352, 491, 502, 503, 504, 531, 573, 574, 575, 577, 578, 580, 597, 601, 612, 635, 636, 637, 662, 679, 680, 691, 782, 822, 833, 837, 900, 945, 1000, 1045, 1083, 1101, 1184, 1205, 1269, 1312, 1408, 1477, 1523, 1661, 1973, 2015, 2095, 2190, 2228, 2303, 2328, 2379, 2459, 2549, 2876, 2925, 3015, 3210, 3393, 3727, 3736, 3744, 3875, 4009, 4070, 4087, 4122, 4294, 4504, 4627, 4781, 4835, 4930, 5237, 5336, 5667, 5792, 5923]
SMILES: CCCCCCCC(CC=CCCC(=O)N(C)C(C)C(=O)NC(C)C(O)C(C)C(=O)N1C(=O)C=C(OC)C1C(C)C)OC
Name Malyngamide X
_

In [21]:
# Fingerprint of sintokamide B (presumably DeepSAT's answer for this compound — confirm)
# and its cosine similarity to the true Kavaratamide A fingerprint.
deepSAT_sintokamide_B_FP = specific_radius_mfp_loader.build_mfp_for_new_SMILES(
    "CCC(=O)N[C@H](C[C@H](C)C(Cl)(Cl)Cl)C(=O)N1[C@H](C(=CC1=O)OC)C[C@H](C)C(Cl)(Cl)Cl"
)
compute_cos_sim(deepSAT_sintokamide_B_FP, Kavaratamide_A_FP)

tensor(0.6227116585)

In [33]:
compound_dir = "/home/ad.ucsd.edu/w6xu/inference/testing_compounds/Kavaratamide_A"
mode = None  # alternative: "flip_sign"
# HSQC-only input (C and H NMR excluded).
inputs = build_input(
    compound_dir,
    mode=mode,
    include_h_nmr=False,
    include_c_nmr=False,
)
show_topK(inputs, k=5, mode=mode, ground_truth_FP=Kavaratamide_A_FP)




Kavaratamide_A


_________________________________________________________
Prediction's cosine similarity to ground truth:  tensor(0.6198346019)





____________________________retival #1, retrival cosine similarity to prediction: 0.6539087295532227_____________________________
Retrival's cosine similarity to ground truth:  tensor(0.4937184453)
SMILES: CCCCC(C)CC(C)C(=O)N(C)C(CC(C)C)C(=O)NC(C(=O)N(C)C(C(=O)N1CCCC1C(=O)N1C(=O)CCC1C)C(C)C)C(C)O
Name "2,3-dihydromicrocolin C"
____________________________retival #2, retrival cosine similarity to prediction: 0.6487895250320435_____________________________
Retrival's cosine similarity to ground truth:  tensor(0.4873884618)
SMILES: CCCCC(C)CC(C)C(=O)N(C)C(CC(C)C)C(=O)NC(C(=O)N(C)C(C(=O)N1CC(O)CC1C(=O)N1C(=O)CCC1C)C(C)C)C(C)O
Name "2,3-dihydromicrocolin D"
____________________________retival #3, retrival cosine similarity to prediction: 0.648706316947937_____________________________
Retrival's cosine similarity to ground truth:  tensor(0.5214038491)
SMILES: CCCCC(C)CC(C)C(=O)N(C)C(CC(C)C)C(=O)NC(C(=O)N(C)C(C(=O)N1CCCC1C(=O)N1C(=O)C=CC1C)C(C)C)C(C)O
Name Microcolin C
_______________________

In [34]:
# Sanity check: show which device the model's weights are on.
model.device

device(type='cpu')

In [35]:
# Best achievable retrievals for new compound A, using its true fingerprint as the query.
retrieve_based_on_groudth(newCompoundA_FP[None], k=3)

____________________________retival #1, cosine similarity: 0.7450761198997498_____________________________
retrived FP [0, 1, 2, 3, 4, 5, 6, 9, 14, 16, 17, 22, 23, 26, 27, 28, 29, 30, 33, 35, 36, 37, 38, 42, 44, 45, 46, 47, 48, 57, 59, 67, 68, 69, 70, 71, 73, 75, 77, 83, 97, 99, 100, 101, 102, 103, 112, 113, 114, 134, 135, 136, 137, 138, 141, 144, 147, 150, 151, 153, 154, 155, 156, 157, 158, 166, 167, 168, 169, 171, 172, 188, 191, 195, 198, 214, 216, 219, 222, 224, 225, 235, 240, 241, 242, 251, 259, 268, 297, 309, 319, 322, 323, 336, 337, 340, 341, 349, 350, 352, 354, 361, 362, 363, 367, 375, 377, 385, 386, 485, 496, 521, 573, 574, 575, 577, 578, 580, 592, 597, 598, 603, 612, 614, 617, 635, 636, 637, 674, 691, 756, 779, 798, 807, 816, 829, 833, 843, 848, 875, 887, 897, 918, 919, 920, 943, 948, 951, 973, 994, 1000, 1015, 1038, 1045, 1050, 1098, 1132, 1143, 1157, 1172, 1305, 1311, 1340, 1362, 1404, 1428, 1439, 1513, 1521, 1527, 1559, 1619, 1650, 1661, 1693, 1701, 1736, 1905, 1921, 1957, 

In [38]:
compound_dir = "/home/ad.ucsd.edu/w6xu/inference/testing_compounds/new_compound_A"
# HSQC-only input; `mode` is the value set in the Kavaratamide A cell.
inputs = build_input(compound_dir, mode=mode, include_h_nmr=False, include_c_nmr=False)
show_topK(inputs, k=3, ground_truth_FP=newCompoundA_FP)




new_compound_A


_________________________________________________________
Prediction's cosine similarity to ground truth:  tensor(0.6327524781)



____________________________retival #1, retrival cosine similarity to prediction: 0.7645297050476074_____________________________
Retrival's cosine similarity to ground truth:  tensor(0.6693841815)
SMILES: CCC(C)C1NC(=O)C(Cc2ccccc2)N(C)C(=O)C(C(C)CC)N2C(=O)C(CCC2OC)NC(=O)C(CCCN=C(N)N)NC(=O)C(NC(=O)C(O)CO)C(C)OC1=O
Name Micropeptin MZ859
____________________________retival #2, retrival cosine similarity to prediction: 0.7630496025085449_____________________________
Retrival's cosine similarity to ground truth:  tensor(0.6606438160)
SMILES: CCC(C)C1NC(=O)C(Cc2ccc(O)c(Cl)c2)N(C)C(=O)C(C(C)CC)N2C(=O)C(CCC2OC)NC(=O)C(CCCN=C(N)N)NC(=O)C(NC(=O)C(O)CO)C(C)OC1=O
Name 
____________________________retival #3, retrival cosine similarity to prediction: 0.7478505373001099_____________________________
Retrival's cosine similarity to ground truth:  tens

In [39]:
compound_dir = "/home/ad.ucsd.edu/w6xu/inference/testing_compounds/new_compound_B1"
# HSQC-only input; no ground-truth fingerprint is available for this compound here.
inputs = build_input(compound_dir, mode=mode, include_h_nmr=False, include_c_nmr=False)
show_topK(inputs, k=3)




new_compound_B1


_________________________________________________________
____________________________retival #1, retrival cosine similarity to prediction: 0.7888862490653992_____________________________
SMILES: CC=C1NC(=O)C(Cc2ccccc2)NC(=O)C(C(C)C)NC(=O)C(C(C)CC)NC(=O)C(NC(=O)C(NC(=O)C(CCCN)NC(=O)C2CCCN2C(=O)C(NC(=O)C(NC(=O)C(NC(=O)C(NC(=O)CCC(C)CC)C(C)C)C(C)O)C(C)C)C(C)C)C(C)CC)C(C)OC(=O)C(C(C)C)NC1=O
Name Elisidepsin
____________________________retival #2, retrival cosine similarity to prediction: 0.7759684920310974_____________________________
SMILES: CC=C1NC(=O)C(Cc2ccccc2)NC(=O)C(C(C)C)NC(=O)C(C(C)CC)NC(=O)C(NC(=O)C(NC(=O)C(CCCCN)NC(=O)C2CCCN2C(=O)C(NC(=O)C(NC(=O)C(NC(=O)C(NC(=O)CCCC(C)C)C(C)C)C(C)O)C(C)C)C(C)C)C(C)CC)C(C)OC(=O)C(C(C)C)NC1=O
Name 
____________________________retival #3, retrival cosine similarity to prediction: 0.7757498025894165_____________________________
SMILES: CCC(C)CCCCCCCCCCCCC(=O)NC(CCC(=O)O)C(=O)NC(CCCN)C(=O)NC1Cc2ccc(cc2)OC(=O)C(C(C)CC)NC(=O)C(Cc2

In [40]:
compound_dir = "/home/ad.ucsd.edu/w6xu/inference/testing_compounds/new_compound_B2"
# HSQC-only input; no ground-truth fingerprint is available for this compound here.
inputs = build_input(compound_dir, mode=mode, include_h_nmr=False, include_c_nmr=False)
show_topK(inputs, k=5)




new_compound_B2


_________________________________________________________


____________________________retival #1, retrival cosine similarity to prediction: 0.8171883225440979_____________________________
SMILES: CCC(C)C1NC(=O)C(C(C)C)NC(=O)C2CCCN2C(=O)C(Cc2ccccc2)NC(=O)C(NC(=O)C(NC(=O)Cc2ccccc2)C(C)C)C(C)OC(=O)C(C(C)C)NC1=O
Name Xentrivalpeptide N
____________________________retival #2, retrival cosine similarity to prediction: 0.8140150904655457_____________________________
SMILES: CCC(C)C(NC(=O)Cc1ccccc1)C(=O)NC1C(=O)NC(Cc2ccccc2)C(=O)N2CCCC2C(=O)NC(C(C)C)C(=O)NC(C(C)C)C(=O)NC(C(C)C)C(=O)OC1C
Name Xentrivalpeptide I
____________________________retival #3, retrival cosine similarity to prediction: 0.8081856966018677_____________________________
SMILES: CCCCCC(=O)NC(C(=O)NC1C(=O)NC(Cc2ccccc2)C(=O)N2CCCC2C(=O)NC(C(C)C)C(=O)NC(C(C)C)C(=O)NC(C(C)C)C(=O)OC1C)C(C)C
Name Xentrivalpeptide G
____________________________retival #4, retrival cosine similarity to prediction: 0.8060168027877808_____________________________
SMILES: CCCC(=O)NC(C(=O)NC1C(=O)NC(Cc2ccccc2)C(