### Example 4: `argmax` -- working solution.

As in `Example 2`:
1. We pass an array to the inference session.
2. The warmup example fixes the length of the array that the ONNX model expects.
3. You can check out the Netron visualization of the model.

In this example, we take an array `x` and compute the `argmax` of its entries.
Denote this index by `k`. The stated goal is to return `m*k + c`; the model defined below returns only the index `k`.

Here, we will try a workaround:
1. Compute the `max` of the array `x`. As we saw in `Example 2`, `onnxruntime` implements `max` (`ReduceMax`).
2. Compute a `mask` marking the entries of the array that are greater than or equal to this `max`. Note that this `mask` is an array of `0`s and `1`s. If the original array has a unique maximum, then this array will have exactly one `1`.
3. Call the array of nonzero indices of the `mask` `nz`, then pick up the `argmax` as `nz[0][0]`. Note that `nz` has shape `(n, 1)`, where `n` is the number of appearances of the `max` in the array.

**Note**:
1. If the `max` appears in the array `x` multiple times, then we pick up the _smallest_ index at which 
the `max` appears. 
If you want to pick up the _last_ appearance of the `max` in the array `x`, then that is simple too:
use `nz[-1][0]` instead of `nz[0][0]`.
2. Also, note that we call `torch.nonzero(mask, as_tuple=False)` with `as_tuple` set explicitly; this is because of [this issue](https://github.com/pytorch/pytorch/issues/32994#issuecomment-629810935).
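
The max, mask, nonzero steps and the first/last tie-breaking can be sketched outside of `torch`. The block below is a minimal NumPy stand-in, not the model itself: the helper name `argmax_via_mask` and its `last` flag are hypothetical, and `np.argwhere` is used as an assumed analogue of `torch.nonzero(mask, as_tuple=False)` because it also returns one row per nonzero entry.

```python
import numpy as np

def argmax_via_mask(x, last=False):
    """Index of the maximum via max -> mask -> nonzero (hypothetical helper)."""
    b = np.max(x)                       # step 1: the maximum value
    mask = (x >= b).astype(np.int64)    # step 2: 1 where x >= max, else 0
    nz = np.argwhere(mask)              # step 3: shape (n, 1), one row per occurrence
    return int(nz[-1][0] if last else nz[0][0])

x = np.array([3, 7, 1, 7], dtype=np.int64)    # the max 7 appears twice
first = argmax_via_mask(x)                    # smallest index of the max
last = argmax_via_mask(x, last=True)          # largest index of the max
nz_shape = np.argwhere(x >= x.max()).shape    # one row per appearance of the max
print(first, last, nz_shape)
```

For this input the two occurrences of the max sit at indices 1 and 3, so the index array has two rows and a single column.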

In [1]:
onnx_model = "example4.onnx"

In [2]:
import torch
import torch.nn as nn
import numpy as np

In [3]:
class SampleNet(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        b = torch.max(x)                          # maximum value of x
        mask = x.ge(b).type(torch.int64)          # 1 where x >= max, else 0
        nz = torch.nonzero(mask, as_tuple=False)  # shape (n, 1): indices of the max
        return nz[0][0]                           # first (smallest) such index

In [4]:
model = SampleNet()
model.eval()

SampleNet()

In [5]:
# the warmup stage
x = torch.LongTensor([1, 2, 3, 4])
out = model(x)  # the argmax

In [6]:
out

tensor(3)

In [7]:
torch.onnx.export(
    model,
    x,
    onnx_model,
    opset_version=11,
    do_constant_folding=True,
    input_names=['input'],
    output_names=['output'],
)

In [8]:
# uncomment to install netron.
#!pip install netron
import netron
netron.start(onnx_model, port=8083)

Serving 'example4.onnx' at http://localhost:8083


### Use the ONNX model

In [10]:
import onnxruntime as ort
import numpy as np
sess_options = ort.SessionOptions()
sess_options.enable_profiling = True
# pass the options to the session; otherwise enable_profiling has no effect
sess = ort.InferenceSession(onnx_model, sess_options)

In [11]:
sess.get_inputs()[0].name

'input'

In [12]:
# sess.run takes the list of output names, then a feed dict mapping input names to arrays.
# The input length was fixed at export time by the warmup example (4 entries);
# passing an array of any other size will result in an error.
outs = sess.run(
    ['output'],
    {'input': np.array([11, 20, 22, 12], dtype=np.int64)},
)

print(outs)  # should be 2, since array[2] = 22

[array(2, dtype=int64)]
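
One way to cross-check the value returned by the session is to compare it against `np.argmax` on the same input. The sketch below does not run the ONNX model: `outs` is a stand-in with the value copied from the printed output above, so it only illustrates the comparison.

```python
import numpy as np

feed = np.array([11, 20, 22, 12], dtype=np.int64)

# Stand-in for the list returned by sess.run (value copied from the output above).
outs = [np.array(2, dtype=np.int64)]

expected = int(np.argmax(feed))     # np.argmax returns the first index of the maximum
matches = int(outs[0]) == expected  # compare the session output with the reference
print(expected, matches)
```

Since `np.argmax` also returns the first occurrence when the maximum is repeated, it agrees with the `nz[0][0]` convention used in the model.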
