In [1]:
import pickle

import pandas as pd
import torch
from rxn.chemutils.miscellaneous import canonicalize_any

# Canonicalization utils
from rxn.chemutils.utils import remove_atom_mapping
from sklearn.utils import gen_batches
from tqdm.notebook import tqdm

# Mapping data to embeddings.

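We compute both kinds of embeddings (reaction fingerprints and text-segment embeddings) in batches, keeping each item's index so the results can be mapped back to the original records.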

In [2]:
# Start from the cleaned dataset written by `00_preprocess.ipynb`.
# NOTE: unpickling can execute arbitrary code; only load files you produced yourself.
with open("../../data/processed/sg_db_clean.bin", mode="rb") as db_file:
    sg_db = pickle.load(db_file)

## RXNFP mapping.

We'll just use rxnfp (https://rxn4chemistry.github.io/rxnfp/)

In [3]:
from rxnfp.transformer_fingerprints import (
    RXNBERTFingerprintGenerator,
    generate_fingerprints,
    get_default_model_and_tokenizer,
)

# Fetch rxnfp's default model/tokenizer pair and wrap them in a fingerprint
# generator; `rxnfp_generator` is the object `embed_rxns` calls to embed a batch.
rxn_model, rxn_tokenizer = get_default_model_and_tokenizer()
rxnfp_generator = RXNBERTFingerprintGenerator(rxn_model, rxn_tokenizer)


def preproc_rxn(smi):
    """Strip atom mapping from a reaction SMILES and canonicalize it.

    Args:
        smi: Reaction SMILES string, possibly atom-mapped.

    Returns:
        The canonicalized string, or "" when canonicalization raises, so the
        caller still gets exactly one entry per input reaction.
    """
    u = remove_atom_mapping(smi)
    try:
        return canonicalize_any(u)
    except Exception:
        # Best-effort on purpose: an unparsable reaction becomes "".
        # Catch `Exception` rather than using a bare `except` so that
        # KeyboardInterrupt / SystemExit still stop a long embedding run.
        return ""


def embed_rxns(rxns):
    """Turn a batch of reaction SMILES into a tensor of rxnfp fingerprints.

    Every reaction is first passed through `preproc_rxn`; the cleaned batch is
    then handed to the module-level `rxnfp_generator`.

    Args:
        rxns: Iterable of reaction SMILES strings.

    Returns:
        A `torch.Tensor` built from the generator's batch output.
    """
    cleaned = list(map(preproc_rxn, rxns))
    return torch.tensor(rxnfp_generator.convert_batch(cleaned))

In [4]:
# Keep only the "reaction set-up" segments. The dict is keyed by each segment's
# position in `sg_db`, so embedding rows can be mapped back to the full dataset
# (dicts preserve insertion order: row k corresponds to the k-th key).
sg_db_setup_d = {i: b for i, b in enumerate(sg_db) if b["sgm_cls"] == "reaction set-up"}
sg_db_setup = list(sg_db_setup_d.values())

# Number of reactions embedded per call.
bs = 2048

batches = gen_batches(len(sg_db_setup), batch_size=bs)
# Ceiling division: gen_batches also yields the final, shorter batch, so a
# floor division would under-report the progress-bar total by one.
len_batch = (len(sg_db_setup) + bs - 1) // bs

rxn_embed_batches = []
for b in tqdm(batches, total=len_batch):
    # Reaction SMILES of the segments in this batch (`b` is a slice).
    sentences = [s["rxn_smi"] for s in sg_db_setup[b]]
    embeds = embed_rxns(sentences)
    torch.cuda.empty_cache()
    rxn_embed_batches.append(embeds)

# Stack the per-batch tensors into one tensor and persist it.
rxn_embeds = torch.concat(rxn_embed_batches)
torch.save(rxn_embeds, "../../data/rxn_embeds_clean.pt")

  0%|          | 0/450 [00:00<?, ?it/s]

In [5]:
# Reload the reaction fingerprints from the file written by the previous cell,
# so later cells can start from disk instead of re-embedding.
rxn_embeds = torch.load("../../data/rxn_embeds_clean.pt")

---
## Natural Language Embeddings. Here we use an open-source model.

According to the MTEB leaderboard (https://huggingface.co/spaces/mteb/leaderboard), `BAAI/bge-large-en-v1.5` appears to be the best-performing model at the time of writing.

In [3]:
import gc

from transformers import AutoModel, AutoTokenizer

# Prefer the first GPU, but fall back to CPU so this cell still runs on a
# machine without CUDA (expect it to be far slower on the full dataset).
device = "cuda:0" if torch.cuda.is_available() else "cpu"
embedding_model = "BAAI/bge-large-en-v1.5"
tokenizer = AutoTokenizer.from_pretrained(embedding_model)
model = AutoModel.from_pretrained(embedding_model).to(device)
# Inference only: put the model in evaluation mode.
model.eval()


def embed_batch(sentences):
    """Embed a batch of sentences with the module-level `model`.

    The sentences are tokenized (padded, truncated to 256 tokens), run through
    the model without gradient tracking, reduced to their first-token (CLS)
    vector and L2-normalized along dim 1.

    Args:
        sentences: List of strings to embed.

    Returns:
        The normalized embeddings as a tensor moved to the CPU.
    """
    batch_tokens = tokenizer(
        sentences,
        max_length=256,
        truncation=True,
        padding=True,
        return_tensors="pt",
    )
    batch_tokens = batch_tokens.to(device)

    with torch.no_grad():
        outputs = model(**batch_tokens)
        # CLS pooling: keep the vector at sequence position 0 of the first output.
        cls_vectors = outputs[0][:, 0]

    # Drop the large intermediates and release cached GPU memory before the
    # next batch is processed.
    del outputs, batch_tokens
    gc.collect()
    torch.cuda.empty_cache()

    # Scale every row to unit L2 norm, then hand the result back on the CPU.
    cls_vectors = torch.nn.functional.normalize(cls_vectors, p=2, dim=1)
    return cls_vectors.to("cpu")


# Number of text segments embedded per forward pass.
bs = 512
batches = gen_batches(len(sg_db), batch_size=bs)
# Ceiling division: gen_batches also yields the final, shorter batch, so a
# floor division would under-report the progress-bar total by one.
len_batch = (len(sg_db) + bs - 1) // bs

segm_embed_batches = []
for b in tqdm(batches, total=len_batch):
    # Raw text of every segment in this batch (`b` is a slice into `sg_db`).
    sentences = [s["txt_sgm"] for s in sg_db[b]]
    embeds = embed_batch(sentences)
    segm_embed_batches.append(embeds)

# Stack the per-batch tensors (same order as `sg_db`) and persist them.
segm_embeds = torch.concat(segm_embed_batches)
torch.save(segm_embeds, "../../data/processed/embeds/segment_embeds.pt")

  0%|          | 0/5052 [00:00<?, ?it/s]

## Finally, split the tensor into per-class subsets to save memory in the subsequent visualization steps.

In [3]:
# Reload the segment embeddings from the file written by the embedding cell,
# so the split below can run without re-embedding.
segm_embeds = torch.load("../../data/processed/embeds/segment_embeds.pt")

In [4]:
%%time
import re

import numpy as np

# Calculate for each segment class
sgm_cls = np.array([s["sgm_cls"] for s in sg_db])
del sg_db  # For memory

classes = ["reaction set-up", "work-up", "purification", "analysis"]
for i, cl in tqdm(enumerate(classes), total=len(classes)):
    idx = np.ones(len(sgm_cls))
    idx = np.where(np.array(sgm_cls) == cl, idx, 0)

    print(f"There are {idx.sum()} {cl}s")
    slice = segm_embeds[idx == 1]
    print(slice.shape)

    # Load subset and make map
    cl = re.sub(" |-", "_", cl)
    torch.save(slice, f"../../data/processed/embeds/segm_embs_{cl}.pt")

  0%|          | 0/4 [00:00<?, ?it/s]

There are 921886.0 reaction set-ups
torch.Size([921886, 1024])
There are 787540.0 work-ups
torch.Size([787540, 1024])
There are 452965.0 purifications
torch.Size([452965, 1024])
There are 417692.0 analysiss
torch.Size([417692, 1024])
CPU times: user 7.37 s, sys: 12.3 s, total: 19.7 s
Wall time: 12.2 s
