In [5]:
# Given Imports
import torch
import re
from torch.utils.data import DataLoader, TensorDataset
from transformers import BertModel, BertTokenizer, RobertaModel, RobertaTokenizer
from datasets import load_dataset

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [7]:
torch.cuda.empty_cache()

### Load Encoders and Tokenizers

In [8]:
# Protein encoder
prot_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
prot_model = BertModel.from_pretrained("Rostlab/prot_bert").to(device)

# Molecule encoder
mol_tokenizer = RobertaTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
mol_model = RobertaModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1").to(device)

max_prot_input_size = prot_model.config.max_position_embeddings
max_mol_input_size = mol_model.config.max_position_embeddings-2

### Load Dataset

In [9]:
dataset = load_dataset("jglaser/binding_affinity")
dataset = dataset["train"].train_test_split(train_size=0.001)['train']

### Preprocess Data

In [10]:
def preprocess_function(example):
    example['seq'] = " ".join(re.sub(r"[UZOB]", "X", example['seq']))
    return example

dataset = dataset.map(preprocess_function)

Map:   0%|          | 0/1836 [00:00<?, ? examples/s]

### Encoding and Tokenizing Functions

In [11]:
# Define Encoding functions
def encode_batch(batch, tokenizer, model, max_input_size):
    tokens = tokenizer(batch, padding=True, truncation=True, max_length=max_input_size, return_tensors='pt')
    with torch.no_grad():
        outputs = model(**tokens.to(device)).pooler_output
    # representations = outputs.last_hidden_state.mean(dim=1)
    return outputs.cpu()

def encode_sequences(prot_seq, mol_smiles, mol_batch_size=16, prot_batch_size=2):
    # Encode in batches to prevent out-of-memory errors
    prot_representations = []
    mol_representations = []
    
    mol_loader = DataLoader(mol_smiles, batch_size=mol_batch_size, shuffle=False)
    for i, mol_batch in enumerate(mol_loader, 1):
        if i % 20 == 0:
            print(f"\rEncoding molecule batch {i}/{len(mol_loader)}...", end="")
        mol_representations.append(encode_batch(mol_batch, mol_tokenizer, mol_model, max_mol_input_size))
    print("done!")
    
    mol_model.to("cpu")
    torch.cuda.empty_cache()
    
    prot_loader = DataLoader(prot_seq, batch_size=prot_batch_size, shuffle=False)
    for i, prot_batch in enumerate(prot_loader, 1):
        print(f"\rEncoding protein batch {i}/{len(prot_loader)}...", end="")
        prot_representations.append(encode_batch(prot_batch, prot_tokenizer, prot_model, max_prot_input_size))
        torch.cuda.empty_cache()
    print("done!")
    return torch.cat(prot_representations, dim=0), torch.cat(mol_representations, dim=0)

In [12]:
def create_tensor_dataset(dataset):
    proteins, smiles, affinities = dataset["seq"], dataset["smiles_can"], dataset["affinity"]
    prot_rep, chem_rep = encode_sequences(proteins, smiles)
    return TensorDataset(prot_rep, chem_rep, torch.tensor(affinities))

In [13]:
print("encoding data...")
tensor_dataset = create_tensor_dataset(dataset)

encoding data...
Encoding molecule batch 100/115...done!
Encoding protein batch 84/918...

KeyboardInterrupt: 

In [None]:
len(tensor_dataset)

1836

In [None]:
torch.save(tensor_dataset, r"C:\Users\tatwo\Downloads\encoded_data_final")

In [None]:
tokens=mol_tokenizer(["CCCCCCCCCCCCCCCCCCCC(=O)O"], padding=True, truncation=True, max_length=512, return_tensors='pt')
mol_model(**tokens)

BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=tensor([[[ 0.4373, -0.6694, -1.0074,  ..., -0.9850,  0.2792,  1.4708],
         [-0.1302, -1.5081, -0.6598,  ...,  0.2113,  1.4561,  0.2045],
         [ 0.8827, -0.3146, -1.0314,  ...,  1.0429,  1.5200,  0.1877],
         ...,
         [-0.3223, -1.1277,  0.2969,  ..., -1.9156,  0.3809,  1.1264],
         [ 0.0172,  1.5409, -0.9348,  ..., -0.8201,  0.8186, -0.6957],
         [-1.4186, -1.0689,  0.3970,  ..., -0.7573,  0.7821,  0.7213]]],
       grad_fn=<NativeLayerNormBackward0>), pooler_output=tensor([[ 2.9266e-01, -2.0106e-01, -2.5144e-01,  2.7638e-01, -5.9401e-01,
          4.7673e-01,  4.0142e-01,  1.5805e-01, -2.2580e-01, -4.7935e-01,
         -6.7652e-01,  1.6724e-02,  9.4531e-01,  3.6973e-01,  3.7976e-01,
         -2.0460e-01, -1.0615e-01, -3.7676e-01, -1.8930e-01, -2.3977e-01,
          8.3878e-03,  2.7116e-01, -2.4827e-01,  2.4240e-01, -6.5158e-01,
         -1.0794e-01, -2.9413e-01, -1.7536e-01,  4.4841e-02, -9.339