## Imports

In [None]:
# Install the notebook's third-party dependencies into the running kernel's environment.
# NOTE(review): `datasets` is not imported in any visible cell — confirm it is still needed.
%pip install pinecone-client datasets PyTDC rdkit datamol pandas numpy molfeat mols2grid


In [None]:
import tqdm
import pandas as pd
import numpy as np
import itertools
from tdc.generation import MolGen
from pinecone import Pinecone, ServerlessSpec
import datamol as dm
import molfeat
from molfeat.calc import FPCalculator, RDKitDescriptors2D
from molfeat.trans import MoleculeTransformer
from molfeat.store.modelstore import ModelStore
from molfeat.trans.pretrained import PretrainedMolTransformer, GraphormerTransformer
import mols2grid


## Load data

In [None]:
# Load the ZINC molecule-generation dataset from TDC and split it.
# MolGen is already imported in the imports cell, so it is not re-imported here.
data = MolGen(name='ZINC')
# Mapping of split name -> DataFrame; this notebook only uses the 'test' split.
split = data.get_split()

In [None]:
# Build the list of molecules from the test split's DataFrame (datamol `from_df`).
# NOTE(review): relies on datamol's default SMILES column name — confirm it matches TDC's column.
mols = dm.convert.from_df(split['test'])

## Featurize

In [None]:
# Compute datamol's default descriptor set (properties_fn=None) for every molecule,
# in parallel and without a progress bar. The result has one row per molecule, in input order.
df = dm.descriptors.batch_compute_many_descriptors(
    mols,
    properties_fn=None,
    add_properties=True,
    n_jobs=-1,
    batch_size='auto',
    progress=False,
    progress_leave=True,
)

In [None]:
# Keep the molecule objects alongside their descriptors (used later as mols2grid's `mol_col`).
df['mols'] = mols

In [None]:
# Row identifiers in the same 'Molecule {i}' format used for the Pinecone vector ids,
# so query hits can be matched back to their rows.
df['mol_id'] = [f'Molecule {row}' for row in range(len(df))]

In [None]:
# Open the molfeat model store and look up a pretrained model card.
store = ModelStore()
# NOTE(review): this bare expression is not the cell's last line, so Jupyter does not display it.
store.available_models
# Take the first search hit (indexing an empty result would presumably raise IndexError).
# NOTE(review): `model_card` is not used in any visible cell.
model_card = store.search(name='pcqm4mv2_graphormer_base')[0]

In [None]:
# Display the models available in the molfeat model store.
store.available_models

In [None]:
# Pretrained featurizer alternatives, kept for reference but intentionally NOT instantiated.
# The featurize cell below unconditionally rebinds `featurizer` to a MoleculeTransformer,
# so building ChemGPT-1.2B here (judging by its name, a ~1.2B-parameter model) cost a large
# download and memory for an object that was immediately discarded.
# To use a pretrained embedding, uncomment one of these lines AND skip the
# `featurizer = MoleculeTransformer(...)` line in the featurize cell.
# featurizer = GraphormerTransformer(kind='pcqm4mv2_graphormer_base', dtype=np.float32, pooling='mean', max_length=None, concat_layers=-1, ignore_padding=True, version=None)
# featurizer = PretrainedMolTransformer(kind='ChemGPT-1.2B', dtype=np.float32, pooling='mean', max_length=None, concat_layers=-1, ignore_padding=True, version=None)

In [None]:
# Alternative calculator: ECFP fingerprints.
# calc = FPCalculator("ecfp")
# RDKit 2D descriptors; replace_nan=True asks the calculator to replace NaN descriptor values.
calc = RDKitDescriptors2D(replace_nan=True)

In [None]:
# Wrap the calculator in a transformer that emits float32 feature vectors.
featurizer = MoleculeTransformer(calc, dtype=np.float32)

# Featurize only the first 100 molecules; these rows become the vectors upserted below.
# dm.without_rdkit_log() silences RDKit logging while featurizing.
with dm.without_rdkit_log():
    feats = np.stack(featurizer(mols[:100]))

# Display the feature-matrix shape (one row per featurized molecule).
feats.shape

## Configure Pinecone index

In [None]:
import os

# Read the API key from the environment instead of a hard-coded placeholder, so credentials
# never get pasted into (and saved/committed with) the notebook.
pinecone_api_key = os.environ.get("PINECONE_API_KEY")
if not pinecone_api_key:
    raise RuntimeError("Set the PINECONE_API_KEY environment variable before connecting to Pinecone.")
pc = Pinecone(api_key=pinecone_api_key)

In [None]:
# Connect to the existing 'molsearch' index (no visible cell creates it).
# NOTE(review): the index dimension must equal the feature length (feats.shape[1]) produced
# by the current featurizer, and the index must be recreated if the featurizer changes — confirm.
index = pc.Index('molsearch')

### Upsert vectors

In [None]:
# One vector id per feature row, in the same 'Molecule {i}' format as df['mol_id'].
id_list = [f'Molecule {row}' for row in range(len(feats))]

In [None]:
# (id, vector) pairs for the upsert. Materialized as a list because a bare zip() is a
# one-shot iterator: re-running the upsert cell would otherwise silently upload nothing.
payload = list(zip(id_list, feats))

In [None]:
def chunks(iterable, batch_size=100):
    """Yield successive tuples of at most ``batch_size`` items from ``iterable``.

    The final tuple may be shorter than ``batch_size``; an empty input yields nothing.

    Args:
        iterable: Any iterable; it is consumed lazily, one batch at a time.
        batch_size: Maximum number of items per yielded tuple.

    Yields:
        Tuples of consecutive items, in input order.
    """
    source = iter(iterable)
    while True:
        batch = tuple(itertools.islice(source, batch_size))
        if not batch:
            return
        yield batch

In [None]:
# Upsert (insert or overwrite by id) the (id, vector) pairs, 10 per request.
for batch in chunks(payload, batch_size=10):
    index.upsert(vectors=batch)

## Similarity search

In [None]:
## Featurize the 101st molecule (mols[100], which was not upserted) and use it as the search query

In [None]:
# Featurize molecule mols[100], which is outside the first 100 molecules that were upserted.
# NOTE(review): this rebinds `feats`, replacing the index feature matrix with the query
# features; re-running the id/payload/upsert cells afterwards would upload the wrong vectors.
with dm.without_rdkit_log():
    feats = np.stack(featurizer(mols[100]))


In [None]:
# Nearest-neighbour search using the first row of the query features as a plain list of floats.
# NOTE(review): only 100 vectors are upserted above, so top_k=100 presumably returns the entire
# index; a smaller top_k would make the precision/recall analysis below meaningful.
query_results = index.query(vector=feats[0].tolist(), top_k=100)

In [None]:
# Ids of the returned matches, in the order Pinecone returned them (best score first).
result_id = [match['id'] for match in query_results['matches']]

In [None]:
# Number of matches returned (at most top_k).
len(result_id)

## Analyze hit similarity

In [None]:
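## Evaluation ideas, treating Tanimoto similarity as the ground truth:
## - Precision: the fraction of returned hits whose Tanimoto similarity to the query exceeds a chosen threshold.
## - Recall: compute Tanimoto similarity against every molecule in the dataset and check how many of the true top 100 were retrieved.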


In [None]:
# Render the query molecule for visual comparison with the hits.
mols[100]

In [None]:
# Rows for the returned hits, ordered by similarity rank rather than by position in `df`
# (a plain isin() filter would silently discard Pinecone's ranking).
# `rank` is 0 for the best match. .copy() avoids pandas' SettingWithCopyWarning when adding the column.
hit_rank = {mol_id: rank for rank, mol_id in enumerate(result_id)}
df_results = df[df['mol_id'].isin(result_id)].copy()
df_results['rank'] = df_results['mol_id'].map(hit_rank)
df_results = df_results.sort_values('rank')

In [None]:
# Display the rows for the retrieved hits.
df_results

In [None]:
# Grid view of the hits: molecule images alongside each hit's id and heavy-atom count.
mols2grid.display(
    df_results,
    mol_col='mols',
    subset=['mol_id', 'img', 'n_heavy_atoms'],
)