## Imports

In [1]:
!pip install pinecone-client datasets PyTDC rdkit datamol pandas numpy molfeat




In [2]:
import itertools
import os

import tqdm
import pandas as pd
import numpy as np
from tdc.generation import MolGen
from pinecone import Pinecone, ServerlessSpec
import datamol as dm
import molfeat
from molfeat.calc import FPCalculator, RDKitDescriptors2D
from molfeat.trans import MoleculeTransformer
from molfeat.store.modelstore import ModelStore
from molfeat.trans.pretrained import PretrainedMolTransformer, GraphormerTransformer

Failed to find the pandas get_adjustment() function to patch
Failed to patch pandas - PandasTools will have limited functionality


## Load data

In [3]:
# Load the ZINC molecule-generation dataset through TDC and fetch its splits.
# MolGen is already imported in the imports cell, so it is not imported a second time here.
data = MolGen(name='ZINC')
split = data.get_split()

Found local copy...
Loading...
Done!


In [4]:
# Convert the rows of the training split into molecule objects with datamol.
train_split = split["train"]
mols = dm.convert.from_df(train_split)

## Featurize

In [5]:
# df = dm.descriptors.batch_compute_many_descriptors(mols, properties_fn=None, add_properties=True, n_jobs=1, batch_size=None, progress=False, progress_leave=True)

In [6]:
# # Load some dummy data
# data = dm.data.freesolv().sample(100).smiles.values

# # Featurize a single molecule
# calc = FPCalculator("ecfp")
# calc(data[0])

# # Define a parallelized featurization pipeline
# mol_transf = MoleculeTransformer(calc, n_jobs=-1)
# mol_transf(data)

# # Easily save and load featurizers
# mol_transf.to_state_yaml_file("state_dict.yml")
# mol_transf = MoleculeTransformer.from_state_yaml_file("state_dict.yml")
# mol_transf(data)

# # List all available featurizers
# store = ModelStore()
# store.available_models

# # Find a featurizer and learn how to use it
# model_card = store.search(name="ChemBERTa-77M-MLM")[0]
# model_card.usage()

In [6]:
# Open the molfeat model store and take the first search hit for the Graphormer base model card.
# (The list of available models is displayed by its own cell; a bare expression in the middle
# of a cell is evaluated and discarded, so it is not repeated here.)
store = ModelStore()
model_card = store.search(name='pcqm4mv2_graphormer_base')[0]

In [7]:
# Display every model card the store reports as available (the output is long).
store.available_models

