# SMILES Enrichment and Embedding over PrimeKG using NIM/MOLMIM

This tutorial walks through an example that enriches the drug nodes in PrimeKG with SMILES strings retrieved from PubChem, and then embeds those SMILES strings with NVIDIA's NIM/MOLMIM.

First, import the necessary libraries:

```python
# Import necessary libraries
import sys
# Make the repository root importable (path is relative to this notebook)
sys.path.append('../../..')
from aiagents4pharma.talk2knowledgegraphs.datasets.primekg import PrimeKG
from aiagents4pharma.talk2knowledgegraphs.utils import pubchem_utils
from aiagents4pharma.talk2knowledgegraphs.utils.enrichments import pubchem_strings
from aiagents4pharma.talk2knowledgegraphs.utils.embeddings import nim_molmim

# Set the logging level for httpx to WARNING to suppress INFO messages
import logging
logging.getLogger("httpx").setLevel(logging.WARNING)
```

### Load PrimeKG

To load the PrimeKG dataset, we use the `PrimeKG` class from the `aiagents4pharma/talk2knowledgegraphs` library.

`PrimeKG` is initialized with the path to a local directory: the dataset files are stored in that directory and, if they already exist there, loaded from it.

```python
# Define primekg data by providing a local directory where the data is stored
primekg_data = PrimeKG(local_dir="../../../../data/primekg/")

# Invoke a method to load the data
primekg_data.load_data()

# Get primekg_nodes and primekg_edges
primekg_nodes = primekg_data.get_nodes()
primekg_edges = primekg_data.get_edges()
```

```text
Loading nodes of PrimeKG dataset ...
../../../../data/primekg/primekg_nodes.tsv.gz already exists. Loading the data from the local directory.
Loading edges of PrimeKG dataset ...
../../../../data/primekg/primekg_edges.tsv.gz already exists. Loading the data from the local directory.
```
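The log messages above suggest that `PrimeKG` caches its files in `local_dir`: they appear to be fetched once and reused on later runs. The sketch below illustrates that load-or-fetch pattern using only the standard library. `load_or_fetch` and its `fetch` callback are hypothetical helpers written for illustration; they are not part of the `aiagents4pharma` API.

```python
import csv
import gzip
from pathlib import Path
from typing import Callable


def load_or_fetch(local_path: str, fetch: Callable[[Path], None]) -> list[dict]:
    """Return the rows of a gzipped TSV, calling the hypothetical `fetch` only if the file is missing."""
    path = Path(local_path)
    if path.exists():
        # Cached copy found: skip the download
        print(f"{path} already exists. Loading the data from the local directory.")
    else:
        # No cached copy: create the directory and let `fetch` write the file
        path.parent.mkdir(parents=True, exist_ok=True)
        fetch(path)
    # Parse the tab-separated file into one dict per row
    with gzip.open(path, "rt", newline="") as handle:
        return list(csv.DictReader(handle, delimiter="\t"))
```

The first call invokes `fetch`; every later call with the same path reads the cached file instead.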


### Dataclass to store drug data

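```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class DrugData:
    """A PrimeKG drug node, enriched step by step with PubChem and MOLMIM data."""
    name: str
    drugbank_id: str
    pubchem_cid: Optional[str] = None
    smiles: Optional[str] = None
    embed_smiles: Optional[list] = None


# Maps DrugBank ID -> DrugData
dic_drug_data = {}
```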

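Next, load the drug nodes of PrimeKG into `dic_drug_data`, keyed by DrugBank ID. Monoclonal antibodies (names ending in `mab`) are skipped, since these large biologics have no small-molecule SMILES representation.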

```python
from tqdm import tqdm

# Select the DrugBank nodes once and reuse the selection
drug_nodes = primekg_nodes[primekg_nodes['node_source'] == 'DrugBank']
for _, row in tqdm(drug_nodes.iterrows(), total=drug_nodes.shape[0]):
    # Skip monoclonal antibodies (names ending in "mab")
    if not row['node_name'].endswith('mab'):
        drugbank_id = row['node_id']
        dic_drug_data[drugbank_id] = DrugData(name=row['node_name'], drugbank_id=drugbank_id)

# Print the number of drug names mapped to DrugBank IDs
print(f"Number of drugs mapped to DrugBank IDs: {len(dic_drug_data)}")
```

```text
100%|██████████| 7957/7957 [00:00<00:00, 40758.60it/s]
Number of drugs mapped to DrugBank IDs: 7715
```
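The same selection can also be written with vectorized pandas operations instead of `iterrows()`, which is typically faster on large tables. The sketch below runs on a small hypothetical DataFrame that mimics the `node_id`, `node_name` and `node_source` columns used above; the rows are invented and are not PrimeKG data. Applied to `primekg_nodes`, the same mask would give the ID-to-name mapping from which the `DrugData` objects are built.

```python
import pandas as pd

# Toy stand-in for primekg_nodes (hypothetical rows, not real PrimeKG data)
nodes = pd.DataFrame(
    {
        "node_id": ["D1", "D2", "G1"],
        "node_name": ["examplol", "examplimab", "GENE1"],
        "node_source": ["DrugBank", "DrugBank", "NCBI"],
    }
)

# Keep DrugBank nodes whose names do not end in "mab"
mask = (nodes["node_source"] == "DrugBank") & ~nodes["node_name"].str.endswith("mab")
drug_nodes = nodes[mask]

# Map DrugBank ID -> drug name without iterating row by row
id_to_name = dict(zip(drug_nodes["node_id"], drug_nodes["node_name"]))
print(id_to_name)
```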





### SMILES enrichment over drug data using PubChemPy

Since drugs in PrimeKG are identified by DrugBank IDs, we first map each DrugBank ID to its corresponding PubChem CID and then use the CID to retrieve the drug's SMILES string.

**NOTE**: For demonstration, the loop below stops after the first three drugs (the `if count == 2: break` check). Remove that check to retrieve SMILES strings for all drugs.

```python
enrichment = pubchem_strings.EnrichmentWithPubChem()

# Retrieve the PubChem CID and SMILES string of each drug
for count, (drugbank_id, drug_data) in enumerate(dic_drug_data.items()):
    # Map the DrugBank ID to a PubChem CID using pubchem_utils
    drug_data.pubchem_cid = pubchem_utils.drugbank_id2pubchem_cid(drugbank_id)
    print(f"DrugBank ID: {drugbank_id}, PubChem CID: {drug_data.pubchem_cid}")
    # Retrieve the SMILES string from the PubChem CID, if one was found
    if drug_data.pubchem_cid:
        smiles = enrichment.enrich_documents([drug_data.pubchem_cid])
        drug_data.smiles = smiles[0]
        print(f"DrugBank ID: {drugbank_id}, SMILES: {drug_data.smiles}")
    # Stop after the first three drugs; remove this check to process all drugs
    if count == 2:
        break
```

```text
DrugBank ID: DB09130, PubChem CID: 23978
DrugBank ID: DB09130, SMILES: [Cu]
DrugBank ID: DB09140, PubChem CID: 977
DrugBank ID: DB09140, SMILES: O=O
DrugBank ID: DB00180, PubChem CID: 82153
DrugBank ID: DB00180, SMILES: C[C@]12C[C@@H]([C@H]3[C@H]([C@@H]1C[C@@H]4[C@]2(OC(O4)(C)C)C(=O)CO)C[C@@H](C5=CC(=O)C=C[C@]35C)F)O
```
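This tutorial does not show how `EnrichmentWithPubChem` queries PubChem internally. As an illustration only, the CID-to-SMILES step maps naturally onto PubChem's PUG REST interface, and the sketch below builds such a request with the standard library. The helper names `build_smiles_url` and `fetch_smiles` are hypothetical, and the default property name `IsomericSMILES` is an assumption; check the PUG REST documentation for the SMILES property names it currently supports.

```python
from urllib.parse import quote
from urllib.request import urlopen

PUG_REST = "https://pubchem.ncbi.nlm.nih.gov/rest/pug"


def build_smiles_url(cid, prop: str = "IsomericSMILES") -> str:
    """Build a PUG REST URL that returns one SMILES property of a compound as plain text."""
    return f"{PUG_REST}/compound/cid/{quote(str(cid))}/property/{prop}/TXT"


def fetch_smiles(cid, prop: str = "IsomericSMILES", timeout: float = 10.0) -> str:
    """Query PubChem (requires network access) and return the SMILES string."""
    with urlopen(build_smiles_url(cid, prop), timeout=timeout) as response:
        return response.read().decode("utf-8").strip()


# URL for CID 977, the compound mapped from DB09140 in the output above
print(build_smiles_url("977"))
```

Calling `fetch_smiles("977")` performs the actual request and therefore needs network access.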


### Embedding SMILES strings using NVIDIA's optimized MOLMIM

We use the `EmbeddingWithMOLMIM` class to compute the embeddings.
The class requires a `base_url` at initialization, which must point to a NIM/MOLMIM service running locally or on a remote machine.
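For orientation, the sketch below shows what a raw request to such a service might look like, bypassing `EmbeddingWithMOLMIM`. It assumes the endpoint accepts a JSON body of the form `{"sequences": [...]}` and replies with `{"embeddings": [...]}`, which is our understanding of the MolMIM NIM embedding API; verify the schema against the NIM documentation for your deployment. The helpers `build_embedding_request` and `embed_smiles` are hypothetical.

```python
import json
from urllib.request import Request, urlopen


def build_embedding_request(base_url: str, smiles_list: list[str]) -> Request:
    """Build a JSON POST request for a MolMIM-style embedding endpoint (assumed schema)."""
    payload = json.dumps({"sequences": smiles_list}).encode("utf-8")
    return Request(
        base_url,
        data=payload,
        headers={"Content-Type": "application/json", "Accept": "application/json"},
        method="POST",
    )


def embed_smiles(base_url: str, smiles_list: list[str], timeout: float = 60.0) -> list[list[float]]:
    """Send the request and return the embeddings (assumed response key "embeddings")."""
    with urlopen(build_embedding_request(base_url, smiles_list), timeout=timeout) as response:
        return json.loads(response.read())["embeddings"]
```

Assuming the service returns one embedding per input in the same order, the results can be paired back with their DrugBank IDs using `zip`.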

```python
# Collect the drugs that have a SMILES string, in a fixed (sorted) order
drugs_with_smiles = [
    drugbank_id for drugbank_id in sorted(dic_drug_data)
    if dic_drug_data[drugbank_id].smiles
]
smiles_list = [dic_drug_data[drugbank_id].smiles for drugbank_id in drugs_with_smiles]

# Define the base URL for the embedding service
base_url = "http://localhost:8000/embedding"
embedding = nim_molmim.EmbeddingWithMOLMIM(base_url=base_url)
# Embed all SMILES strings in a single call
smiles_embedding = embedding.embed_documents(smiles_list)

# Attach each embedding to its drug (assumes embeddings come back in input order)
for drugbank_id, embed in zip(drugs_with_smiles, smiles_embedding):
    dic_drug_data[drugbank_id].embed_smiles = embed
    print(f"DrugBank ID: {drugbank_id}, SMILES: {dic_drug_data[drugbank_id].smiles}, Embedding: {embed}")
```

DrugBank ID: DB00180, SMILES: C[C@]12C[C@@H]([C@H]3[C@H]([C@@H]1C[C@@H]4[C@]2(OC(O4)(C)C)C(=O)CO)C[C@@H](C5=CC(=O)C=C[C@]35C)F)O, Embedding: [-0.1204390674829483, -0.1940891444683075, -0.4339446425437927, -0.2996540367603302, -0.729110836982727, -0.8603639006614685, 0.1087680459022522, -0.7364030480384827, 0.3626856505870819, -1.1047277450561523, 0.2942107021808624, 0.28778666257858276, 0.37707236409187317, 0.44420892000198364, 0.7568902373313904, -0.3640008270740509, 0.1164623573422432, -1.0804017782211304, 0.21842145919799805, -1.0145331621170044, -0.4529482424259186, 0.29504555463790894, 1.0504405498504639, 0.4566451907157898, 0.17165839672088623, -0.09387531131505966, -0.8585189580917358, -0.09521356970071793, -0.2553885281085968, 0.017894838005304337, -0.04478992521762848, 0.15266816318035126, 0.6867157220840454, 0.04769552871584892, -0.4582587480545044, 0.3558366596698761, 0.7061216831207275, -0.10821279883384705, -0.3186778128147125, 0.4906691610813141, 0.6112119555473328, -0.80