# Export ProtBERT and ChemBERTa encoder models to ONNX files

In [1]:
# Load model directly
from transformers import RobertaTokenizer, RobertaModel, RobertaConfig, BertModel, BertTokenizer, BertConfig
import torch
import torch.onnx
import onnx

# Hugging Face Hub checkpoints for the two encoders.
MOL_MODEL_ID = "seyonec/ChemBERTa-zinc-base-v1"  # ChemBERTa (RoBERTa architecture), SMILES input
PROT_MODEL_ID = "Rostlab/prot_bert"              # ProtBERT (BERT architecture), protein sequence input

mol_tokenizer = RobertaTokenizer.from_pretrained(MOL_MODEL_ID)
mol_encoder = RobertaModel.from_pretrained(MOL_MODEL_ID)

# do_lower_case=False keeps the upper-case residue letters exactly as written.
prot_tokenizer = BertTokenizer.from_pretrained(PROT_MODEL_ID, do_lower_case=False)
prot_encoder = BertModel.from_pretrained(PROT_MODEL_ID)

In [2]:
# Fixed (padded) sequence lengths used to build the dummy inputs that are traced for export.
# Protein inputs are capped at 3200 tokens instead of using ProtBERT's positional limit
# (prot_encoder.config.max_position_embeddings): longer sequences use far too much VRAM
# and become a bottleneck.
max_prot_input_size = 3200
# Length of the longest tokenized molecule in the dataset.
max_mol_input_size = 278

In [9]:
# Example inputs used to trace the ONNX export. Both are padded to a fixed max_length;
# only the batch axis is declared dynamic in the export cells, so this padded length
# becomes part of each exported graph's input shape.
dummy_smiles = "CCCCCCCCCCCCCCCCCCCC(=O)O"
# Residues are written space-separated, the format used for ProtBERT input.
dummy_protein_sequence = "M T V P D R S E I A G K W Y V V A L A S N T E F F L R E K D K M K M A M A R I S F L G E D E L K V S Y A V P K P N G C R K W E T T F K K T S D D G E V Y Y S E E A K K K V E V L D T D Y K S Y A V I Y A T R V K D G R T L H M M R L Y S R S P E V S P A A T A I F R K L A G E R N Y T D E M V A M L P R Q E E C T V D E V"

dummy_input_mol = mol_tokenizer(
    dummy_smiles,
    padding='max_length',
    max_length=max_mol_input_size,
    return_tensors="pt",
)
dummy_input_prot = prot_tokenizer(
    [dummy_protein_sequence],
    padding='max_length',
    max_length=max_prot_input_size,
    return_tensors="pt",
)

In [6]:
# Put both encoders in inference mode (e.g. dropout off) before tracing and validation.
for encoder in (mol_encoder, prot_encoder):
    encoder.eval()

The exported ONNX models are saved to your Documents folder, at **`~/Documents/WELP-PLAPT/models`** (the directory is created if it does not exist).

**Note:** this path handling has not been tested on macOS and may not work there.

In [7]:
import os

# Exported ONNX files go to ~/Documents/WELP-PLAPT/models.
home_dir = os.path.expanduser('~')
documents_folder = os.path.join(home_dir, 'Documents')

# Build the models directory once, one path component per argument, so the directory
# and both file paths use the platform's separator consistently.
models_dir = os.path.join(documents_folder, "WELP-PLAPT", "models")

prot_encoder_output_path = os.path.join(models_dir, "prot_encoder.onnx")
mol_encoder_output_path = os.path.join(models_dir, "mol_encoder.onnx")

# Create the output directory (no-op if it already exists).
os.makedirs(models_dir, exist_ok=True)

In [32]:
# Export the molecule (ChemBERTa / RoBERTa) encoder to ONNX.
# The RoBERTa tokenizer output fed here has input_ids and attention_mask only.
mol_input_names = ['input_ids', 'attention_mask']
# Only axis 0 (batch) is dynamic for every input and the named output.
# NOTE: dynamic axes may not be supported when importing into Wolfram Language / Mathematica.
mol_dynamic_axes = {name: {0: 'batch_size'} for name in mol_input_names + ['output']}

torch.onnx.export(
    mol_encoder,
    args=tuple(dummy_input_mol[name] for name in mol_input_names),
    f=mol_encoder_output_path,
    input_names=mol_input_names,
    output_names=['output'],
    opset_version=15,  # or another version depending on compatibility
    do_constant_folding=True,  # fold constant subgraphs at export time
    dynamic_axes=mol_dynamic_axes,
)

In [33]:
# Export the protein (ProtBERT / BERT) encoder to ONNX.
# The BERT tokenizer output also carries token_type_ids, so the graph takes three inputs.
prot_input_names = ['input_ids', 'attention_mask', 'token_type_ids']
# Only axis 0 (batch) is dynamic for every input and the named output.
# NOTE: dynamic axes may not be supported when importing into Wolfram Language / Mathematica.
prot_dynamic_axes = {name: {0: 'batch_size'} for name in prot_input_names + ['output']}

torch.onnx.export(
    prot_encoder,
    args=tuple(dummy_input_prot[name] for name in prot_input_names),
    f=prot_encoder_output_path,
    input_names=prot_input_names,
    output_names=['output'],
    opset_version=15,
    do_constant_folding=True,
    dynamic_axes=prot_dynamic_axes,
)

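### Validation

Run the PyTorch encoders and the exported ONNX graphs on the same dummy inputs, then check that their outputs agree within an absolute tolerance of `1e-3`.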

In [10]:
# PyTorch reference outputs on the same dummy inputs used for export (no autograd needed).
# Protein encoder runs first, then the molecule encoder.
with torch.no_grad():
    original_prot_output = prot_encoder(**dummy_input_prot)
    original_mol_output = mol_encoder(**dummy_input_mol)

In [None]:
# Keep element 0 of each model output (last_hidden_state for Hugging Face encoder models),
# which corresponds to the first ONNX graph output named 'output'.
# Guarded so re-running this cell does not index into the tensor itself, which would
# silently drop the batch dimension and break the later comparison.
if not isinstance(original_prot_output, torch.Tensor):
    original_prot_output = original_prot_output[0]
if not isinstance(original_mol_output, torch.Tensor):
    original_mol_output = original_mol_output[0]

In [30]:
import onnxruntime as ort

# Function to run ONNX inference
def run_onnx_inference(onnx_model_path, dummy_input, providers=None):
    """Run an exported ONNX encoder on tokenizer output and return its first output.

    Every graph input is fed *by name* from ``dummy_input``. This works for any number
    of inputs (two for the molecule encoder, three for the protein encoder) and does not
    depend on the order in which the graph lists its inputs.

    Args:
        onnx_model_path: path to the ``.onnx`` file.
        dummy_input: mapping (e.g. a tokenizer ``BatchEncoding``) from input name to a
            torch tensor; must contain a key for every graph input name.
        providers: ONNX Runtime execution providers; defaults to CPU only.

    Returns:
        The first element of ``session.run`` (a numpy array), i.e. the graph output
        exported as ``'output'``.

    Raises:
        KeyError: if the graph declares an input that ``dummy_input`` does not contain.
    """
    if providers is None:
        # Pass providers explicitly: some GPU-enabled onnxruntime builds (1.9+) raise when
        # they are omitted, and the CPU provider is available in every build.
        providers = ["CPUExecutionProvider"]
    session = ort.InferenceSession(onnx_model_path, providers=providers)
    feed = {
        graph_input.name: dummy_input[graph_input.name].cpu().numpy()
        for graph_input in session.get_inputs()
    }
    return session.run(None, feed)[0]

# Feed each exported graph the same dummy batch it was traced with.
onnx_prot_output = run_onnx_inference(
    prot_encoder_output_path, dummy_input_prot
)  # protein encoder: input_ids, attention_mask, token_type_ids
onnx_mol_output = run_onnx_inference(
    mol_encoder_output_path, dummy_input_mol
)  # molecule encoder: input_ids, attention_mask

In [31]:
import numpy as np

# Function to compare outputs
def compare_outputs(original_output, onnx_output, threshold=1e-3):
    """Return True if two arrays have the same shape and match elementwise.

    Args:
        original_output: reference array (the PyTorch output converted to numpy).
        onnx_output: array produced by ONNX Runtime.
        threshold: absolute tolerance passed to ``np.allclose`` as ``atol``
            (its default relative tolerance still applies).

    Returns:
        bool: False on any shape mismatch, otherwise the ``np.allclose`` result.
    """
    reference = np.asarray(original_output)
    candidate = np.asarray(onnx_output)
    # np.allclose broadcasts, so e.g. (1, L, H) vs (L, H) would otherwise compare "equal".
    if reference.shape != candidate.shape:
        return False
    return bool(np.allclose(reference, candidate, atol=threshold))

# Compare each PyTorch reference output against its ONNX Runtime counterpart
# and report per-encoder results plus an overall verdict.
is_prot_output_similar = compare_outputs(
    original_prot_output.cpu().numpy(), onnx_prot_output
)
print("Protein Encoder Outputs Similar:", is_prot_output_similar)

is_mol_output_similar = compare_outputs(
    original_mol_output.cpu().numpy(), onnx_mol_output
)
print("Molecule Encoder Outputs Similar:", is_mol_output_similar)

both_match = is_prot_output_similar and is_mol_output_similar
print("Passed!" if both_match else "Failed!")

Protein Encoder Outputs Similar: True
Molecule Encoder Outputs Similar: True
