### Example 6: conditionals
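Here we consider a simple conditional. Given an array `x`, we want an indicator that is `1` if the first entry satisfies `x[0] >= 0` and `0` otherwise; a related problem uses the strict condition `x[0] > 0`. The model below returns `m * indicator + c` with `m = 2` and `c = 1`, i.e. `3` when the condition holds and `1` when it does not.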

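There is an easy solution that avoids Python control flow:
1. For `x[0] >= 0`, build the elementwise mask `x.ge(0)`, cast it to an integer type, and take its first entry.
2. For `x[0] > 0`, use `x.gt(0)` in the same way.

Writing the condition as a tensor operation keeps it in the exported graph. `torch.onnx.export`, as used here, records the graph by tracing the model on an example input, so a Python `if x[0] >= 0:` would be evaluated only once, on that input, and only the branch taken would be exported.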
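A quick check of the two masks on an input whose first entry is exactly `0`, the only case where `x[0] >= 0` and `x[0] > 0` disagree. The input values are made up for illustration.

```python
import torch

# First entry is exactly 0: the only case where the two conditions differ.
x = torch.tensor([0, -2, 3, -4])

ge_first = x.ge(0).type(torch.int64)[0]  # indicator for x[0] >= 0
gt_first = x.gt(0).type(torch.int64)[0]  # indicator for x[0] > 0

print(ge_first.item(), gt_first.item())  # prints: 1 0
```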

In [1]:
onnx_model = "example6.onnx"

In [2]:
import torch
import torch.nn as nn
import numpy as np

In [3]:
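class SampleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.m = 2  # weight applied to the indicator
        self.c = 1  # constant offset

    def forward(self, x):
        # Elementwise x >= 0, cast from bool to int64 so it can be used in arithmetic.
        mask = x.ge(0).type(torch.int64)
        # mask[0] is 1 if x[0] >= 0 and 0 otherwise, so the output is 3 or 1.
        return self.m * mask[0] + self.c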
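Why the mask matters can be checked without ONNX at all. The sketch below uses `torch.jit.trace` as a stand-in for the tracing step of the export (an assumption about how closely it mirrors the exporter), comparing a hypothetical `BranchNet`, written with a Python `if`, against `MaskNet`, a copy of `SampleNet` above. Both are traced on an input whose first entry is non-negative and then evaluated on one whose first entry is negative. Tracing `BranchNet` emits a `TracerWarning` about converting a tensor to a Python boolean.

```python
import torch
import torch.nn as nn


class BranchNet(nn.Module):
    """Hypothetical variant that uses a Python `if` instead of a mask."""

    def forward(self, x):
        if x[0] >= 0:  # converted to a Python bool once, at trace time
            return x[0] * 0 + 3
        return x[0] * 0 + 1


class MaskNet(nn.Module):
    """Same computation as SampleNet, written with a mask."""

    def __init__(self):
        super().__init__()
        self.m = 2
        self.c = 1

    def forward(self, x):
        mask = x.ge(0).type(torch.int64)
        return self.m * mask[0] + self.c


example = torch.tensor([1, -2, 3, -4])  # first entry >= 0 while tracing
negative = torch.tensor([-1, 2, 3, 4])  # first entry < 0 afterwards

traced_branch = torch.jit.trace(BranchNet().eval(), (example,))
traced_mask = torch.jit.trace(MaskNet().eval(), (example,))

print(traced_branch(negative).item())  # 3: the recorded branch is reused
print(traced_mask(negative).item())    # 1: the comparison is part of the graph
```

The traced `BranchNet` keeps returning `3` because the comparison was resolved in Python during tracing, while the mask version evaluates `x.ge(0)` on every call.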

In [4]:
model = SampleNet()
model.eval()

SampleNet()

In [5]:
# Run the model once on an example input; the same tensor is reused as the tracing input for export.
x = torch.tensor([1, -2, 3, -4], dtype=torch.int64)
out = model(x)

In [6]:
out

tensor(3)
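The printed `tensor(3)` follows directly from the formula: `x[0] = 1` gives `mask[0] = 1`, hence `2 * 1 + 1 = 3`. The sketch below (self-contained, so it redefines `SampleNet`) evaluates the model on three made-up inputs to show both outcomes, including the boundary case `x[0] = 0`.

```python
import torch
import torch.nn as nn


class SampleNet(nn.Module):
    # Same model as above: returns m * [x[0] >= 0] + c with m = 2, c = 1.
    def __init__(self):
        super().__init__()
        self.m = 2
        self.c = 1

    def forward(self, x):
        mask = x.ge(0).type(torch.int64)
        return self.m * mask[0] + self.c


model = SampleNet().eval()

# Made-up inputs whose first entries are positive, zero and negative.
results = {}
for first in (1, 0, -1):
    x = torch.tensor([first, -2, 3, -4])
    results[first] = model(x).item()

print(results)  # {1: 3, 0: 3, -1: 1}
```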

In [7]:
torch.onnx.export(
    model,
    x,  # example input: the model is traced on it, fixing the input to shape [4] and dtype int64
    onnx_model,
    opset_version=11,
    do_constant_folding=True,
    input_names=['input'],
    output_names=['output'])

In [8]:
# uncomment to install netron.
#!pip install netron
import netron
netron.start(onnx_model, port=8084)

Serving 'example6.onnx' at http://localhost:8084


### Use the ONNX model

In [9]:
import onnxruntime as ort
import numpy as np
sess = ort.InferenceSession(onnx_model)

In [10]:
sess.get_inputs()[0].name

'input'

In [11]:
# sess.run(output_names, input_feed): first the list of output names to fetch
# (None fetches every output), then a dict mapping input names to NumPy arrays,
# similar to a TensorFlow feed_dict.
# The input must match the exported signature, an int64 array of shape (4,);
# an array of any other size raises an error.
outs = sess.run(['output'],
                {'input': np.array([22, 22, 22, 12], dtype=np.int64)})

print(outs)

[array(3, dtype=int64)]
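Because the exported input has a static shape and an int64 dtype, it can help to build the feed dict through a small check before calling `sess.run`. `make_feed` and `EXPECTED_SHAPE` below are hypothetical names, and the shape `(4,)` is an assumption that holds only for this export (no dynamic axes were declared).

```python
import numpy as np

EXPECTED_SHAPE = (4,)  # assumed static input shape fixed at export time


def make_feed(values, input_name="input"):
    """Hypothetical helper: build the input_feed dict for sess.run.

    Casts to int64 (the dtype the model was traced with) and rejects
    arrays whose shape differs from the exported one, instead of letting
    onnxruntime fail at run time.
    """
    arr = np.asarray(values, dtype=np.int64)
    if arr.shape != EXPECTED_SHAPE:
        raise ValueError(f"expected shape {EXPECTED_SHAPE}, got {arr.shape}")
    return {input_name: arr}


feed = make_feed([22, 22, 22, 12])
print(feed["input"].dtype, feed["input"].shape)  # int64 (4,)
# outs = sess.run(['output'], feed)  # with the session created above
```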
