In [26]:
# Given Imports
import torch
import torch.nn as nn
import torch.optim as optim
import statistics
import re
import os
from torch.utils.data import DataLoader, TensorDataset
from transformers import BertModel, BertTokenizer, RobertaModel, RobertaTokenizer, AdamW
from datasets import load_dataset

In [27]:
# Choose the compute device once: the GPU when PyTorch can see one, otherwise the CPU.
# The encoder models below are moved onto this device.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [28]:
# Hand unoccupied memory held by PyTorch's CUDA caching allocator back to the driver
# (helpful when cells are re-run in a long-lived kernel). Nothing to free without CUDA.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

### Load Encoders and Tokenizers

In [29]:
# Protein encoder: ProtBert (BERT architecture). Model is used for inference only.
prot_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
prot_model = BertModel.from_pretrained("Rostlab/prot_bert").to(device).eval()

# Molecule encoder: ChemBERTa (RoBERTa architecture) over SMILES strings.
mol_tokenizer = RobertaTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
mol_model = RobertaModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1").to(device).eval()
# .eval() (dropout off) is what from_pretrained already sets; it is spelled out so the
# embeddings computed later are deterministic regardless of earlier cell history.

# Longest token sequence (special tokens included) each encoder can take; used as the
# tokenizer's truncation max_length.
# BERT position ids start at 0, so the full position-embedding table is usable.
max_prot_input_size = prot_model.config.max_position_embeddings
# RoBERTa position ids start at pad_token_id + 1, so the first pad_token_id + 1 rows of
# the position table are never used by real tokens. Truncating at the raw
# max_position_embeddings would let long SMILES index past the end of the table.
max_mol_input_size = (
    mol_model.config.max_position_embeddings - mol_model.config.pad_token_id - 1
)

### Load Dataset

In [30]:
# Fractions of the source "train" split used for this quick experiment.
# For a full run use TRAIN_FRACTION = 0.8 and TEST_FRACTION = 0.2.
TRAIN_FRACTION = 0.0008
TEST_FRACTION = 0.0002
SPLIT_SEED = 42

dataset = load_dataset("jglaser/binding_affinity")

# One seeded split: both subsets come from a single shuffled permutation, so they are
# disjoint (no train/test leakage) and identical across re-runs. Rows outside the two
# fractions are simply left out.
subset_split = dataset["train"].train_test_split(
    train_size=TRAIN_FRACTION,
    test_size=TEST_FRACTION,
    seed=SPLIT_SEED,
)
train_dataset = subset_split["train"]
test_dataset = subset_split["test"]

### Preprocess Data

In [31]:
def preprocess_function(example):
    """Format one example's protein sequence the way the ProtBert tokenizer expects.

    The Rostlab/prot_bert model card requires residues separated by single spaces
    ("M K T ..."). An unspaced string is a single word that the tokenizer cannot split
    into its per-residue vocabulary, so it collapses to [UNK]. The rare residues
    U, Z, O and B are mapped to X, following the same model card.

    Args:
        example: mapping with a "seq" string (one row, as passed by ``Dataset.map``).
            It is modified in place.

    Returns:
        The same ``example`` with ``example['seq']`` rewritten.
    """
    # Drop any existing whitespace first so already-spaced input is not double-spaced.
    residues = re.sub(r"\s+", "", example['seq'])
    residues = re.sub(r"[UZOB]", "X", residues)
    example['seq'] = " ".join(residues)
    return example

# Normalise the protein sequences of both splits with the same per-example function.
train_dataset = train_dataset.map(function=preprocess_function)
test_dataset = test_dataset.map(function=preprocess_function)

Map:   0%|          | 0/1469 [00:00<?, ? examples/s]

Map:   0%|          | 0/368 [00:00<?, ? examples/s]

In [38]:
# Define Encoding functions
def encode_batch(batch, tokenizer, model, max_input_size):
    """Embed a batch of strings as the masked mean of the encoder's last hidden layer.

    Args:
        batch: list of strings tokenized together (padded to the longest one).
        tokenizer: Hugging Face-style tokenizer; called with ``return_tensors='pt'`` and
            expected to return at least ``input_ids`` and ``attention_mask``.
        model: encoder whose output exposes ``last_hidden_state``. Inputs are moved to
            the device of the model's parameters.
        max_input_size: truncation length passed to the tokenizer.

    Returns:
        CPU tensor of shape ``(len(batch), hidden_size)``, one pooled vector per string.
    """
    tokens = tokenizer(batch, padding=True, truncation=True, max_length=max_input_size, return_tensors='pt')
    model_device = next(model.parameters()).device
    tokens = {name: tensor.to(model_device) for name, tensor in tokens.items()}
    with torch.no_grad():
        outputs = model(**tokens)
    hidden = outputs.last_hidden_state  # (batch, seq_len, hidden_size)
    # Average over real tokens only. With padding=True, shorter strings are padded to
    # the longest in the batch; averaging those positions too would make a string's
    # embedding depend on which other strings share its batch.
    mask = tokens["attention_mask"].unsqueeze(-1).to(hidden.dtype)
    summed = (hidden * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1)  # guard against an all-padding row
    return (summed / counts).cpu()

def encode_sequences(prot_seq, mol_smiles, batch_size=32):
    """Embed protein sequences and SMILES strings, a mini-batch at a time.

    Batching keeps each encoder forward pass small so large inputs do not exhaust
    device memory.

    Args:
        prot_seq: protein strings, encoded with ``prot_tokenizer`` / ``prot_model``.
        mol_smiles: SMILES strings, encoded with ``mol_tokenizer`` / ``mol_model``.
        batch_size: number of strings per ``encode_batch`` call.

    Returns:
        ``(protein_embeddings, molecule_embeddings)``: each the ``encode_batch`` outputs
        for that input concatenated along dim 0, in input order.
    """
    def _embed_chunks(texts, tokenizer, model, limit):
        # shuffle=False keeps output rows aligned with the input order.
        chunk_loader = DataLoader(texts, batch_size=batch_size, shuffle=False)
        return [encode_batch(chunk, tokenizer, model, limit) for chunk in chunk_loader]

    protein_chunks = _embed_chunks(prot_seq, prot_tokenizer, prot_model, max_prot_input_size)
    molecule_chunks = _embed_chunks(mol_smiles, mol_tokenizer, mol_model, max_mol_input_size)

    protein_embeddings = torch.cat(protein_chunks, dim=0)
    molecule_embeddings = torch.cat(molecule_chunks, dim=0)
    return protein_embeddings, molecule_embeddings

In [42]:
def create_tensor_dataset(dataset, batch_size=8):
    """Encode one binding-affinity split into a TensorDataset of embeddings.

    Args:
        dataset: split indexable by column name, providing "seq" (protein sequences),
            "smiles_can" (SMILES strings) and "affinity" (numeric targets).
        batch_size: strings per encoder forward pass in ``encode_sequences``. It
            defaults to 8, the value previously hard-coded here.

    Returns:
        TensorDataset yielding ``(protein_embedding, molecule_embedding, affinity)``.
    """
    proteins = dataset["seq"]
    smiles = dataset["smiles_can"]
    # Explicit float32 target, so a regression loss works even if the column
    # happens to hold integers. Built before the slow encoding step so a bad column
    # fails fast.
    affinities = torch.tensor(list(dataset["affinity"]), dtype=torch.float32)
    prot_rep, chem_rep = encode_sequences(proteins, smiles, batch_size)
    return TensorDataset(prot_rep, chem_rep, affinities)

# Encode every example once (encode_batch runs the pretrained encoders without gradients).
train_tensor_dataset = create_tensor_dataset(train_dataset)
test_tensor_dataset = create_tensor_dataset(test_dataset)
# One encoded row per source example in each split.
assert len(train_tensor_dataset) == len(train_dataset)
assert len(test_tensor_dataset) == len(test_dataset)

LOADER_BATCH_SIZE = 32
LOADER_SEED = 42
# A seeded generator makes the training shuffle order reproducible across kernel restarts.
train_loader = DataLoader(
    train_tensor_dataset,
    batch_size=LOADER_BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(LOADER_SEED),
)
test_loader = DataLoader(test_tensor_dataset, batch_size=LOADER_BATCH_SIZE)

In [34]:
torch.save(train_tensor_dataset, 'data\\train_data_processed')
torch.save(test_tensor_dataset, 'data\\test_data_processed')

In [35]:
train_tensor_dataset = torch.load('data\\test_data_processed')
train_tensor_dataset = torch.load('data\\train_data_processed')