In [None]:
# ! pip install transformers==4.38.1
# ! pip install rdkit==2023.9.4
# ! pip install accelerate==0.27.2
# ! pip install flash-attn
# ! pip install -q -U bitsandbytes
# ! pip install datasets
# ! pip install loralib
# ! pip install git+https://github.com/huggingface/peft.git
# ! pip install sentencepiece

In [None]:
# ! pip install tensorflow==2.10.0

In [13]:
import random, pickle, json, os
from datasets import Dataset
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import bitsandbytes as bnb
from peft import get_peft_model, LoraConfig

import sys
sys.path.append('../credentials/')
from HF_credentials import *

The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.


In [None]:
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel, AutoConfig

# Tokenizer

In [None]:
# Tokenizer for Mistral-7B-Instruct; create_datasets() uses it to render chat-formatted prompts.
llm_tokenizer = AutoTokenizer.from_pretrained('mistralai/Mistral-7B-Instruct-v0.2', token=HF_CREDENTIALS, model_max_length=256, add_prefix_space=False)
# Reuse the EOS token as the padding token.
llm_tokenizer.pad_token = llm_tokenizer.eos_token
# Pad on the right-hand side of each sequence.
llm_tokenizer.padding_side = "right"

In [None]:
# Two-turn (user -> assistant) message skeleton with empty contents.
chat = [
  {"role": "user", "content": ""},
  {"role": "assistant", "content": ""}
]

# Render the skeleton as a plain string (tokenize=False) to inspect the template markup.
llm_tokenizer.apply_chat_template(chat, tokenize=False)

# Data

In [None]:
# One entry per LlaSMol property-prediction task:
#   (file stem, user-question template, assistant-answer template, key into 'output' or None)
# When the key is not None, the record's 'output' is a mapping and only that entry is used.
_PROPERTY_TASKS = (
    ("bbbp",
     "Is blood-brain barrier permeability (BBBP) a property of <SMILES> {smiles} </SMILES>?",
     "<BOOLEAN> {output} </BOOLEAN>",
     None),
    ("clintox",
     "Is <SMILES> {smiles} </SMILES> toxic?",
     "<BOOLEAN> {output} </BOOLEAN>",
     None),
    ("esol",
     "How soluble is <SMILES> {smiles} </SMILES>?",
     "Its log solubility is <NUMBER> {output} </NUMBER> mol/L",
     None),
    ("hiv",
     "Can <SMILES> {smiles} </SMILES> serve as an inhibitor of HIV replication?",
     "<BOOLEAN> {output} </BOOLEAN>",
     None),
    ("lipo",
     "Predict the octanol/water distribution coefficient logD under the circumstances of pH 7.4 for <SMILES> {smiles} </SMILES>",
     "<NUMBER> {output} </NUMBER>",
     None),
    # NOTE(review): the question mentions the heart while the label read is
    # 'Vascular disorders' -- confirm this pairing is intended.
    ("sider",
     "Are there any known side effects of <SMILES> {smiles} </SMILES> affecting the heart?",
     "<BOOLEAN> {output} </BOOLEAN>",
     "Vascular disorders"),
)


