In [None]:
import os

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, lr_scheduler
from torch.utils.data import DataLoader, TensorDataset
from torchvision import transforms

from root_utils import tree_to_pandas
# Input ROOT file and the CSV file the loaded tree is written to.
ROOT_DATA_PATH = "AnalysisResults_trees.root"
TRAIN_DATA = "train_params.csv"

# Output location for the serialized PyTorch and ONNX models.
MODEL_PATH = "../models/"
MODEL_NAME = "Simple_example_multioutput"
# os.path.join avoids the "../models//..." double slash that "%s/%s" produced.
ONNX_MODEL_PATH = os.path.join(MODEL_PATH, "%s.onnx" % MODEL_NAME)

batch_size = 32   # samples per optimisation step
data_size = 1024  # number of synthetic training samples
classes = 5       # number of output classes of SimpleNet

# Seed torch so the synthetic training data and the weight initialisation
# are reproducible under Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

In [None]:
# Read the PID-track tree out of the ROOT file into a DataFrame.
# The "*" argument presumably selects every branch -- TODO confirm in root_utils.
tree_path = "DF_2955850012345678000/O2pidtracks"
training_data = tree_to_pandas(ROOT_DATA_PATH, tree_path, "*")
print(training_data)
# Persist as a semicolon-separated CSV.
training_data.to_csv(TRAIN_DATA, sep=";")

In [None]:
class SimpleNet(nn.Module):
    """Small fully connected classifier: n_inputs -> 15 -> 10 -> n_classes.

    Hidden layers use tanh; the forward pass ends in a softmax over the
    class axis, so each output row is a probability vector.

    Args:
        n_inputs: number of input features (default 5).
        n_classes: number of output classes; when omitted, falls back to the
            notebook-level ``classes`` constant.
    """

    def __init__(self, n_inputs=5, n_classes=None):
        super(SimpleNet, self).__init__()
        if n_classes is None:
            n_classes = classes

        # Attribute names are unchanged so existing state_dicts / saved models still load.
        self.input = nn.Linear(n_inputs, 15)
        self.hidden = nn.Linear(15, 10)
        self.out = nn.Linear(10, n_classes)

    def forward(self, x):
        """Map features of shape (..., n_inputs) to class probabilities (..., n_classes)."""
        x = torch.tanh(self.input(x))
        x = torch.tanh(self.hidden(x))
        # Normalise over the class axis (last dim). dim=0 would normalise across
        # the batch for 2-D input, so rows would not sum to 1.
        # Drop this softmax if switching to CrossEntropyLoss, which expects logits.
        x = torch.softmax(self.out(x), dim=-1)
        return x

# Instantiate the network with its default layer sizes; this single instance
# is the one trained, saved and exported in the cells below.
model: nn.Module = SimpleNet()

In [None]:
# Bare last expression: the cell output is the module's repr (its layer structure).
model

### Training data - 5 random features per sample, uniform in [0, 1); the target class is the floor of their sum (an integer from 0 to 4)

In [None]:
# Synthetic features: data_size rows of 5 values drawn uniformly from [0, 1).
x = torch.rand((data_size, 5))
x[:5]

In [None]:
# Class label = floor of the per-row feature sum, shaped (N, 1) so it can be
# used directly as a scatter_/one_hot index. With x from torch.rand(..., 5),
# the sum lies in [0, 5), so labels are 0..4.
y = torch.floor(x.sum(1)).long().view(-1, 1)
# Guard: every label must be a valid class index, e.g. if the feature count changes.
assert int(y.min()) >= 0 and int(y.max()) < classes, "labels must lie in [0, classes)"
print(y[:5])

# One-hot encoding buffer created once and reused. torch.zeros replaces the
# legacy, uninitialised torch.FloatTensor(...) constructor, so its contents are
# well-defined from the start.
y_onehot = torch.zeros(batch_size, classes)

In [None]:
# Pair features with labels; batches are served in order, batch_size at a time.
dataset = TensorDataset(x, y)
train_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False)

### Training

In [None]:
# Loss: mean squared error between softmax probabilities and one-hot targets.
criterion = nn.MSELoss()
optim = Adam(params=model.parameters(), lr=1e-3)
# NOTE(review): the training loop never calls scheduler.step(), so the
# learning rate stays at 1e-3 for the whole run.
scheduler = lr_scheduler.ExponentialLR(optimizer=optim, gamma=0.95)
# NOTE(review): overwritten by the `for epoch in range(...)` loop variable;
# it does not control the number of epochs.
epoch = 50

In [None]:
# Train for 1000 epochs over train_loader, logging every 5th epoch.
for epoch in range(1000):
    for idx, (train_x, train_label) in enumerate(train_loader):
        # One-hot targets are built per batch, so a final batch smaller than
        # batch_size (data_size not divisible by batch_size) still matches the
        # prediction's shape; the fixed-size y_onehot buffer would not.
        target = F.one_hot(train_label.view(-1), num_classes=classes).float()
        optim.zero_grad()
        predict_y = model(train_x)
        _error = criterion(predict_y, target)
        _error.backward()
        optim.step()
    if epoch % 5 == 0:
        # Reports the loss of the last batch of this epoch as a plain number.
        print('epoch:{}, idx: {}, loss: {}'.format(epoch, idx, _error.item()))

In [None]:
# A single unbatched sample of 5 features; also used as the trace input for ONNX export.
dummy_input = torch.as_tensor([0.3, 1.0, 1.2, 0.3, 1.0], dtype=torch.float32)
model(dummy_input)

In [None]:
# torch.save does not create parent directories, so make sure MODEL_PATH exists.
os.makedirs(MODEL_PATH, exist_ok=True)
# NOTE(review): this pickles the whole module, so loading it requires SimpleNet to be
# importable; saving model.state_dict() would be the more portable choice.
torch.save(model, os.path.join(MODEL_PATH, "%s.torch" % MODEL_NAME))

### Saving as onnx

In [None]:
# Trace the model with the unbatched dummy sample and write the ONNX graph.
# The names "input"/"out" are what the onnxruntime cells below feed and read.
torch.onnx.export(
    model,
    dummy_input,
    ONNX_MODEL_PATH,
    input_names=["input"],
    output_names=["out"],
)

## Testing with onnx

In [None]:
import onnx

In [None]:
# Reload the exported file from disk for validation and inspection in the next cells.
model_onnx = onnx.load(ONNX_MODEL_PATH)

In [None]:
# Structural validation of the loaded model; raises if it is invalid, otherwise produces no output.
onnx.checker.check_model(model_onnx)

In [None]:
# Human-readable text dump of the exported graph, to inspect its inputs, outputs and nodes.
print(onnx.helper.printable_graph(model_onnx.graph))

### Inference with onnxruntime

In [None]:
import onnxruntime as ort

In [None]:
# Pass providers explicitly: since onnxruntime 1.9, builds with several execution
# providers enabled (e.g. GPU builds) raise a ValueError when `providers` is omitted.
ort_session = ort.InferenceSession(ONNX_MODEL_PATH, providers=["CPUExecutionProvider"])
outputs_meta = ort_session.get_outputs()
# Shape of the first (only) graph output, as recorded at export time.
outputs_meta[0].shape

In [None]:
# Same sample as dummy_input, as float32 to match the exported graph's input type.
example = np.asarray([0.3, 1, 1.2, 0.3, 1], dtype=np.float32)
# Feed by the input name given at export; None requests every graph output.
feed = {"input": example}
outputs = ort_session.run(None, feed)

In [None]:
# List returned by InferenceSession.run: one array per graph output (only "out" here).
# The model's forward already ends in softmax, so no extra softmax is applied.
outputs