### Example 7: conditionals.
This is a modification of `Example 6`. 

The setup is the same as that of `Example 6`, but here we will consider all the indices in the array instead of just one:
1. return `1` if `x[i] >= 0` for _all_ `i`.
2. return `1` if `x[i] > 0` for _all_ `i`.


As in the previous solution, we will maintain the `mask` array, and we will sum it up. If the sum 
reaches the size of the array, then all the values in the array are indeed `>=0`.
The solution for part (2) is similar.

In [1]:
onnx_model = "example7.onnx"

In [2]:
import torch
import torch.nn as nn
import numpy as np

In [3]:
class SampleNet(nn.Module):
    def __init__(self):
        super(SampleNet, self).__init__()
        self.m = 1
        self.c = 0
    
    def forward(self, x):
        b = torch.max(x)
        mask = x.ge(0).type(torch.int64)
        flag = mask.sum().eq(x.size(0)).type(torch.int64)
        return self.m * flag + self.c

In [4]:
model = SampleNet()
model.eval()

SampleNet()

In [5]:
# the warmup stage
x = torch.LongTensor([1, 2, 3, 4])
out = model(x)

In [6]:
out

tensor(1)

In [9]:
torch.onnx.export(
model,
x,  # warming up the model
onnx_model,
opset_version=11,
do_constant_folding=True,
input_names=['input'],
output_names=['output'], 
verbose=True)

graph(%input : Long(4:1)):
  %1 : Long() = onnx::Constant[value={0}]()
  %2 : Tensor = onnx::Less(%input, %1)
  %3 : Bool(4:1) = onnx::Not(%2) # <ipython-input-3-c9b0d20f8881>:9:0
  %4 : Long(4:1) = onnx::Cast[to=7](%3) # <ipython-input-3-c9b0d20f8881>:9:0
  %5 : Long() = onnx::ReduceSum[keepdims=0](%4) # <ipython-input-3-c9b0d20f8881>:10:0
  %6 : Tensor = onnx::Shape(%input)
  %7 : Tensor = onnx::Constant[value={0}]()
  %8 : Long() = onnx::Gather[axis=0](%6, %7) # <ipython-input-3-c9b0d20f8881>:10:0
  %9 : Bool() = onnx::Equal(%5, %8) # <ipython-input-3-c9b0d20f8881>:10:0
  %10 : Long() = onnx::Cast[to=7](%9) # <ipython-input-3-c9b0d20f8881>:10:0
  %11 : Long() = onnx::Constant[value={1}]()
  %12 : Long() = onnx::Mul(%10, %11)
  %13 : Long() = onnx::Constant[value={0}]()
  %output : Long() = onnx::Add(%12, %13)
  return (%output)



In [8]:
# uncomment to install netron.
#!pip install netron
import netron
netron.start(onnx_model, port=8085)

Serving 'example7.onnx' at http://localhost:8085


### use the onnx model

In [9]:
import onnxruntime as ort
import numpy as np
sess = ort.InferenceSession(onnx_model)

In [10]:
sess.get_inputs()[0].name

'input'

In [11]:
# check out the signature of sess.run - it has to have the output, then something like a feed_dict.

# either pass the feed dict directly.
# passing an array of any other size will result in an error.
outs = sess.run(['output'],
                       {
                         'input': np.array([22, -22, 22, 12], dtype=np.int64)  
                       })

print(outs) 

[array(0, dtype=int64)]