def create_datasets(split='train'):
    """Build chat-formatted conversations for every property-prediction task.

    Reads ``./data/LlaSMol/{split}/property_prediction-<task>.jsonl`` for each
    task in ``_PROPERTY_TASKS``; every non-blank line must be a JSON object
    with ``input`` (the SMILES string) and ``output`` keys. Each record is
    turned into a user/assistant message pair and rendered to a string with
    the module-level ``llm_tokenizer.apply_chat_template(..., tokenize=False)``.

    A fresh message list is built per record, so the module-level ``chat``
    skeleton is never modified by this function.

    Args:
        split: Sub-directory of ``./data/LlaSMol`` to read (e.g. 'train', 'test').

    Returns:
        ``(conversations, input_smiles)``: two parallel lists holding the
        rendered conversation string and the raw SMILES of every record, in
        task order and then file order.

    Raises:
        OSError: if a task file cannot be opened.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a record lacks 'input', 'output' or the task's output key.
    """
    conversations = []
    input_smiles = []

    for task, question_template, answer_template, output_key in _PROPERTY_TASKS:
        with open(f'./data/LlaSMol/{split}/property_prediction-{task}.jsonl', 'r') as f:
            for line in f:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                if not line.strip():
                    continue
                txt = json.loads(line)
                output = txt['output'] if output_key is None else txt['output'][output_key]
                messages = [
                    {"role": "user", "content": question_template.format(smiles=txt['input'])},
                    {"role": "assistant", "content": answer_template.format(output=output)},
                ]
                conversations.append(llm_tokenizer.apply_chat_template(messages, tokenize=False))
                input_smiles.append(txt['input'])
        # Show the most recent sample as a sanity check; guarded so that an
        # empty first file does not raise IndexError.
        if conversations:
            print(conversations[-1])

    print(len(conversations))

    return conversations, input_smiles

In [None]:
# Build the conversation strings and matching SMILES lists for both splits.
# create_datasets() prints one sample per task file plus the total count.
print('Train:')
train_conversations, train_input_smiles = create_datasets('train')
print('Test:')
test_conversations, test_input_smiles = create_datasets('test')

In [None]:
class CombinedDataset(Dataset):
    """Pairs each SMILES string with its chat-formatted conversation.

    Every item is tokenized on access: the SMILES string with
    ``encoder_tokenizer`` and the conversation with ``llm_tokenizer``, both
    truncated/padded to ``max_length``.

    Args:
        smiles_list: SMILES strings, one per sample.
        conversations: Conversation strings already rendered by a chat
            template, parallel to ``smiles_list``.
        encoder_tokenizer: Tokenizer for the molecule encoder.
        llm_tokenizer: Tokenizer for the language model.
        max_length: Length both encodings are padded/truncated to.
        device: Device the returned tensors are moved to. ``None`` (default)
            selects 'cuda' when available and falls back to 'cpu' otherwise.

    Raises:
        ValueError: if ``smiles_list`` and ``conversations`` differ in length.
    """

    def __init__(self, smiles_list, conversations, encoder_tokenizer, llm_tokenizer, max_length=256, device=None):
        if len(smiles_list) != len(conversations):
            raise ValueError(
                f"smiles_list and conversations must have the same length, "
                f"got {len(smiles_list)} and {len(conversations)}"
            )
        self.smiles_list = smiles_list
        self.conversations = conversations
        self.encoder_tokenizer = encoder_tokenizer
        self.llm_tokenizer = llm_tokenizer
        self.max_length = max_length
        if device is None:
            # Same target as before on GPU machines; no crash on CPU-only ones.
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device

    def __len__(self):
        """Number of (SMILES, conversation) pairs."""
        return len(self.smiles_list)

    def __getitem__(self, idx):
        """Tokenize and return sample ``idx``.

        Returns:
            A pair ``(smiles_encoding, conversation_encoding)``. The first is a
            plain dict whose tensors have the leading batch axis removed
            (``tensor[0]``); the second is whatever ``llm_tokenizer`` returns
            with ``return_tensors='pt'`` (leading axis kept), moved to
            ``self.device``.
        """
        smiles = self.smiles_list[idx]
        smiles_encoding = self.encoder_tokenizer(
            smiles,
            return_tensors='pt',
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
        )
        # The conversations were rendered by a chat template, which already
        # inserts the special tokens; adding them again would duplicate BOS.
        conversation_tokenized = self.llm_tokenizer(
            self.conversations[idx],
            add_special_tokens=False,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt',
        )
        smiles_tensors = {key: tensor[0].to(self.device) for key, tensor in smiles_encoding.items()}
        return smiles_tensors, conversation_tokenized.to(self.device)

