## Protein Folding with ESMFold and 🤗`transformers`

ESMFold ([paper link](https://www.biorxiv.org/content/10.1101/2022.07.20.500902v2)) is a recently released protein folding model from FAIR. Unlike other protein folding models, it does not require external databases or search tools to predict structures, and is up to 60X faster as a result.

The port to the HuggingFace `transformers` library is even easier to use, as we've removed the dependency on tools like `openfold` - once you `pip install transformers`, you're ready to use this model! 

Note that all the code that follows will be running the model **locally**, rather than calling an external API. This means that no rate limiting applies here - you can predict as many structures as your computer can handle. 

In testing, we found that ESMFold needs about 16-24GB of GPU memory to run well, depending on protein length. This may be too much for the smaller free GPUs on Colab.

First step, make sure you're up to date - you'll need the most recent release of `transformers` and `accelerate`! If you want to visualize your predicted protein structure in the notebook, you should also install py3Dmol. 

In [1]:
!pip install py3Dmol accelerate

Collecting py3Dmol
  Downloading py3Dmol-2.4.2-py2.py3-none-any.whl.metadata (1.9 kB)
Downloading py3Dmol-2.4.2-py2.py3-none-any.whl (7.0 kB)
Installing collected packages: py3Dmol
Successfully installed py3Dmol-2.4.2
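
To confirm which versions ended up installed, a small check along the following lines may help. It is only a sketch: the helper name `installed_versions` is hypothetical, and it relies on `importlib.metadata` from the standard library (Python 3.8 or later).

```python
from importlib import metadata


def installed_versions(packages):
    """Return {package: version string, or None if the package is not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# Report the packages this notebook depends on.
for name, version in installed_versions(["transformers", "accelerate", "py3Dmol"]).items():
    print(f"{name}: {version or 'not installed'}")
```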


We also quickly upload some telemetry - this tells us which examples and software versions are getting used so we know where to prioritize our maintenance efforts. We don't collect (or care about) any personally identifiable information, but if you'd prefer not to be counted, feel free to skip this step or delete this cell entirely.

## Preparing your model and tokenizer

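Now we load our model and tokenizer. The cell below passes `device_map='auto'`, so `accelerate` places the weights on the available devices for you; if you load the model without `device_map`, use `model.cuda()` to transfer the model to GPU. The cell also sets the default dtype to float16 to reduce memory usage.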

In [20]:
from transformers import AutoTokenizer, EsmForProteinFolding
import torch
torch.set_default_dtype(torch.float16)
tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", low_cpu_mem_usage=True, device_map = 'auto')

Some weights of EsmForProteinFolding were not initialized from the model checkpoint at facebook/esmfold_v1 and are newly initialized: ['esm.contact_head.regression.bias', 'esm.contact_head.regression.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some parameters are on the meta device because they were offloaded to the cpu.


## Performance optimizations

Since ESMFold is quite a large model, there are some considerations regarding memory usage and performance.

Firstly, we can optionally convert the language model stem to float16 to improve performance and memory usage when running on a modern GPU. This was used during model training, and so should not make the outputs from the rest of the model invalid.

In [4]:
# Switch the language model stem to float16 (comment this line out to skip the conversion)
model.esm = model.esm.half()

Secondly, you can enable TensorFloat32 computation for a general speedup if your hardware supports it. This line has no effect if your hardware doesn't support it.

In [5]:
import torch

torch.backends.cuda.matmul.allow_tf32 = True

Finally, we can reduce the `chunk_size` used in the folding trunk. Smaller chunk sizes use less memory, but run slightly slower.

In [5]:
# Recommended if your GPU memory is 16GB or less, or if you're folding longer (over 600 or so) sequences; comment it out otherwise
model.trunk.set_chunk_size(64)
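
The rule of thumb in the comment above can be written down as a small helper. This is a sketch only: `suggest_chunk_size` is a hypothetical name, and the thresholds (16GB, 600 residues, chunk size 64) are taken directly from the comment and cell above rather than from any benchmark.

```python
def suggest_chunk_size(gpu_memory_gb, sequence_length,
                       memory_threshold_gb=16, length_threshold=600,
                       reduced_chunk_size=64):
    """Return a reduced chunk size when memory is tight, else None (keep the model's setting)."""
    if gpu_memory_gb <= memory_threshold_gb or sequence_length > length_threshold:
        return reduced_chunk_size
    return None


chunk_size = suggest_chunk_size(gpu_memory_gb=16, sequence_length=350)
print(chunk_size)  # 64, because the GPU has 16GB or less
```

With the model loaded, one would then call `model.trunk.set_chunk_size(chunk_size)` only when the helper returns a value other than `None`.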

## Folding a single chain

First, we tokenize our input. If you've used `transformers` before, proteins are processed like any other input string. Make sure **not** to add special tokens - ESM was trained with them, but ESMFold was trained without them. 

In [6]:
# This is the sequence for human GNAT1, because I worked on it when
# I was a postdoc and so everyone else has to learn to appreciate it too.
# Feel free to substitute your own peptides of interest
# Depending on memory constraints you may wish to use shorter sequences.
test_protein = "MGAGASAEEKHSRELEKKLKEDAEKDARTVKLLLLGAGESGKSTIVKQMKIIHQDGYSLEECLEFIAIIYGNTLQSILAIVRAMTTLNIQYGDSARQDDARKLMHMADTIEEGTMPKEMSDIIQRLWKDSGIQACFERASEYQLNDSAGYYLSDLERLVTPGYVPTEQDVLRSRVKTTGIIETQFSFKDLNFRMFDVGGQRSERKKWIHCFEGVTCIIFIAALSAYDMVLVEDDEVNRMHESLHLFNSICNHRYFATTSIVLFLNKKDVFFEKIKKAHLSICFPDYDGPNTYEDAGNYIKVQFLELNMRRDVKEIYSHMTCATDTQNVKFVFDAVTDIIIKENLKDCGLF"

tokenized_input = tokenizer([test_protein], return_tensors="pt", add_special_tokens=False)['input_ids']
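
Before tokenizing your own sequences, it can be worth checking that they contain only the characters you expect. The sketch below is a conservative pre-check against the 20 standard amino-acid letters. The helper names are hypothetical, and the tokenizer's vocabulary may well accept further symbols, so treat a hit as something to look at rather than as an error.

```python
STANDARD_AMINO_ACIDS = set("ACDEFGHIKLMNPQRSTVWY")


def clean_sequence(sequence):
    """Strip whitespace and upper-case a sequence, e.g. one pasted from a FASTA file."""
    return "".join(sequence.split()).upper()


def find_invalid_residues(sequence):
    """Return (position, character) pairs for characters outside the 20 standard residues."""
    return [(i, ch) for i, ch in enumerate(sequence) if ch not in STANDARD_AMINO_ACIDS]


raw = "mgagasaeek HSRELEKKLK\nEDAEKDARTV"
sequence = clean_sequence(raw)
print(sequence)
print(find_invalid_residues(sequence))  # [] for a clean sequence
print(find_invalid_residues("MGAXB"))   # [(3, 'X'), (4, 'B')]
```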


If you're using a GPU, you'll need to move the tokenized data to the GPU now.

In [7]:
tokenized_input = tokenized_input.cuda()

With our preparations out of the way, getting your model outputs is as simple as...

In [9]:
import torch

with torch.no_grad():
    output = model(tokenized_input)

Now here's the tricky bit - we convert the model outputs to a PDB file. This will likely be moved to a function in `transformers` in the future, but everything's still quite new, so it lives here for now! This code comes from the original ESMFold repo, and uses some functions from `openfold` that have been ported to `transformers`.

In [10]:
from transformers.models.esm.openfold_utils.protein import to_pdb, Protein as OFProtein
from transformers.models.esm.openfold_utils.feats import atom14_to_atom37

def convert_outputs_to_pdb(outputs):
    final_atom_positions = atom14_to_atom37(outputs["positions"][-1], outputs)
    outputs = {k: v.to("cpu").numpy() for k, v in outputs.items()}
    final_atom_positions = final_atom_positions.cpu().numpy()
    final_atom_mask = outputs["atom37_atom_exists"]
    pdbs = []
    for i in range(outputs["aatype"].shape[0]):
        aa = outputs["aatype"][i]
        pred_pos = final_atom_positions[i]
        mask = final_atom_mask[i]
        resid = outputs["residue_index"][i] + 1
        pred = OFProtein(
            aatype=aa,
            atom_positions=pred_pos,
            atom_mask=mask,
            residue_index=resid,
            b_factors=outputs["plddt"][i],
            chain_index=outputs["chain_index"][i] if "chain_index" in outputs else None,
        )
        pdbs.append(to_pdb(pred))
    return pdbs

In [11]:
pdb = convert_outputs_to_pdb(output)

Now we have our pdb - can we visualize it?

In [12]:
import py3Dmol

view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js', width=800, height=400)
view.addModel("".join(pdb), 'pdb')
view.setStyle({'model': -1}, {"cartoon": {'color': 'spectrum'}})

<py3Dmol.view at 0x7f38dafd1160>

Looks good! We can colour it differently, though - our model outputs a `plddt` field containing a per-atom confidence score (predicted lDDT) that indicates how confident the model is in that part of the structure. In the conversion function above we passed the `plddt` field as the `b_factors` argument, so it was included in our `pdb` string. Let's use it so that we can see high- and low-confidence areas of the structure visually!

In [13]:
# The plddt field is scaled from 0-1 on earlier versions of ESMFold but will be updated
# to match AlphaFold's scale of 0-100 in future versions.
# We check here so that this code will work on either:

if torch.max(output['plddt']) <= 1.0:
    vmin = 0.5
    vmax = 0.95
else:
    vmin = 50
    vmax = 95

view.setStyle({'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min': vmin,'max': vmax}}})

<py3Dmol.view at 0x7f38dafd1160>

Blue indicates high confidence, so that's a pretty high-quality prediction! Not too surprising considering GNAT1 was almost certainly in the training data, but nevertheless good to see. Finally, we can write our PDB string out to a file, which you can download and use in other tools.

In [14]:
with open("output_structure.pdb", "w") as f:
    f.write("".join(pdb))
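
Because the confidence values live in the B-factor column, they can also be read back from the PDB text without any extra tooling. The following is a minimal sketch that assumes the standard fixed-width ATOM record layout (residue number in columns 23-26, B-factor in columns 61-66). Both helper names are hypothetical, and the toy records are made up purely for the demonstration.

```python
from collections import defaultdict


def mean_bfactor_per_residue(pdb_text):
    """Average the B-factor column over the atoms of each residue.

    Returns {residue_number: mean B-factor}. Only ATOM records are read.
    """
    values = defaultdict(list)
    for line in pdb_text.splitlines():
        if line.startswith("ATOM"):
            residue_number = int(line[22:26])  # columns 23-26
            b_factor = float(line[60:66])      # columns 61-66
            values[residue_number].append(b_factor)
    return {res: sum(v) / len(v) for res, v in values.items()}


def format_atom_line(serial, name, resname, resseq, b_factor):
    """Build a toy ATOM record (coordinates set to zero) for the demonstration below."""
    return (f"ATOM  {serial:5d} {name:<4s} {resname:3s} A{resseq:4d}    "
            f"{0.0:8.3f}{0.0:8.3f}{0.0:8.3f}{1.0:6.2f}{b_factor:6.2f}")


toy_pdb = "\n".join([
    format_atom_line(1, "N", "MET", 1, 0.90),
    format_atom_line(2, "CA", "MET", 1, 0.80),
    format_atom_line(3, "N", "GLY", 2, 0.50),
])
per_residue = mean_bfactor_per_residue(toy_pdb)
print(per_residue)
```

For the structure predicted in this notebook, the same function would be called as `mean_bfactor_per_residue("".join(pdb))`.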

## Export ESM language model + extra computations

In [2]:
import torch
torch.cuda.empty_cache()

In [22]:
import torch
torch.set_default_dtype(torch.float16)
import torch.nn as nn
from torch.onnx import export as onnx_export

class ESMStemWrapper(nn.Module):
    def __init__(self, esm_for_protein_folding):
        super().__init__()
        self.esm = esm_for_protein_folding.esm
        self.af2_to_esm = esm_for_protein_folding.af2_to_esm
        self.esm_s_combine = esm_for_protein_folding.esm_s_combine
        self.esm_s_mlp = esm_for_protein_folding.esm_s_mlp
        self.esm_dict_cls_idx = esm_for_protein_folding.esm_dict_cls_idx
        self.esm_dict_eos_idx = esm_for_protein_folding.esm_dict_eos_idx
        self.esm_dict_padding_idx = esm_for_protein_folding.esm_dict_padding_idx

    def forward(self, input_ids, attention_mask):
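        # Convert AF2 residue indices to ESM indices. Padding positions are set to 0
        # *before* the lookup so that they map to the ESM padding token.
        esmaa = self.af2_to_esm[(input_ids + 1).masked_fill(attention_mask != 1, 0)]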

        # Add BOS and EOS tokens
        batch_size = esmaa.shape[0]
        bos = esmaa.new_full((batch_size, 1), self.esm_dict_cls_idx)
        eos = esmaa.new_full((batch_size, 1), self.esm_dict_padding_idx)
        esmaa = torch.cat([bos, esmaa, eos], dim=1)
        esmaa[range(batch_size), (esmaa != 1).sum(1)] = self.esm_dict_eos_idx

        # Compute ESM representations
        esm_hidden_states = self.esm(esmaa, attention_mask=esmaa != 1, output_hidden_states=True)["hidden_states"]
        esm_s = torch.stack(esm_hidden_states, dim=2)[:, 1:-1]  # Remove BOS and EOS

        # Combine ESM representations
        esm_s = (self.esm_s_combine.softmax(0).unsqueeze(0) @ esm_s).squeeze(2)
        s_s_0 = self.esm_s_mlp(esm_s)

        return s_s_0
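
The BOS/EOS step in `forward` is compact, so here is the same idea for a single row in plain Python: prepend BOS, append one padding slot, then overwrite the first padding position with EOS. This is an illustrative sketch only. The index values below are assumptions chosen for the example (the wrapper reads the real ones from `esm_dict_cls_idx`, `esm_dict_padding_idx` and `esm_dict_eos_idx`), and `add_bos_eos` is a hypothetical helper.

```python
CLS_IDX, PAD_IDX, EOS_IDX = 0, 1, 2  # assumed values, for illustration only


def add_bos_eos(row):
    """Mimic the wrapper's BOS/EOS logic for one row of ESM token indices.

    `row` may already end with PAD_IDX entries (right padding).
    """
    padded = [CLS_IDX] + list(row) + [PAD_IDX]
    # The number of non-padding tokens equals the index of the first padding
    # position, provided that padding only appears at the end of the row.
    first_pad = sum(1 for token in padded if token != PAD_IDX)
    padded[first_pad] = EOS_IDX
    return padded


print(add_bos_eos([5, 6, 7]))     # [0, 5, 6, 7, 2]
print(add_bos_eos([5, 6, 1, 1]))  # [0, 5, 6, 2, 1, 1]
```

The trick only works when padded positions actually hold the padding index, which is why the count `(esmaa != 1).sum(1)` can double as the EOS position.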

In [25]:
# Initialize the ESM stem wrapper
esm_stem_wrapper = ESMStemWrapper(model)

# Example inputs
input_ids = torch.randint(0, 20, (1, 100), device="cuda")  # Example input (batch_size=1, seq_len=100)
attention_mask = torch.ones_like(input_ids, device="cuda")  # Example input

# Export to ONNX with dynamic batch size
onnx_export(
    esm_stem_wrapper,
    (input_ids, attention_mask),
    "esm.onnx",
    input_names=["input_ids", "attention_mask"],
    output_names=["s_s_0"],
    opset_version=17,
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "seq_len"},  # Dynamic batch size and sequence length
        "attention_mask": {0: "batch_size", 1: "seq_len"},  # Dynamic batch size and sequence length
        "s_s_0": {0: "batch_size", 1: "seq_len"},  # Dynamic batch size and sequence length
    },
)
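
The `dynamic_axes` argument maps each tensor name to a dictionary of `{axis index: label}`. Since the same pattern recurs for every export in this notebook, a small helper can reduce the repetition. This is a sketch: `build_dynamic_axes` is a hypothetical name, and it only builds the dictionary without calling the exporter.

```python
def build_dynamic_axes(axis_labels):
    """Turn {tensor_name: [label for axis 0, label for axis 1, ...]} into the
    {tensor_name: {axis_index: label}} mapping that `dynamic_axes` expects."""
    return {
        name: {axis: label for axis, label in enumerate(labels)}
        for name, labels in axis_labels.items()
    }


dynamic_axes = build_dynamic_axes({
    "input_ids": ["batch_size", "seq_len"],
    "attention_mask": ["batch_size", "seq_len"],
    "s_s_0": ["batch_size", "seq_len"],
})
print(dynamic_axes["input_ids"])  # {0: 'batch_size', 1: 'seq_len'}
```

Only the leading axes listed for a tensor are marked dynamic; trailing feature dimensions stay fixed, as in the export call above.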

In [24]:
import h5py
import torch

esm_stem_wrapper = ESMStemWrapper(model)

# Example inputs for ESM stem
input_ids = torch.randint(0, 20, (1, 100), device="cuda")  # Example input (batch_size=1, seq_len=100)
attention_mask = torch.ones_like(input_ids, device="cuda")  # Example input

# Run the ESM stem to get outputs
with torch.no_grad():
    s_s_0 = esm_stem_wrapper(input_ids, attention_mask)

# Save inputs and outputs to HDF5
with h5py.File("esm_stem_test_data.h5", "w") as f:
    f.create_dataset("input_ids", data=input_ids.cpu().numpy())
    f.create_dataset("attention_mask", data=attention_mask.cpu().numpy())
    f.create_dataset("s_s_0", data=s_s_0.cpu().numpy())

## Export trunk

In [1]:
import torch
torch.set_default_dtype(torch.float16)
import torch.nn as nn
from torch.onnx import export
from typing import Optional

# Wrapper for Pairwise Positional Embedding
class PairwisePositionalEmbeddingWrapper(nn.Module):
    def __init__(self, trunk):
        super().__init__()
        self.pairwise_positional_embedding = trunk.pairwise_positional_embedding

    def forward(self, z, residx, mask):
        return self.pairwise_positional_embedding(residx, mask=mask) + z

# Wrapper that runs the full stack of triangular self-attention blocks in sequence
class TriangularSelfAttentionBlockWrapper(nn.Module):
    def __init__(self, blocks):
        super().__init__()
        self.blocks = blocks

    def forward(self, s, z, mask, residx):
        for block in self.blocks:
            s, z = block(s, z, mask=mask, residue_index=residx, chunk_size=30)
        return s,z

# Wrapper for the Structure Module
class StructureModuleWrapper(nn.Module):
    def __init__(self, structure_module):
        super().__init__()
        self.structure_module = structure_module

    def forward(self, s, z, aatype, mask):
        return self.structure_module({"single": s, "pair": z}, aatype, mask.float())

# Export Pairwise Positional Embedding
def export_pairwise_positional_embedding(trunk, output_path, seq_length, batch_size=1):
    wrapper = PairwisePositionalEmbeddingWrapper(trunk)
    z = torch.randn(batch_size, seq_length, seq_length, trunk.config.pairwise_state_dim)
    residx = torch.arange(seq_length).unsqueeze(0).expand(batch_size, -1)
    mask = torch.ones(batch_size, seq_length, dtype=torch.bool)

    export(
        wrapper,
        (z, residx, mask),
        output_path,
        input_names=["z", "residx", "mask"],
        output_names=["z_out"],
        dynamic_axes={
            "z": {0: "batch_size", 1: "seq_length", 2: "seq_length"},
            "residx": {0: "batch_size", 1: "seq_length"},
            "mask": {0: "batch_size", 1: "seq_length"},
            "z_out": {0: "batch_size", 1: "seq_length", 2: "seq_length"},
        },
        opset_version=17,
    )

# Export the stack of triangular self-attention blocks as a single ONNX graph
def export_triangular_self_attention_block(trunk, output_path, seq_length, batch_size=1):
    wrapper = TriangularSelfAttentionBlockWrapper(trunk.blocks)
    s = torch.randn(batch_size, seq_length, trunk.blocks[0].config.sequence_state_dim)
    z = torch.randn(batch_size, seq_length, seq_length, trunk.blocks[0].config.pairwise_state_dim)
    mask = torch.ones(batch_size, seq_length)
    residx = torch.arange(seq_length).unsqueeze(0).expand(batch_size, -1)

    export(
        wrapper,
        (s, z, mask, residx),
        output_path,
        input_names=["s", "z", "mask", "residx"],
        output_names=["s_out", "z_out"],
        dynamic_axes={
            "s": {0: "batch_size", 1: "seq_length"},
            "z": {0: "batch_size", 1: "seq_length", 2: "seq_length"},
            "mask": {0: "batch_size", 1: "seq_length"},
            "residx": {0: "batch_size", 1: "seq_length"},
            "s_out": {0: "batch_size", 1: "seq_length"},
            "z_out": {0: "batch_size", 1: "seq_length", 2: "seq_length"},
        },
        opset_version=17,
    )

# Export the Structure Module
def export_structure_module(trunk, output_path, seq_length, batch_size=1):
    wrapper = StructureModuleWrapper(trunk.structure_module)
    s = torch.randn(batch_size, seq_length, trunk.config.structure_module.sequence_dim)
    z = torch.randn(batch_size, seq_length, seq_length, trunk.config.structure_module.pairwise_dim)
    aatype = torch.randint(0, 20, (batch_size, seq_length))
    mask = torch.ones(batch_size, seq_length, dtype=torch.bool)

    export(
        wrapper,
        (s, z, aatype, mask),
        output_path,
        input_names=["s", "z", "aatype", "mask"],
        output_names=["output"],
        dynamic_axes={
            "s": {0: "batch_size", 1: "seq_length"},
            "z": {0: "batch_size", 1: "seq_length", 2: "seq_length"},
            "aatype": {0: "batch_size", 1: "seq_length"},
            "mask": {0: "batch_size", 1: "seq_length"},
            "output": {0: "batch_size", 1: "seq_length"},
        },
        opset_version=17,
    )
    
# Load the model
from transformers import EsmForProteinFolding
import torch
torch.set_default_dtype(torch.float16)
model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", low_cpu_mem_usage=True, device_map = 'auto')
trunk = model.trunk



In [None]:
# Define sequence length and batch size
seq_length = 10
batch_size = 1
# Export components
#export_pairwise_positional_embedding(trunk, "pairwise_positional_embedding.onnx", seq_length, batch_size)
export_triangular_self_attention_block(trunk, "triangular_self_attention_blocks.onnx", seq_length, batch_size)
#export_structure_module(trunk, "structure_module.onnx", seq_length, batch_size)



In [25]:
import torch
torch.set_default_dtype(torch.float16)
import torch.nn as nn
from torch.onnx import export as onnx_export

class TrunkWrapper(nn.Module):
    def __init__(self, esm_for_protein_folding):
        super().__init__()
        self.trunk = esm_for_protein_folding.trunk

    def forward(self, s_s_0, input_ids, attention_mask, position_ids):
        # Initialize pairwise features
        B, L = s_s_0.shape[:2]
        s_z_0 = s_s_0.new_zeros(B, L, L, self.trunk.config.pairwise_state_dim)

        # Run the trunk
        structure = self.trunk(s_s_0, s_z_0, input_ids, position_ids, attention_mask, no_recycles=0)
        return structure
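
The pairwise tensor `s_z_0` has shape `(B, L, L, pairwise_state_dim)`, so its size grows quadratically with sequence length. A back-of-the-envelope estimate is sketched below. It assumes a pairwise dimension of 128 (the size used for the example `s_z` inputs later in this notebook) and 2 bytes per element for float16. The helper name is hypothetical, and the figure covers this one tensor only, so it is a lower bound on what the trunk needs.

```python
def pairwise_state_bytes(seq_len, batch_size=1, pairwise_dim=128, bytes_per_element=2):
    """Size in bytes of a (batch, L, L, pairwise_dim) tensor; 2 bytes per element is float16."""
    return batch_size * seq_len * seq_len * pairwise_dim * bytes_per_element


for length in (100, 350, 600):
    mib = pairwise_state_bytes(length) / 2**20
    print(f"L={length}: {mib:.1f} MiB")
```

Doubling the sequence length quadruples this number, which is one reason longer proteins need noticeably more GPU memory.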

In [3]:
# Initialize the trunk wrapper
trunk_wrapper = TrunkWrapper(model)

# Example inputs
input_ids = torch.randint(0, 20, (1, 100), device="cuda")  # Example input (batch_size=1, seq_len=100)
s_s_0 = torch.randn(1, 100, 1024, device="cuda")  # Example input (batch_size=1, seq_len=100)
position_ids = torch.arange(100, device="cuda").unsqueeze(0)  # Example input
attention_mask = torch.ones_like(input_ids, device="cuda")  # Example input

# Export to ONNX with dynamic batch size
onnx_export(
    trunk_wrapper,
    (s_s_0, input_ids, attention_mask, position_ids),
    "trunk.onnx",
    input_names=["s_s_0", "input_ids", "attention_mask", "position_ids"],
    output_names=["structure"],
    opset_version=17,
    dynamic_axes={
        "s_s_0": {0: "batch_size", 1: "seq_len"},  # Dynamic batch size and sequence length
        "input_ids": {0: "batch_size", 1: "seq_len"},  # Dynamic batch size and sequence length
        "attention_mask": {0: "batch_size", 1: "seq_len"},  # Dynamic batch size and sequence length
        "position_ids": {0: "batch_size", 1: "seq_len"},  # Dynamic batch size and sequence length
        "structure": {0: "batch_size", 1: "seq_len"},  # Dynamic batch size and sequence length
    },
)


In [31]:
import h5py

# Initialize the trunk wrapper
trunk_wrapper = TrunkWrapper(model)
# Reuse s_s_0, input_ids and attention_mask from the export cell above; rebuild position_ids
position_ids = torch.arange(100, device="cuda").unsqueeze(0)  # Example input

# Run the trunk to get outputs
with torch.no_grad():
    structure = trunk_wrapper(s_s_0, input_ids, attention_mask, position_ids)

# Save inputs and outputs to HDF5
with h5py.File("trunk_test_data.h5", "w") as f:
    f.create_dataset("s_s_0", data=s_s_0.cpu().numpy())
    f.create_dataset("input_ids", data=input_ids.cpu().numpy())
    f.create_dataset("attention_mask", data=attention_mask.cpu().numpy())
    f.create_dataset("position_ids", data=position_ids.cpu().numpy())
    for k,v in structure.items():
        f.create_dataset(k, data=v.cpu().numpy())

## Heads

In [29]:
import torch
torch.set_default_dtype(torch.float16)
import torch.nn as nn
from torch.onnx import export as onnx_export
class HeadsWrapper(nn.Module):
    def __init__(self, esm_for_protein_folding):
        super().__init__()
        self.distogram_head = esm_for_protein_folding.distogram_head
        self.ptm_head = esm_for_protein_folding.ptm_head
        self.lm_head = esm_for_protein_folding.lm_head
        self.lddt_head = esm_for_protein_folding.lddt_head

    def forward(self, s_s, s_z, structure):
        # Compute outputs from the heads
        disto_logits = self.distogram_head(s_z)
        disto_logits = (disto_logits + disto_logits.transpose(1, 2)) / 2

        lm_logits = self.lm_head(s_s)

        lddt_head = self.lddt_head(structure).reshape(
            structure.shape[0], s_s.shape[0], s_s.shape[1], -1, 50
        )

        ptm_logits = self.ptm_head(s_z)

        return {
            "distogram_logits": disto_logits,
            "lm_logits": lm_logits,
            "lddt_head": lddt_head,
            "ptm_logits": ptm_logits,
        }

In [30]:
# Initialize the heads wrapper
heads_wrapper = HeadsWrapper(model)

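# Example inputs (batch_size=1, seq_len=100); feature sizes follow the example shapes used elsewhere in this notebook
s_s = torch.randn(1, 100, 1024, device="cuda")  # Sequence state
s_z = torch.randn(1, 100, 100, 128, device="cuda")  # Pairwise state
structure = torch.randn(1, 100, 384, device="cuda")  # Structure module states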
# Export to ONNX with dynamic batch size
onnx_export(
    heads_wrapper,
    (s_s, s_z, structure),
    "heads.onnx",
    input_names=["s_s", "s_z", "structure"],
    output_names=["distogram_logits", "lm_logits", "lddt_head", "ptm_logits"],
    opset_version=17,
    dynamic_axes={
        "s_s": {0: "batch_size", 1: "seq_len"},  # Dynamic batch size and sequence length
        "s_z": {0: "batch_size", 1: "seq_len", 2: "seq_len"},  # Dynamic batch size and sequence length
        "structure": {0: "batch_size", 1: "seq_len"},  # Dynamic batch size and sequence length
        "distogram_logits": {0: "batch_size", 1: "seq_len", 2: "seq_len"},  # Dynamic batch size and sequence length
        "lm_logits": {0: "batch_size", 1: "seq_len"},  # Dynamic batch size and sequence length
        "lddt_head": {0: "batch_size", 1: "seq_len"},  # Dynamic batch size and sequence length
        "ptm_logits": {0: "batch_size", 1: "seq_len", 2: "seq_len"},  # Dynamic batch size and sequence length
    },
)

In [31]:
# Run the heads to get outputs
with torch.no_grad():
    outputs = heads_wrapper(s_s, s_z, structure)

# Save inputs and outputs to HDF5
with h5py.File("heads_test_data.h5", "w") as f:
    f.create_dataset("s_s", data=s_s.cpu().numpy())
    f.create_dataset("s_z", data=s_z.cpu().numpy())
    f.create_dataset("structure", data=structure.cpu().numpy())
    f.create_dataset("distogram_logits", data=outputs["distogram_logits"].cpu().numpy())
    f.create_dataset("lm_logits", data=outputs["lm_logits"].cpu().numpy())
    f.create_dataset("lddt_head", data=outputs["lddt_head"].cpu().numpy())
    f.create_dataset("ptm_logits", data=outputs["ptm_logits"].cpu().numpy())

## Test

In [47]:
import h5py
import onnxruntime as ort
import numpy as np

# Load the ONNX model
esm_stem_session = ort.InferenceSession("esm.onnx")

# Load inputs and outputs from HDF5
with h5py.File("esm_stem_test_data.h5", "r") as f:
    input_ids = f["input_ids"][:]
    attention_mask = f["attention_mask"][:]
    expected_s_s_0 = f["s_s_0"][:]

# Run the ONNX model
esm_stem_inputs = {
    "input_ids": input_ids,
    "attention_mask": attention_mask,
}
esm_stem_outputs = esm_stem_session.run(None, esm_stem_inputs)
s_s_0 = esm_stem_outputs[0]

# Compare outputs
assert np.allclose(s_s_0, expected_s_s_0, atol=1e-4), "ESM stem outputs do not match!"
print("ESM stem test passed!")


In [None]:
# Load the ONNX model
import h5py
import onnxruntime as ort
import numpy as np

trunk_session = ort.InferenceSession("trunk.onnx")

# Load inputs and outputs from HDF5
with h5py.File("trunk_test_data.h5", "r") as f:
    s_s_0 = f["s_s_0"][:]
    input_ids = f["input_ids"][:]
    attention_mask = f["attention_mask"][:]
    position_ids = f["position_ids"][:]
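    # NOTE: trunk_test_data.h5 was written with one dataset per key of the trunk's
    # output dict, so it contains no "structure" dataset. Replace this key with the
    # one that corresponds to the first ONNX output before running this cell.
    expected_structure = f["structure"][:]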

# Run the ONNX model
trunk_inputs = {
    "s_s_0": s_s_0,
    "input_ids": input_ids,
    "attention_mask": attention_mask,
    "position_ids": position_ids,
}
trunk_outputs = trunk_session.run(None, trunk_inputs)
structure = trunk_outputs[0]

# Compare outputs
assert np.allclose(structure, expected_structure, atol=1e-5), "Trunk outputs do not match!"
print("Trunk test passed!")

In [32]:
# Load the ONNX model
import h5py
import onnxruntime as ort
import numpy as np
structure_module_session = ort.InferenceSession("structure_module.onnx")

# Load inputs and outputs from HDF5
with h5py.File("structure_module_test_data.h5", "r") as f:
    s_s = f["s_s"][:]
    s_z = f["s_z"][:]
    input_ids = f["input_ids"][:]
    attention_mask = f["attention_mask"][:]
    expected_structure = f["structure"][:]

# Run the ONNX model
structure_module_inputs = {
    "s_s": s_s,
    "s_z": s_z,
    "input_ids": input_ids,
    "attention_mask": attention_mask,
}
structure_module_outputs = structure_module_session.run(None, structure_module_inputs)
structure = structure_module_outputs[0]

# Compare outputs
assert np.allclose(structure, expected_structure, atol=1e-5), "Structure module outputs do not match!"
print("Structure module test passed!")

Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from structure_module.onnx failed:Node (/structure_module/ipa/Transpose_6) Op (Transpose) [TypeInferenceError] Invalid attribute perm {0, -3, -1, -4, -2}, input shape = {0, 0, 0, 0, 0}

In [1]:
# Load the ONNX model
import h5py
import onnxruntime as ort
import numpy as np
heads_session = ort.InferenceSession("heads.onnx")

# Load inputs and outputs from HDF5
with h5py.File("heads_test_data.h5", "r") as f:
    s_s = f["s_s"][:]
    s_z = f["s_z"][:]
    structure = f["structure"][:]
    expected_distogram_logits = f["distogram_logits"][:]
    expected_lm_logits = f["lm_logits"][:]
    expected_lddt_head = f["lddt_head"][:]
    expected_ptm_logits = f["ptm_logits"][:]

# Run the ONNX model
heads_inputs = {
    "s_s": s_s,
    "s_z": s_z,
    "structure": structure,
}
heads_outputs = heads_session.run(None, heads_inputs)

# Compare outputs
np.testing.assert_allclose(heads_outputs[0], expected_distogram_logits, rtol=1e-03, atol=1e-03) 
np.testing.assert_allclose(heads_outputs[1], expected_lm_logits, rtol=1e-03, atol=1e-3)
np.testing.assert_allclose(heads_outputs[2], expected_lddt_head, rtol=1e-02, atol=1e-3)
np.testing.assert_allclose(heads_outputs[3], expected_ptm_logits, rtol=1e-03, atol=1e-3)
print("Heads test passed!")

AssertionError: 
Not equal to tolerance rtol=0.01, atol=0.001

Mismatched elements: 894 / 185000 (0.483%)
Max absolute difference among violations: 0.1719
Max relative difference among violations: 20.58
 ACTUAL: array([[[[[-1.5477e+01, -4.9094e+01, -2.9516e+01, ...,  1.5844e+01,
            1.4602e+01,  1.0148e+01],
          [-1.1570e+01, -3.0859e+00, -1.9482e+00, ...,  2.2078e+01,...
 DESIRED: array([[[[[-1.5477e+01, -4.9094e+01, -2.9516e+01, ...,  1.5844e+01,
            1.4602e+01,  1.0148e+01],
          [-1.1570e+01, -3.0879e+00, -1.9697e+00, ...,  2.2078e+01,...
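
`assert_allclose` accepts an element when `|actual - desired| <= atol + rtol * |desired|`, so values close to zero can fail the check with a small absolute error while showing a very large relative difference. The sketch below summarises such mismatches. It assumes only NumPy, `summarize_mismatch` is a hypothetical helper, and the toy arrays are made up rather than taken from the run above. It also prints the float16 machine epsilon, since float16 carries roughly three significant decimal digits.

```python
import numpy as np


def summarize_mismatch(actual, desired, rtol=1e-2, atol=1e-3):
    """Report how many elements violate |actual - desired| <= atol + rtol * |desired|."""
    actual = np.asarray(actual, dtype=np.float64)
    desired = np.asarray(desired, dtype=np.float64)
    abs_diff = np.abs(actual - desired)
    violations = abs_diff > atol + rtol * np.abs(desired)
    return {
        "mismatched": int(violations.sum()),
        "total": int(violations.size),
        "max_abs_diff": float(abs_diff.max()),
    }


# Toy values (not taken from the notebook): the second pair is close to zero,
# so a small absolute error already exceeds the tolerance.
actual = np.array([15.84, 0.030, -29.52], dtype=np.float16)
desired = np.array([15.84, 0.010, -29.52], dtype=np.float16)
summary = summarize_mismatch(actual, desired)
print(summary)
print("float16 machine epsilon:", np.finfo(np.float16).eps)
```

A summary like this makes it easier to judge whether a failure reflects a real export problem or just reduced-precision noise on small values.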

In [None]:
import onnxruntime as ort

# Load the ONNX models
esm_stem_session = ort.InferenceSession("esm.onnx")
trunk_session = ort.InferenceSession("trunk.onnx")
structure_module_session = ort.InferenceSession("structure_module.onnx")
heads_session = ort.InferenceSession("heads.onnx")

In [None]:
import numpy as np

# Example inputs (batch_size=1, seq_len=100)
input_ids = np.random.randint(0, 20, (1, 100)).astype(np.int64)  # Token IDs
attention_mask = np.ones((1, 100), dtype=np.int64)  # Attention mask

# Run ESM stem
esm_stem_inputs = {
    "input_ids": input_ids,
    "attention_mask": attention_mask,
}
esm_stem_outputs = esm_stem_session.run(None, esm_stem_inputs)
s_s_0 = esm_stem_outputs[0]  # Output: s_s_0

# Prepare inputs for the trunk
position_ids = np.arange(100, dtype=np.int64).reshape(1, -1)  # Position IDs

# Run trunk
trunk_inputs = {
    "s_s_0": s_s_0,
    "input_ids": input_ids,
    "attention_mask": attention_mask,
    "position_ids": position_ids,
}
trunk_outputs = trunk_session.run(None, trunk_inputs)
structure = trunk_outputs[0]  # Output: structure

# Prepare inputs for the structure module
# (float16, because the models above were exported with float16 as the default dtype)
s_s = np.random.randn(1, 100, 1024).astype(np.float16)  # Example s_s
s_z = np.random.randn(1, 100, 100, 128).astype(np.float16)  # Example s_z

# Run structure module
structure_module_inputs = {
    "s_s": s_s,
    "s_z": s_z,
    "input_ids": input_ids,
    "attention_mask": attention_mask,
}
structure_module_outputs = structure_module_session.run(None, structure_module_inputs)
structure = structure_module_outputs[0]  # Output: structure

# Prepare inputs for the heads
# (384 features and float16, matching the `structure` example used when exporting heads.onnx)
structure = {
    "states": np.random.randn(1, 100, 384).astype(np.float16),  # Example states
}

# Run heads
heads_inputs = {
    "s_s": s_s,
    "s_z": s_z,
    "structure": structure["states"],
}
heads_outputs = heads_session.run(None, heads_inputs)

# Extract outputs
distogram_logits = heads_outputs[0]  # Output: distogram_logits
lm_logits = heads_outputs[1]  # Output: lm_logits
lddt_head = heads_outputs[2]  # Output: lddt_head
ptm_logits = heads_outputs[3]  # Output: ptm_logits

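# Example: Compute pLDDT as the expected value over the lDDT bins (not a plain mean of the logits).
# Assumes the bins are evenly spaced over [0, 1].
logits = lddt_head[-1].astype(np.float32)  # Last entry along the first axis: (batch, seq_len, atoms, bins)
probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
probs /= probs.sum(axis=-1, keepdims=True)  # Softmax over bins
bin_centers = (np.arange(probs.shape[-1]) + 0.5) / probs.shape[-1]
plddt = (probs * bin_centers).sum(axis=-1)  # Per-atom pLDDT in [0, 1]
print("pLDDT:", plddt)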