**Table of contents**<a id='toc0_'></a>    
- [Load and run model predictions](#toc1_)    
  - [Load the model](#toc1_1_)    
  - [Model Inference](#toc1_2_)    
  - [Exporting the model to ONNX](#toc1_3_)    


In [20]:
!pip install onnxruntime

Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple


In [21]:
%matplotlib inline
import torch
import onnxruntime
from torch import nn
import torch.onnx as onnx
import torchvision.models as models
from torchvision import datasets
from torchvision.transforms import ToTensor

# <a id='toc1_'></a>[Load and run model predictions](#toc0_)

## <a id='toc1_1_'></a>[Load the model](#toc0_)

In [5]:
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

In [6]:
model = NeuralNetwork()
model.load_state_dict(torch.load('data/model.pth'))
model.eval()

NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
    (5): ReLU()
  )
)

## <a id='toc1_2_'></a>[Model Inference](#toc0_)

- Optimizing a model to run on a variety of platforms and programming languages is difficult. Maximizing performance across every combination of framework and hardware is very time-consuming.
- ONNX Runtime, the runtime for Open Neural Network Exchange (ONNX) models, lets you train once and accelerate inference across hardware, cloud, and edge devices.
- ONNX is a common format, supported by a number of vendors, for sharing neural networks and other machine learning models. You can use the ONNX format to run inference on your model from other programming languages and frameworks such as Java, JavaScript, C#, and ML.NET.

> **NOTE** ONNX models let you run inference on different platforms and in different programming languages.

## <a id='toc1_3_'></a>[Exporting the model to ONNX](#toc0_)

- PyTorch has native ONNX export support.
- Given the dynamic nature of the PyTorch execution graph, however, the export process must traverse the execution graph to produce a persisted ONNX model.
- For this reason, a test variable of the appropriate size should be passed in to the export routine. In our case, we create a dummy zero tensor of the correct size. You can get the size from the `shape` attribute of a sample from your training dataset, for example `tensor.shape`:

In [7]:
input_image = torch.zeros((1,28,28))
onnx_model = 'data/model.onnx'
onnx.export(model, input_image, onnx_model)
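To see why a `(1, 28, 28)` tensor is the appropriate size here, the following sketch mirrors the shape arithmetic of the network in NumPy. It is an illustration under stated assumptions, not the exported model: the weights are random placeholders rather than the trained parameters in `data/model.pth`, and the helper name `forward_shapes` is hypothetical. It only shows how `nn.Flatten` keeps the first dimension and collapses the rest into 784 features.

```python
import numpy as np


def forward_shapes(x, seed=0):
    """Mimic the layer stack with random weights and record each shape.

    Hypothetical helper: the weights are random, so only the shapes
    (not the values) correspond to the real model.
    """
    rng = np.random.default_rng(seed)
    shapes = [x.shape]

    # Equivalent of nn.Flatten(start_dim=1): keep dim 0, collapse the rest
    x = x.reshape(x.shape[0], -1)
    shapes.append(x.shape)

    # Linear followed by ReLU for each layer, as in the model definition
    for n_in, n_out in [(28 * 28, 512), (512, 512), (512, 10)]:
        w = rng.standard_normal((n_in, n_out)).astype(np.float32)
        b = np.zeros(n_out, dtype=np.float32)
        x = np.maximum(x @ w + b, 0.0)
        shapes.append(x.shape)

    return shapes


# Same shape as the dummy tensor passed to the export routine
dummy = np.zeros((1, 28, 28), dtype=np.float32)
for shape in forward_shapes(dummy):
    print(shape)
```

The final shape `(1, 10)` is one row of ten class scores, which is why the inference cell indexes `result[0][0]` before calling `argmax`.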

In [27]:
# Use the test dataset as sample data to run inference with the ONNX model and make predictions.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]
x, y = test_data[0][0], test_data[0][1]
print(type(x),type(y))
sample_img, _ = test_data[0]
sample_img.shape

<class 'torch.Tensor'> <class 'int'>


torch.Size([1, 28, 28])

In [15]:
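# Create an inference session with onnxruntime.InferenceSession. To run inference on the ONNX model,
# call run and pass in the list of outputs you want returned (leave it empty if you want all of them)
# and a map of the input values.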
session = onnxruntime.InferenceSession(onnx_model, None)
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

result = session.run([output_name], {input_name: x.numpy()})
predicted, actual = classes[result[0][0].argmax(0)], classes[y]
print(f'Predicted: "{predicted}", Actual: "{actual}"')


Predicted: "Ankle boot", Actual: "Ankle boot"
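The session returns raw class scores rather than probabilities. As an optional post-processing step that is not part of the original notebook, the sketch below applies a softmax and lists the three highest-scoring classes. The `scores` values are made-up placeholders standing in for `result[0][0]`, and `softmax` and `top_k` are hypothetical helpers written for this illustration.

```python
import numpy as np

classes = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]


def softmax(scores):
    """Convert raw scores to probabilities that sum to 1."""
    shifted = scores - np.max(scores)  # subtract the max for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum()


def top_k(scores, labels, k=3):
    """Return the k most probable (label, probability) pairs, best first."""
    probs = softmax(np.asarray(scores, dtype=np.float64))
    order = np.argsort(probs)[::-1][:k]  # indices sorted by descending probability
    return [(labels[i], float(probs[i])) for i in order]


# Placeholder scores for illustration only; in the notebook, pass result[0][0]
scores = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 1.2, 0.0, 2.1, 0.3, 4.5], dtype=np.float32)
for label, prob in top_k(scores, classes):
    print(f"{label}: {prob:.3f}")
```

Looking at the runner-up classes can help judge how confident a prediction is, for example whether "Sneaker" is a close second to "Ankle boot".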