In [None]:
# Load tokenizers: ChemBERTa for SMILES, Mistral-Instruct for the conversations.
# NOTE(review): mistral_tokenizer repeats the llm_tokenizer setup but without
# token=HF_CREDENTIALS and model_max_length -- confirm a second instance is needed.
chemberta_tokenizer = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-77M-MTR')
mistral_tokenizer = AutoTokenizer.from_pretrained('mistralai/Mistral-7B-Instruct-v0.2', add_prefix_space=False)
mistral_tokenizer.pad_token = mistral_tokenizer.eos_token
mistral_tokenizer.padding_side = "right"

# Create combined dataset
# NOTE(review): this is built from the *test* split, and combined_loader is what
# the training loop iterates -- confirm this is a deliberate smoke-test choice.
combined_dataset = CombinedDataset(test_input_smiles, test_conversations, chemberta_tokenizer, mistral_tokenizer)

# Define DataLoader (shuffled mini-batches of `batch_size` samples)
batch_size = 2
combined_loader = DataLoader(combined_dataset, batch_size=batch_size, shuffle=True)

## Test

In [None]:
# x, y = next(iter(combined_loader))

# mol_encoder = AutoModel.from_pretrained("DeepChem/ChemBERTa-77M-MTR")
# llm_model = AutoModelForCausalLM.from_pretrained('mistralai/Mistral-7B-Instruct-v0.2',
#             torch_dtype=torch.bfloat16,
#             # quantization_config=bnb_config,
#             device_map="auto",
#             token=HF_CREDENTIALS
# )

# mol_encoder(**x)['last_hidden_state'];
# llm_model.model.embed_tokens(y['input_ids'].to('cuda'));

# Model

In [None]:
class MolEncoderLLMPipeline(nn.Module):
    """Frozen ChemBERTa encoder plus frozen Mistral-7B-Instruct embedding lookup.

    This variant only produces embeddings: ``forward`` returns the molecule
    embedding and the LLM token embeddings side by side without combining
    them.

    NOTE(review): a class with the same name is defined again further down the
    notebook and shadows this one once that cell runs.

    Args:
        lora_rank: Accepted for interface compatibility; unused here.
        lora_alpha: Accepted for interface compatibility; unused here.
    """

    def __init__(self, lora_rank=32, lora_alpha=64):
        super().__init__()
        # Prefer the GPU but do not crash on CPU-only machines.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # Load molecule encoder
        self.mol_encoder = AutoModel.from_pretrained("DeepChem/ChemBERTa-77M-MTR").to(device)

        # UNCOMMENT TO BRING DOWN FROM 15GB TO 7GB
        # bnb_config = BitsAndBytesConfig(
        #     load_in_4bit= True,
        #     bnb_4bit_quant_type= "nf4",
        #     bnb_4bit_compute_dtype= torch.bfloat16,
        #     bnb_4bit_use_double_quant= False,
        # )
        self.llm_config = AutoConfig.from_pretrained('mistralai/Mistral-7B-Instruct-v0.2', token=HF_CREDENTIALS)
        self.llm_model = AutoModelForCausalLM.from_pretrained('mistralai/Mistral-7B-Instruct-v0.2',
            torch_dtype=torch.bfloat16,
            # quantization_config=bnb_config,
            device_map="auto",
            token=HF_CREDENTIALS
        )

        # Freeze encoder and LLM weights
        for param in self.mol_encoder.parameters():
            param.requires_grad = False
        for param in self.llm_model.parameters():
            param.requires_grad = False

        # Projection from the encoder width to the LLM width. It is the only
        # trainable module here, but forward() does not apply it yet.
        self.linear_project = nn.Linear(self.mol_encoder.config.hidden_size, self.llm_config.hidden_size)

    def forward(self, smiles_tokens, text_tokens):
        """Embed a batch of SMILES encodings and conversation token ids.

        Args:
            smiles_tokens: Mapping of keyword tensors for the molecule encoder
                (e.g. ``input_ids`` / ``attention_mask``).
            text_tokens: Mapping with an ``input_ids`` tensor of LLM token ids.

        Returns:
            ``(smiles_embedding, llm_embeddings)``: the encoder's last hidden
            state at sequence position 0, shape ``[batch, hidden]``, and the
            LLM embedding of ``text_tokens['input_ids']`` with one trailing
            embedding axis appended to the ids' shape.
        """
        # Run the encoder on whatever device it lives on.
        encoder_device = next(self.mol_encoder.parameters()).device
        smiles_tokens = {key: tensor.to(encoder_device) for key, tensor in smiles_tokens.items()}

        # Encoder forward pass / Get SMILES embeddings
        mol_encoder_output = self.mol_encoder(**smiles_tokens)
        smiles_embedding = mol_encoder_output['last_hidden_state'][:, 0, :]  # first position only: [batch, hidden]

        # Get embeddings from LLM for the question
        embedding_layer = self.llm_model.model.embed_tokens
        input_ids = text_tokens['input_ids']
        embed_device = embedding_layer.weight.device
        # Offloaded weights report the 'meta' device; leave the ids where they
        # are in that case instead of moving them to a placeholder device.
        if embed_device.type != 'meta':
            input_ids = input_ids.to(embed_device)
        llm_embeddings = embedding_layer(input_ids)

        return smiles_embedding, llm_embeddings

