## Protein Folding with ESMFold

In [4]:
import openfold
print(openfold.__file__)

/home/siria/anaconda3/envs/esmfold2/lib/python3.9/site-packages/openfold/__init__.py


In [None]:
!pip install onnxruntime
!pip install onnx --upgrade
!pip install biotite
!pip install scipy
!pip install einops
!pip install -e .

In [None]:
!pip install 'dllogger @ git+https://github.com/NVIDIA/dllogger.git'
!pip install 'openfold @ git+https://github.com/aqlaboratory/openfold.git@4b41059694619831a7db195b7e0988fc4ff3a307'

In [3]:
!pip install h5py

Collecting h5py
  Downloading h5py-3.13.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.5 kB)
Downloading h5py-3.13.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.6/4.6 MB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: h5py
Successfully installed h5py-3.13.0


In [1]:
import torch
import esm
model = esm.pretrained.esmfold_structure_module_only_8M()
model = model.eval().cuda()

  from .autonotebook import tqdm as notebook_tqdm
  from scipy.stats import truncnorm


In [None]:
sequence = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"
# Multimer prediction can be done with chains separated by ':'
model.cuda()
with torch.no_grad():
    output = model.infer_pdb(sequence)

with open("result.pdb", "w") as f:
    f.write(output)

import biotite.structure.io as bsio
struct = bsio.load_structure("result.pdb", extra_fields=["b_factor"])
print(struct.b_factor.mean())  # this will be the pLDDT
# 88.3

TypeError: One of data, shape or dtype must be specified

In [None]:
# Create dummy inputs
batch_size, seq_len = 1, 50
aa = torch.randint(0, 20, (batch_size, seq_len), dtype=torch.long).cuda()
out=model(aa)
out.keys()

dict_keys(['frames', 'sidechain_frames', 'unnormalized_angles', 'angles', 'positions', 'states', 's_s', 's_z', 'distogram_logits', 'lm_logits', 'aatype', 'atom14_atom_exists', 'residx_atom14_to_atom37', 'residx_atom37_to_atom14', 'atom37_atom_exists', 'residue_index', 'lddt_head', 'plddt', 'ptm_logits', 'ptm', 'aligned_confidence_probs', 'predicted_aligned_error', 'max_predicted_aligned_error'])

In [None]:
import torch
import torch.nn as nn
from typing import Optional, Dict
from openfold.data.data_transforms import make_atom14_masks
from esm.esmfold.v1.categorical_mixture import categorical_lddt
from openfold.utils.loss import compute_predicted_aligned_error, compute_tm


