In [None]:
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam, lr_scheduler
from torch.utils.data import DataLoader, TensorDataset
from torchvision import transforms
import torch.nn.functional as F

In [None]:
class SimpleNet(nn.Module):
    """Three-layer fully connected network ending in a sigmoid.

    Trained below with BCELoss on 0/1 targets, so each output is a probability.

    Args:
        in_features: number of input features per sample (5 in this notebook).
        hidden1: width of the first hidden layer.
        hidden2: width of the second hidden layer.
        out_features: number of sigmoid outputs per sample.
    """

    def __init__(self, in_features=5, hidden1=30, hidden2=10, out_features=1):
        super().__init__()
        # Attribute names are kept as-is so state_dict keys stay compatible.
        self.input = nn.Linear(in_features, hidden1)
        self.hidden = nn.Linear(hidden1, hidden2)
        self.out = nn.Linear(hidden2, out_features)

    def forward(self, x):
        """Map inputs (..., in_features) to probabilities (..., out_features) in (0, 1)."""
        x = F.leaky_relu(self.input(x))
        x = F.leaky_relu(self.hidden(x))
        # The output layer yields raw logits; no activation before the sigmoid.
        # A leaky_relu here would scale negative logits by 0.01, so reaching a
        # low probability would need a ~100x larger pre-activation.
        return torch.sigmoid(self.out(x))


torch.manual_seed(0)  # reproducible weight initialisation
model = SimpleNet()

In [None]:
# Display the module's repr, which lists its Linear layers and their sizes.
model

### Training data: uniform random inputs in [0, 1); the target is 1 if a sample's values sum to more than 2, otherwise 0

In [None]:
# Reproducible synthetic inputs: N_SAMPLES rows of uniform values in [0, 1).
N_SAMPLES, N_FEATURES = 1000, 5  # N_FEATURES must match SimpleNet's input width
data_rng = torch.Generator().manual_seed(0)
x = torch.rand(N_SAMPLES, N_FEATURES, generator=data_rng)
x[:5]

In [None]:
# Label each row 1.0 when its features sum to more than 2, else 0.0, shaped
# as a single column so it lines up with the model's one output per sample.
row_sums = x.sum(dim=1)
y = (row_sums > 2).to(torch.float32).unsqueeze(-1)
y[:5]

In [None]:
# Pair inputs with labels so DataLoader can yield (inputs, labels) mini-batches.
dataset = TensorDataset(x, y)
# Reshuffle every epoch so batches differ between epochs (standard for training);
# the seeded generator keeps the shuffle order reproducible.
train_loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    generator=torch.Generator().manual_seed(0),
)

### Training

In [None]:
# Optimiser, learning-rate schedule and loss used by the training loop.
optim = Adam(model.parameters(), lr=0.001)
# NOTE(review): the learning rate only decays when `scheduler.step()` is called;
# each call multiplies it by gamma=0.95.
scheduler = lr_scheduler.ExponentialLR(optim, gamma=0.95)
# BCELoss takes probabilities; SimpleNet's forward already ends in a sigmoid.
criterion = nn.BCELoss()

In [None]:
# Train for N_EPOCHS full passes over `train_loader`, logging the mean batch
# loss every LOG_EVERY epochs.
# NOTE(review): `scheduler` is created in the setup cell but never stepped here,
# so the learning rate stays at its initial value. Calling `scheduler.step()`
# once per epoch would enable the decay, but with gamma=0.95 it drops the rate
# below 1e-6 after ~150 epochs — choose the epoch count before enabling it.
N_EPOCHS = 1000
LOG_EVERY = 5

model.train()
for epoch in range(N_EPOCHS):
    running_loss = 0.0
    n_batches = 0
    for train_x, train_label in train_loader:
        optim.zero_grad()
        predict_y = model(train_x)
        _error = criterion(predict_y, train_label)
        _error.backward()
        optim.step()
        running_loss += _error.item()  # .item() detaches and avoids holding the graph
        n_batches += 1
    if epoch % LOG_EVERY == 0:
        # max(...) guards against an empty loader.
        mean_loss = running_loss / max(n_batches, 1)
        print(f"epoch: {epoch}, batches: {n_batches}, mean loss: {mean_loss:.4f}")

In [None]:
# Sanity-check one hand-made sample. Its values sum to 3.8 (> 2), so a
# well-trained model should output a probability close to 1.
dummy_input = torch.tensor([0.3, 1, 1.2, 0.3, 1])
model.eval()
with torch.no_grad():  # inference only: don't build an autograd graph
    dummy_prediction = model(dummy_input)
dummy_prediction

In [None]:
# Persist the whole pickled module. Per the PyTorch serialization notes, loading
# it back requires the SimpleNet class to be importable; saving
# `model.state_dict()` instead would be more portable.
TORCH_MODEL_PATH = "Simple_example.torch"
torch.save(model, TORCH_MODEL_PATH)

### Saving as ONNX

In [None]:
# Export `model` to ONNX, using `dummy_input` as the example input. The graph
# input is named "input", which the onnxruntime cell feeds by name.
ONNX_PATH = "Simple_example.onnx"
torch.onnx.export(
    model,
    dummy_input,
    ONNX_PATH,
    input_names=["input"],
    output_names=["out"],
)

## Testing with ONNX

In [None]:
import onnx

In [None]:
# Read the exported file back in; onnx.load is an alias of onnx.load_model.
model_onnx = onnx.load_model("Simple_example.onnx")

In [None]:
# Validate the exported model's structure; check_model is expected to raise
# if the model is invalid (see the onnx.checker docs) and prints nothing otherwise.
onnx.checker.check_model(model_onnx)

In [None]:
# Human-readable dump of the graph; print() so its newlines render as line breaks.
graph_text = onnx.helper.printable_graph(model_onnx.graph)
print(graph_text)

### Inference with ONNX Runtime

In [None]:
import onnxruntime as ort

In [None]:
# Pin the execution provider explicitly: since onnxruntime 1.9, builds that ship
# more than the CPU provider (e.g. GPU builds) raise if `providers` is omitted.
ort_session = ort.InferenceSession(
    'Simple_example.onnx',
    providers=['CPUExecutionProvider'],
)

In [None]:
# One sample as a float32 vector: the exported graph's input is float32, and
# onnxruntime does not cast float64 inputs automatically.
example = np.asarray([0.3, 1, 1.2, 0.3, 1], dtype=np.float32)
# output_names=None asks for every graph output.
outputs = ort_session.run(output_names=None, input_feed={'input': example})

In [None]:
# Cross-check ONNX Runtime against the PyTorch model on the same sample; a
# faithful export should agree up to float32 rounding.
with torch.no_grad():
    torch_reference = model(torch.from_numpy(example)).numpy()
np.testing.assert_allclose(outputs[0], torch_reference, rtol=1e-3, atol=1e-5)
outputs