In [None]:
# Instantiate the pipeline; the constructor loads the ChemBERTa and Mistral weights.
model = MolEncoderLLMPipeline()

In [None]:
# x, y = next(iter(combined_loader))
# model(x,y)

In [None]:
# Deliberate stop point: sys.exit() raises SystemExit here, so the cells below
# are not reached by a top-to-bottom run unless this cell is skipped.
import sys
sys.exit()

# Model (LoRA fine-tuning variant)

In [None]:
class LoRA(nn.Module):
    """Low-rank additive update for a square ``embed_dim x embed_dim`` weight.

    The update is ``(alpha / rank) * A @ B`` with ``A`` of shape
    ``(embed_dim, rank)`` and ``B`` of shape ``(rank, embed_dim)``. ``A`` is
    drawn from a standard normal and ``B`` starts at zero, so a freshly
    constructed adapter leaves the original weight unchanged.

    Args:
        embed_dim: Side length of the weight matrix being adapted.
        rank: Inner dimension of the low-rank factors; must be positive.
        alpha: LoRA scaling numerator (the update is scaled by ``alpha / rank``).
        dropout_rate: Dropout probability applied to the weight update.

    Raises:
        ValueError: if ``rank`` is not positive.
    """

    def __init__(self, embed_dim, rank, alpha, dropout_rate=0.05):
        super(LoRA, self).__init__()
        if rank <= 0:
            raise ValueError(f"rank must be a positive integer, got {rank}")
        self.rank = rank
        self.alpha = alpha  # Scaling factor for LoRA
        # Dividing by the rank keeps the update magnitude comparable across ranks.
        self.scaling = alpha / rank

        # Low-rank matrices A and B. B is zero-initialised so that A @ B == 0
        # at the start and training begins from the unmodified weight.
        self.A = nn.Parameter(torch.randn(embed_dim, rank))
        self.B = nn.Parameter(torch.zeros(rank, embed_dim))

        # Dropout layer
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, original_weight):
        """Return ``original_weight`` plus the (dropout-regularised) low-rank update."""
        delta_weight = self.scaling * torch.matmul(self.A, self.B)
        # Dropout acts on entries of the weight update; it is a no-op in eval mode.
        delta_weight = self.dropout(delta_weight)
        return original_weight + delta_weight

In [53]:
# TODO
# - Add special tokens
# - Convert weight of ChemBERTa to bfloat16. Done
# - Add projection layer for mol embeddings. Done

