
different output value in pytorch->onnx->tflite(int8 quantization) #52357

Closed
JunhoohnuJ opened this issue Oct 13, 2021 · 19 comments
Assignees
Labels
comp:lite (TF Lite related issues) · stale (to be closed automatically if no activity) · stat:awaiting response (awaiting response from author) · TF 2.11 (issues related to TF 2.11) · type:performance (Performance Issue)

Comments

@JunhoohnuJ

Please make sure that this is an issue related to performance of TensorFlow.
As per our
GitHub Policy,
we only address code/doc bugs, performance issues, feature requests and
build/installation issues on GitHub. tag:performance_template

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): No
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 20.04
  • Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
  • TensorFlow installed from (source or binary): tensorflow:2.5.0-gpu docker
  • TensorFlow version (use command below): 2.5.0
  • Python version: 3.6.9
  • Bazel version (if compiling from source):
  • GCC/Compiler version (if compiling from source):
  • CUDA/cuDNN version: 11.2/ 8.1.0
  • GPU model and memory: RTX 3090

You can collect some of this information using our environment capture
script
You can also obtain the TensorFlow version with:

  1. TF 1.0: python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"
  2. TF 2.0: python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"

Describe the current behavior
I convert a ResNet-50 model from PyTorch -> ONNX -> TFLite with int8 quantization, and validate the output values between pytorch <-> onnx, pytorch <-> pb, pytorch <-> tflite, and pb <-> tflite.
The input is the same 256x256 image in every case, and the outputs are compared with "np.testing.assert_allclose(output1, output2, rtol=1e-3, atol=1e-05)".
(The TFLite model is run with the TFLite interpreter only: https://www.tensorflow.org/lite/guide/python?hl=ko)

Max absolute difference: 0.00076199 in pytorch <-> onnx
Max absolute difference: 0.00112534 in pytorch <-> pb
Max absolute difference: 13.387602 in pytorch <-> tflite (quantized)
Max absolute difference: 13.387438 in pb <-> tflite (quantized)
The non-quantized TFLite model stays within the same small margins against the others (pytorch, onnx, pb),
e.g. ~0.0076 in pytorch <-> tflite (no quant), ~0.0011 in pb <-> tflite (no quant).
I don't know why the quantized model differs this much.
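
For reference, a minimal sketch of how these max-absolute-difference numbers can be computed from the saved output files used later in this thread (the helper name is illustrative):

import numpy as np

def max_abs_diff(path_a, path_b):
    # Load two saved outputs and report their maximum absolute difference.
    a, b = np.load(path_a), np.load(path_b)
    diff = float(np.max(np.abs(a.astype(np.float64) - b.astype(np.float64))))
    print(f"{path_a} vs {path_b}: max abs diff = {diff:.6f}")
    return diff

max_abs_diff('output_pytorch.npy', 'output_onnx.npy')
max_abs_diff('output_pytorch.npy', 'output_tflite.npy')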

Standalone code to reproduce the issue
Provide a reproducible test case that is the bare minimum necessary to generate
the problem. If possible, please share a link to Colab/Jupyter/any notebook.

Other info / logs Include any logs or source code that would be helpful to
diagnose the problem. If including tracebacks, please include the full
traceback. Large logs and files should be attached.
pb to tflite log
2021-10-13 09:18:56.162936: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0
2021-10-13 09:18:57.485452: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcuda.so.1
2021-10-13 09:18:57.511230: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-10-13 09:18:57.511916: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1733] Found device 0 with properties:
pciBusID: 0000:01:00.0 name: NVIDIA GeForce RTX 3090 computeCapability: 8.6
coreClock: 1.86GHz coreCount: 82 deviceMemorySize: 23.68GiB deviceMemoryBandwidth: 871.81GiB/s
2021-10-13 09:18:57.511955: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0
2021-10-13 09:18:57.513717: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcublas.so.11
2021-10-13 09:18:57.513767: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcublasLt.so.11
2021-10-13 09:18:57.514354: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcufft.so.10
2021-10-13 09:18:57.514537: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcurand.so.10
2021-10-13 09:18:57.515198: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcusolver.so.11
2021-10-13 09:18:57.515720: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcusparse.so.11
2021-10-13 09:18:57.515866: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudnn.so.8
2021-10-13 09:18:57.515918: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-10-13 09:18:57.516398: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-10-13 09:18:57.516976: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1871] Adding visible gpu devices: 0
2021-10-13 09:18:57.517199: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2021-10-13 09:18:57.517766: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-10-13 09:18:57.518224: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1733] Found device 0 with properties:
pciBusID: 0000:01:00.0 name: NVIDIA GeForce RTX 3090 computeCapability: 8.6
coreClock: 1.86GHz coreCount: 82 deviceMemorySize: 23.68GiB deviceMemoryBandwidth: 871.81GiB/s
2021-10-13 09:18:57.518272: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-10-13 09:18:57.518814: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-10-13 09:18:57.519243: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1871] Adding visible gpu devices: 0
2021-10-13 09:18:57.519268: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0
2021-10-13 09:18:57.810376: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1258] Device interconnect StreamExecutor with strength 1 edge matrix:
2021-10-13 09:18:57.810410: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1264] 0
2021-10-13 09:18:57.810420: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1277] 0: N
2021-10-13 09:18:57.810591: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-10-13 09:18:57.811162: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-10-13 09:18:57.811684: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-10-13 09:18:57.812186: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1418] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 21512 MB memory) -> physical GPU (device: 0, name: NVIDIA GeForce RTX 3090, pci bus id: 0000:01:00.0, compute capability: 8.6)
2021-10-13 09:18:58.498192: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:345] Ignored output_format.
2021-10-13 09:18:58.498225: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:348] Ignored drop_control_dependency.
2021-10-13 09:18:58.498234: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:354] Ignored change_concat_input_ranges.
2021-10-13 09:18:58.498881: I tensorflow/cc/saved_model/reader.cc:38] Reading SavedModel from: backbone_saved_model/
2021-10-13 09:18:58.515289: I tensorflow/cc/saved_model/reader.cc:90] Reading meta graph with tags { serve }
2021-10-13 09:18:58.515331: I tensorflow/cc/saved_model/reader.cc:132] Reading SavedModel debug info (if present) from: backbone_saved_model/
2021-10-13 09:18:58.515383: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1258] Device interconnect StreamExecutor with strength 1 edge matrix:
2021-10-13 09:18:58.515393: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1264]
2021-10-13 09:18:58.527926: I tensorflow/cc/saved_model/loader.cc:206] Restoring SavedModel bundle.
2021-10-13 09:18:58.546224: I tensorflow/core/platform/profile_utils/cpu_utils.cc:114] CPU Frequency: 3699850000 Hz
2021-10-13 09:18:58.563849: I tensorflow/cc/saved_model/loader.cc:190] Running initialization op on SavedModel bundle at path: backbone_saved_model/
2021-10-13 09:18:58.577967: I tensorflow/cc/saved_model/loader.cc:277] SavedModel load for tags { serve }; Status: success: OK. Took 79088 microseconds.
2021-10-13 09:18:58.657933: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:210] disabling MLIR crash reproducer, set env var MLIR_CRASH_REPRODUCER_DIRECTORY to enable.
2021-10-13 09:18:58.675431: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-10-13 09:18:58.675985: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1733] Found device 0 with properties:
pciBusID: 0000:01:00.0 name: NVIDIA GeForce RTX 3090 computeCapability: 8.6
coreClock: 1.86GHz coreCount: 82 deviceMemorySize: 23.68GiB deviceMemoryBandwidth: 871.81GiB/s
2021-10-13 09:18:58.676068: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-10-13 09:18:58.676635: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-10-13 09:18:58.677107: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1871] Adding visible gpu devices: 0
2021-10-13 09:18:58.677148: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1258] Device interconnect StreamExecutor with strength 1 edge matrix:
2021-10-13 09:18:58.677157: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1264] 0
2021-10-13 09:18:58.677165: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1277] 0: N
2021-10-13 09:18:58.677253: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-10-13 09:18:58.677779: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-10-13 09:18:58.678280: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1418] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 21512 MB memory) -> physical GPU (device: 0, name: NVIDIA GeForce RTX 3090, pci bus id: 0000:01:00.0, compute capability: 8.6)
fully_quantize: 0, inference_type: 6, input_inference_type: 0, output_inference_type: 0

@JunhoohnuJ JunhoohnuJ added the type:performance Performance Issue label Oct 13, 2021
@sushreebarsa sushreebarsa added comp:lite TF Lite related issues TF 2.5 Issues related to TF 2.5 labels Oct 13, 2021
@sushreebarsa
Contributor

@JunhoohnuJ
In order to expedite the trouble-shooting process, please provide a code snippet to reproduce the issue reported here. Thanks!

@sushreebarsa sushreebarsa added the stat:awaiting response Status - Awaiting response from author label Oct 13, 2021
@JunhoohnuJ
Author

JunhoohnuJ commented Oct 13, 2021

#1. """make input"""

import numpy as np
# save as float32 so the ONNX/TF/TFLite runtimes below accept it directly
np.save('input.npy', np.random.rand(1, 3, 256, 256).astype(np.float32))
"""end make"""

#2. """pytorch to onnx"""

import numpy as np
import torch
from torch import nn
from torchvision.models import resnet50

model = resnet50(pretrained=True)
# backbone only: layer1-layer4 of ResNet-50 (stem and classification head are not included)
torch_model = nn.Sequential(*[model.layer1, model.layer2, model.layer3, model.layer4])
torch_model = torch_model.cuda()

x = np.load('input.npy')
x = torch.tensor(x, dtype=torch.float32).cuda()
with torch.no_grad():
    coord = torch_model(x)
out = coord.detach().cpu().numpy()
np.save('output_pytorch.npy', out)

torch.onnx.export(torch_model,
                  x,
                  "backbone.onnx",
                  export_params=True,
                  opset_version=12,
                  input_names=['input'],
                  output_names=['output'],
                  keep_initializers_as_inputs=True,
                  # dynamic_axes={'input': {0: 'batch_size'}}
                  )

#3. """onnx runtime"""
import numpy as np
import onnxruntime as ort

x = np.load('input.npy')
output_pytorch = np.load('output_pytorch.npy')
ort.set_default_logger_severity(0)
options = ort.SessionOptions()
ort_session = ort.InferenceSession("backbone.onnx", options)
ort_inputs = {ort_session.get_inputs()[0].name: x}
onnxOut = ort_session.run(None, ort_inputs)
np.save('output_onnx.npy', onnxOut[0])
np.testing.assert_allclose(output_pytorch, onnxOut[0], rtol=1e-3, atol=1e-05)

#4. """onnx to pb"""
import numpy as np
import onnx
import tensorflow as tf
from onnx_tf.backend import prepare
import os

onnx_model = onnx.load("backbone.onnx")
tf_model_path = "backbone_saved_model"

tf_rep = prepare(onnx_model)
tf_rep.export_graph(tf_model_path)
print('saved model')

model = tf.saved_model.load(tf_model_path)
model.trainable = False
print('load model')

x = np.load('input.npy')
out = model(**{'input': x})

output_pytorch = np.load('output_pytorch.npy')
np.save('output_pb.npy', out)
np.testing.assert_allclose(out, output_pytorch, rtol=1e-3, atol=1e-05)
print("Exported model has been tested with ONNXRuntime, and the result looks good!")

#5. """pb to tflite"""
import os

import cv2
import numpy as np
import tensorflow as tf
import torchvision.transforms as transforms

input_shape = (3, 256, 256)

def representative_data_gen():
    a = []
    pixel_mean = (0.485, 0.456, 0.406)
    pixel_std = (0.229, 0.224, 0.225)
    # COCO dataset
    datapath = "/nvme1/datasets/COCO/val2017/"  ########### need to change
    file_list = os.listdir(datapath)
    for i in range(160):
        img = cv2.imread(os.path.join(datapath, file_list[i]))
        img = cv2.resize(img, (256, 256))
        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize(mean=pixel_mean, std=pixel_std)])
        img = transform(img).numpy()
        img = img.astype(np.float32)
        a.append(img)
    a = np.array(a)
    img = tf.data.Dataset.from_tensor_slices(a).batch(1).take(100)
    # note: img.take(1) below yields only a single calibration batch
    for i in img.take(1):
        print(i)
        yield [i]

tf_model_path = "backbone_saved_model/"
converter = tf.lite.TFLiteConverter.from_saved_model(tf_model_path)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.representative_dataset = representative_data_gen
tflite_model = converter.convert()
tflite_model_path = "backbone.tflite"

with open(tflite_model_path, 'wb') as f:
    f.write(tflite_model)
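
As a side note, the conversion log above reports input_inference_type: 0 and output_inference_type: 0, i.e. the converted model still takes and returns float32 and only the internal ops are quantized. A minimal sketch, assuming the same representative_data_gen as above, of requesting a fully integer model instead (the output file name is illustrative):

import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_saved_model("backbone_saved_model/")
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_data_gen
# Restrict to int8 builtins only, so any op that cannot be quantized fails
# loudly instead of silently falling back to a float kernel.
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# Also quantize the model's input and output tensors.
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8

with open("backbone_int8.tflite", "wb") as f:
    f.write(converter.convert())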

#6. """inference tflite"""
import tflite_runtime.interpreter as tflite
import numpy as np

pose_post_model_file = 'backbone.tflite'

interpreter = tflite.Interpreter(model_path=pose_post_model_file)
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
input_shape = input_details[0]['shape']

x = np.load('input.npy')
interpreter.set_tensor(input_details[0]['index'], x)

interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])

output_pytorch = np.load('output_pytorch.npy')
np.save('output_tflite.npy', output_data)
np.testing.assert_allclose(output_data, output_pytorch, rtol=1e-3, atol=1e-05)
print("Exported model has been tested with ONNXRuntime, and the result looks good!")

Each output is saved as output_*.npy (* is pytorch, onnx, pb, tflite).
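
When a quantized model diverges this much from the float one, a common next step is to localize which layer first introduces the error. A rough sketch below; the experimental_preserve_all_tensors flag exists in recent TF releases, but treat this as an illustrative starting point rather than a verified recipe:

import numpy as np
import tensorflow as tf

x = np.load('input.npy').astype(np.float32)

# Keep intermediate tensors alive after invoke() so they can be read back.
interpreter = tf.lite.Interpreter(model_path='backbone.tflite',
                                  experimental_preserve_all_tensors=True)
interpreter.allocate_tensors()
interpreter.set_tensor(interpreter.get_input_details()[0]['index'], x)
interpreter.invoke()

# Print basic statistics for every tensor in the graph; the first tensor whose
# range looks off points at the op where the quantization error appears.
for d in interpreter.get_tensor_details():
    try:
        t = interpreter.get_tensor(d['index'])
        print(d['name'], t.shape, float(np.min(t)), float(np.max(t)))
    except ValueError:
        continue  # some tensors cannot be read back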

@sushreebarsa sushreebarsa removed the stat:awaiting response Status - Awaiting response from author label Oct 13, 2021
@r3krut

r3krut commented Oct 18, 2021

Same issue with a custom CenterNet model based on PyTorch. My conversion pipeline is as follows: PyTorch -> ONNX -> Keras -> TFLite. Interestingly, the error (absolute and squared error) between the output tensors of the PyTorch and Keras (after conversion) models is near zero. But once the conversion from Keras to TFLite is done, the error is large.

@JunhoohnuJ
Author

Same issue with a custom CenterNet model based on PyTorch. My conversion pipeline is as follows: PyTorch -> ONNX -> Keras -> TFLite. Interestingly, the error (absolute and squared error) between the output tensors of the PyTorch and Keras (after conversion) models is near zero. But once the conversion from Keras to TFLite is done, the error is large.

Thank you for the comment. Can you tell me specifically how large the errors were?

@jvishnuvardhan jvishnuvardhan added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Oct 19, 2021
@nyadla-sys
Member

@Xhark @liufengdb @ebrevdo @jianlijianli

I am seeing that the mobilenet_v2 TFLite (quantized) model I generated from PyTorch is not working as expected, while the float32 model seems to work fine.
Maybe I am doing something wrong with post-training quantization; could any of you look at the Colab/GitHub links below and let me know?

//colab link
https://colab.research.google.com/github/nyadla-sys/pytorch_2_tflite/blob/main/pytorch_to_onnx_to_tflite(quantized)_with_imagedata.ipynb

//github link
https://github.com/nyadla-sys/pytorch_2_tflite/blob/main/pytorch_to_onnx_to_tflite(quantized)_with_imagedata.ipynb

@nyadla-sys
Member

@Xhark @liufengdb @ebrevdo @jianlijianli: Gentle reminder!

@Hastyrush

Same issue with a ResNet-based backbone feature extractor initialized and trained from timm in PyTorch. PyTorch -> ONNX -> TensorFlow PB all converts with minimal accuracy loss. The TFLite FP32 and dynamically quantized INT8 models work fine as well, but the statically quantized INT8 model fails completely with 0 AP.
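
For context on that gap: dynamic-range quantization keeps activations in float at runtime (only the weights are stored as int8), so it usually tracks the float model closely, while static/full-integer quantization also quantizes activations using ranges calibrated from the representative dataset. A minimal sketch of the two converter configurations, assuming a SavedModel directory and calibration generator like the ones earlier in this thread:

import tensorflow as tf

# Dynamic-range quantization: int8 weights, float activations at runtime.
conv_dyn = tf.lite.TFLiteConverter.from_saved_model("backbone_saved_model/")
conv_dyn.optimizations = [tf.lite.Optimize.DEFAULT]
dynamic_tflite = conv_dyn.convert()

# Static (full-integer) quantization: int8 weights and activations,
# calibrated from a representative dataset.
conv_static = tf.lite.TFLiteConverter.from_saved_model("backbone_saved_model/")
conv_static.optimizations = [tf.lite.Optimize.DEFAULT]
conv_static.representative_dataset = representative_data_gen  # defined earlier in the thread
static_tflite = conv_static.convert()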

@mohantym
Contributor

mohantym commented Oct 6, 2022

@JunhoohnuJ ! @nyadla-sys !
I could replicate this issue in the 2.11 version. Attached gist for reference.
Thank you!

@mohantym mohantym added TF 2.10 and removed TF 2.5 Issues related to TF 2.5 labels Oct 6, 2022
@pjpratik pjpratik added TF 2.11 Issues related to TF 2.11 and removed TF 2.10 labels Jan 12, 2023
@mikel-brostrom

mikel-brostrom commented Mar 13, 2023

Same issue for all the YOLOX family models. TFLite FP32 and FP16 work perfectly, but the output discrepancies between the float TFLite models and the quantized ones are so large that the quantized models are unusable.

| Model | size | mAPval 0.5:0.95 | mAPval 0.5 |
| --- | --- | --- | --- |
| YOLOX-nano PyTorch (original model) | 416 | 0.256 | 0.411 |
| YOLOX-nano ONNX | 416 | 0.256 | 0.411 |
| YOLOX-nano TFLite FP16 | 416 | 0.256 | 0.411 |
| YOLOX-nano TFLite FP32 | 416 | 0.256 | 0.411 |
| YOLOX-nano TFLite full_integer_quant | 416 | 0 | 0 |
| YOLOX-nano TFLite dynamic_range_quant | 416 | 0 | 0 |
| YOLOX-nano TFLite integer_quant | 416 | 0 | 0 |

@PINTO0309
Contributor

PINTO0309 commented Mar 19, 2023

YOLOX's INT8 outputs now match within an almost acceptable margin of error. Sorry for the trouble; I have deleted my immediately preceding post, which was misleading. It is an issue of the ML model's structure.

A quote from an issue of a conversion tool I am working on: the conversion route turned out to be unrelated to the accuracy degradation. SiLU (Swish) was found to significantly degrade the accuracy of the model during quantization. As an additional research note, HardSwish also seems to cause significant accuracy degradation during quantization, just like SiLU (Swish). Activation functions whose quantization range cannot be computed cleanly are a poor fit for quantization.

It is a matter of model structure. The activation function, the kernel size and stride of the pooling layers, and the kernel size and stride of the convolutions should be completely revised. See: https://github.com/PINTO0309/onnx2tf/issues/244#issuecomment-1475128445

  • e.g. YOLOv8 https://docs.openvino.ai/latest/notebooks/230-yolov8-optimization-with-output.html

  • e.g. YOLOX-Nano https://github.com/TexasInstruments/edgeai-yolox

    Before | After
    Swish/SiLU | ReLU
    DepthwiseConv2D | Conv2D
    MaxPool, kernel_size=5x5,9x9,13x13 | MaxPool, kernel_size=3x3
    (before/after comparison images omitted)
    ### Float32 - YOLOX-Nano
    (1, 52, 52, 85)
    array([[[
        [ 0.971787,  0.811184,  0.550566, ..., -5.962632, -7.403673, -6.735206],
        [ 0.858804,  1.351296,  1.231673, ..., -6.479690, -8.277064, -7.664936],
        [ 0.214827,  1.035119,  1.458006, ..., -6.291425, -8.229385, -7.761562],
            ...,
        [ 0.450116,  1.391900,  1.533354, ..., -5.672194, -7.121591, -6.880231],
        [ 0.593133,  2.112723,  0.968755, ..., -6.150078, -7.370633, -6.874294],
        [ 0.088263,  1.985220,  0.619998, ..., -5.507928, -6.914980, -6.234259]]]]),
    
    ### INT8 - YOLOX-Nano
    (1, 52, 52, 85)
    array([[[
        [ 0.941908,  0.770652,  0.513768, ..., -5.993958, -7.449634, -6.850238],
        [ 0.856280,  1.284420,  1.198792, ..., -6.507727, -8.391542, -7.792146],
        [ 0.256884,  0.941908,  1.455676, ..., -6.336471, -8.305914, -7.877774],
            ...,
        [ 0.342512,  1.370048,  1.541304, ..., -5.737075, -7.192750, -7.107122],
        [ 0.513768,  2.226327,  1.027536, ..., -6.165215, -7.449634, -7.021494],
        [ 0.085628,  2.055072,  0.685024, ..., -5.480191, -7.021494, -6.422099]]]]),
    

@PINTO0309
Contributor

No, it is more of a SiLU issue.

@mikel-brostrom

mikel-brostrom commented Mar 20, 2023

Will try to replace all the SiLU modules with ReLU. Should be straightforward. Then retrain, of course...
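
A rough sketch of that swap on the PyTorch side before re-export (and retraining); the recursion simply rewrites any submodule attribute that is an nn.SiLU:

import torch.nn as nn

def replace_silu_with_relu(module: nn.Module) -> None:
    # Recursively replace every nn.SiLU submodule with nn.ReLU, in place.
    for name, child in module.named_children():
        if isinstance(child, nn.SiLU):
            setattr(module, name, nn.ReLU(inplace=True))
        else:
            replace_silu_with_relu(child)

# e.g. replace_silu_with_relu(model); then fine-tune/retrain and re-export.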

@PINTO0309
Contributor

[image: histogram (the Y axis is discussed in the comments below)]

@mikel-brostrom

mikel-brostrom commented Mar 21, 2023

Btw, what is the Y axis @PINTO0309 ?

@PINTO0309
Contributor

It is simply the number of data samples.

https://github.com/PINTO0309/onnx2tf#6-if-the-accuracy-of-the-int8-quantized-model-degrades-significantly

https://gist.github.com/motokimura/1a90c0b8c5628914b99a81cd91369636
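
For anyone who wants to reproduce that kind of activation histogram without the linked gist, a minimal sketch using a PyTorch forward hook and matplotlib; the resnet50/layer4 choice just reuses the model from earlier in this thread, and for the YOLOX case you would hook the SiLU/Swish layers instead:

import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision.models import resnet50

activations = []

def capture(_module, _inputs, output):
    # Collect the flattened layer output for plotting.
    activations.append(output.detach().cpu().numpy().ravel())

model = resnet50(pretrained=True).eval()
handle = model.layer4.register_forward_hook(capture)
with torch.no_grad():
    model(torch.from_numpy(np.load('input.npy')).float())
handle.remove()

plt.hist(np.concatenate(activations), bins=100)
plt.xlabel('activation value')
plt.ylabel('number of data samples')  # the Y axis referred to above
plt.show()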

@sachinprasadhs
Contributor

If your issue is resolved, could you please close this issue? Thanks!

@sachinprasadhs sachinprasadhs self-assigned this Apr 27, 2023
@sachinprasadhs sachinprasadhs added stat:awaiting response Status - Awaiting response from author and removed stat:awaiting tensorflower Status - Awaiting response from tensorflower labels Apr 27, 2023
@github-actions

github-actions bot commented May 5, 2023

This issue is stale because it has been open for 7 days with no activity. It will be closed if no further activity occurs. Thank you.

@github-actions github-actions bot added the stale This label marks the issue/pr stale - to be closed automatically if no activity label May 5, 2023
@github-actions

This issue was closed because it has been inactive for 7 days since being marked as stale. Please reopen if you'd like to work on this further.