class ESMFoldONNXWrapper(nn.Module):
    def __init__(self, esmfold_model):
        super().__init__()
        self.esmfold = esmfold_model

        # Replace the trunk with a dummy trunk if num_folding_blocks is 0
        self.trunk = self.esmfold.trunk

    def _af2_idx_to_esm_idx(self, aa, mask):
        aa = (aa + 1).masked_fill(mask != 1, 0)
        return self.esmfold.af2_to_esm[aa]

    def _mask_inputs_to_esm(self, esmaa, pattern):
        """
        Replace `pattern == 1` with ONNX-supported operations.
        """
        new_esmaa = esmaa.clone()
        mask = pattern * 1  # Convert boolean mask to integer (1s and 0s)
        new_esmaa[mask == 1] = self.esmfold.esm_dict.mask_idx  # Supported by ONNX
        return new_esmaa

    def _compute_language_model_representations(
        self, esmaa: torch.Tensor
    ) -> torch.Tensor:
        """Adds bos/eos tokens for the language model, since the structure module doesn't use these."""
        batch_size = esmaa.size(0)

        bosi, eosi = self.esmfold.esm_dict.cls_idx, self.esmfold.esm_dict.eos_idx
        bos = esmaa.new_full((batch_size, 1), bosi)
        eos = esmaa.new_full((batch_size, 1), self.esmfold.esm_dict.padding_idx)
        esmaa = torch.cat([bos, esmaa, eos], dim=1)
        # Use the first padding index as eos during inference.
        esmaa[range(batch_size), (esmaa != 1).sum(1)] = eosi

        res = self.esmfold.esm(
            esmaa,
            repr_layers=range(self.esmfold.esm.num_layers + 1),
            need_head_weights=self.esmfold.cfg.use_esm_attn_map,
        )
        esm_s = torch.stack(
            [v for _, v in sorted(res["representations"].items())], dim=2
        )
        esm_s = esm_s[:, 1:-1]  # B, L, nLayers, C
        esm_z = (
            res["attentions"].permute(0, 4, 3, 1, 2).flatten(3, 4)[:, 1:-1, 1:-1, :]
            if self.esmfold.cfg.use_esm_attn_map
            else None
        )
        return esm_s, esm_z

    def _mask_inputs_to_esm(self, esmaa, pattern):
        new_esmaa = esmaa.clone()
        mask = pattern * 1  # Convert boolean mask to integer (1s and 0s)
        new_esmaa[mask == 1] = self.esm_dict.mask_idx  # Now this is supported
        return new_esmaa

    def forward(
        self,
        aa: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        residx: Optional[torch.Tensor] = None,
        masking_pattern: Optional[torch.Tensor] = None,
        num_recycles: Optional[int] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Full inference pipeline for ONNX export.
        """
        if mask is None:
            mask = torch.ones_like(aa)

        B = aa.shape[0]
        L = aa.shape[1]
        device = aa.device

        if residx is None:
            residx = torch.arange(L, device=device).expand_as(aa)

        # === ESM ===
        esmaa = self._af2_idx_to_esm_idx(aa, mask)

        if masking_pattern is not None:
            esmaa = self._mask_inputs_to_esm(esmaa, masking_pattern)

        esm_s, esm_z = self._compute_language_model_representations(esmaa)

        # Convert esm_s to the precision used by the trunk and
        # the structure module. These tensors may be a lower precision if, for example,
        # we're running the language model in fp16 precision.
        esm_s = esm_s.to(self.esmfold.esm_s_combine.dtype)
        esm_s = esm_s.detach()

        # === preprocessing ===
        esm_s = (self.esmfold.esm_s_combine.softmax(0).unsqueeze(0) @ esm_s).squeeze(2)

        s_s_0 = self.esmfold.esm_s_mlp(esm_s)
        if self.esmfold.cfg.use_esm_attn_map:
            esm_z = esm_z.to(self.esmfold.esm_s_combine.dtype)
            esm_z = esm_z.detach()
            s_z_0 = self.esmfold.esm_z_mlp(esm_z)
        else:
            s_z_0 = s_s_0.new_zeros(B, L, L, self.esmfold.cfg.trunk.pairwise_state_dim)

        s_s_0 += self.esmfold.embedding(aa)
        print(s_s_0.shape)
        structure: dict = self.esmfold.trunk(
            s_s_0, s_z_0, aa, residx, mask, no_recycles=num_recycles
        )
        # Documenting what we expect:
        structure = {
            k: v
            for k, v in structure.items()
            if k
            in [
                "s_z",
                "s_s",
                "frames",
                "sidechain_frames",
                "unnormalized_angles",
                "angles",
                "positions",
                "states",
            ]
        }
        print(structure["s_s"].shape)
        disto_logits = self.esmfold.distogram_head(structure["s_z"])
        disto_logits = (disto_logits + disto_logits.transpose(1, 2)) / 2
        structure["distogram_logits"] = disto_logits

        lm_logits = self.esmfold.lm_head(structure["s_s"])
        structure["lm_logits"] = lm_logits

        structure["aatype"] = aa
        make_atom14_masks(structure)

        for k in [
            "atom14_atom_exists",
            "atom37_atom_exists",
        ]:
            structure[k] *= mask.unsqueeze(-1)
        structure["residue_index"] = residx

        lddt_head = self.esmfold.lddt_head(structure["states"]).reshape(
            structure["states"].shape[0], B, L, -1, self.esmfold.lddt_bins
        )
        structure["lddt_head"] = lddt_head
        plddt = categorical_lddt(lddt_head[-1], bins=self.esmfold.lddt_bins)
        structure["plddt"] = (
            100 * plddt
        )  # we predict plDDT between 0 and 1, scale to be between 0 and 100.

        ptm_logits = self.esmfold.ptm_head(structure["s_z"])

        seqlen = mask.type(torch.int64).sum(1)
        structure["ptm_logits"] = ptm_logits
        structure["ptm"] = torch.stack(
            [
                compute_tm(
                    batch_ptm_logits[None, :sl, :sl],
                    max_bins=31,
                    no_bins=self.esmfold.distogram_bins,
                )
                for batch_ptm_logits, sl in zip(ptm_logits, seqlen)
            ]
        )
        structure.update(
            compute_predicted_aligned_error(
                ptm_logits, max_bin=31, no_bins=self.esmfold.distogram_bins
            )
        )

        return structure

In [10]:
# Load the pretrained ESMFold model
esmfold_model = model

# Wrap the model for ONNX export
wrapped_model = ESMFoldONNXWrapper(esmfold_model)
wrapped_model.eval().cuda()


ESMFoldONNXWrapper(
  (esmfold): ESMFold(
    (esm): ESM2(
      (embed_tokens): Embedding(33, 320, padding_idx=1)
      (layers): ModuleList(
        (0): TransformerLayer(
          (self_attn): MultiheadAttention(
            (k_proj): Linear(in_features=320, out_features=320, bias=True)
            (v_proj): Linear(in_features=320, out_features=320, bias=True)
            (q_proj): Linear(in_features=320, out_features=320, bias=True)
            (out_proj): Linear(in_features=320, out_features=320, bias=True)
            (rot_emb): RotaryEmbedding()
          )
          (self_attn_layer_norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=320, out_features=1280, bias=True)
          (fc2): Linear(in_features=1280, out_features=320, bias=True)
          (final_layer_norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
        )
        (1): TransformerLayer(
          (self_attn): MultiheadAttention(
            (k_proj): Linear(i

In [None]:
from esm.esmfold.v1.misc import batch_encode_sequences, collate_dense_tensors
import typing as T
sequences = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"
def infer(
        sequences: T.Union[str, T.List[str]],
        residx=None,
        masking_pattern: T.Optional[torch.Tensor] = None,
        num_recycles: T.Optional[int] = 0,
        residue_index_offset: T.Optional[int] = 512,
        chain_linker: T.Optional[str] = "G" * 25,
    ):
        """Runs a forward pass given input sequences.

        Args:
            sequences (Union[str, List[str]]): A list of sequences to make predictions for. Multimers can also be passed in,
                each chain should be separated by a ':' token (e.g. "<chain1>:<chain2>:<chain3>").
            residx (torch.Tensor): Residue indices of amino acids. Will assume contiguous if not provided.
            masking_pattern (torch.Tensor): Optional masking to pass to the input. Binary tensor of the same size
                as `aa`. Positions with 1 will be masked. ESMFold sometimes produces different samples when
                different masks are provided.
            num_recycles (int): How many recycle iterations to perform. If None, defaults to training max
                recycles (cfg.trunk.max_recycles), which is 4.
            residue_index_offset (int): Residue index separation between chains if predicting a multimer. Has no effect on
                single chain predictions. Default: 512.
            chain_linker (str): Linker to use between chains if predicting a multimer. Has no effect on single chain
                predictions. Default: length-25 poly-G ("G" * 25).
        """
        if isinstance(sequences, str):
            sequences = [sequences]

        aatype, mask, _residx, linker_mask, chain_index = batch_encode_sequences(
            sequences, residue_index_offset, chain_linker
        )

        if residx is None:
            residx = _residx
        elif not isinstance(residx, torch.Tensor):
            residx = collate_dense_tensors(residx)

        aatype, mask, residx, linker_mask = map(
            lambda x: x.to("cuda"), (aatype, mask, residx, linker_mask)
        )
        print(wrapped_model(aatype))
        
        torch.onnx.export(
            wrapped_model,
            aatype,
            "esmfold_full_inference.onnx",
            input_names=["aa"],
            dynamic_axes={
                "aa": {1: "seq_len"}
            },
            verbose=True,
        )
        output = wrapped_model(
            aatype,
            mask=mask,
            residx=residx,
            masking_pattern=masking_pattern,
            num_recycles=num_recycles,
        )

        return output
infer(sequences)

torch.Size([1, 65, 1024])
torch.Size([1, 65, 1024])
{'frames': tensor([[[[  0.8384,  -0.0645,  -0.5042,  ...,   5.0050,  10.6682, -14.5706],
          [  0.7023,   0.1900,  -0.6542,  ...,   5.6512,   9.9312, -12.1012],
          [  0.2541,   0.4043,  -0.5546,  ...,   5.8012,   9.4977, -10.0755],
          ...,
          [  0.6699,   0.4468,  -0.4067,  ...,   6.1647,  -5.6255,   5.2155],
          [  0.9002,  -0.0832,  -0.2578,  ...,   7.4492,  -6.5384,   6.5318],
          [  0.8823,  -0.2721,  -0.0890,  ...,   9.7180,  -8.7572,   8.0467]]],


        [[[  0.8299,  -0.0388,  -0.4311,  ...,   7.0072,  11.4975, -13.9104],
          [  0.4702,   0.1923,  -0.7234,  ...,   7.7972,  10.6426, -11.6366],
          [ -0.0757,   0.2709,  -0.4249,  ...,   6.7236,  10.8560, -10.6881],
          ...,
          [  0.8250,  -0.1426,  -0.3811,  ...,   6.8207,  -5.9426,   5.2841],
          [  0.6914,  -0.0825,  -0.5306,  ...,   7.9848,  -4.7659,   6.4169],
          [  0.7530,  -0.5792,  -0.2185,  ...

  if not padding_mask.any():
  assert embed_dim == self.embed_dim
  assert list(query.size()) == [tgt_len, bsz, embed_dim]
  if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
  assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
  assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
  assert residue_index.shape == mask.shape
  if((rot_mats is not None and rot_mats.shape[-2:] != (3, 3)) or
  if((rots.shape != trans.shape[:-1]) or
  if(t.shape[-2:] != (4, 4)):


torch.Size([1, 65, 1024])


  restype_atom14_to_atom37 = torch.tensor(
  restype_atom37_to_atom14 = torch.tensor(
  restype_atom14_mask = torch.tensor(
  for batch_ptm_logits, sl in zip(ptm_logits, seqlen)
  clipped_n = max(n, 19)
  return per_alignment[tuple(argmax)]
  return per_alignment[tuple(argmax)]
  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
  _C._jit_pass_onnx_graph_shape_type_inference(
  _C._jit_pass_onnx_graph_shape_type_inference(


In [1]:
import torch
import torch.nn as nn
from typing import Dict


class ESMLanguageModelWrapper(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.esmfold = model

    def _mask_inputs_to_esm(self, esmaa, mask):
        """
        Replace `pattern == 1` with ONNX-supported operations.
        """
        new_esmaa = esmaa.clone()
        new_esmaa[mask==1] = self.esmfold.esm_dict.mask_idx  # Supported by ONNX
        return new_esmaa

    def forward(
        self,
        aa: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass for the ESM language model and preprocessing.
        """
        
        mask = torch.ones_like(aa)

        # === ESM Language Model ===
        esmaa = self.esmfold._af2_idx_to_esm_idx(aa, mask)

        esm_s, esm_z = self.esmfold._compute_language_model_representations(esmaa)
        esm_s = esm_s.to(self.esmfold.esm_s_combine.dtype)
        esm_s = (self.esmfold.esm_s_combine.softmax(0).unsqueeze(0) @ esm_s).squeeze(2)

        # === Preprocessing ===
        s_s_0 = self.esmfold.esm_s_mlp(esm_s)
        if self.esmfold.cfg.use_esm_attn_map:
            esm_z = esm_z.to(self.esmfold.esm_s_combine.dtype)
            s_z_0 = self.esmfold.esm_z_mlp(esm_z)
        else:
            s_z_0 = s_s_0.new_zeros(aa.shape[0], aa.shape[1], aa.shape[1],
                                   self.esmfold.cfg.trunk.pairwise_state_dim)

        s_s_0 += self.esmfold.embedding(aa)

        return {
            "s_s_0": s_s_0,
            "s_z_0": s_z_0,
        }

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import torch
import esm

model = esm.pretrained.esmfold_structure_module_only_8M()

# Create dummy inputs
batch_size, seq_len = 1, 1024
aa = torch.randint(0, 20, (batch_size, seq_len), dtype=torch.long)

# Initialize the wrapper
esm_lm_wrapper = ESMLanguageModelWrapper(model)
esm_lm_wrapper.eval().to("cpu")
esm_lm_wrapper(aa)
# Export to ONNX
torch.onnx.export(
    esm_lm_wrapper,
    aa,
    "esm_lm.onnx",
    export_params=True,
    do_constant_folding=True,
    input_names=["aa"],
    output_names=["s_s_0", "s_z_0"],
    dynamic_axes={
        "aa": {0:"batch",1: "seq_len"}
    },
    opset_version=17
)

  from scipy.stats import truncnorm
  if not padding_mask.any():
  assert embed_dim == self.embed_dim
  assert list(query.size()) == [tgt_len, bsz, embed_dim]
  if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
  assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
  assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]


