In [1]:
import onnxruntime.training.onnxblock as onnxblock
from onnxruntime.training.api import CheckpointState, Module, Optimizer
from onnxruntime.training import artifacts
from onnxruntime import InferenceSession
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import torch
import onnx
import io

## Generate forward-only graph

In [27]:
# Minimal two-layer perceptron; this PyTorch module is the source of the exported graphs.
class SimpleNet(torch.nn.Module):
    """Fully connected network: Linear -> ReLU -> Linear.

    Args:
        input_size: Number of input features per sample.
        hidden_size: Width of the hidden layer.
        num_classes: Number of output values per sample.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()

        # The attribute names determine the parameter names (fc1.weight, fc2.bias, ...).
        self.fc1 = torch.nn.Linear(input_size, hidden_size)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(hidden_size, num_classes)

    def forward(self, model_input):
        """Run the forward pass and return the output of the second linear layer."""
        hidden = self.relu(self.fc1(model_input))
        return self.fc2(hidden)

# Model dimensions and target device, followed by the model instance itself.
device = "cpu"
batch_size = 64
input_size = 784
hidden_size = 500
output_size = 10
pt_model = SimpleNet(
    input_size=input_size,
    hidden_size=hidden_size,
    num_classes=output_size,
).to(device)

## Freezing the parameters

In [35]:
# Mark every parameter of the model as trainable.
for tensor in pt_model.parameters():
    tensor.requires_grad_(True)

In [36]:
# Show the current requires_grad flag of each named parameter.
for param_name, tensor in pt_model.named_parameters():
    print('Name: ', param_name, 'Requires_Grad: ', tensor.requires_grad)

Name:  fc1.weight Requires_Grad:  True
Name:  fc1.bias Requires_Grad:  True
Name:  fc2.weight Requires_Grad:  True
Name:  fc2.bias Requires_Grad:  True


In [56]:
# Freeze the first layer explicitly and keep the second layer trainable.
# The autograd flag is spelled `requires_grad`; assigning to any other attribute
# name (e.g. `required_grad`) just creates an unrelated attribute and has no
# effect on gradient tracking.
pt_model.fc1.weight.requires_grad = False
pt_model.fc1.bias.requires_grad = False
pt_model.fc2.weight.requires_grad = True
pt_model.fc2.bias.requires_grad = True

## Getting the ONNX export ready

In [57]:
# Random input batch used to trace the model during export.
model_inputs = (torch.randn(batch_size, input_size, device=device),)

# Run one forward pass and normalise a single-tensor result into a list.
model_outputs = pt_model(*model_inputs)
model_outputs = [model_outputs] if isinstance(model_outputs, torch.Tensor) else model_outputs

input_names = ["input"]
output_names = ["output"]
# Axis 0 of every graph input and output is exported as the symbolic "batch_size".
dynamic_axes = {tensor_name: {0: "batch_size"} for tensor_name in input_names + output_names}

# Export into an in-memory buffer, then parse the bytes back into an ONNX ModelProto.
# NOTE(review): constant folding is off and the mode is TRAINING, presumably so the
# parameters remain intact for the training artifacts - confirm.
export_options = {
    "input_names": input_names,
    "output_names": output_names,
    "opset_version": 14,
    "do_constant_folding": False,
    "training": torch.onnx.TrainingMode.TRAINING,
    "dynamic_axes": dynamic_axes,
    "export_params": True,
    "keep_initializers_as_inputs": False,
}
f = io.BytesIO()
torch.onnx.export(pt_model, model_inputs, f, **export_options)
onnx_model = onnx.load_model_from_string(f.getvalue())

## Method 1: Create the training, eval, and optimizer graphs and the checkpoint at once

In [58]:
# Split the parameter names by trainability, keeping named_parameters() order.
named_params = list(pt_model.named_parameters())
requires_grad = [param_name for param_name, tensor in named_params if tensor.requires_grad]
frozen_params = [param_name for param_name, tensor in named_params if not tensor.requires_grad]

# Write the training artifacts into the "SimpleNet" directory.
# Alternative loss to experiment with: artifacts.LossType.CrossEntropyLoss
artifacts.generate_artifacts(
    onnx_model,
    requires_grad=requires_grad,
    frozen_params=frozen_params,
    loss=artifacts.LossType.MSELoss,
    optimizer=artifacts.OptimType.AdamW,
    artifact_directory="SimpleNet",
    additional_output_names=["output"],
)

In [59]:
# Load the training graph and print it in readable form.
# The directory must be spelled exactly like the artifact_directory passed to
# generate_artifacts ("SimpleNet"); a different letter case only resolves on
# case-insensitive filesystems.
model = onnx.load("SimpleNet/training_model.onnx")
print('Model :\n\n{}'.format(onnx.helper.printable_graph(model.graph)))

Model :

graph main_graph (
  %input[FLOAT, batch_sizex784]
  %target[FLOAT, batch_sizex10]
  %fc1.weight[FLOAT, 500x784]
  %fc1.bias[FLOAT, 500]
  %fc2.weight[FLOAT, 10x500]
  %fc2.bias[FLOAT, 10]
  %fc2.weight_grad.accumulation.buffer[FLOAT, 10x500]
  %fc2.bias_grad.accumulation.buffer[FLOAT, 10]
  %lazy_reset_grad[BOOL, 1]
) initializers (
  %onnx::pow_exponent::73[FLOAT, 1]
  %onnx::reducemean_output::76_grad[FLOAT, scalar]
  %/fc2/Gemm_Grad/ReduceAxes_for_/fc2/Gemm_Grad/dC_reduced[INT64, 1]
  %OneConstant_Type1[FLOAT, 1]
) {
  %/fc1/Gemm_output_0 = Gemm[alpha = 1, beta = 1, transA = 0, transB = 1](%input, %fc1.weight, %fc1.bias)
  %/relu/Relu_output_0 = Relu(%/fc1/Gemm_output_0)
  %output = Gemm[alpha = 1, beta = 1, transA = 0, transB = 1](%/relu/Relu_output_0, %fc2.weight, %fc2.bias)
  %onnx::sub_output::71 = Sub(%output, %target)
  %onnx::pow_output::74 = Pow(%onnx::sub_output::71, %onnx::pow_exponent::73)
  %onnx::ReduceMean::77_Grad/Sized_X = Size(%onnx::pow_output::74)
  %onn

In [19]:
# List the parameter names in registration order.
print([param_name for param_name, _ in pt_model.named_parameters()])

['fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias']


## Method 2: Create the training, eval, and optimizer graphs and the checkpoint separately

In [None]:
# Training block that applies a loss to the model output.
class SimpleNetTrainingBlock(onnxblock.TrainingBlock):
    """Training block that applies a cross-entropy loss to a graph output."""

    def __init__(self):
        super().__init__()
        # Swap in a different onnxblock loss here to experiment.
        self.loss = onnxblock.loss.CrossEntropyLoss()

    def build(self, output_name):
        """Return the loss built on `output_name`, together with `output_name` itself."""
        loss_output = self.loss(output_name)
        return loss_output, output_name

In [None]:
# Build the ONNX model with the loss attached.
training_block = SimpleNetTrainingBlock()

# Carry the freeze configured on the PyTorch model over to the training block, so
# this method trains the same parameters as the generate_artifacts route. Marking
# every initializer as trainable would silently undo the freeze of fc1.
trainable_names = {
    param_name for param_name, tensor in pt_model.named_parameters() if tensor.requires_grad
}
for param in onnx_model.graph.initializer:
    print(param.name)
    training_block.requires_grad(param.name, param.name in trainable_names)

# Build the training and eval graphs on top of the exported forward graph.
model_params = None
with onnxblock.base(onnx_model):
    _ = training_block(*[output.name for output in onnx_model.graph.output])
    training_model, eval_model = training_block.to_model_proto()
    model_params = training_block.parameters()

# Build the optimizer graph from the parameters collected above.
optimizer_block = onnxblock.optim.AdamW()
with onnxblock.empty_base() as accessor:
    _ = optimizer_block(model_params)
    optimizer_model = optimizer_block.to_model_proto()

In [None]:
# Save the training graph to "training_model.onnx" (path relative to the working directory).
onnx.save(training_model, "training_model.onnx")

In [None]:
# Optional: also save the optimizer and eval graphs.
for graph_proto, file_name in (
    (optimizer_model, "optimizer_model.onnx"),
    (eval_model, "eval_model.onnx"),
):
    onnx.save(graph_proto, file_name)