# Extracting ChemBERTa embeddings
ChemBERTa `[Chithrananda et al., 2020]` is used to generate an embedding vector for each SMILES string.

## Setting up the environment

In [2]:
import os

os.chdir('/home/yz979/code/kaggle-perturbation/')
os.environ['CUDA_VISIBLE_DEVICES'] = '3'

In [3]:
import numpy as np
import pandas as pd
import scanpy as sc
import torch
from tqdm import tqdm
from transformers import AutoModelForMaskedLM, AutoTokenizer

## Loading the dataset and the ChemBERTa model
The preprocessed `de_train.h5ad` and `id_map.h5ad` files are loaded with the `sc.read_h5ad` function. The ChemBERTa model and its tokenizer are loaded with the `from_pretrained` method of the `AutoModelForMaskedLM` and `AutoTokenizer` classes from the `transformers` library.

In [17]:
de_train = sc.read_h5ad('data/de_train.h5ad')
id_map = sc.read_h5ad('data/id_map.h5ad')

print(de_train)
print(id_map)

AnnData object with n_obs × n_vars = 614 × 18211
    obs: 'cell_type', 'sm_name', 'sm_lincs_id', 'SMILES', 'control'
    obsm: 'chemberta', 'multivi'
AnnData object with n_obs × n_vars = 255 × 18211
    obs: 'cell_type', 'sm_name', 'sm_lincs_id', 'SMILES', 'control'
    obsm: 'chemberta', 'multivi'


In [11]:
chemberta = AutoModelForMaskedLM.from_pretrained("DeepChem/ChemBERTa-77M-MTR")
tokenizer = AutoTokenizer.from_pretrained("DeepChem/ChemBERTa-77M-MTR")

Some weights of RobertaForMaskedLM were not initialized from the model checkpoint at DeepChem/ChemBERTa-77M-MTR and are newly initialized: ['lm_head.decoder.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.weight', 'lm_head.dense.bias', 'lm_head.bias', 'lm_head.layer_norm.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


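## Extracting embeddings for each SMILES string
Each SMILES string is tokenized and passed through the ChemBERTa model, and the output at the first token position (the `[CLS]` token) is kept as the embedding vector. Because the model is loaded with `AutoModelForMaskedLM`, this vector is the row of masked-LM `logits` at that position, not a raw hidden state, and the loading warning above reports the `lm_head` weights as newly initialized. The embedding vectors are stored in the `obsm['chemberta']` slot of the `de_train` and `id_map` objects, which are then written back to their `.h5ad` files.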

In [14]:
smiles_de = de_train.obs['SMILES'].values

# Generate one embedding per row of de_train; gradients are not needed for inference
embeddings = []
with torch.no_grad():
    for smile in tqdm(smiles_de):
        encoded_input = tokenizer(smile, return_tensors="pt", padding=True, truncation=True)
        # logits has shape (1, seq_len, vocab_size); keep the row at the first token
        model_output = chemberta(**encoded_input).logits
        embeddings.append(model_output.squeeze(0)[0])
embeddings = torch.stack(embeddings)  # shape (n_obs, vocab_size)

de_train.obsm['chemberta'] = embeddings.numpy()
de_train.write_h5ad('data/de_train.h5ad')

100%|██████████| 614/614 [00:03<00:00, 162.88it/s]
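The cell above keeps only the vector at the first token position of each sequence. As a minimal sketch of what that selection does, and of a common alternative (masked mean pooling over all tokens, which this notebook does not use), the example below works on NumPy stand-in arrays instead of real model outputs. The names `first_token_pool`, `masked_mean_pool`, `hidden`, and `mask` are hypothetical and chosen for illustration only.

```python
import numpy as np

def first_token_pool(hidden):
    """Return the vector at position 0 of each sequence, shape (batch, dim)."""
    return hidden[:, 0, :]

def masked_mean_pool(hidden, mask):
    """Average the token vectors of each sequence, ignoring padded positions."""
    weights = mask[:, :, None].astype(hidden.dtype)   # (batch, seq_len, 1)
    summed = (hidden * weights).sum(axis=1)           # (batch, dim)
    counts = np.clip(weights.sum(axis=1), 1, None)    # avoid division by zero
    return summed / counts

# Hypothetical toy data: 2 sequences, 3 token positions, 4 features each.
hidden = np.arange(24, dtype=np.float64).reshape(2, 3, 4)
mask = np.array([[1, 1, 1], [1, 1, 0]])  # the second sequence has one padded position

first = first_token_pool(hidden)
mean = masked_mean_pool(hidden, mask)
```

With per-token outputs from a real model, `hidden` would be the model's output tensor converted to an array and `mask` the tokenizer's `attention_mask`. Which pooling works better for a downstream task is an empirical question.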


In [18]:
smiles_id = id_map.obs['SMILES'].values

# Generate one embedding per row of id_map; gradients are not needed for inference
embeddings = []
with torch.no_grad():
    for smile in tqdm(smiles_id):
        encoded_input = tokenizer(smile, return_tensors="pt", padding=True, truncation=True)
        # logits has shape (1, seq_len, vocab_size); keep the row at the first token
        model_output = chemberta(**encoded_input).logits
        embeddings.append(model_output.squeeze(0)[0])
embeddings = torch.stack(embeddings)  # shape (n_obs, vocab_size)

id_map.obsm['chemberta'] = embeddings.numpy()
id_map.write_h5ad('data/id_map.h5ad')

100%|██████████| 255/255 [00:01<00:00, 155.67it/s]
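Both loops above run the model once per row. If the same SMILES string appears in several rows (an assumption about the data, for example one row per cell type for each compound), each distinct string only needs to be embedded once. The sketch below shows that caching pattern in plain Python. `embed_with_cache` and `toy_embed` are hypothetical helpers; in the notebook, the embedding function would wrap the tokenizer and model call from the cells above.

```python
def embed_with_cache(smiles_list, embed_fn):
    """Embed each distinct SMILES string once and reuse the result for repeats."""
    cache = {}
    for smiles in smiles_list:
        if smiles not in cache:
            cache[smiles] = embed_fn(smiles)
    # Return one vector per input row, in the original row order.
    return [cache[smiles] for smiles in smiles_list]

calls = []  # records which strings were actually embedded

def toy_embed(smiles):
    """Stand-in for the model call: a two-number 'embedding' per string."""
    calls.append(smiles)
    return [float(len(smiles)), float(smiles.count("C"))]

rows = ["CCO", "c1ccccc1", "CCO", "CCO"]
vectors = embed_with_cache(rows, toy_embed)
```

Here `toy_embed` runs twice for four rows. At the throughput shown in the progress bars the saving is small for this dataset, so this matters mainly for larger inputs or slower models.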