[ModelInfo(name='cats2d', inputs='smiles', type='hashed', version=0, group='all', submitter='Datamol', description='2D version of the 6 Potential Pharmacophore Points CATS (Chemically Advanced Template Search) pharmacophore. This version differs from `pharm2D-cats` on the process to make the descriptors fuzzy, which is closer to the original paper implementation. Implementation is based on work by Rajarshi Guha (08/26/07) and Chris Arthur (1/11/2015)', representation='vector', require_3D=False, tags=['CATS', 'hashed', '2D', 'pharmacophore', 'search'], authors=['Michael Reutlinger', 'Christian P Koch', 'Daniel Reker', 'Nickolay Todoroff', 'Petra Schneider', 'Tiago Rodrigues', 'Gisbert Schneider', 'Rajarshi Guha', 'Chris Arthur'], reference='https://doi.org/10.1021/ci050413p', created_at=datetime.datetime(2023, 7, 20, 9, 40, 19, 315784), sha256sum='9c298d589a2158eb513cb52191144518a2acab2cb0c04f1df14fca0f712fa4a1', model_usage=None),
 ModelInfo(name='cats3d', inputs='mol', type='hashed', 

In [8]:
# Build a pretrained transformer for the 'ChemGPT-1.2B' model with mean pooling and float32 output.
# NOTE(review): `featurizer` is rebound to a descriptor-based MoleculeTransformer in a later cell,
# before any features are computed, so this instance is not what produces the indexed vectors.
# Alternative Graphormer featurizer, kept for reference:
# featurizer = GraphormerTransformer(kind='pcqm4mv2_graphormer_base', dtype=np.float32, pooling='mean', max_length=None, concat_layers=-1, ignore_padding=True, version=None)
featurizer = PretrainedMolTransformer(kind='ChemGPT-1.2B', dtype=np.float32, pooling='mean', max_length=None, concat_layers=-1, ignore_padding=True, version=None)

In [9]:
# Echo the transformer bound to `featurizer` to inspect its repr.
featurizer

In [14]:
# Take the `smiles` values from datamol's FreeSolv sample dataset.
# NOTE(review): this rebinds `data`, which previously held the ZINC MolGen loader.
data = dm.data.freesolv().smiles.values

In [16]:
# Number of entries currently held in `data`.
len(data)

642

In [31]:
# Number of molecule objects converted from the training split.
len(mols)

174618

In [22]:
# Calculator for RDKit 2D descriptors.
# NOTE(review): replace_nan=True presumably substitutes NaN descriptor values — confirm the
# replacement value in the molfeat documentation.
calc = RDKitDescriptors2D(replace_nan=True)

In [33]:
# Wrap the descriptor calculator in a transformer that emits float32 features.
featurizer = MoleculeTransformer(calc, dtype=np.float32)

# Featurize only the first 100 training molecules, with RDKit logging silenced,
# and stack the per-molecule rows into a single 2-D array.
first_hundred = mols[:100]
with dm.without_rdkit_log():
    per_mol_feats = featurizer(first_hundred)
    feats = np.stack(per_mol_feats)

feats.dtype

dtype('float32')

In [34]:
# Length of a single feature row (row 30), i.e. the dimensionality of each vector.
len(feats[30])

216

In [35]:
# Inspect the stacked feature matrix.
feats

array([[1.36326046e+01, 1.36326046e+01, 1.20596364e-02, ...,
        1.00000000e+00, 0.00000000e+00, 0.00000000e+00],
       [1.30737009e+01, 1.30737009e+01, 1.83970407e-02, ...,
        1.00000000e+00, 0.00000000e+00, 0.00000000e+00],
       [1.06977310e+01, 1.06977310e+01, 1.17962964e-01, ...,
        0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
       ...,
       [1.27002373e+01, 1.27002373e+01, 1.83461681e-02, ...,
        3.00000000e+00, 0.00000000e+00, 0.00000000e+00],
       [1.27412386e+01, 1.27412386e+01, 1.88195139e-01, ...,
        2.00000000e+00, 0.00000000e+00, 1.00000000e+00],
       [6.07339907e+00, 6.07339907e+00, 6.93189323e-01, ...,
        0.00000000e+00, 0.00000000e+00, 0.00000000e+00]], dtype=float32)

In [36]:
# One string ID per feature row, numbered by row position: 'Molecule 0', 'Molecule 1', ...
id_list = [f'Molecule {row}' for row in range(len(feats))]

In [42]:
# Pair every ID with its feature vector as (id, values) tuples for upserting.
# A list is built instead of a bare zip iterator: a zip object is exhausted after one pass,
# so re-running the upsert cell would otherwise silently send nothing.
# Each vector is converted to a plain list of Python floats.
payload = [(mol_id, vector.tolist()) for mol_id, vector in zip(id_list, feats)]

## Configure Pinecone index

In [10]:
# Read the Pinecone API key from the environment instead of embedding it in the notebook.
# A missing PINECONE_API_KEY raises KeyError here rather than failing later with an auth error.
# NOTE(review): the key that used to be hard-coded in this cell has been exposed and
# should be revoked and replaced.
pc = Pinecone(api_key=os.environ["PINECONE_API_KEY"])

In [45]:
# Get a handle on the index named 'molsearch'.
# NOTE(review): none of the visible cells create this index (ServerlessSpec is imported but
# unused here), so it is assumed to exist already with a dimension matching the feature
# vectors — confirm.
index = pc.Index('molsearch')

In [None]:
## Upsert vectors

In [49]:
def chunks(iterable, batch_size=100):
    """Lazily split ``iterable`` into consecutive tuples of at most ``batch_size`` items.

    Args:
        iterable: Any iterable; it is consumed once, in order.
        batch_size: Maximum number of items per yielded tuple (default 100).

    Yields:
        Non-empty tuples of items. Only the final tuple may be shorter than
        ``batch_size``; an empty input yields nothing.
    """
    source = iter(iterable)
    while True:
        batch = tuple(itertools.islice(source, batch_size))
        # An empty batch means the underlying iterator is exhausted.
        if not batch:
            return
        yield batch

In [50]:
# Send the payload to the index in batches of 10 (id, vector) pairs per request.
for batch in chunks(payload, batch_size=10):
    index.upsert(vectors=batch)

## Similarity search

In [None]:
## Get 101st molecule and search

In [52]:
# Featurize the molecule at position 100, the first one outside the indexed slice mols[:100].
# NOTE(review): this rebinds `feats`, replacing the 100-row matrix built earlier.
query_mol = mols[100]
with dm.without_rdkit_log():
    feats = np.stack(featurizer(query_mol))

feats.dtype

dtype('float32')

In [53]:
# Inspect the feature array computed for the query molecule.
feats

array([[ 1.29483337e+01,  1.29483337e+01,  1.16233744e-01,
        -7.70639777e-01,  7.36901402e-01,  1.97407398e+01,
         3.89882996e+02,  3.65691010e+02,  3.89150604e+02,
         1.44000000e+02,  0.00000000e+00,  3.13710541e-01,
        -4.65503365e-01,  4.65503365e-01,  3.13710541e-01,
         1.33333337e+00,  2.14814806e+00,  2.81481481e+00,
         3.54956932e+01,  9.74621105e+00,  2.44752812e+00,
        -2.40689683e+00,  2.35990405e+00, -2.51510048e+00,
         6.30159569e+00, -1.57973692e-01,  3.07205391e+00,
         1.88684714e+00,  8.35384338e+02,  1.94409466e+01,
         1.56510925e+01,  1.64070206e+01,  1.29483261e+01,
         9.13980770e+00,  9.51777172e+00,  6.97781372e+00,
         7.41424942e+00,  5.04411745e+00,  5.27921438e+00,
         3.75574517e+00,  3.96863031e+00, -2.33999991e+00,
         3.07205391e+00,  1.94223709e+01,  8.12850857e+00,
         4.15932465e+00,  1.63603958e+02,  9.63677311e+00,
         5.69392776e+00,  0.00000000e+00,  0.00000000e+0

In [69]:
# Show the first row of `feats` as a plain Python list (the form passed to the query below).
feats.tolist()[0]

[12.948333740234375,
 12.948333740234375,
 0.11623374372720718,
 -0.7706397771835327,
 0.7369014024734497,
 19.740739822387695,
 389.88299560546875,
 365.6910095214844,
 389.1506042480469,
 144.0,
 0.0,
 0.3137105405330658,
 -0.46550336480140686,
 0.46550336480140686,
 0.3137105405330658,
 1.3333333730697632,
 2.1481480598449707,
 2.814814805984497,
 35.49569320678711,
 9.746211051940918,
 2.447528123855591,
 -2.4068968296051025,
 2.3599040508270264,
 -2.5151004791259766,
 6.301595687866211,
 -0.15797369182109833,
 3.072053909301758,
 1.8868471384048462,
 835.3843383789062,
 19.440946578979492,
 15.651092529296875,
 16.407020568847656,
 12.948326110839844,
 9.13980770111084,
 9.51777172088623,
 6.977813720703125,
 7.414249420166016,
 5.044117450714111,
 5.279214382171631,
 3.7557451725006104,
 3.968630313873291,
 -2.3399999141693115,
 3.072053909301758,
 19.42237091064453,
 8.128508567810059,
 4.159324645996094,
 163.6039581298828,
 9.636773109436035,
 5.693927764892578,
 0.0,
 0.0,
 5

In [70]:
# Query the index with the first row of `feats` and request the 100 closest matches.
query_vector = feats[0].tolist()
query_results = index.query(vector=query_vector, top_k=100)

In [71]:
# Display the matches returned by the query (ID and similarity score per match).
query_results

{'matches': [{'id': 'Molecule 64', 'score': 0.998084903, 'values': []},
             {'id': 'Molecule 69', 'score': 0.997916281, 'values': []},
             {'id': 'Molecule 41', 'score': 0.997634649, 'values': []},
             {'id': 'Molecule 49', 'score': 0.997574329, 'values': []},
             {'id': 'Molecule 27', 'score': 0.997482896, 'values': []},
             {'id': 'Molecule 80', 'score': 0.997457862, 'values': []},
             {'id': 'Molecule 35', 'score': 0.997237921, 'values': []},
             {'id': 'Molecule 50', 'score': 0.997204602, 'values': []},
             {'id': 'Molecule 42', 'score': 0.997161388, 'values': []},
             {'id': 'Molecule 68', 'score': 0.99692, 'values': []},
             {'id': 'Molecule 98', 'score': 0.996802092, 'values': []},
             {'id': 'Molecule 12', 'score': 0.996789217, 'values': []},
             {'id': 'Molecule 36', 'score': 0.996574163, 'values': []},
             {'id': 'Molecule 28', 'score': 0.996573091, 'values': [