Skip to content

[Bug] MaxPool1d returned a float instead of -FLT_MAX #24819

Open
@Mohamad11Dab

Description

@Mohamad11Dab

Expected Behaviour

ONNXRuntime should be consistent in returning -FLT_MAX.

Actual Behaviour

```
––––– MISMATCH DETECTED –––––

Not equal to tolerance rtol=0.01, atol=0.001

x and y nan location mismatch:
x: array([[[ 7.828446e-01, -3.402823e+38, 3.303281e-01, ...,
9.437089e-01, 6.650881e-01, -3.402823e+38],
[ 6.911096e-01, 5.645327e-02, 1.617989e-01, ...,...
y: array([[[0.782845, nan, 0.330328, ..., nan, 0.665088,
nan],
[0.69111 , 0.056453, nan, ..., 0.88878 , nan,...
```

1.617989e-01 was returned instead of -FLT_MAX.

Reproduce

import random
import sys, os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import torch
import torch.nn as nn
import torch.nn.functional as F
import tempfile
import onnx
import onnxruntime as ort
from numpy.testing import assert_allclose
import tvm
from tvm import relay
from tvm.contrib import graph_executor
import numpy as np
import nas_model_2

class SimpleBugModel(nn.Module):
    """Minimal reproduction network: 1x1 conv -> sqrt -> flatten -> max pool.

    The stages are applied in the order ``input_conv``, ``block1``,
    ``block2``, ``block3``. The final stage is a 1-D max pool with
    ``kernel_size=2``, ``stride=2`` and ``ceil_mode=True``.
    """

    def __init__(self):
        super().__init__()
        # 3 -> 16 channel pointwise convolution; its output feeds the sqrt stage.
        self.input_conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=1)
        # Element-wise square root wrapper provided by the nas_model_2 module.
        self.block1 = nas_model_2.SqrtWrapper()
        # Collapse every dimension from index 2 onward into a single axis.
        self.block2 = nn.Flatten(start_dim=2)
        self.block3 = nn.MaxPool1d(kernel_size=2, stride=2, ceil_mode=True)

    def forward(self, x):
        """Run ``x`` through the four stages in order and return the result."""
        stages = (self.input_conv, self.block1, self.block2, self.block3)
        out = x
        for stage in stages:
            out = stage(out)
        return out

def main():
    """Export the model to ONNX, run it with ONNX Runtime and TVM, and compare.

    All RNGs are seeded so the random input (and the conv weights) are
    reproducible. The model is exported to a temporary ``.onnx`` file, which
    is removed again once both ONNX Runtime and the ONNX loader are done with
    it. Any numerical mismatch between the two backends is printed rather
    than raised.
    """
    seed = 811777723
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    model = SimpleBugModel()
    model.eval()
    dummy = torch.randn(1, 3, 32, 32, dtype=torch.float32)
    dummy_np = dummy.numpy()

    # Only the path is needed: the handle is closed immediately so that the
    # exporter can write to the file on every platform.
    with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as tmp:
        onnx_path = tmp.name
    try:
        torch.onnx.export(model, dummy, onnx_path, opset_version=19, input_names=['input'], output_names=['output'])

        ort_sess = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])
        ort_out = ort_sess.run(None, {'input': dummy_np})[0]
        # The full array is printed (not just its shape) so NaN / -FLT_MAX
        # positions are visible in the report.
        print('ORT output:', ort_out)

        onnx_model = onnx.load(onnx_path)
    finally:
        # delete=False means nothing else cleans this file up. Removal is
        # best-effort: a failure here must not mask the real result.
        try:
            os.remove(onnx_path)
        except OSError:
            pass

    shape_dict = {'input': dummy_np.shape}
    mod, params = relay.frontend.from_onnx(onnx_model, shape_dict, freeze_params=True)
    with tvm.transform.PassContext(opt_level=4):
        lib = relay.build(mod, target='llvm', params=params)
    m = graph_executor.GraphModule(lib['default'](tvm.cpu()))
    m.set_input('input', tvm.nd.array(dummy_np))
    m.run()
    tvm_out = m.get_output(0).numpy()
    print('TVM output:', tvm_out)

    try:
        assert_allclose(ort_out, tvm_out, rtol=1e-2, atol=1e-3, equal_nan=True)
    except AssertionError as e:
        print('––––– MISMATCH DETECTED –––––')
        print(e)
    except Exception as e:
        # Top-level boundary of the repro script: report anything unexpected
        # from the comparison instead of crashing without context.
        print('––––– UNEXPECTED ERROR DURING COMPARISON –––––')
        print(f'{type(e).__name__}: {e}')

# Run the reproduction only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
## nas_model_2
@basic_unit
class SqrtWrapper(nni_nn.Module):
    """Module wrapper that applies ``torch.sqrt`` to its input."""

    def forward(self, x):
        """Return the element-wise square root of ``x``.

        Note: per the ``torch.sqrt`` documentation, negative elements
        produce NaN in the result.
        """
        rooted = torch.sqrt(x)
        return rooted

Versions

ONNXRuntime: 0.16.3
TVM: 0.17.0

Metadata

Metadata

Assignees

No one assigned

    Labels

    ep:tvm — issues related to TVM execution provider

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions