In [1]:
# Given Imports
import torch
import re
import statistics
from torch.utils.data import DataLoader, TensorDataset
from transformers import BertModel, BertTokenizer, RobertaModel, RobertaTokenizer
from datasets import load_dataset, Dataset

In [2]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# Release any cached, unused GPU memory held by PyTorch before the encoders are loaded.
torch.cuda.empty_cache()

### Load Encoders and Tokenizers

In [4]:
# Protein encoder (ProtBERT). do_lower_case=False stops the tokenizer from lower-casing the sequences.
prot_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
prot_model = BertModel.from_pretrained("Rostlab/prot_bert").to(device)

# Molecule encoder (ChemBERTa)
mol_tokenizer = RobertaTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
mol_model = RobertaModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1").to(device)

# Maximum tokenized lengths, passed to the tokenizers as `max_length` (longer inputs are truncated).
# NOTE(review): the protein cap was described as 3000 in the original comment but the value is 3200 -- confirm which is intended.
# max_prot_input_size = prot_model.config.max_position_embeddings
max_prot_input_size = 3200 # capped because longer protein inputs use way too much VRAM
max_mol_input_size = 278 # capped at 278 since the longest tokenized SMILES sequence in the dataset has a length of 278

### Load Dataset

In [5]:
# Load the 'train' split of the jglaser/binding_affinity dataset.
dataset = load_dataset("jglaser/binding_affinity")['train']

In [6]:
# Fixed-size subsets taken from the start of the training split (first 100k / 10k / 1k rows).
dataset100k = dataset.select(range(100000))
dataset10k = dataset.select(range(10000))
dataset1k = dataset.select(range(1000))

### Preprocess & Tokenize Data

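Replace the irregular amino acids (U, Z, O, B) in the dataset's protein sequences with "X" and separate the residues with single spaces, which is necessary for accurate tokenization and encoding by the ProtBERT model.
The preprocessing is applied with a parallelized `map` call (8 worker processes).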

In [7]:
def preprocess_function(example):
    """Normalise one protein sequence for ProtBERT tokenization.

    The irregular residue codes U, Z, O and B are replaced by "X" and the
    residues are then separated by single spaces.

    Any whitespace already present in the sequence is removed first, so the
    function is idempotent: applying it to an already preprocessed example
    (e.g. when the notebook cell is re-run) returns the same spaced sequence
    instead of spacing out the spaces themselves.

    Args:
        example: mapping with a 'seq' entry holding the protein sequence string.

    Returns:
        The same ``example`` object with 'seq' overwritten by the spaced sequence.
    """
    # Local import keeps the function self-contained when `map` runs it with num_proc > 1.
    import re

    residues = re.sub(r"\s+", "", example['seq'])
    example['seq'] = " ".join(re.sub(r"[UZOB]", "X", residues))
    return example

# Apply the sequence preprocessing to every subset, using 8 worker processes per call.
# Each name is rebound to the mapped result, so this cell expects freshly selected subsets.
dataset100k = dataset100k.map(preprocess_function, num_proc=8)
dataset10k = dataset10k.map(preprocess_function, num_proc=8)
dataset1k = dataset1k.map(preprocess_function, num_proc=8)

In [8]:
def tokenize_prot(example):
    """Tokenize ``example['seq']`` with the ProtBERT tokenizer.

    The tokenizer is imported and loaded inside the function on every call.
    The result is padded, truncated to the notebook-level
    ``max_prot_input_size`` and returned as PyTorch tensors.
    """
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
    sequences = example['seq']
    return tokenizer(
        sequences,
        padding=True,
        truncation=True,
        max_length=max_prot_input_size,
        return_tensors='pt',
    )

def tokenize_mol(example):
    """Tokenize ``example['smiles_can']`` with the ChemBERTa tokenizer.

    The tokenizer is imported and loaded inside the function on every call.
    The result is padded (no truncation is applied) and returned as PyTorch
    tensors.
    """
    from transformers import RobertaTokenizer

    tokenizer = RobertaTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
    smiles = example['smiles_can']
    return tokenizer(smiles, padding=True, return_tensors='pt')
    # Truncating alternative (not used):
    # return mol_tokenizer(example['smiles_can'], padding=True, truncation=True, max_length=max_mol_input_size, return_tensors='pt')

### Encoding and Tokenizing Functions

In [9]:
# Define Encoding functions
def encode_batch(batch, tokenizer, model, max_input_size):
    """Encode one batch of strings into pooled embeddings.

    Args:
        batch: the sequences to encode (passed straight to ``tokenizer``).
        tokenizer: callable tokenizer; invoked with padding, truncation,
            ``max_length=max_input_size`` and ``return_tensors='pt'``.
        model: encoder whose output exposes ``pooler_output`` and which
            exposes a ``device`` attribute.
        max_input_size: maximum tokenized length; longer inputs are truncated.

    Returns:
        The model's ``pooler_output`` for the batch, moved to the CPU.
    """
    tokens = tokenizer(batch, padding=True, truncation=True, max_length=max_input_size, return_tensors='pt')
    # Send the inputs to wherever the model currently lives instead of the
    # notebook-level `device`. encode_sequences moves mol_model to the CPU after
    # use, so targeting the global device would raise a device-mismatch error
    # the next time that model is used on a CUDA machine.
    tokens = tokens.to(model.device)
    with torch.no_grad():  # inference only; no autograd graph is needed
        outputs = model(**tokens).pooler_output
    # Alternative pooling (not used): mean over last_hidden_state along dim=1.
    return outputs.cpu()

def encode_sequences(prot_seq, mol_smiles, mol_batch_size=16, prot_batch_size=2):
    """Encode protein sequences and SMILES strings into embedding tensors.

    Both inputs are encoded in batches to prevent out-of-memory errors.
    Molecules are encoded first; the molecule model is then moved to the CPU
    and the CUDA cache is emptied before the proteins are encoded.

    Args:
        prot_seq: protein sequences to encode with ``prot_model``.
        mol_smiles: SMILES strings to encode with ``mol_model``.
        mol_batch_size: batch size used for the molecule encoder.
        prot_batch_size: batch size used for the protein encoder.

    Returns:
        ``(prot_embeddings, mol_embeddings)``: the per-batch outputs of
        ``encode_batch`` concatenated along dim 0.
    """
    prot_representations = []
    mol_representations = []

    # A previous call leaves mol_model on the CPU (see below). Move it back to
    # the working device so the function can be called more than once.
    mol_model.to(device)

    mol_loader = DataLoader(mol_smiles, batch_size=mol_batch_size, shuffle=False)
    for i, mol_batch in enumerate(mol_loader, 1):
        if i % 20 == 0:  # only refresh the progress line every 20 batches
            print(f"\rEncoding molecule batch {i}/{len(mol_loader)}...", end="")
        mol_representations.append(encode_batch(mol_batch, mol_tokenizer, mol_model, max_mol_input_size))
    print("done!")

    # Move the molecule model off the device and drop cached blocks before the
    # protein batches are encoded.
    mol_model.to("cpu")
    torch.cuda.empty_cache()

    prot_loader = DataLoader(prot_seq, batch_size=prot_batch_size, shuffle=False)
    for i, prot_batch in enumerate(prot_loader, 1):
        print(f"\rEncoding protein batch {i}/{len(prot_loader)}...", end="")
        prot_representations.append(encode_batch(prot_batch, prot_tokenizer, prot_model, max_prot_input_size))
        torch.cuda.empty_cache()
    print("done!")
    return torch.cat(prot_representations, dim=0), torch.cat(mol_representations, dim=0)

In [10]:
def create_tensor_dataset(dataset):
    """Build a TensorDataset of (protein embedding, molecule embedding, affinity).

    Reads the 'seq', 'smiles_can' and 'affinity' columns of ``dataset``,
    encodes the first two with ``encode_sequences`` and converts the
    affinities to a tensor.
    """
    # Read all three columns up front, before the encoding step runs.
    columns = {name: dataset[name] for name in ("seq", "smiles_can", "affinity")}
    prot_rep, chem_rep = encode_sequences(columns["seq"], columns["smiles_can"])
    labels = torch.tensor(columns["affinity"])
    return TensorDataset(prot_rep, chem_rep, labels)

In [11]:
# Encode the 100k-row subset with both encoders and wrap the result in a TensorDataset.
print("encoding data...")
tensor_dataset = create_tensor_dataset(dataset100k)

encoding data...
Encoding molecule batch 6240/6250...done!
Encoding protein batch 50000/50000...done!


In [12]:
# Inspect the first (protein embedding, molecule embedding, affinity) triple.
tensor_dataset[0]

(tensor([-0.2664,  0.2810, -0.2537,  ...,  0.2681,  0.2509, -0.2737]),
 tensor([ 2.9266e-01, -2.0106e-01, -2.5144e-01,  2.7638e-01, -5.9401e-01,
          4.7673e-01,  4.0143e-01,  1.5805e-01, -2.2580e-01, -4.7935e-01,
         -6.7652e-01,  1.6724e-02,  9.4531e-01,  3.6973e-01,  3.7976e-01,
         -2.0460e-01, -1.0615e-01, -3.7676e-01, -1.8930e-01, -2.3976e-01,
          8.3889e-03,  2.7116e-01, -2.4827e-01,  2.4240e-01, -6.5158e-01,
         -1.0794e-01, -2.9413e-01, -1.7536e-01,  4.4842e-02, -9.2611e-05,
         -1.0005e-01,  5.1793e-02,  4.2853e-01, -1.9518e-01, -3.7019e-01,
         -6.4103e-01,  6.4570e-01, -3.3646e-01,  3.8487e-01, -1.3384e-01,
          4.4586e-02,  6.1111e-01, -2.3334e-01, -5.1104e-01, -3.2470e-01,
         -3.8882e-02,  2.2602e-01, -2.6721e-01, -7.1416e-01, -9.9216e-02,
         -6.6145e-01, -6.7018e-02,  6.7712e-01,  4.4566e-01, -6.1443e-01,
         -7.4043e-01, -3.6978e-01,  3.1468e-01,  5.8633e-01, -6.9997e-01,
         -8.4047e-01,  7.7237e-01,  6.069

### Export embeddings

In [5]:
import os

# Get the user's home directory
home_dir = os.path.expanduser('~')

# Construct the path to the Documents folder
documents_folder = os.path.join(home_dir, 'Documents')

# Directory that receives the exported embeddings. It is built from separate
# components so the platform's path separator is used throughout.
data_dir = os.path.join(documents_folder, "WELP-PLAPT", "data")

# Full paths of the serialized TensorDataset and of its JSON export
tensor_dataset_output_path = os.path.join(data_dir, "tensor_dataset.data")
tensor_dataset_output_path_json = os.path.join(data_dir, "tensor_dataset.json")

# Ensure the data directory exists
os.makedirs(data_dir, exist_ok=True)

In [14]:
# Serialize the whole TensorDataset object (not just its tensors) to disk.
torch.save(tensor_dataset, tensor_dataset_output_path)

### Validate Data

In [15]:
def compare_datasets(dataset1, dataset2):
    """Return True when two datasets hold identical contents.

    Supported inputs:
      * ``torch.Tensor``: compared with ``torch.equal``;
      * ``TensorDataset``: compared through their ``tensors`` tuples;
      * ``dict``: same keys, and every value compared recursively;
      * ``list`` / ``tuple``: same length, and every item compared recursively;
      * anything else: compared with ``==``.

    Objects of different types are never equal. Types not listed above no
    longer compare as equal by default: previously a ``TensorDataset`` fell
    through every branch and was reported identical without its tensors
    being inspected.

    Args:
        dataset1: first object to compare.
        dataset2: second object to compare.

    Returns:
        bool: True if the contents match, False otherwise.
    """
    # Check if both are the same type
    if type(dataset1) is not type(dataset2):
        return False

    if isinstance(dataset1, torch.Tensor):
        return torch.equal(dataset1, dataset2)

    # A TensorDataset is just a tuple of tensors; compare those.
    if isinstance(dataset1, TensorDataset):
        return compare_datasets(dataset1.tensors, dataset2.tensors)

    if isinstance(dataset1, dict):
        if dataset1.keys() != dataset2.keys():
            return False
        return all(compare_datasets(dataset1[key], dataset2[key]) for key in dataset1)

    if isinstance(dataset1, (list, tuple)):
        if len(dataset1) != len(dataset2):
            return False
        return all(compare_datasets(item1, item2) for item1, item2 in zip(dataset1, dataset2))

    # Fallback for plain values (numbers, strings, ...). Objects without a
    # custom __eq__ only compare equal to themselves, which is the safe answer.
    return bool(dataset1 == dataset2)

In [16]:
import os
import torch

# Validate the exported file: it must exist, be non-empty, load back, and match
# the in-memory `tensor_dataset`.
if not os.path.exists(tensor_dataset_output_path):
    print("File does not exist:", tensor_dataset_output_path)
else:
    print("File found:", tensor_dataset_output_path)

    # Check file size (should not be zero)
    if os.path.getsize(tensor_dataset_output_path) == 0:
        print("File is empty.")
    else:
        print("File size is non-zero.")

        # Load the tensor dataset into a new variable.
        # The file holds a pickled TensorDataset object rather than bare
        # tensors, so it must be loaded with weights_only=False; PyTorch
        # releases whose torch.load defaults to weights_only=True refuse to
        # unpickle it. Unpickling can run arbitrary code: only do this for
        # files written by this notebook.
        try:
            loaded_tensor_dataset = torch.load(tensor_dataset_output_path, weights_only=False)
        except Exception as e:
            print("Error loading the dataset:", e)
        else:
            print("Dataset successfully loaded.")

            # Compare outside the try block so that a failure in the
            # comparison itself is not reported as a loading error.
            if compare_datasets(tensor_dataset, loaded_tensor_dataset):
                print("Success: The loaded dataset is identical to the original.")
            else:
                print("Error: The loaded dataset differs from the original.")

File found: C:\Users\tatwo\Documents\WELP-PLAPT/data\tensor_dataset.data
File size is non-zero.
Dataset successfully loaded.
Success: The loaded dataset is identical to the original.


### Export Tensor Dataset to JSON

In [3]:
import pandas as pd
import torch
from torch.utils.data import TensorDataset, Subset

# Reload the exported TensorDataset. `tensor_dataset_output_path` comes from the
# "Export embeddings" cell, which must have been run in this kernel first.
# The file holds a pickled TensorDataset object, so weights_only=False is
# required on PyTorch releases whose torch.load defaults to weights_only=True.
# Unpickling can run arbitrary code: only load files written by this notebook.
tensor_dataset = torch.load(tensor_dataset_output_path, weights_only=False)

# Initialize lists to store the extracted data
prot_embeddings_list = []
mol_embeddings_list = []
affinity_list = []

# Extract data from the TensorDataset, one (protein, molecule, affinity) row at a time
for prot_emb, mol_emb, affinity in tensor_dataset:
    # Flatten the feature tensors if they are multi-dimensional
    prot_emb_flat = prot_emb.flatten().numpy()
    mol_emb_flat = mol_emb.flatten().numpy()

    # Append the flattened data and the label to the lists
    prot_embeddings_list.append(prot_emb_flat)
    mol_embeddings_list.append(mol_emb_flat)
    affinity_list.append(affinity.item())  # .item() requires a single-element tensor

# Create a DataFrame with one row per sample; each embedding cell holds a 1-D array
df = pd.DataFrame({
    'prot_embeddings': prot_embeddings_list,
    'mol_embeddings': mol_embeddings_list,
    'affinity': affinity_list
})


In [4]:
# Write the frame as a JSON array with one object per row (orient='records').
df.to_json(tensor_dataset_output_path_json, orient='records')