In [6]:
import onnxruntime as ort
import numpy as np

# Load the ONNX model
onnx_session = ort.InferenceSession("esm_lm.onnx")
esm_lm_wrapper = ESMLanguageModelWrapper(model)
esm_lm_wrapper.eval().to("cpu")

batch_size, seq_len = 1, 1024
aa = torch.randint(0, 20, (batch_size, seq_len), dtype=torch.long)
# Prepare the input for the ONNX model
onnx_input = {onnx_session.get_inputs()[0].name: aa.to("cpu").numpy()}

# Run inference with the ONNX model
onnx_output = onnx_session.run(None, onnx_input)
# Run inference with the PyTorch model
with torch.no_grad():
    pytorch_output = esm_lm_wrapper(aa)
pytorch_output = [pytorch_output["s_s_0"].to("cpu").numpy(),pytorch_output["s_z_0"].to("cpu").numpy()]

# Calculate the absolute difference between the outputs
diff = np.abs(pytorch_output[0] - onnx_output[0])

# Print statistics about the difference
print(f"Max difference: {np.max(diff)}")
print(f"Mean difference: {np.mean(diff)}")
print(f"Number of differences > 1e-5: {np.sum(diff > 1e-4)}")

diff = np.abs(pytorch_output[1] - onnx_output[1])

# Print statistics about the difference
print(f"Max difference: {np.max(diff)}")
print(f"Mean difference: {np.mean(diff)}")
print(f"Number of differences > 1e-5: {np.sum(diff > 1e-4)}")
    

Max difference: 4.00543212890625e-05
Mean difference: 4.804411219083704e-06
Number of differences > 1e-5: 0
Max difference: 0.00012111663818359375
Mean difference: 2.462937231939577e-07
Number of differences > 1e-5: 1


In [5]:
import torch
import torch.nn as nn

class Distogram(nn.Module):
    def __init__(self, min_bin, max_bin, num_bins):
        super().__init__()
        self.min_bin = min_bin
        self.max_bin = max_bin
        self.num_bins = num_bins

    def forward(self, coords):
        boundaries = torch.linspace(self.min_bin, self.max_bin, self.num_bins - 1, device=coords.device)
        boundaries = boundaries**2
        N, CA, C = [x.squeeze(-2) for x in coords.chunk(3, dim=-2)]
        b = CA - N
        c = C - CA
        a = b.cross(c, dim=-1)
        CB = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA
        dists = (CB[..., None, :, :] - CB[..., :, None, :]).pow(2).sum(dim=-1, keepdims=True)
        bins = torch.sum(dists > boundaries, dim=-1)
        return bins

# Export Distogram
distogram = Distogram(min_bin=3.375, max_bin=21.375, num_bins=15)
coords = torch.randn(1, 100, 3, 3)  # Example input

class trunkWrapper(nn.Module):
    def __init__(self, model):
        super().__init__()
        
        self.recycle_s_norm = model.trunk.recycle_s_norm 
        self.recycle_z_norm = model.trunk.recycle_z_norm 
        self.recycle_disto = model.trunk.recycle_disto 
        self.pairwise_positional_embedding=model.trunk.pairwise_positional_embedding
        self.structure_module = model.trunk.structure_module
        self.trunk2sm_s = model.trunk.trunk2sm_s
        self.trunk2sm_z = model.trunk.trunk2sm_z
        self.distogram = Distogram(min_bin=3.375, max_bin=21.375, num_bins=15)

    
    def forward(self,true_aa, s_s_0, s_z_0, recycle_s, recycle_z, recycle_bins, residx, mask):
        def trunk_iter(s, z, residx, mask):
            z = z + self.pairwise_positional_embedding(residx, mask=mask)
            return s, z

        recycle_s = self.recycle_s_norm(recycle_s.detach())
        recycle_z = self.recycle_z_norm(recycle_z.detach())
        recycle_z += self.recycle_disto(recycle_bins.detach())
        s_s, s_z = trunk_iter(s_s_0 + recycle_s, s_z_0 + recycle_z, residx, mask)
        structure = self.structure_module(
            {"single": self.trunk2sm_s(s_s), "pair":  self.trunk2sm_z(s_z)},
            true_aa
        )
        recycle_s = s_s
        recycle_z = s_z
        recycle_bins=self.distogram(structure["positions"][-1][:, :, :3])
        return structure, s_s, s_z, recycle_s, recycle_z, recycle_bins

# Export Recycling
true_aa = torch.randint(0, 20, (1, 100))
s = torch.randn(1, 100, 1024)
z = torch.randn(1, 100, 100, 128)
trunk = trunkWrapper(model).to("cpu")
recycle_s = torch.randn(1, 100, 1024)  # Example input
recycle_z = torch.randn(1, 100, 100, 128) # Example input
recycle_bins = torch.randint(0, 15, (1, 100, 100))  # Example input
residx = torch.arange(100).unsqueeze(0)  # Shape: [batch_size, sequence_length]
mask = torch.ones(1, 100)  # Shape: [batch_size, sequence_length]


