In [1]:
import onnxruntime.training.onnxblock as onnxblock
from onnxruntime.training.api import CheckpointState, Module, Optimizer
from onnxruntime.training import artifacts
from onnxruntime import InferenceSession
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import onnx
import io

In [2]:
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention with trainable projections.

    The input is projected into queries, keys and values by three weight
    matrices of shape ``(d_in, d_out)``. Each output row is the
    softmax-weighted combination of the value vectors.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the last dimension of the output.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        self.d_out = d_out
        # torch.rand draws the initial weights uniformly from [0, 1). The
        # creation order (query, key, value) determines the values obtained
        # under a fixed seed, so it is kept as-is.
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key   = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        """Compute the context vectors for ``x``.

        Args:
            x: Tensor of shape ``(seq_len, d_in)``, optionally with leading
                batch dimensions, i.e. ``(..., seq_len, d_in)``.

        Returns:
            Tensor of shape ``(..., seq_len, d_out)``.
        """
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value

        # Swap only the last two axes. ``keys.T`` reverses *all* axes, which
        # is correct for a 2-D input but breaks as soon as a batch dimension
        # is present. For 2-D input both forms are identical.
        attn_scores = queries @ keys.transpose(-2, -1) # omega
        # Scale by sqrt(d_k) before the softmax over the key axis.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)

        context_vec = attn_weights @ values
        return context_vec

In [3]:
# Toy input: one 3-dimensional embedding row per token of the sentence
# "Your journey starts with one step".
inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # x^1  "Your"
    [0.55, 0.87, 0.66],  # x^2  "journey"
    [0.57, 0.85, 0.64],  # x^3  "starts"
    [0.22, 0.58, 0.33],  # x^4  "with"
    [0.77, 0.25, 0.10],  # x^5  "one"
    [0.05, 0.80, 0.55],  # x^6  "step"
])

In [4]:
# Layer dimensions: the input width is read from the second axis of `inputs`,
# the output width is chosen by hand.
d_in = inputs.size(1)  # input embedding size, d=3
d_out = 2              # output embedding size, d=2
batch_size = 1
print(d_in)

3


In [5]:
# Seed first: SelfAttention draws its weights with torch.rand on construction,
# so the seed fixes both the parameters and the values printed below.
torch.manual_seed(123)
device = "cpu"
sa = SelfAttention(d_in, d_out)
sa = sa.to(device)
print(sa(inputs))

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


In [6]:
# Run the module once and normalise a single-tensor result to a list.
model_outputs = sa(inputs)
model_outputs = [model_outputs] if isinstance(model_outputs, torch.Tensor) else model_outputs

# Tensor names for the exported graph.
input_names = ["input"]
output_names = ["output"]
# Axis 0 of every named tensor is exported under the symbolic name "batch_size".
dynamic_axes = {name: {0: "batch_size"} for name in (*input_names, *output_names)}

In [7]:
# Export the attention module to ONNX into an in-memory buffer, then parse the
# serialized bytes back into an ONNX model object.
f = io.BytesIO()
torch.onnx.export(
    sa,
    inputs,
    f,
    # Graph interface: tensor names and the symbolic (dynamic) axes.
    input_names=input_names,
    output_names=output_names,
    dynamic_axes=dynamic_axes,
    # Export in training mode, with constant folding switched off.
    training=torch.onnx.TrainingMode.TRAINING,
    do_constant_folding=False,
    # Parameter handling: export the weights, but do not list the
    # initializers among the graph inputs.
    export_params=True,
    keep_initializers_as_inputs=False,
    opset_version=14,
)
onnx_model = onnx.load_model_from_string(f.getvalue())

In [8]:
# Partition the module's parameter names by their requires_grad flag.
requires_grad = [pname for pname, p in sa.named_parameters() if p.requires_grad]
frozen_params = [pname for pname, p in sa.named_parameters() if not p.requires_grad]

# Generate the ONNX Runtime training artifacts from the exported model into
# the "Attention" directory.
artifacts.generate_artifacts(
    onnx_model,
    requires_grad=requires_grad,
    frozen_params=frozen_params,
    # Loss function; another option to experiment with is
    # artifacts.LossType.CrossEntropyLoss.
    loss=artifacts.LossType.MSELoss,
    optimizer=artifacts.OptimType.AdamW,
    artifact_directory="Attention",
    # Request the "output" tensor as an additional graph output.
    additional_output_names=["output"],
)

In [9]:
# Load the generated training graph from disk and print it in readable form.
model = onnx.load("Attention/training_model.onnx")
print(
    'Model :\n\n{}'.format(
        onnx.helper.printable_graph(model.graph)
    )
)

Model :

graph main_graph (
  %input[FLOAT, batch_sizex3]
  %target[FLOAT, batch_sizex2]
  %W_query[FLOAT, 3x2]
  %W_key[FLOAT, 3x2]
  %W_value[FLOAT, 3x2]
  %W_query_grad.accumulation.buffer[FLOAT, 3x2]
  %W_key_grad.accumulation.buffer[FLOAT, 3x2]
  %W_value_grad.accumulation.buffer[FLOAT, 3x2]
  %lazy_reset_grad[BOOL, 1]
) initializers (
  %onnx::pow_exponent::3[FLOAT, 1]
  %/Pow_output_0[FLOAT, scalar]
  %onnx::reducemean_output::6_grad[FLOAT, scalar]
  %OneConstant_Type1[FLOAT, 1]
) {
  %/MatMul_2_output_0 = MatMul(%input, %W_value)
  %/MatMul_1_output_0 = MatMul(%input, %W_query)
  %/MatMul_output_0 = MatMul(%input, %W_key)
  %/Transpose_output_0 = Transpose[perm = [1, 0]](%/MatMul_output_0)
  %/MatMul_3_output_0 = MatMul(%/MatMul_1_output_0, %/Transpose_output_0)
  %/Div_output_0 = Div(%/MatMul_3_output_0, %/Pow_output_0)
  %/Softmax_output_0 = Softmax[axis = -1](%/Div_output_0)
  %output = MatMul(%/Softmax_output_0, %/MatMul_2_output_0)
  %onnx::sub_output::1 = Sub(%output, %ta

In [10]:
# Parameter names whose requires_grad flag is set (passed as trainable to generate_artifacts).
print(requires_grad)

['W_query', 'W_key', 'W_value']


In [11]:
# Parameter names whose requires_grad flag is not set (passed as frozen to generate_artifacts).
print(frozen_params)

[]
