### Example 5: `argmax` -- counts.

The setup is the same as in `Example 4`.
Here, however, we want to _count_ how many elements of the array `x` equal its maximum, and use
that count in the `forward` function.

In terms of the solution in `Example 4`, the count is `nz.size(0)`.

As it turns out, `nz.size(0)` works here. With earlier versions of `onnxruntime` (including
the previous version, `1.2.0`), it did not work, and workarounds were needed.

In [1]:
onnx_model = "example5.onnx"

In [2]:
import torch
import torch.nn as nn
import numpy as np

In [3]:
class SampleNet(nn.Module):
    def __init__(self):
        super(SampleNet, self).__init__()
        self.m = 2
        self.c = 1
    
    def forward(self, x):
        b = torch.max(x)
        mask = x.ge(b).type(torch.int64)
        nz = torch.nonzero(mask, as_tuple=False)
        return self.m * nz.size(0) + self.c

In [4]:
model = SampleNet()
model.eval()

SampleNet()

In [5]:
# the warmup stage
x = torch.LongTensor([1, 2, 3, 4])
out = model(x)

In [6]:
out

3
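
The warmup result can be checked by hand: the maximum of `[1, 2, 3, 4]` is `4`, it occurs once, so `forward` returns `2 * 1 + 1 = 3`. The sketch below reproduces that arithmetic in NumPy. The helper name `expected_output` is hypothetical (it is not part of the original notebook), and its defaults `m=2`, `c=1` are copied from `SampleNet`.

```python
import numpy as np

def expected_output(values, m=2, c=1):
    """Hypothetical reference helper: m * (number of max elements) + c."""
    x = np.asarray(values)
    # Count the entries equal to the maximum, mirroring nz.size(0) in forward().
    count = int(np.count_nonzero(x >= x.max()))
    return m * count + c

print(expected_output([1, 2, 3, 4]))      # warmup input: one max element
print(expected_output([22, 22, 22, 12]))  # three max elements
```

The two calls print `3` and `7`, which gives a reference value to compare against both the PyTorch output and the ONNX output.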

In [7]:
torch.onnx.export(
    model,
    x,  # warming up the model
    onnx_model,
    opset_version=11,
    do_constant_folding=True,
    input_names=['input'],
    output_names=['output'])

In [8]:
# uncomment to install netron.
#!pip install netron
import netron
netron.start(onnx_model, port=8080)

Serving 'example5.onnx' at http://localhost:8080


### Use the ONNX model

In [9]:
import onnxruntime as ort
import numpy as np
sess = ort.InferenceSession(onnx_model)

In [10]:
sess.get_inputs()[0].name

'input'

In [11]:
# sess.run takes the list of output names first, then a dict mapping
# input names to arrays (something like a feed_dict).

# Here the feed dict is passed directly.
# Passing an array of any other size will result in an error.
outs = sess.run(['output'],
                {
                    'input': np.array([22, 22, 22, 12], dtype=np.int64)
                })

print(outs)  # should be 7: there are 3 max elements in the array, so 2 * 3 + 1

[array(7, dtype=int64)]
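
Since the cell above notes that an array of any other size results in an error, it can help to validate the input before calling `sess.run`. The following is a minimal sketch, assuming the exported model expects the same shape `(4,)` and dtype `int64` as the warmup tensor. The names `EXPECTED_SHAPE` and `prepare_input` are hypothetical and not part of `onnxruntime`.

```python
import numpy as np

# Assumption: the model was exported with a length-4 int64 tensor.
EXPECTED_SHAPE = (4,)

def prepare_input(values, expected_shape=EXPECTED_SHAPE):
    """Hypothetical helper: convert to int64 and reject unexpected shapes."""
    arr = np.asarray(values, dtype=np.int64)
    if arr.shape != expected_shape:
        raise ValueError(f"expected shape {expected_shape}, got {arr.shape}")
    return arr

# The resulting dict can be passed as the second argument of sess.run.
feed = {"input": prepare_input([22, 22, 22, 12])}
print(feed["input"].dtype, feed["input"].shape)
```

With this guard, a wrongly sized input fails early with a readable `ValueError` rather than inside the runtime.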