torch.onnx.export(
    trunk,
    (true_aa, s, z, recycle_s, recycle_z, recycle_bins,residx, mask),
    "structure_module_new.onnx",
    input_names=["aa","s_s_0","s_z_0","recycle_s", "recycle_z", "recycle_bins","residx","mask"],
    output_names=['frames', 'sidechain_frames', 'unnormalized_angles', 'angles', 'positions', 'states', 'single',"s_s","s_z", "updated_recycle_s", "updated_recycle_z","updated_recycle_bins"],
    dynamic_axes={
        "aa": {0:"batch",1: "sequence_length"},
        "s_s_0": {0:"batch",1: "sequence_length"},
        "s_z_0": {0:"batch",1: "sequence_length", 2: "sequence_length"},
        "recycle_s": {0:"batch",1: "sequence_length"},
        "recycle_z": {0:"batch",1: "sequence_length", 2: "sequence_length"},
        "recycle_bins": {0:"batch",1: "sequence_length", 2: "sequence_length"},
        "residx": {0:"batch",1: "sequence_length"},
        "mask": {0:"batch",1: "sequence_length"},
    },
)

  assert residue_index.shape == mask.shape
  if((rot_mats is not None and rot_mats.shape[-2:] != (3, 3)) or
  if((rots.shape != trans.shape[:-1]) or
  return torch.tensor(_CACHED_QUATS[quat_key], dtype=dtype, device=device)
  if(t.shape[-2:] != (4, 4)):
  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
  _C._jit_pass_onnx_graph_shape_type_inference(
  _C._jit_pass_onnx_graph_shape_type_inference(


In [39]:
import torch
import torch.nn as nn

class RelativePosition(nn.Module):
    def __init__(self,model):
        super().__init__()
        self.pairwise_positional_embedding=model.trunk.pairwise_positional_embedding

    def forward(self, z, residx, mask):
        return z + self.pairwise_positional_embedding(residx, mask=mask)

# Export RelativePosition
relative_position = RelativePosition(model).to("cpu")
z = torch.randn(1, 100, 100, 128)
residx = torch.arange(100).unsqueeze(0)  # Example input
mask = torch.ones(1, 100)  # Example input

torch.onnx.export(
    relative_position,
    (z, residx, mask),
    "relative_position.onnx",
    input_names=["s_z_0","residx", "mask"],
    output_names=["s_z"],
    dynamic_axes={
        "s_z_0":{0:"batch",1: "sequence_length",2: "sequence_length"},
        "residx": {0:"batch",1: "sequence_length"},
        "mask": {0:"batch",1: "sequence_length"},
    },
)

  assert residue_index.shape == mask.shape


In [40]:

# Create an instance of the model
relative_position = RelativePosition(model).to("cpu")

# Create example inputs
z = torch.randn(1, 100, 100, 128)
residx = torch.arange(100).unsqueeze(0)  # Shape: [batch_size, sequence_length]
mask = torch.ones(1, 100)  # Shape: [batch_size, sequence_length]

# Run inference with the PyTorch model
with torch.no_grad():
    pytorch_output = relative_position(z, residx, mask).numpy()  # Convert to numpy for comparison

# Load the ONNX model
onnx_session = ort.InferenceSession("relative_position.onnx")

# Prepare the inputs for the ONNX model
onnx_input = {
    "s_z_0": z.numpy(),
    "residx": residx.numpy(),
    "mask": mask.numpy(),
}

# Run inference with the ONNX model
onnx_output = onnx_session.run(None, onnx_input)

diff = np.abs(pytorch_output - onnx_output)
print(f"Max difference: {np.max(diff)}")
print(f"Mean difference: {np.mean(diff)}")
print(f"Number of differences > 1e-5: {np.sum(diff > 1e-5)}")

Max difference: 0.0
Mean difference: 0.0
Number of differences > 1e-5: 0


In [38]:
# Assuming `model` is already defined and has the required attributes
recycling = Recycling(model).cuda()  # Move to GPU if necessary

# Create example inputs
recycle_s = torch.randn(1, 100, 1024).cuda()  # Example input
recycle_z = torch.randn(1, 100, 100, 128).cuda()  # Example input
recycle_bins = torch.randint(0, 15, (1, 100, 100)).cuda()  # Example input

# Run inference with the PyTorch model
with torch.no_grad():
    pytorch_output_s, pytorch_output_z = recycling(recycle_s, recycle_z, recycle_bins)
    pytorch_output_s = pytorch_output_s.cpu().numpy()  # Convert to numpy for comparison
    pytorch_output_z = pytorch_output_z.cpu().numpy()

# Load the ONNX model
onnx_session = ort.InferenceSession("recycling.onnx", providers=["CUDAExecutionProvider"])

# Prepare the inputs for the ONNX model
onnx_input = {
    "recycle_s": recycle_s.cpu().numpy(),  # Move to CPU for ONNX
    "recycle_z": recycle_z.cpu().numpy(),
    "recycle_bins": recycle_bins.cpu().numpy(),
}

# Run inference with the ONNX model
onnx_output = onnx_session.run(None, onnx_input)
onnx_output_s, onnx_output_z = onnx_output[0], onnx_output[1]

# Compare the outputs

print("Outputs do not match!")
diff_s = np.abs(pytorch_output_s - onnx_output_s)
diff_z = np.abs(pytorch_output_z - onnx_output_z)
print(f"Max difference (recycle_s): {np.max(diff_s)}")
print(f"Mean difference (recycle_s): {np.mean(diff_s)}")
print(f"Number of differences > 1e-5 (recycle_s): {np.sum(diff_s > 1e-5)}")
print(f"Max difference (recycle_z): {np.max(diff_z)}")
print(f"Mean difference (recycle_z): {np.mean(diff_z)}")
print(f"Number of differences > 1e-5 (recycle_z): {np.sum(diff_z > 1e-5)}")

Outputs do not match!
Max difference (recycle_s): 2.86102294921875e-06
Mean difference (recycle_s): 1.414014292322463e-07
Number of differences > 1e-5 (recycle_s): 0
Max difference (recycle_z): 1.6689300537109375e-06
Mean difference (recycle_z): 6.774012462074097e-08
Number of differences > 1e-5 (recycle_z): 0




In [26]:
class TrunkToStructureModule(nn.Module):
    def __init__(self,model):
        super().__init__()
        self.trunk2sm_s = model.trunk.trunk2sm_s
        self.trunk2sm_z = model.trunk.trunk2sm_z

    def forward(self, s_s, s_z):
        sm_s = self.trunk2sm_s(s_s)
        sm_z = self.trunk2sm_z(s_z)
        return sm_s, sm_z

# Export TrunkToStructureModule
trunk_to_sm = TrunkToStructureModule(model)
s_s = torch.randn(1, 100, 1024)  # Example input
s_z = torch.randn(1, 100, 100, 128)  # Example input

torch.onnx.export(
    trunk_to_sm,
    (s_s, s_z),
    "trunk_to_sm.onnx",
    input_names=["s_s", "s_z"],
    output_names=["sm_s", "sm_z"],
    dynamic_axes={
        "s_s": {0:"batch",1: "sequence_length"},
        "s_z": {0:"batch",1: "sequence_length", 2: "sequence_length"},
    },
)

In [28]:
# Assuming `model` is already defined and has the required attributes
trunk_to_sm = TrunkToStructureModule(model)  # Move to GPU if necessary

# Create example inputs
s_s = torch.randn(1, 100, 1024)  # Example input
s_z = torch.randn(1, 100, 100, 128)  # Example input

# Run inference with the PyTorch model
with torch.no_grad():
    pytorch_output_sm_s, pytorch_output_sm_z = trunk_to_sm(s_s, s_z)
    pytorch_output_sm_s = pytorch_output_sm_s.cpu().numpy()  # Convert to numpy for comparison
    pytorch_output_sm_z = pytorch_output_sm_z.cpu().numpy()

# Load the ONNX model
onnx_session = ort.InferenceSession("trunk_to_sm.onnx", providers=["CUDAExecutionProvider"])

# Prepare the inputs for the ONNX model
onnx_input = {
    "s_s": s_s.cpu().numpy(),  # Move to CPU for ONNX
    "s_z": s_z.cpu().numpy(),
}

# Run inference with the ONNX model
onnx_output = onnx_session.run(None, onnx_input)
onnx_output_sm_s, onnx_output_sm_z = onnx_output[0], onnx_output[1]

# Compare the outputs
if np.allclose(pytorch_output_sm_s, onnx_output_sm_s, atol=1e-5) and np.allclose(pytorch_output_sm_z, onnx_output_sm_z, atol=1e-5):
    print("Outputs match within tolerance!")
else:
    print("Outputs do not match!")
    diff_sm_s = np.abs(pytorch_output_sm_s - onnx_output_sm_s)
    diff_sm_z = np.abs(pytorch_output_sm_z - onnx_output_sm_z)
    print(f"Max difference (sm_s): {np.max(diff_sm_s)}")
    print(f"Mean difference (sm_s): {np.mean(diff_sm_s)}")
    print(f"Number of differences > 1e-5 (sm_s): {np.sum(diff_sm_s > 1e-5)}")
    print(f"Max difference (sm_z): {np.max(diff_sm_z)}")
    print(f"Mean difference (sm_z): {np.mean(diff_sm_z)}")
    print(f"Number of differences > 1e-5 (sm_z): {np.sum(diff_sm_z > 1e-5)}")

# Print intermediate values for debugging
print("PyTorch output_sm_s shape:", pytorch_output_sm_s.shape)
print("ONNX output_sm_s shape:", onnx_output_sm_s.shape)

print("PyTorch output_sm_z shape:", pytorch_output_sm_z.shape)
print("ONNX output_sm_z shape:", onnx_output_sm_z.shape)


Outputs match within tolerance!
PyTorch output_sm_s shape: (1, 100, 384)
ONNX output_sm_s shape: (1, 100, 384)
PyTorch output_sm_z shape: (1, 100, 100, 128)
ONNX output_sm_z shape: (1, 100, 100, 128)


In [29]:
class Distogram(nn.Module):
    def __init__(self, min_bin, max_bin, num_bins):
        super().__init__()
        self.min_bin = min_bin
        self.max_bin = max_bin
        self.num_bins = num_bins

    def forward(self, coords):
        boundaries = torch.linspace(self.min_bin, self.max_bin, self.num_bins - 1, device=coords.device)
        boundaries = boundaries**2
        N, CA, C = [x.squeeze(-2) for x in coords.chunk(3, dim=-2)]
        b = CA - N
        c = C - CA
        a = b.cross(c, dim=-1)
        CB = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA
        dists = (CB[..., None, :, :] - CB[..., :, None, :]).pow(2).sum(dim=-1, keepdims=True)
        bins = torch.sum(dists > boundaries, dim=-1)
        return bins

# Export Distogram
distogram = Distogram(min_bin=3.375, max_bin=21.375, num_bins=15)
coords = torch.randn(1, 100, 3, 3)  # Example input

torch.onnx.export(
    distogram,
    (coords,),
    "distogram.onnx",
    input_names=["coords"],
    output_names=["bins"],
    dynamic_axes={
        "coords": {0:"batch",1: "sequence_length"},
    },
)

In [30]:
# Create an instance of the model
distogram = Distogram(min_bin=3.375, max_bin=21.375, num_bins=15)

# Create example inputs
coords = torch.randn(1, 100, 3, 3)  # Example input (batch_size=1, seq_len=100, 3 atoms, 3 coordinates)

# Run inference with the PyTorch model
with torch.no_grad():
    pytorch_output = distogram(coords).numpy()  # Convert to numpy for comparison

# Load the ONNX model
onnx_session = ort.InferenceSession("distogram.onnx")

# Prepare the inputs for the ONNX model
onnx_input = {
    "coords": coords.numpy(),
}

# Run inference with the ONNX model
onnx_output = onnx_session.run(None, onnx_input)[0]

# Compare the outputs
if np.allclose(pytorch_output, onnx_output, atol=1e-5):
    print("Outputs match within tolerance!")
else:
    print("Outputs do not match!")
    diff = np.abs(pytorch_output - onnx_output)
    print(f"Max difference: {np.max(diff)}")
    print(f"Mean difference: {np.mean(diff)}")
    print(f"Number of differences > 1e-5: {np.sum(diff > 1e-5)}")

# Print intermediate values for debugging
print("PyTorch output shape:", pytorch_output.shape)
print("ONNX output shape:", onnx_output.shape)
print("PyTorch output (first 5x5):\n", pytorch_output[0, :5, :5])
print("ONNX output (first 5x5):\n", onnx_output[0, :5, :5])

Outputs match within tolerance!
PyTorch output shape: (1, 100, 100)
ONNX output shape: (1, 100, 100)
PyTorch output (first 5x5):
 [[0 1 1 3 1]
 [1 0 2 1 0]
 [1 2 0 4 1]
 [3 1 4 0 2]
 [1 0 1 2 0]]
ONNX output (first 5x5):
 [[0 1 1 3 1]
 [1 0 2 1 0]
 [1 2 0 4 1]
 [3 1 4 0 2]
 [1 0 1 2 0]]


In [22]:
import torch
import torch.nn as nn
from typing import Dict
import torch
import torch.nn as nn
from typing import Dict
from openfold.data.data_transforms import make_atom14_masks
from esm.esmfold.v1.categorical_mixture import categorical_lddt

from openfold.utils.loss import compute_predicted_aligned_error, compute_tm

import torch
import torch.nn as nn
from typing import Dict, Optional

class PostProcessingWrapper(nn.Module):
    def __init__(self, esmfold_model):
        super().__init__()
        self.esmfold = esmfold_model
        self.distogram_bins = 64
        self.lddt_bins = 50

    def forward(
        self,
        aa,
        structure
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass for post-processing.
        """
        B, L=aa.shape
        mask = torch.ones_like(aa)
        residx = torch.arange(aa.shape[1], device=aa.device).expand_as(aa)

        # === Distogram Head ===
        disto_logits = self.esmfold.distogram_head(structure["s_z"])
        disto_logits = (disto_logits + disto_logits.transpose(1, 2)) / 2
        structure["distogram_logits"] = disto_logits

        lm_logits = self.esmfold.lm_head(structure["s_s"])
        structure["lm_logits"] = lm_logits

        structure["aatype"] = aa
        make_atom14_masks(structure)
        for k in [
            "atom14_atom_exists",
            "atom37_atom_exists",
        ]:
            structure[k] *= mask.unsqueeze(-1)
        structure["residue_index"] = residx

        lddt_head = self.esmfold.lddt_head(structure["states"]).reshape(
            structure["states"].shape[0], B, L, -1, self.lddt_bins
        )
        structure["lddt_head"] = lddt_head
        plddt = categorical_lddt(lddt_head[-1], bins=self.lddt_bins)
        structure["plddt"] = (
            100 * plddt
        )  # we predict plDDT between 0 and 1, scale to be between 0 and 100.

        ptm_logits = self.esmfold.ptm_head(structure["s_z"])

        seqlen = mask.type(torch.int64).sum(1)
        structure["ptm_logits"] = ptm_logits
        structure["ptm"] = torch.stack(
            [
                compute_tm(
                    batch_ptm_logits[None, :sl, :sl],
                    max_bins=31,
                    no_bins=self.distogram_bins,
                )
                for batch_ptm_logits, sl in zip(ptm_logits, seqlen)
            ]
        )
        structure.update(
            compute_predicted_aligned_error(
                ptm_logits, max_bin=31, no_bins=self.distogram_bins
            )
        )

        return structure

In [23]:
import torch
import esm
model = esm.pretrained.esmfold_structure_module_only_8M()
model = model.eval()
esmfold_model=model

# Create an instance of the wrapper
post_processing_wrapper = PostProcessingWrapper(esmfold_model)

# Set the model to evaluation mode
post_processing_wrapper.eval()

# Create dummy inputs for tracing
B, L = 1, 100  # Batch size and sequence length
aa = torch.randint(0, 20, (B, L))  # Amino acid indices
structure={'frames': torch.randn([8, B, L, 7]),
           'sidechain_frames': torch.randn([8, B, L, 8, 4, 4]),
           'unnormalized_angles': torch.randn([8, B, L, 7, 2]), 
           'angles': torch.randn([8, B, L, 7, 2]), 
           'positions': torch.randn([8, B, L, 14, 3]), 
           'states': torch.randn([8, B, L, 384]), 
           's_s': torch.randn([B, L, esmfold_model.cfg.trunk.sequence_state_dim]), 
           's_z': torch.randn([B, L, L, esmfold_model.cfg.trunk.pairwise_state_dim])}
result=post_processing_wrapper(aa,structure)

In [27]:
import torch
import esm
model = esm.pretrained.esmfold_structure_module_only_8M()
model = model.eval()
esmfold_model=model

# Create an instance of the wrapper
post_processing_wrapper = PostProcessingWrapper(esmfold_model)

# Set the model to evaluation mode
post_processing_wrapper.eval()

# Create dummy inputs for tracing
B, L = 1, 100  # Batch size and sequence length
aa = torch.randint(0, 20, (B, L))  # Amino acid indices
structure={'frames': torch.randn([8, B, L, 7]),
           'sidechain_frames': torch.randn([8, B, L, 8, 4, 4]),
           'unnormalized_angles': torch.randn([8, B, L, 7, 2]), 
           'angles': torch.randn([8, B, L, 7, 2]), 
           'positions': torch.randn([8, B, L, 14, 3]), 
           'states': torch.randn([8, B, L, 384]), 
           's_s': torch.randn([B, L, esmfold_model.cfg.trunk.sequence_state_dim]), 
           's_z': torch.randn([B, L, L, esmfold_model.cfg.trunk.pairwise_state_dim])}
# Export to ONNX
onnx_file_path = "post_processing.onnx"
torch.onnx.export(
    post_processing_wrapper,
    (aa,structure,{}),
    onnx_file_path,
    export_params=True,
    opset_version=17,
    do_constant_folding=True,
    input_names=[
        "aa",
        'frames_0',
        'sidechain_frames_0',
        'unnormalized_angles_0', 
        'angles_0', 
        'positions_0', 
        'states_0', 
        's_s_0', 
        's_z_0'],
    output_names=[
        'frames', 'sidechain_frames', 'unnormalized_angles', 'angles',
        'positions', 'states', 's_s', 's_z', 'distogram_logits', 'lm_logits', 
        'aatype', 'atom14_atom_exists', 'residx_atom14_to_atom37', 
        'residx_atom37_to_atom14', 'atom37_atom_exists', 'residue_index', 
        'lddt_head', 'plddt', 'ptm_logits', 'ptm', 'aligned_confidence_probs', 
        'predicted_aligned_error', 'max_predicted_aligned_error'
        ],
    dynamic_axes={
        "aa": {0: "batch_size", 1: "seq_len"},
        'frames_0': {1: "batch_size", 2: "seq_len"},
        'sidechain_frames_0': {1: "batch_size", 2: "seq_len"},
        'unnormalized_angles_0': {1: "batch_size", 2: "seq_len"}, 
        'angles_0': {1: "batch_size", 2: "seq_len"}, 
        'positions_0': {1: "batch_size", 2: "seq_len"}, 
        "states_0": {1: "batch_size", 2: "seq_len"},
        "s_s_0": {0: "batch_size", 1: "seq_len"},
        "s_z_0": {0: "batch_size", 1: "seq_len", 2: "seq_len"},
    },
)

print(f"Model exported to {onnx_file_path}")

  for batch_ptm_logits, sl in zip(ptm_logits, seqlen)


Model exported to post_processing.onnx


In [21]:
# Assuming `esmfold_model` is already defined and has the required attributes
post_processing_wrapper = PostProcessingWrapper(model) # Move to GPU if necessary

# Create example inputs
B, L = 1, 100  # Batch size and sequence length
aa = torch.randint(0, 20, (B, L))  # Amino acid indices
structure={'frames': torch.randn([8, B, L, 7]),
           'sidechain_frames': torch.randn([8, B, L, 8, 4, 4]),
           'unnormalized_angles': torch.randn([8, B, L, 7, 2]), 
           'angles': torch.randn([8, B, L, 7, 2]), 
           'positions': torch.randn([8, B, L, 14, 3]), 
           'states': torch.randn([8, B, L, 384]), 
           's_s': torch.randn([B, L, esmfold_model.cfg.trunk.sequence_state_dim]), 
           's_z': torch.randn([B, L, L, esmfold_model.cfg.trunk.pairwise_state_dim])}

# Run inference with the PyTorch model
with torch.no_grad():
    pytorch_output = post_processing_wrapper(aa,structure)
pytorch_output.keys()

dict_keys(['frames', 'sidechain_frames', 'unnormalized_angles', 'angles', 'positions', 'states', 's_s', 's_z', 'distogram_logits', 'lm_logits', 'aatype', 'atom14_atom_exists', 'residx_atom14_to_atom37', 'residx_atom37_to_atom14', 'atom37_atom_exists', 'residue_index', 'lddt_head', 'plddt', 'ptm_logits', 'ptm', 'aligned_confidence_probs', 'predicted_aligned_error', 'max_predicted_aligned_error'])

In [None]:
# Load the ONNX model
onnx_session = ort.InferenceSession("post_processing.onnx", providers=["CUDAExecutionProvider"])

# Prepare the inputs for the ONNX model
onnx_input = {
    "aa": aa.numpy()
}

# Run inference with the ONNX model
onnx_output = onnx_session.run(None, onnx_input)
onnx_output = {
    "distogram_logits": onnx_output[0],
    "plddt": onnx_output[1],
    "ptm": onnx_output[2],
    "predicted_aligned_error": onnx_output[5],
    "atom14_atom_exists": onnx_output[6],
    "atom37_atom_exists": onnx_output[7],
    "residue_index": onnx_output[8],
}

# Compare the outputs
for key in ['distogram_logits', 'plddt', 'ptm', 'atom14_atom_exists', 'atom37_atom_exists', 'residue_index']:
    if np.allclose(pytorch_output[key], onnx_output[key], atol=1e-5):
        print(f"{key} outputs match within tolerance!")
    else:
        print(f"{key} outputs do not match!")
        diff = np.abs(pytorch_output[key] - onnx_output[key])
        print(f"Max difference: {np.max(diff)}")
        print(f"Mean difference: {np.mean(diff)}")
        print(f"Number of differences > 1e-5: {np.sum(diff > 1e-5)}")

    print(f"PyTorch {key} shape:", pytorch_output[key].shape)
    print(f"ONNX {key} shape:", onnx_output[key].shape)


distogram_logits outputs match within tolerance!
PyTorch distogram_logits shape: torch.Size([1, 100, 100, 64])
ONNX distogram_logits shape: (1, 100, 100, 64)
plddt outputs match within tolerance!
PyTorch plddt shape: torch.Size([1, 100, 37])
ONNX plddt shape: (1, 100, 37)
ptm outputs match within tolerance!
PyTorch ptm shape: torch.Size([1])
ONNX ptm shape: (1,)
atom14_atom_exists outputs match within tolerance!
PyTorch atom14_atom_exists shape: torch.Size([1, 100, 14])
ONNX atom14_atom_exists shape: (1, 100, 14)
atom37_atom_exists outputs match within tolerance!
PyTorch atom37_atom_exists shape: torch.Size([1, 100, 37])
ONNX atom37_atom_exists shape: (1, 100, 37)
residue_index outputs match within tolerance!
PyTorch residue_index shape: torch.Size([1, 100])
ONNX residue_index shape: (1, 100)


In [None]:
class StructureModuleTransitionWrapper(nn.Module):
    def __init__(self, transition_module):
        super(
              TransitionWrapper, self).__init__()
        self.transition_module = transition_module

    def forward(self, s):
        """
        Args:
            s: Single representation tensor of shape [B, L, C_s]
        Returns:
            Updated single representation tensor of shape [B, L, C_s]
        """
        output = self.transition_module(s)
        return output

# Create dummy inputs
s = torch.randn(B, L, 384)  # Single representation

# Create an instance of the wrapper
transition_wrapper = StructureModuleTransitionWrapper(model.trunk.structure_module.transition).eval()

# Export to ONNX
onnx_file_path = "transition_wrapper.onnx"
torch.onnx.export(
    transition_wrapper,
    (s,),
    onnx_file_path,
    export_params=True,
    opset_version=17,
    do_constant_folding=True,
    input_names=["s"],
    output_names=["output"],
    dynamic_axes={
        "s": {0: "batch_size", 1: "seq_len"},
    },
    
)

print(f"StructureModuleTransitionWrapper exported to {onnx_file_path}")
# Run PyTorch inference
with torch.no_grad():
    pytorch_output = transition_wrapper(s)

# Run ONNX inference
import onnxruntime as ort
ort_session = ort.InferenceSession("transition_wrapper.onnx")
onnx_output = ort_session.run(None, {
    "s": s.numpy(),
})[0]

# Compare outputs
print("Difference:", torch.abs(pytorch_output - torch.tensor(onnx_output)).max())

StructureModuleTransitionWrapper exported to transition_wrapper.onnx
Difference: tensor(1.3113e-06)


In [48]:
class AngleResnetWrapper(nn.Module):
    def __init__(self, angle_resnet):
        super(AngleResnetWrapper, self).__init__()
        self.angle_resnet = angle_resnet

    def forward(self, s, s_initial):
        """
        Args:
            s: Single representation tensor of shape [B, L, C_s]
            s_initial: Initial single representation tensor of shape [B, L, C_s]
        Returns:
            unnormalized_angles: Tensor of shape [B, L, no_angles, 2]
            angles: Tensor of shape [B, L, no_angles, 2]
        """
        unnormalized_angles, angles = self.angle_resnet(s, s_initial)
        return unnormalized_angles, angles
# Create dummy inputs
B,L=1,100
s = torch.randn(B, L,384)  # Single representation
s_initial = torch.randn(B, 100, 384)  # Initial single representation

# Create an instance of the wrapper
angle_resnet_wrapper = AngleResnetWrapper(model.trunk.structure_module.angle_resnet).eval()

# Export to ONNX
onnx_file_path = "angle_resnet_wrapper.onnx"
torch.onnx.export(
    angle_resnet_wrapper,
    (s, s_initial),
    onnx_file_path,
    export_params=True,
    opset_version=17,
    do_constant_folding=True,
    input_names=["s", "s_initial"],
    output_names=["unnormalized_angles", "angles"],
    dynamic_axes={
        "s": {0: "batch_size", 1: "seq_len"},
        "s_initial": {0: "batch_size", 1: "seq_len"},
    },
)

print(f"AngleResnetWrapper exported to {onnx_file_path}")

# Run PyTorch inference
with torch.no_grad():
    pytorch_output = angle_resnet_wrapper(s, s_initial)

# Run ONNX inference
import onnxruntime as ort
ort_session = ort.InferenceSession("angle_resnet_wrapper.onnx")
onnx_output = ort_session.run(None, {"s":s.numpy(), "s_initial": s_initial.numpy()})[1]

# Compare outputs
print("Difference:", torch.abs(pytorch_output[1] - torch.tensor(onnx_output)).max())

AngleResnetWrapper exported to angle_resnet_wrapper.onnx
Difference: tensor(5.9605e-06)


  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
  _C._jit_pass_onnx_graph_shape_type_inference(
  _C._jit_pass_onnx_graph_shape_type_inference(


In [50]:
class BackboneUpdateWrapper(nn.Module):
    def __init__(self, backbone_update):
        super(BackboneUpdateWrapper, self).__init__()
        self.backbone_update = backbone_update

    def forward(self, s):
        """
        Args:
            s: Single representation tensor of shape [B, L, C_s]
        Returns:
            Update vector tensor of shape [B, L, 6]
        """
        update = self.backbone_update(s)
        return update
# Create dummy inputs
s = torch.randn(B, L, 384)  # Single representation

# Create an instance of the wrapper
backbone_update_wrapper = BackboneUpdateWrapper(model.trunk.structure_module.bb_update).eval()

# Export to ONNX
onnx_file_path = "backbone_update_wrapper.onnx"
torch.onnx.export(
    backbone_update_wrapper,
    (s,),
    onnx_file_path,
    export_params=True,
    opset_version=17,
    do_constant_folding=True,
    input_names=["s"],
    output_names=["update"],
    dynamic_axes={
        "s": {0: "batch_size", 1: "seq_len"},
    },
)

print(f"BackboneUpdateWrapper exported to {onnx_file_path}")
# Run PyTorch inference
with torch.no_grad():
    pytorch_output = backbone_update_wrapper(s)

# Run ONNX inference
import onnxruntime as ort
ort_session = ort.InferenceSession("backbone_update_wrapper.onnx")
onnx_output = ort_session.run(None, {"s": s.numpy()})[0]

# Compare outputs
print("Difference:", torch.abs(pytorch_output - torch.tensor(onnx_output)).max())

BackboneUpdateWrapper exported to backbone_update_wrapper.onnx
Difference: tensor(2.1458e-06)


In [51]:
class IPAWrapper(nn.Module):
    def __init__(self, ipa_module):
        super(IPAWrapper, self).__init__()
        self.ipa_module = ipa_module

    def forward(self, s, z, rigids_tensor, mask):
        """
        Args:
            s: Single representation tensor of shape [B, L, C_s]
            z: Pair representation tensor of shape [B, L, L, C_z]
            rigids_tensor: Rigid transformations tensor of shape [B, L, 7] (quaternion + translation)
            mask: Mask tensor of shape [B, L]
        Returns:
            Updated single representation tensor of shape [B, L, C_s]
        """
        # Convert rigids tensor to Rigid object
        rigids = Rigid.from_tensor_7(rigids_tensor)  # Implement this function

        # Run IPA
        output = self.ipa_module(s, z, rigids, mask)

        return output
    
# Create dummy inputs
B, L = 1, 100  # Batch size and sequence length
s = torch.randn(B, L, 384)  # Single representation
z = torch.randn(B, L, L, 128)  # Pair representation
rigids_tensor = torch.randn(B, L, 7)  # Rigid transformations (quaternion + translation)
mask = torch.ones(B, L)  # Mask

# Create an instance of the wrapper
ipa_wrapper = IPAWrapper(model.trunk.structure_module.ipa).eval()

# Export to ONNX
onnx_file_path = "ipa_wrapper.onnx"
torch.onnx.export(
    ipa_wrapper,
    (s, z, rigids_tensor, mask),
    onnx_file_path,
    export_params=True,
    opset_version=17,
    do_constant_folding=True,
    input_names=["s", "z", "rigids", "mask"],
    output_names=["output"],
    dynamic_axes={
        "s": {0: "batch_size", 1: "seq_len"},
        "z": {0: "batch_size", 1: "seq_len", 2: "seq_len"},
        "rigids": {0: "batch_size", 1: "seq_len"},
        "mask": {0: "batch_size", 1: "seq_len"},
    },
)

print(f"IPAWrapper exported to {onnx_file_path}")

# Run PyTorch inference
with torch.no_grad():
    pytorch_output = ipa_wrapper(s, z, rigids_tensor, mask)

# Run ONNX inference
import onnxruntime as ort
ort_session = ort.InferenceSession("ipa_wrapper.onnx")
onnx_output = ort_session.run(None, {
    "s": s.numpy(),
    "z": z.numpy(),
    "rigids": rigids_tensor.numpy(),
    "mask": mask.numpy(),
})[0]

# Compare outputs
print("Difference:", torch.abs(pytorch_output - torch.tensor(onnx_output)).max())

  if(t.shape[-1] != 7):
  if((rot_mats is not None and rot_mats.shape[-2:] != (3, 3)) or
  if((rots.shape != trans.shape[:-1]) or


IPAWrapper exported to ipa_wrapper.onnx
Difference: tensor(0.0042)


In [1]:
import torch
import torch.nn as nn
class StructureModuleWrapper(nn.Module):
    def __init__(self, folding_trunk):
        super().__init__()
        self.structure_module = folding_trunk.structure_module
        self.trunk2sm_s = model.trunk.trunk2sm_s
        self.trunk2sm_z = model.trunk.trunk2sm_z

    def forward(self, s_s, s_z, true_aa):
        """
        Inputs:
          s_s:           B x L x C            sequence features
          s_z:           B x L x L x C        pairwise features
          true_aa:       B x L                true amino acid indices

        Outputs:
          structure:     dict                 predicted structure
        """
        structure = self.structure_module(
            {"single": self.trunk2sm_s(s_s), "pair":  self.trunk2sm_z(s_z)},
            true_aa
        )
        return structure

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Initialize the structure module
import esm
model = esm.pretrained.esmfold_structure_module_only_8M()
model = model.eval().cuda()
folding_trunk=model.trunk
structure_module_wrapper = StructureModuleWrapper(folding_trunk).eval()

# Example inputs
batch_size=1
true_aa = torch.randint(0, 20, (batch_size, 65), device="cuda")
s_s = torch.randn((batch_size, 65, 1024)).cuda()
s_z = torch.randn((batch_size, 65, 65, 128)).cuda()

structure_module_wrapper(s_s, s_z, true_aa).keys()

dict_keys(['frames', 'sidechain_frames', 'unnormalized_angles', 'angles', 'positions', 'states', 'single'])

In [4]:
# Initialize the structure module
import esm
model = esm.pretrained.esmfold_structure_module_only_8M()
model = model.eval().cuda()
folding_trunk=model.trunk
structure_module_wrapper = StructureModuleWrapper(folding_trunk).eval()

# Example inputs
batch_size=1
true_aa = torch.randint(0, 20, (batch_size, 65), device="cuda")
s_s = torch.randn((batch_size, 65, 1024)).cuda()
s_z = torch.randn((batch_size, 65, 65, 128)).cuda()

# Export to ONNX
torch.onnx.export(
    structure_module_wrapper,
    (s_s, s_z, true_aa),
    "structure_module.onnx",
    input_names=["s_s", "s_z", "true_aa"],
    output_names=['frames', 'sidechain_frames', 'unnormalized_angles', 'angles', 'positions', 'states', 'single'],
    dynamic_axes={
        "s_s": {0:"B", 1: "L"},
        "s_z": {0:"B", 1: "L", 2: "L"},
        "true_aa": {0:"B", 1: "L"}
        }
    )

  if((rot_mats is not None and rot_mats.shape[-2:] != (3, 3)) or
  if((rots.shape != trans.shape[:-1]) or
  torch.tensor(
  torch.tensor(
  torch.tensor(
  torch.tensor(
  if(t.shape[-2:] != (4, 4)):
  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
  _C._jit_pass_onnx_graph_shape_type_inference(
  _C._jit_pass_onnx_graph_shape_type_inference(


In [55]:
import onnxruntime as ort
import numpy as np
folding_trunk = model.trunk
structure_module_wrapper = StructureModuleWrapper(folding_trunk).to("cpu")

# Example inputs
batch_size = 1
true_aa = torch.randint(0, 20, (batch_size, 65))  # Example input (batch_size=1, seq_len=65)
s_s = torch.randn((batch_size, 65, 384))  # Example input (batch_size=1, seq_len=65, sequence_state_dim=384)
s_z = torch.randn((batch_size, 65, 65, 128)) # Example input (batch_size=1, seq_len=65, seq_len=65, pairwise_state_dim=128)

# Run inference with the PyTorch model
with torch.no_grad():
    pytorch_output = structure_module_wrapper(s_s, s_z, true_aa)
    pytorch_output = {k: v.numpy() for k, v in pytorch_output.items()}  # Convert to numpy for comparison

# Load the ONNX model
onnx_session = ort.InferenceSession("structure_module.onnx", providers=["CUDAExecutionProvider"])

# Prepare the inputs for the ONNX model
onnx_input = {
    "s_s": s_s.numpy(),  # Move to CPU for ONNX
    "s_z": s_z.numpy(),
    "true_aa": true_aa.numpy(),
}

# Run inference with the ONNX model
onnx_output = onnx_session.run(None, onnx_input)
onnx_output = {
    "frames": onnx_output[0],
    "sidechain_frames": onnx_output[1],
    "unnormalized_angles": onnx_output[2],
    "angles": onnx_output[3],
    "positions": onnx_output[4],
    "states": onnx_output[5],
    "single": onnx_output[6],
}

# Compare the outputs
for key in pytorch_output:
    if np.allclose(pytorch_output[key], onnx_output[key], atol=1e-4):
        print(f"{key} outputs match within tolerance!")
    else:
        print(f"{key} outputs do not match!")
        diff = np.abs(pytorch_output[key] - onnx_output[key])
        print(f"Max difference: {np.max(diff)}")
        print(f"Mean difference: {np.mean(diff)}")
        print(f"Number of differences > 1e-4: {np.sum(diff > 1e-4)}")

# Print intermediate values for debugging
for key in pytorch_output:
    print(f"PyTorch {key} shape:", pytorch_output[key].shape)
    print(f"ONNX {key} shape:", onnx_output[key].shape)

frames outputs do not match!
Max difference: 0.00017702579498291016
Mean difference: 7.190011729107937e-06
Number of differences > 1e-4: 2
sidechain_frames outputs do not match!
Max difference: 0.0002570152282714844
Mean difference: 4.123567578062648e-06
Number of differences > 1e-4: 47
unnormalized_angles outputs do not match!
Max difference: 0.00021857023239135742
Mean difference: 4.058340152823803e-07
Number of differences > 1e-4: 3
angles outputs do not match!
Max difference: 0.00028818845748901367
Mean difference: 4.814603471459122e-07
Number of differences > 1e-4: 3
positions outputs do not match!
Max difference: 0.00027441978454589844
Mean difference: 1.0693239346437622e-05
Number of differences > 1e-4: 65
states outputs match within tolerance!
single outputs match within tolerance!
PyTorch frames shape: (8, 1, 65, 7)
ONNX frames shape: (8, 1, 65, 7)
PyTorch sidechain_frames shape: (8, 1, 65, 8, 4, 4)
ONNX sidechain_frames shape: (8, 1, 65, 8, 4, 4)
PyTorch unnormalized_angles s

In [57]:
import torch
import torch.nn as nn
from typing import Dict

class Distogram(nn.Module):
    def __init__(self, min_bin, max_bin, num_bins):
        super().__init__()
        self.min_bin = min_bin
        self.max_bin = max_bin
        self.num_bins = num_bins

    def forward(self, coords):
        boundaries = torch.linspace(self.min_bin, self.max_bin, self.num_bins - 1, device=coords.device)
        boundaries = boundaries**2
        N, CA, C = [x.squeeze(-2) for x in coords.chunk(3, dim=-2)]
        b = CA - N
        c = C - CA
        a = b.cross(c, dim=-1)
        CB = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA
        dists = (CB[..., None, :, :] - CB[..., :, None, :]).pow(2).sum(dim=-1, keepdims=True)
        bins = torch.sum(dists > boundaries, dim=-1)
        return bins

class TrunkWrapper(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.structure_module = model.trunk.structure_module
        self.trunk2sm_s = model.trunk.trunk2sm_s
        self.trunk2sm_z = model.trunk.trunk2sm_z
        self.recycle_s_norm = model.trunk.recycle_s_norm 
        self.recycle_z_norm = model.trunk.recycle_z_norm 
        self.recycle_disto = model.trunk.recycle_disto 
        self.pairwise_positional_embedding = model.trunk.pairwise_positional_embedding
        self.distogram=Distogram(min_bin=3.375, max_bin=21.375, num_bins=15)
    
    def forward(
        self,
        true_aa,
        s_s_0: torch.Tensor,
        s_z_0: torch.Tensor,
        recycle_s: torch.Tensor,
        recycle_z: torch.Tensor,
        recycle_bins: torch.Tensor
    ):
        """
        Forward pass for the trunk.
        """
        mask = torch.ones_like(true_aa)
        L = true_aa.shape[1]
        device = true_aa.device
        
        residx = torch.arange(L, device=device).expand_as(true_aa)        
        
        def trunk_iter(z, residx, mask):
            z = z + self.pairwise_positional_embedding(residx, mask=mask)
            return z
    
        recycle_s = self.recycle_s_norm(recycle_s.detach())
        recycle_z = self.recycle_z_norm(recycle_z.detach())
        recycle_z += self.recycle_disto(recycle_bins.detach())
        s_z = trunk_iter( s_z_0 + recycle_z, residx, mask)
        s_s = s_s_0 + recycle_s
        # === Structure module ===
        structure = self.structure_module(
            {"single": self.trunk2sm_s(s_s), "pair": self.trunk2sm_z(s_z)},
            true_aa,
            mask.float(),
        )
        recycle_s = s_s
        recycle_z = s_z
        # Distogram needs the N, CA, C coordinates, and bin constants same as alphafold.
        recycle_bins =Distogram(min_bin=3.375, max_bin=21.375, num_bins=15)(structure["positions"][-1][:, :, :3])
        return recycle_s,recycle_z,recycle_bins,s_s, s_z,structure

In [None]:
import torch
import esm
model = esm.pretrained.esmfold_structure_module_only_8M()
model = model.eval()
# Create dummy inputs
batch_size, seq_len = 1, 50
aa = torch.randint(0, 20, (batch_size, seq_len), dtype=torch.long).to("cuda")
s_s_0 = torch.randn(batch_size, seq_len, 1024).to("cuda")
s_z_0 = torch.randn(batch_size, seq_len, seq_len, 128).to("cuda")
recycle_s = torch.randn(1, seq_len, 1024).to("cuda") # Example input
recycle_z = torch.randn(1, seq_len, seq_len, 128).to("cuda") # Example input
recycle_bins = torch.randint(0, 15, (1, seq_len, seq_len)).to("cuda")  # Example input


# Initialize the wrapper
trunk_wrapper = TrunkWrapper(model).to("cuda").eval()

# Export to ONNX
torch.onnx.export(
    trunk_wrapper,
    (aa, s_s_0, s_z_0, recycle_s,recycle_z,recycle_bins),
    "trunk_wrapper.onnx",
    input_names=["s_s_0", "s_z_0", "recycle_s","recycle_z","recycle_bins"],
    dynamic_axes={
        "s_s_0": {0:"batch",1: "seq_len"},
        "s_z_0": {0:"batch",1: "seq_len", 2: "seq_len"},
        "recycle_s": {0:"batch",1: "seq_len"},
        "recycle_z": {0:"batch",1: "seq_len",2: "seq_len"},
        "recycle_bins": {0:"batch",1: "seq_len",2: "seq_len"},
        
    },
    verbose=True,
)

  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
  _C._jit_pass_onnx_graph_shape_type_inference(
  _C._jit_pass_onnx_graph_shape_type_inference(


In [None]:
import h5py
import numpy as np

h = h5py.File('./weights.h5', 'w')
data=model.state_dict()
for k,v in data.items():
    h.create_dataset(k, data=v.cpu().numpy())
h.close()