In [1]:
import onnxruntime.training.onnxblock as onnxblock
from onnxruntime.training.api import CheckpointState, Module, Optimizer
from onnxruntime.training import artifacts
from onnxruntime import InferenceSession
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import onnx
import io

In [2]:
class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout=0.1):
        super(TransformerBlock, self).__init__()
        self.attention = nn.MultiheadAttention(embed_size, heads)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, 4 * embed_size),
            nn.ReLU(),
            nn.Linear(4 * embed_size, embed_size)
        )
        self.layer_norm1 = nn.LayerNorm(embed_size)
        self.layer_norm2 = nn.LayerNorm(embed_size)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask=None):
        # Multi-head self-attention
        attn_output, _ = self.attention(x, x, x, attn_mask=mask)
        attn_output = self.dropout(attn_output)
        # Add and norm
        out1 = self.layer_norm1(x + attn_output)
        
        # Feed forward network
        ff_output = self.feed_forward(out1)
        ff_output = self.dropout(ff_output)
        # Add and norm
        out2 = self.layer_norm2(out1 + ff_output)
        
        return out2


In [3]:
# Example usage of TransformerBlock
input_tensor = torch.randn(10, 32, 512)  # (sequence_length, batch_size, embed_size)
transformer_block = TransformerBlock(embed_size=512, heads=8)
output_tensor = transformer_block(input_tensor)
print(output_tensor.shape)  # should print: torch.Size([10, 32, 512])


torch.Size([10, 32, 512])


In [4]:
model_outputs = transformer_block(input_tensor)
if isinstance(model_outputs, torch.Tensor):
    model_outputs = [model_outputs]
    
input_names = ["input"]
output_names = ["output"]
dynamic_axes = {"input": {0: "batch_size"}, "output": {0: "batch_size"}}

In [5]:
f = io.BytesIO()
torch.onnx.export(
    transformer_block,
    input_tensor,
    f,
    input_names=input_names,
    output_names=output_names,
    opset_version=14,
    do_constant_folding=False,
    training=torch.onnx.TrainingMode.TRAINING,
    dynamic_axes=dynamic_axes,
    export_params=True,
    keep_initializers_as_inputs=False,
)
onnx_model = onnx.load_model_from_string(f.getvalue())

In [6]:
requires_grad = [name for name, param in transformer_block.named_parameters() if param.requires_grad]

frozen_params = [name for name, param in transformer_block.named_parameters() if not param.requires_grad]

artifacts.generate_artifacts(
    onnx_model,
    optimizer=artifacts.OptimType.AdamW,
    #loss=artifacts.LossType.CrossEntropyLoss, #Specify the loss function, try with different ones
    loss=artifacts.LossType.MSELoss,
    requires_grad=requires_grad,
    frozen_params=frozen_params,
    artifact_directory="Transformer",
    additional_output_names=["output"])

In [7]:
model = onnx.load("Transformer/training_model.onnx")
print('Model :\n\n{}'.format(onnx.helper.printable_graph(model.graph)))

Model :

graph main_graph (
  %input[FLOAT, batch_sizex32x512]
  %target[FLOAT, batch_sizex32x512]
  %attention.in_proj_weight[FLOAT, 1536x512]
  %attention.in_proj_bias[FLOAT, 1536]
  %attention.out_proj.weight[FLOAT, 512x512]
  %attention.out_proj.bias[FLOAT, 512]
  %feed_forward.0.weight[FLOAT, 2048x512]
  %feed_forward.0.bias[FLOAT, 2048]
  %feed_forward.2.weight[FLOAT, 512x2048]
  %feed_forward.2.bias[FLOAT, 512]
  %layer_norm1.weight[FLOAT, 512]
  %layer_norm1.bias[FLOAT, 512]
  %layer_norm2.weight[FLOAT, 512]
  %layer_norm2.bias[FLOAT, 512]
  %attention.in_proj_weight_grad.accumulation.buffer[FLOAT, 1536x512]
  %attention.in_proj_bias_grad.accumulation.buffer[FLOAT, 1536]
  %attention.out_proj.weight_grad.accumulation.buffer[FLOAT, 512x512]
  %attention.out_proj.bias_grad.accumulation.buffer[FLOAT, 512]
  %feed_forward.0.weight_grad.accumulation.buffer[FLOAT, 2048x512]
  %feed_forward.0.bias_grad.accumulation.buffer[FLOAT, 2048]
  %feed_forward.2.weight_grad.accumulation.buffer[