In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.onnx

import numpy as np

import onnxmltools
import frontmatter as fm

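`nn.Linear` encodes the computation $y = x A^T + b$. The arguments passed to the `nn.Linear` constructor, `(in_features, out_features)`, give the *shape* of $A^T$; the stored weight $A$ therefore has shape `(out_features, in_features)`. For example, `nn.Linear(20, 30)` holds a $30 \times 20$ weight, as the exported graph below shows.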

In [2]:
# A single affine layer mapping 20 input features to 30 output features.
m = nn.Linear(in_features=20, out_features=30)

In [3]:
# Dummy batch of 128 samples with 20 features each, used to trace the layer.
# NOTE: this shadows the builtin `input()`; later cells rely on the name.
input = torch.randn((128, 20))

In [None]:
# Forward pass over the (128, 20) batch: one output row per input row,
# with `out_features` columns.
output = m(input)
output.shape

In [4]:
# Export the single nn.Linear layer `m` to its own ONNX file.
# This used to write 'trivial.onnx', which the `Trivial` export further down
# overwrites. That meant this artifact never survived a full run, and the file
# reloaded later never held this model. A distinct file name keeps both exports.
input_names = ["input_0"]
output_names = ["output_0"]

torch.onnx.export(
    m,
    input,
    "linear.onnx",
    verbose=True,  # `verbose` prints the traced graph to stdout
    input_names=input_names,
    output_names=output_names,
)

graph(%input_0 : Float(128, 20, strides=[20, 1], requires_grad=0, device=cpu),
      %weight : Float(30, 20, strides=[20, 1], requires_grad=1, device=cpu),
      %bias : Float(30, strides=[1], requires_grad=1, device=cpu)):
  %output_0 : Float(128, 30, strides=[30, 1], requires_grad=1, device=cpu) = onnx::Gemm[alpha=1., beta=1., transB=1](%input_0, %weight, %bias) # /opt/conda/lib/python3.9/site-packages/torch/nn/functional.py:1753:0
  return (%output_0)



In [2]:
class Trivial(nn.Module):
    """Two-layer MLP: Linear -> ReLU -> Linear.

    Args:
        in_features: size of each input sample (default 1000).
        hidden_features: width of the hidden layer (default 100).
        out_features: size of each output sample (default 10).

    The defaults reproduce the original 1000 -> 100 -> 10 network, so
    ``Trivial()`` builds the same model with the same parameter names
    (``lin1.*``, ``lin2.*``) and shapes.
    """

    def __init__(self, in_features=1000, hidden_features=100, out_features=10):
        super().__init__()
        self.lin1 = nn.Linear(in_features, hidden_features)
        self.lin2 = nn.Linear(hidden_features, out_features)

    def forward(self, input):
        """Apply lin1, ReLU, then lin2 to ``input``.

        ``nn.Linear`` acts on the last dimension, so ``input`` may be a single
        1-D sample of size ``in_features`` or a batch with features last.
        """
        hidden = F.relu(self.lin1(input))
        return self.lin2(hidden)

In [3]:
# Export the two-layer `Trivial` model.
# Weight initialisation and torch.randn both draw from the global RNG,
# so the model is built before the dummy input is sampled.
t = Trivial()
input = torch.randn(1000)  # a single 1-D sample, no batch dimension
# NOTE(review): without a batch dim or dynamic_axes, the exported graph
# presumably only accepts shape (1000,). Confirm this is intended.

input_names = ["input_0"]
output_names = ["output_0"]
torch.onnx.export(
    t,
    input,
    'trivial.onnx',
    input_names=input_names,
    output_names=output_names,
)

In [5]:
# Reload the exported model from disk so its metadata can be edited.
onnx_path = "trivial.onnx"
deser_model = onnxmltools.load_model(onnx_path)

In [6]:
# Sanity check: the loaded object is an ONNX ModelProto (a protobuf message),
# so fields such as `metadata_props` can be edited on it directly.
type(deser_model)

onnx.onnx_ml_pb2.ModelProto

In [7]:
# Set a "version" entry in the model's metadata_props (repeated key/value pairs).
# Reuse an existing "version" entry if there is one, so re-running this cell
# does not append duplicates. Display the entry actually written rather than
# assuming it sits at index 0, which is wrong if the model already had metadata.
meta = next((p for p in deser_model.metadata_props if p.key == "version"), None)
if meta is None:
    meta = deser_model.metadata_props.add()
    meta.key = "version"
meta.value = "0.0.1"
meta

key: "version"
value: "0.0.1"

In [9]:
# Parse the model-card source file into front matter and body.
# Assumes 'testfile.md' exists next to the notebook.
# NOTE(review): the stderr text shown below looks like a warning from the
# `frontmatter` package's own `yaml.load` call, not from this cell. Confirm.
md_path = 'testfile.md'
post = fm.Frontmatter.read_file(md_path)

  "attributes": yaml.load(fmatter),


In [10]:
# Read a dictionary
post['attributes']
#print(post['body'], "\n")        # String
#print(post['frontmatter'])       # String

{'title': 'Third Post', 'date': 'Oct 8, 2018 5:26pm PST'}

In [11]:
# Read the raw model-card markdown (front matter included) to embed in the
# ONNX metadata. Decode explicitly as UTF-8 (assumes the file is UTF-8).
# The platform default encoding, e.g. cp1252 on Windows, can garble or reject
# non-ASCII text, and ONNX metadata values are serialized as UTF-8 strings.
with open('testfile.md', 'r', encoding='utf-8') as f:
    mdtxt = f.read()

In [14]:
# Get the "model_card" metadata entry, creating it only if it is missing.
# Re-running a bare `.add()` appended another empty entry to metadata_props on
# every execution.
modelcard_meta = next(
    (p for p in deser_model.metadata_props if p.key == "model_card"), None
)
if modelcard_meta is None:
    modelcard_meta = deser_model.metadata_props.add()
    modelcard_meta.key = "model_card"

In [15]:
# Store the whole markdown text (front matter included) under "model_card".
modelcard_meta.key, modelcard_meta.value = "model_card", mdtxt

In [16]:
# Write the model with its added metadata to a new file, leaving the
# originally exported 'trivial.onnx' untouched.
out_path = "trivial_1.onnx"
onnxmltools.utils.save_model(deser_model, out_path)