class MolEncoderLLMPipeline(nn.Module):
    """ChemBERTa molecule encoder feeding a LoRA-adapted Mistral-7B-Instruct.

    The encoder's hidden state at the first sequence position is projected to
    the LLM hidden size and prepended, as a single extra position, to the
    LLM's token embeddings. The encoder and the base LLM weights are frozen;
    the projection layer and the LoRA adapters remain trainable.

    Args:
        lora_rank: Rank ``r`` of the LoRA adapters.
        lora_alpha: LoRA ``lora_alpha`` scaling value.
        cache_dir: Optional Hugging Face cache directory for the checkpoints.
    """

    def __init__(self,
                 lora_rank=32,
                 lora_alpha=64,
                 cache_dir=None,
                 ):
        super().__init__()
        # Load molecule encoder
        self.mol_encoder = AutoModel.from_pretrained(
            "DeepChem/ChemBERTa-77M-MTR",
            torch_dtype=torch.bfloat16,
            cache_dir=cache_dir)

        self.llm_model = AutoModelForCausalLM.from_pretrained(
            'mistralai/Mistral-7B-Instruct-v0.2',
            torch_dtype=torch.bfloat16,
            device_map="auto",
            token=HF_CREDENTIALS,
            cache_dir=cache_dir,
        )
        self.llm_model.config.use_cache = False
        self.llm_model.config.pretraining_tp = 1

        # LoRA configuration for the attention projections of the LLM
        self.lora_config = LoraConfig(
            r=lora_rank,
            lora_alpha=lora_alpha,
            lora_dropout=0.05,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
            bias="none",
            task_type="CAUSAL_LM",
        )

        # projection layer for molecular vectors
        self.mol_prj = nn.Linear(
            self.mol_encoder.config.hidden_size,
            self.llm_model.config.hidden_size,
            dtype=torch.bfloat16)

        # Freeze encoder and LLM weights (mol_prj is left trainable)
        for param in self.mol_encoder.parameters():
            param.requires_grad = False
        for param in self.llm_model.parameters():
            param.requires_grad = False

        # Apply LoRA modification
        self.llm_model = get_peft_model(self.llm_model, self.lora_config)

    def forward(self, smiles_tokens, input_ids, smiles_attention_mask=None, attention_mask=None, labels=None):
        """Run the LLM on ``[projected molecule embedding] + token embeddings``.

        Args:
            smiles_tokens: Token ids for the molecule encoder, ``[batch, smiles_len]``.
            input_ids: LLM token ids, ``[batch, text_len]``.
            smiles_attention_mask: Optional padding mask for ``smiles_tokens``,
                forwarded to the encoder. ``None`` keeps the previous behaviour
                (no mask passed).
            attention_mask: Optional padding mask for ``input_ids``. ``None``
                keeps the previous behaviour (every position attended).
            labels: Optional target ids aligned with ``input_ids``. One
                ``-100`` column is prepended for the molecule position before
                they are handed to the LLM. Padding positions are not masked
                here; the caller must already have set them to ``-100``.

        Returns:
            Whatever the wrapped LLM returns for a sequence of length
            ``text_len + 1``.
        """
        # Encoder forward pass / Get SMILES embeddings
        encoder_device = next(self.mol_encoder.parameters()).device
        smiles_tokens = smiles_tokens.to(encoder_device)
        if smiles_attention_mask is not None:
            smiles_attention_mask = smiles_attention_mask.to(encoder_device)
        mol_encoder_output = self.mol_encoder(input_ids=smiles_tokens, attention_mask=smiles_attention_mask)
        # Keep only the first position, as a length-1 sequence: [batch, 1, hidden]
        mol_embeddings = mol_encoder_output.last_hidden_state[:, :1, :]
        mol_embeddings = self.mol_prj(mol_embeddings)

        # Get embeddings from LLM for the question
        embedding_layer = self.llm_model.model.model.embed_tokens
        embed_device = embedding_layer.weight.device
        # Offloaded weights report the 'meta' device; leave the ids untouched then.
        if embed_device.type != "meta":
            input_ids = input_ids.to(embed_device)
        llm_embeddings = embedding_layer(input_ids)

        # The projection may live on another device/dtype than the LLM embeddings.
        mol_embeddings = mol_embeddings.to(device=llm_embeddings.device, dtype=llm_embeddings.dtype)

        # Concatenate encoder and LLM embeddings along the sequence axis
        combined_embeddings = torch.cat((mol_embeddings, llm_embeddings), dim=1)

        # Attention mask: the molecule position is always visible, followed by
        # the caller's text mask (or all ones when none is given).
        batch_size = combined_embeddings.shape[0]
        mask_device = combined_embeddings.device
        mol_mask = torch.ones(batch_size, 1, dtype=torch.long, device=mask_device)
        if attention_mask is None:
            text_mask = torch.ones(batch_size, llm_embeddings.shape[1], dtype=torch.long, device=mask_device)
        else:
            text_mask = attention_mask.to(device=mask_device, dtype=torch.long)
        extended_attention_mask = torch.cat((mol_mask, text_mask), dim=1)

        llm_kwargs = {
            "inputs_embeds": combined_embeddings,
            "attention_mask": extended_attention_mask,
        }
        if labels is not None:
            labels = labels.to(mask_device)
            # No target exists for the molecule position: ignore it in the loss.
            ignore_column = torch.full((batch_size, 1), -100, dtype=labels.dtype, device=mask_device)
            llm_kwargs["labels"] = torch.cat((ignore_column, labels), dim=1)

        # Pass through Mistral's transformer layers with LoRA adjustments
        return self.llm_model(**llm_kwargs)

