In [None]:
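# Optional environment setup: uncomment and run once in a fresh environment.
# `%pip` (rather than `! pip`) installs into the environment of the running kernel.
# %pip install transformers==4.38.1
# %pip install rdkit==2023.9.4
# %pip install accelerate==0.27.2
# %pip install flash-attn
# %pip install -q -U bitsandbytes
# %pip install datasets
# %pip install loralib
# %pip install git+https://github.com/huggingface/peft.git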

In [9]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel, AutoConfig

In [None]:
class ChemBERTaEmbedder:
    """Embed SMILES strings with a pretrained ChemBERTa encoder.

    Each molecule is tokenized and encoded on its own (no padding is involved),
    and the hidden state at the first token position is kept as its embedding.
    """

    def __init__(self, model_name="seyonec/ChemBERTa-zinc-base-v1"):
        """Load the tokenizer and encoder weights for ``model_name``.

        Args:
            model_name: Checkpoint identifier handed to ``from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

    def forward(self, smiles_input):
        """Return one embedding tensor per SMILES string.

        Args:
            smiles_input: A single SMILES string or an iterable of SMILES strings.

        Returns:
            A list of 1-D tensors (one per input molecule, in input order), each
            holding the encoder's last hidden state at token position 0. An empty
            iterable yields an empty list.
        """
        # A bare string is itself iterable; without this guard it would be
        # embedded character by character instead of as one molecule.
        if isinstance(smiles_input, str):
            smiles_input = [smiles_input]

        embeddings = []
        for smile in smiles_input:
            input_ids = self.tokenizer.encode(smile, add_special_tokens=True, return_tensors="pt")
            # Feature extraction only, so no autograd graph is needed.
            # The ids are placed on the CPU; this class never moves the model elsewhere.
            with torch.no_grad():
                outputs = self.model(input_ids.cpu())
            # Take position 0 of the single sequence and drop only the batch axis.
            embedding = outputs.last_hidden_state[:, 0, :].squeeze(0)
            embeddings.append(embedding)
        return embeddings

    def __call__(self, smiles_input):
        """Allow ``embedder(smiles)`` as shorthand for ``embedder.forward(smiles)``."""
        return self.forward(smiles_input)

In [2]:
# Single checkpoint name shared by the tokenizer and the encoder so the two always match.
CHEMBERTA_CHECKPOINT = 'DeepChem/ChemBERTa-77M-MTR'

tokenizer = AutoTokenizer.from_pretrained(CHEMBERTA_CHECKPOINT)
model = AutoModel.from_pretrained(CHEMBERTA_CHECKPOINT)

Some weights of RobertaModel were not initialized from the model checkpoint at DeepChem/ChemBERTa-77M-MTR and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [11]:
# Two example molecules, written as SMILES strings, used by the cells below.
smiles_list = [
    'CC(=O)OC1/C=C/C(C)=C/CC(O)/C=C/C(C)=C/C(NC(=O)C(C)=O)C2(C)C(=O)OC(C1)C(C)C2=O',
    'O=C(CN1CC[C@H](C2=CC=C(F)C=C2)[C@@H](COC2=CC=C3OCOC3=C2)C1)C1=CC=C(F)C=C1',
]

In [16]:
class SmilesDataset(Dataset):
    """Dataset that tokenizes one SMILES string per index, at access time."""

    def __init__(self, smiles_list, tokenizer, max_length):
        """Keep references to the inputs; nothing is tokenized here.

        Args:
            smiles_list: Indexable collection of SMILES strings.
            tokenizer: Callable applied to a single SMILES string in ``__getitem__``.
            max_length: Length passed to the tokenizer for truncation and padding.
        """
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.smiles_list = smiles_list

    def __len__(self):
        """Return how many SMILES strings the dataset holds."""
        return len(self.smiles_list)

    def __getitem__(self, idx):
        """Tokenize the SMILES at ``idx`` and return its fields without the batch axis.

        Returns:
            A dict with the same keys as the tokenizer output, each mapped to the
            first (and only) row of the corresponding value.
        """
        tokenizer_kwargs = dict(
            return_tensors='pt',
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
        )
        batch_of_one = self.tokenizer(self.smiles_list[idx], **tokenizer_kwargs)

        # Only one string was tokenized, so row 0 is the whole result for this item.
        item = {}
        for field, values in batch_of_one.items():
            item[field] = values[0]
        return item

# Usage: wrap the example SMILES in a dataset, then batch it.
# NOTE(review): shuffle=True means the rows of a batch need not follow smiles_list order —
# confirm this is intended if the outputs must be matched back to their molecules.
dataset = SmilesDataset(smiles_list=smiles_list, tokenizer=tokenizer, max_length=512)
dataloader = DataLoader(dataset=dataset, batch_size=2, shuffle=True)

In [17]:
# Draw a single batch from the loader to inspect the tensors it yields.
next(iter(dataloader))

{'input_ids': tensor([[12, 19, 22,  ...,  0,  0,  0],
         [12, 16, 16,  ...,  0,  0,  0]]),
 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]])}

In [18]:
# Run the encoder on one batch from the loader. The result is only inspected,
# so autograd is switched off, matching ChemBERTaEmbedder.forward.
with torch.no_grad():
    outputs = model(**next(iter(dataloader)))

In [22]:
# Dimensions of the `last_hidden_state` tensor held in `outputs`.
outputs['last_hidden_state'].shape

torch.Size([2, 512, 384])

In [4]:
# Tokenize all SMILES as one batch of PyTorch tensors, padded to the longest sequence
inputs = tokenizer(smiles_list, return_tensors="pt", padding=True, truncation=True)

# Forward pass for feature extraction only: skip autograd bookkeeping,
# as ChemBERTaEmbedder.forward does
with torch.no_grad():
    outputs = model(**inputs)

# Keep the hidden state at token position 0 of every sequence (one vector per SMILES)
cls_embeddings = outputs.last_hidden_state[:, 0, :]

In [7]:
# Show the tokenizer output (token ids and attention mask) for the SMILES batch.
inputs

{'input_ids': tensor([[12, 16, 16, 17, 22, 19, 18, 19, 16, 20, 39, 16, 22, 16, 39, 16, 17, 16,
         18, 22, 16, 39, 16, 16, 17, 19, 18, 39, 16, 22, 16, 39, 16, 17, 16, 18,
         22, 16, 39, 16, 17, 23, 16, 17, 22, 19, 18, 16, 17, 16, 18, 22, 19, 18,
         16, 21, 17, 16, 18, 16, 17, 22, 19, 18, 19, 16, 17, 16, 20, 18, 16, 17,
         16, 18, 16, 21, 22, 19, 13],
        [12, 19, 22, 16, 17, 16, 23, 20, 16, 16, 16, 17, 16, 21, 22, 16, 16, 22,
         16, 17, 27, 18, 16, 22, 16, 21, 18, 16, 17, 16, 19, 16, 21, 22, 16, 16,
         22, 16, 26, 19, 16, 19, 16, 26, 22, 16, 21, 18, 16, 20, 18, 16, 20, 22,
         16, 16, 22, 16, 17, 27, 18, 16, 22, 16, 20, 13,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,

In [6]:
# Shape of the position-0 embeddings: one row per SMILES in the batch.
cls_embeddings.shape

torch.Size([2, 384])