In [1]:
import os
import sys

if os.path.abspath('../src') not in sys.path:
    sys.path.append(os.path.abspath('../src'))

import time
from pathlib import Path

import pandas as pd
from datasets import load_dataset
from multiprocess import set_start_method
from tqdm.auto import tqdm

from interface import MolRSmilesEmbedder

# Select the "spawn" start method for worker processes.
# NOTE(review): presumably required so that the embedder can be used from
# `num_proc` workers — confirm. A RuntimeError here (likely because the start
# method was already set, e.g. on a cell re-run — confirm) is reported and
# otherwise ignored.
try:
    set_start_method("spawn")
except RuntimeError as err:
    print(err)

from datasets import disable_caching

# disable_caching()

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Locations of the saved model and of the CSV file to embed.
MODEL_PATH = Path('../saved/tag_1024')
DATA_PATH = Path("../../data/CS2/all.csv")  # Path to the dataset

# Build the embedder from the saved model and load the CSV as a single split.
embedder = MolRSmilesEmbedder(MODEL_PATH)
dataset = load_dataset('csv', data_files=str(DATA_PATH), split='train')

# Embed only the first 10000 rows as a quick smoke test, then report the
# length of the 'vector' entry of the first embedded row.
emb_dataset = dataset.select(range(10000)).map(
    embedder,
    batched=True,
    batch_size=256,
    num_proc=4,
)
dimension = len(emb_dataset[0]['vector'])
print(f"embedding dimension: {dimension}")

NotImplementedError: Loading a dataset cached in a LocalFileSystem is not supported.

In [4]:
def time_vs_batch_size_datasets(embedding_processor, dataset, batch_sizes=(1024,), num_proc=4):
    """Time ``dataset.map(embedding_processor, ...)`` for several batch sizes.

    Args:
        embedding_processor: Callable handed to ``dataset.map``. Because the
            map is run with ``with_rank=True`` it is also given the worker
            rank, so it must accept that extra argument.
        dataset: Object exposing ``map(...)`` and ``len()``; every row is
            processed once per batch size.
        batch_sizes: Iterable of batch sizes to benchmark. Defaults to a
            single run with 1024.
        num_proc: Number of worker processes passed to ``dataset.map``.

    Returns:
        pandas.DataFrame with one row per batch size and the columns
        ``batch_size``, ``time`` (elapsed seconds for the whole map) and
        ``time_per_mol`` (``time`` divided by ``len(dataset)``).
    """
    # Materialise once: the sizes are iterated below and stored in the result.
    batch_sizes = list(batch_sizes)
    times = []
    for batch_size in tqdm(batch_sizes, leave=False, position=0):
        # perf_counter is monotonic, so clock adjustments cannot skew a run.
        start = time.perf_counter()
        dataset.map(
            embedding_processor,
            batched=True,
            batch_size=batch_size,
            with_rank=True,
            num_proc=num_proc,
            # Force recomputation: a cached result from an earlier run would
            # make this measure a cache load instead of the embedding work.
            load_from_cache_file=False,
        )
        times.append(time.perf_counter() - start)
    res = pd.DataFrame({'batch_size': batch_sizes, 'time': times})
    res['time_per_mol'] = res['time'] / len(dataset)
    return res

In [None]:
# Benchmark the first 10000 rows over power-of-two batch sizes from 8 to 2048.
results = time_vs_batch_size_datasets(
    embedder,
    dataset.select(range(10000)),
    batch_sizes=[2 ** exponent for exponent in range(3, 12)],
)

In [None]:
# Per-molecule time against batch size, with log scaling on both axes.
results.plot(
    x='batch_size',
    y='time_per_mol',
    loglog=True,
)

In [8]:
def convert_all(embedder, cs_numbers = [1, 2, 3, 5, 6, 7, 8]):
    """Embed the ``all.csv`` file of each listed CS folder and save the result.

    For every number ``N`` in ``cs_numbers`` the file ``../../data/CS{N}/all.csv``
    is loaded as a CSV dataset, mapped through ``embedder`` in batches of 256
    using 4 processes, and written with ``save_to_disk`` to an
    ``all_embeddings`` entry next to the CSV file. One progress line is
    printed per folder.

    Args:
        embedder: Callable passed to ``Dataset.map`` with ``batched=True``.
        cs_numbers: Numbers identifying the ``CS{N}`` data folders to convert.
    """
    for position, cs_number in enumerate(cs_numbers, start=1):
        print(f"Converting CS{cs_number}, {position}/{len(cs_numbers)}")
        csv_path = Path(f"../../data/CS{cs_number}/all.csv")  # Path to the dataset
        raw_dataset = load_dataset('csv', data_files=str(csv_path), split='train')
        embedded = raw_dataset.map(embedder, batched=True, batch_size=256, num_proc=4)
        # Store the embeddings beside the source CSV.
        embedded.save_to_disk(csv_path.parent / 'all_embeddings')

In [None]:
# Convert the listed CS datasets (this call passes an explicit list that omits CS1).
convert_all(
    embedder,
    cs_numbers=[2, 3, 5, 6, 7, 8],
)