In [54]:
# Build the LoRA pipeline; checkpoints are cached under hf_cache/.
model = MolEncoderLLMPipeline(cache_dir="hf_cache/")

Some weights of RobertaModel were not initialized from the model checkpoint at DeepChem/ChemBERTa-77M-MTR and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00,  5.20it/s]


In [23]:
# Tokenizers for the smoke test below: one per sub-model, sharing the hf_cache/ directory.
mol_tokenizer = AutoTokenizer.from_pretrained("DeepChem/ChemBERTa-77M-MTR", cache_dir="hf_cache/")
text_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2", cache_dir="hf_cache/")



In [26]:
# Tokenize one prompt and one SMILES string as inputs for a forward-pass smoke test.
text = "Name this molecules"
# NOTE(review): "cccccc" has no ring-closure digits -- confirm whether a valid
# aromatic SMILES such as benzene (c1ccccc1) was intended.
smiles = "cccccc"
text_ids = text_tokenizer(text, return_tensors="pt").input_ids
smiles_ids = mol_tokenizer(smiles, return_tensors="pt").input_ids
print(text_ids, smiles_ids)

tensor([[    1,  6620,   456, 12160, 21649]]) tensor([[12, 15, 15, 15, 15, 15, 15, 13]])


In [43]:
# Look up the LLM token embeddings for the prompt and inspect their shape.
embedding_layer = model.llm_model.model.model.embed_tokens
llm_embeddings = embedding_layer(text_ids)
llm_embeddings.shape

torch.Size([1, 5, 4096])

In [42]:
# Shape of the molecule encoder's last hidden state for the SMILES ids.
model.mol_encoder(smiles_ids).last_hidden_state.shape

torch.Size([1, 8, 384])

In [55]:
# Full forward pass through the pipeline with the smoke-test inputs.
model(smiles_ids, text_ids)

CausalLMOutputWithPast(loss=None, logits=tensor([[[-6.4375, -6.3438, -2.4062,  ..., -6.0625, -4.4688, -2.0469],
         [-5.7188, -5.7500, -0.0967,  ..., -4.5312, -3.4531, -4.0625],
         [-8.4375, -9.4375, -3.8281,  ..., -4.8750, -5.2500, -5.1562],
         [-7.3438, -7.5625, -4.3125,  ..., -5.4062, -5.2812, -6.0312],
         [-7.2812, -7.5938, -1.6719,  ..., -6.0625, -7.6875, -4.9688],
         [-9.0000, -9.7500, -0.1021,  ..., -6.0625, -8.3125, -5.0625]]],
       grad_fn=<ToCopyBackward0>), past_key_values=None, hidden_states=None, attentions=None)

In [57]:
# Print each parameter with its requires_grad flag to verify what is frozen.
for param_name, param in model.named_parameters():
    print(param_name, "---", param.requires_grad)

mol_encoder.embeddings.word_embeddings.weight --- False
mol_encoder.embeddings.position_embeddings.weight --- False
mol_encoder.embeddings.token_type_embeddings.weight --- False
mol_encoder.embeddings.LayerNorm.weight --- False
mol_encoder.embeddings.LayerNorm.bias --- False
mol_encoder.encoder.layer.0.attention.self.query.weight --- False
mol_encoder.encoder.layer.0.attention.self.query.bias --- False
mol_encoder.encoder.layer.0.attention.self.key.weight --- False
mol_encoder.encoder.layer.0.attention.self.key.bias --- False
mol_encoder.encoder.layer.0.attention.self.value.weight --- False
mol_encoder.encoder.layer.0.attention.self.value.bias --- False
mol_encoder.encoder.layer.0.attention.output.dense.weight --- False
mol_encoder.encoder.layer.0.attention.output.dense.bias --- False
mol_encoder.encoder.layer.0.attention.output.LayerNorm.weight --- False
mol_encoder.encoder.layer.0.attention.output.LayerNorm.bias --- False
mol_encoder.encoder.layer.0.intermediate.dense.weight --- Fals

# Train

In [None]:
# Run Python garbage collection, then ask PyTorch to empty its CUDA cache,
# before starting training.
import gc
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Fine-tune the trainable parts of `model` with a next-token (causal LM) loss.
# Only parameters with requires_grad=True are handed to the optimizer; the
# frozen ones never receive gradients.
trainable_params = [param for param in model.parameters() if param.requires_grad]
optimizer = optim.Adam(trainable_params, lr=1e-4)

# Training loop
epochs = 5
model.train()
for epoch in range(epochs):
    for batch in combined_loader:
        smiles_data, conversation_data = batch
        # The dataset keeps a singleton axis on the conversation tensors; drop it.
        smiles_input_ids = smiles_data['input_ids'].squeeze(1)
        convo_input_ids = conversation_data['input_ids'].squeeze(1)
        convo_attention_mask = conversation_data['attention_mask'].squeeze(1)

        optimizer.zero_grad()

        # Forward pass
        outputs = model(smiles_input_ids, convo_input_ids)

        # The model prepends one molecule position to the text, so the logits
        # at sequence index i predict text token i; the final logit has no
        # target and is dropped.
        logits = outputs.logits[:, :-1, :]
        # Padding positions (attention mask 0) are excluded from the loss.
        labels = convo_input_ids.masked_fill(convo_attention_mask == 0, -100).to(logits.device)
        loss = nn.functional.cross_entropy(
            logits.reshape(-1, logits.size(-1)).float(),  # float32 for a stable loss
            labels.reshape(-1),
            ignore_index=-100,
        )

        # Backward and optimize
        loss.backward()
        optimizer.step()

        print(f"Epoch {epoch+1}, Loss: {loss.item()}")

# Eval

In [None]:
# Re-enable the key/value cache for generation. The config belongs to the
# wrapped LLM; the pipeline module itself has no `config` attribute.
model.llm_model.config.use_cache = True