
Model converted to TFLite always returns NaN as output. #22803

Closed
sercant opened this issue Oct 7, 2018 · 24 comments
Labels: comp:lite TF Lite related issues

@sercant

sercant commented Oct 7, 2018

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): MacOS 10.13.6
  • Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device: --
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (use command below): v1.10.0-12-g4dcfddc5d1 1.10.1
  • Python version: 3.6.5
  • Bazel version (if compiling from source): --
  • GCC/Compiler version (if compiling from source): --
  • CUDA/cuDNN version: --
  • GPU model and memory: Intel Iris Plus Graphics 650 1536 MB
  • Exact command to reproduce: python3 test.py

Describe the problem

I have been trying to convert a frozen graph trained using this repo for use on Android with TFLite. The trained model uses MobileNetV2 as the frontend and Mobile UNet for Semantic Segmentation as the model. The problem I am facing is that the frozen pb graph segments the image correctly, but the TFLite-converted model returns all NaN as output. To reproduce the problem I wrote the following script. The model converts without any errors or warnings, but the output is not correct. Do you have any idea what might be causing this?

Note: the converted model also returns NaNs on an Android device.

Frozen graph: output_graph.pb

Source code / logs

test.py

import tensorflow as tf
import numpy as np
import cv2
from tensorflow.python.platform import gfile
from tensorflow.contrib.lite.python.convert_saved_model import set_tensor_shapes

sess = tf.Session()

# load graph
with gfile.FastGFile('output_graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
sess.graph.as_default()
tf.import_graph_def(graph_def, name='')

# get tensors
input_tensor = sess.graph.get_tensor_by_name('Placeholder:0')
output_tensor = sess.graph.get_tensor_by_name('logits/Conv2D:0')

# generate random image
input_image = np.array(np.random.random_sample(
    [1, 128, 128, 3]), dtype=np.float32)

# run the model with tf
output_image = sess.run(output_tensor, feed_dict={input_tensor: input_image})

# print tf output
print('--- Tensorflow output ---')
print(output_image)
print('-------------------------')

# set shapes
input_tensor.set_shape([1, 128, 128, 3])
output_tensor.set_shape([1, 128, 128, 32])

# convert model
converter = tf.contrib.lite.TocoConverter.from_session(
    sess, [input_tensor], [output_tensor])
tflite_model = converter.convert()

# Prepare interpreter
interpreter = tf.contrib.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# set input data
interpreter.set_tensor(input_details[0]['index'], input_image)

# run model on interpreter
interpreter.invoke()

# retrieve output
output_data = interpreter.get_tensor(output_details[0]['index'])

# print tflite output
print('--- TFLite output ---')
print(output_data)
print('---------------------')

output

--- Tensorflow output ---
[[[[-14.484754   -14.454916    -3.9344878  ... -10.294399
     -2.837898    -8.190185  ]
   [-14.120294   -10.590508    -4.032942   ...  -6.7745924
     -0.4497184   -9.78646   ]
   [-14.561665   -10.49988     -8.065053   ...  -7.422716
     -0.7991432  -10.160792  ]
   ...
   [-13.12197     -7.3976164   -7.1669674  ...  -9.533363
     -2.0361094  -10.951963  ]
   [-15.041047    -7.3879066   -6.724542   ... -11.897878
     -2.1202648  -13.670592  ]
   [-14.483544   -10.037312    -6.356632   ... -12.075281
     -2.2860763  -10.284541  ]]

  [[-10.372202   -13.09114     -3.6517806  ...  -7.623592
     -1.8009435   -6.817739  ]
   [-10.72727    -10.886565    -5.621975   ...  -7.8185344
     -1.4768337  -10.389865  ]
   [-11.611484   -10.158413    -7.931344   ...  -4.938987
     -0.23626254  -8.830031  ]
   ...
   [-12.590868    -6.102834   -10.619679   ...  -9.990441
     -1.0927511  -10.764243  ]
   [-12.30341     -4.7649236   -6.600345   ...  -9.458132
     -0.8608778  -12.198781  ]
   [-11.649162    -6.2056537   -5.922945   ... -10.207803
     -1.5887291   -9.819743  ]]

  [[-11.40545    -13.755798    -6.9160714  ... -11.7735195
     -3.3357754  -11.139454  ]
   [-11.398698   -11.785369    -6.5561953  ...  -9.794318
     -2.8272014  -11.654141  ]
   [ -9.548821    -7.3276024   -8.640192   ...  -4.349879
      0.14261375  -7.0007625 ]
   ...
   [-12.497658    -5.8748426   -9.083981   ...  -9.841493
     -1.4732579  -11.357761  ]
   [-14.517144    -5.2391934   -8.496638   ... -10.834668
     -2.6033173  -13.944796  ]
   [-14.292226    -7.0837607   -6.3621516  ... -10.551426
     -3.6190045  -12.224428  ]]

  ...

  [[ -6.1242228  -14.730902    -6.034355   ...  -5.2220926
     -1.1160429   -2.2097938 ]
   [ -5.003286   -16.216772    -5.28262    ...  -5.2270694
     -1.7447093   -4.245701  ]
   [ -5.595118   -15.978978    -4.214302   ...  -5.4203877
     -1.8398296   -4.396698  ]
   ...
   [-13.178917   -13.012176   -10.450902   ... -15.064126
     -1.9914117   -9.5184765 ]
   [-10.992667    -8.671063    -6.456934   ... -14.054223
     -1.4051182   -9.887496  ]
   [ -9.728466   -10.335494    -7.3331285  ... -10.754501
     -1.7173084   -4.671226  ]]

  [[ -5.4983754  -15.449182    -5.7204423  ...  -4.4113154
     -1.0589103   -2.6990566 ]
   [ -5.384841   -16.741693    -5.5674496  ...  -5.684756
     -1.8891927   -4.65452   ]
   [ -5.7909193  -16.244637    -4.5293765  ...  -6.4048567
     -2.3706574   -4.982708  ]
   ...
   [-10.004818   -11.296059    -7.158481   ... -10.9329
     -2.0753372   -8.129092  ]
   [ -7.942011    -8.787835    -2.8869028  ... -10.7461605
     -1.7351687   -7.8243003 ]
   [ -9.368582   -11.195904    -5.3443894  ...  -8.967132
     -1.5083878   -5.205722  ]]

  [[ -7.6940765  -15.492795    -4.6488175  ...  -5.7006836
     -1.3711176   -3.7699785 ]
   [ -5.243174   -15.9268875   -5.07713    ...  -3.642994
     -1.4748344   -4.1258245 ]
   [ -4.8627806  -13.911514    -4.372596   ...  -2.4015875
     -1.4164882   -3.6560988 ]
   ...
   [ -9.049875   -12.410313    -5.53057    ...  -8.292001
     -2.442209    -4.6609883 ]
   [ -7.18582    -11.061987    -3.3339026  ...  -7.413499
     -2.0413182   -5.4470387 ]
   [ -9.58725    -13.576278    -5.9882216  ...  -8.204617
     -2.0788593   -5.216848  ]]]]
-------------------------
--- TFLite output ---
[[[[nan nan nan ... nan nan nan]
   [nan nan nan ... nan nan nan]
   [nan nan nan ... nan nan nan]
   ...
   [nan nan nan ... nan nan nan]
   [nan nan nan ... nan nan nan]
   [nan nan nan ... nan nan nan]]

  [[nan nan nan ... nan nan nan]
   [nan nan nan ... nan nan nan]
   [nan nan nan ... nan nan nan]
   ...
   [nan nan nan ... nan nan nan]
   [nan nan nan ... nan nan nan]
   [nan nan nan ... nan nan nan]]

  [[nan nan nan ... nan nan nan]
   [nan nan nan ... nan nan nan]
   [nan nan nan ... nan nan nan]
   ...
   [nan nan nan ... nan nan nan]
   [nan nan nan ... nan nan nan]
   [nan nan nan ... nan nan nan]]

  ...

  [[nan nan nan ... nan nan nan]
   [nan nan nan ... nan nan nan]
   [nan nan nan ... nan nan nan]
   ...
   [nan nan nan ... nan nan nan]
   [nan nan nan ... nan nan nan]
   [nan nan nan ... nan nan nan]]

  [[nan nan nan ... nan nan nan]
   [nan nan nan ... nan nan nan]
   [nan nan nan ... nan nan nan]
   ...
   [nan nan nan ... nan nan nan]
   [nan nan nan ... nan nan nan]
   [nan nan nan ... nan nan nan]]

  [[nan nan nan ... nan nan nan]
   [nan nan nan ... nan nan nan]
   [nan nan nan ... nan nan nan]
   ...
   [nan nan nan ... nan nan nan]
   [nan nan nan ... nan nan nan]
   [nan nan nan ... nan nan nan]]]]
---------------------
@ymodak ymodak added the comp:lite TF Lite related issues label Oct 7, 2018
@ymodak ymodak assigned gargn and unassigned ymodak Oct 10, 2018
@hubert0527

I faced the same problem too!

After days of investigation, I found that the problem is caused by batch norm.
The values of the feature maps increase significantly (around x100, and sometimes x1e+35) each time they pass through a batch norm layer (either slim.batch_norm or tf.nn.fused_batch_norm). This eventually causes the values to become inf or nan (and only nan shows up in the final output).

I'm not sure whether this is specific to a particular TensorFlow version; it happens for me on both TF 1.11.0 and TF-nightly.
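
One way to see where the blow-up happens (a rough sketch reusing the script from the original report; the intermediate tensor name and its shape are hypothetical and must be replaced with a real batch-norm output from the frozen graph) is to convert the graph with that intermediate tensor as the TFLite output and compare it against the TensorFlow value:

    # Hypothetical batch-norm output; list candidates with e.g.
    # [n.name for n in graph_def.node if 'BatchNorm' in n.name].
    probe = sess.graph.get_tensor_by_name(
        'MobilenetV2/Conv/BatchNorm/FusedBatchNorm:0')
    input_tensor.set_shape([1, 128, 128, 3])
    probe.set_shape([1, 64, 64, 32])  # shape of that feature map (hypothetical)

    tf_value = sess.run(probe, feed_dict={input_tensor: input_image})

    probe_converter = tf.contrib.lite.TocoConverter.from_session(
        sess, [input_tensor], [probe])
    probe_interpreter = tf.contrib.lite.Interpreter(
        model_content=probe_converter.convert())
    probe_interpreter.allocate_tensors()
    probe_interpreter.set_tensor(
        probe_interpreter.get_input_details()[0]['index'], input_image)
    probe_interpreter.invoke()
    lite_value = probe_interpreter.get_tensor(
        probe_interpreter.get_output_details()[0]['index'])

    # If the blow-up starts at this layer, the TFLite magnitude will be far
    # larger than the TensorFlow one (or already NaN).
    print('max |TF|    :', np.abs(tf_value).max())
    print('max |TFLite|:', np.abs(lite_value).max())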

@hazirbas

hazirbas commented Nov 6, 2018

Hi @hubert0527 and @sercant,
Did you guys find any solution to this problem so far?

@sercant
Author

sercant commented Nov 7, 2018

Hi @hazirbas,
Not on my end, unfortunately.

@dimitree54

dimitree54 commented Jan 15, 2019

@hubert0527 Thank you for pointing at batch normalization: when I removed it from the network, my outputs became normal (not NaN). But of course the quality of my network dropped dramatically without batch normalization. I tried to replace tf.layers.batch_normalization with tf.keras.layers.BatchNormalization and tf.contrib.layers.batch_norm, but it had no effect. Finally I solved the problem by implementing my own batch normalization like this:

    def my_moments(input_tensor):
        # Per-channel mean and variance over the batch and spatial dimensions
        # (a TFLite-friendly replacement for tf.nn.moments).
        mean = tf.reduce_mean(input_tensor, axis=[0, 1, 2])
        dev = input_tensor - mean
        dev = dev * dev
        dev = tf.reduce_mean(dev, axis=[0, 1, 2])
        return mean, dev

    def my_bn(input_tensor):
        # Learnable per-channel scale (mu) and offset (beta).
        mu = tf.Variable(tf.ones(input_tensor.shape[3]))
        beta = tf.Variable(tf.zeros(input_tensor.shape[3]))
        mean, dev = my_moments(input_tensor)
        # Normalize with the batch statistics; 0.001 avoids division by zero.
        return beta + mu * (input_tensor - mean) / (tf.sqrt(dev) + 0.001)

Note that this is not a literal implementation of batch norm (the moving averages are not used here), because only training mode was required for my project. Also note that we cannot use tf.nn.moments to calculate the mean and variance because it is not supported by TFLite (so we need to implement our own moments function). After replacing batch normalization with the functions above, I was able to train my network, export it to TFLite, and run inference in TFLite correctly.
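
A short usage sketch (the surrounding layers are illustrative; net stands for the current feature map):

    net = tf.layers.conv2d(net, 64, 3, padding='same')
    net = my_bn(net)  # in place of tf.layers.batch_normalization(net, training=True)
    net = tf.nn.relu(net)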

@sercant
Author

sercant commented Jan 16, 2019

@hubert0527 Thank you for pointing at batch normalization: when I had removed it from network my outputs became normal (not nan). But of course the quality of my network fell drammatically without batch normalization. I tried to replace tf.layer.batch_normalization with tf.keras.layers.BatchNormalization and tf.contrib.layers.batch_norm, but no effect. Finally I solved the problem by implementing my own batch normalization like this:

    def my_moments(input_tensor):
        mean = tf.reduce_mean(input_tensor, axis=[0, 1, 2])
        dev = input_tensor - mean
        dev = dev * dev
        dev = tf.reduce_mean(dev, axis=[0, 1, 2])
        return mean, dev
    def my_bn(input_tensor):
        mu = tf.Variable(tf.ones(input_tensor.shape[3]))
        beta = tf.Variable(tf.zeros(input_tensor.shape[3]))
        mean, dev = my_moments(input_tensor)
        return beta + mu * (input_tensor - mean) / (tf.sqrt(dev) + 0.001)

Note that this is not literal implementation of batch norm (here moving average is not used), because only train mode was required for my project. Also note that we cannot use tf.nn.moments to calc mean and dev because it is not supported by tflite (so we need to implement own function for moments). After replacing batch normalization with provided functions I was able to train my network, export it to tflite and use it during inference in tflite correctly.

Have you tried using tf.layers.batch_normalization? I have another network and it seems to be working on TensorFlow v1.12.0.

@dimitree54

[Quoting @sercant's question above:] Have you tried using tf.layers.batch_normalization? I have another network and it seems to be working on TensorFlow v1.12.0.

Yes, I tried tf.layers.batch_normalization; it was the original normalization function I started my investigation from. I was also using TensorFlow v1.12 when I ran into the problem.

@miaout17 miaout17 assigned karimnosseir and unassigned gargn Mar 19, 2019
@miaout17
Contributor

Load balancing... @karimnosseir could you take a look?

@IrvingMeng

@sercant When freezing a model for inference, the "is_training" attribute of the BN layers should be set to false. In your frozen model, "is_training" is true, which makes the means/variances of the BN layers all 0s.
Maybe you should regenerate the frozen model with "is_training=false" and then convert it to a TFLite model.
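
A minimal sketch of that export path, assuming a TF 1.x model built with tf.layers.batch_normalization (the layer stack and shapes below are illustrative, not the repo's actual code):

    import tensorflow as tf

    def build_model(images, is_training):
        net = tf.layers.conv2d(images, 32, 3, padding='same')
        # Batch norm must use the moving mean/variance at inference time, so
        # the training flag is passed explicitly and hard-coded to False for
        # the graph that gets frozen and converted.
        net = tf.layers.batch_normalization(net, training=is_training)
        return tf.nn.relu(net)

    images = tf.placeholder(tf.float32, [1, 128, 128, 3])
    outputs = build_model(images, is_training=False)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Restore the trained checkpoint here before converting.
        converter = tf.contrib.lite.TocoConverter.from_session(
            sess, [images], [outputs])
        tflite_model = converter.convert()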

@karimnosseir
Contributor

Hi Sercant,

As IrvingMeng mentioned, you need to make sure is_training is false. Can you please confirm, and share the tflite & pb files for the model?

Thanks

@dimitree54

@IrvingMeng Thank you, that really fixes the problem.

In my case I was intentionally using training=True during export because it gives better performance on my task. It worked fine during pb inference, so I was surprised that TFLite was not working. But now I see that batch normalization in TFLite works correctly when used as intended. For my task I should use a custom normalization instead, for example the one suggested above.

@ricardobnjunior

ricardobnjunior commented Apr 17, 2019

@sercant Have you found some solution for this problem? I have the same problem :(

@sercant
Author

sercant commented Apr 21, 2019

Thank you @IrvingMeng for pointing out the issue with the batch normalization parameter.

@sercant Have you found some solution for this problem? I have the same problem :(

@ricardobnjunior Have you tried the solution suggested by @IrvingMeng? Unfortunately, I no longer work on this issue, since I initially opened it in October and only got the answer in March. If you can confirm that the suggested solution works, the issue can be closed.

@ablenesi
Contributor

Hello,

I have been playing around with the pix2pix example:
https://colab.research.google.com/drive/1Uv5QwTcygHgrD0e3AO1L_uwyuYqICrVJ

If you check out the link, you can see that setting trainable=False for BatchNormalization does not help.
But I can confirm that the issue is with BatchNormalization, since the conversion runs as expected when those layers are removed.

@IrvingMeng

@ablenesi I tested my solution with tf.layers.batch_normalization from TensorFlow v1 (following the steps on this page: https://www.tensorflow.org/lite/guide/get_started). However, it won't work with models built with Keras layers.
I encountered issue #24591 when freezing a TF model built with Keras layers.
Please let me know if you find a solution for Keras models. :)

@prathamesh0

@ablenesi Hi, did you find any workaround?

@wangrui261

I found a solution to this problem: just set the fused attribute of tf.keras.layers.BatchNormalization to False. The output is then correct.
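
A minimal sketch of that change, assuming a TF 2.x Keras model (the layer stack is illustrative):

    import tensorflow as tf

    inputs = tf.keras.Input(shape=(128, 128, 3))
    x = tf.keras.layers.Conv2D(32, 3, padding='same')(inputs)
    # fused=False selects the non-fused batch-norm kernel, which reportedly
    # converts correctly.
    x = tf.keras.layers.BatchNormalization(fused=False)(x)
    outputs = tf.keras.layers.ReLU()(x)
    model = tf.keras.Model(inputs, outputs)

    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    tflite_model = converter.convert()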

@MeghnaNatraj MeghnaNatraj self-assigned this Aug 10, 2020
@MeghnaNatraj
Member

Marking issue as resolved due to inactivity. Feel free to re-open this if it's unresolved or file a new issue

@javirk

javirk commented Aug 18, 2020

Hi, I'm not sure if I should be filing a different issue or if it's OK to re-open this one. I am facing the same issue with NaN values after converting the pix2pix generator to TFLite. As @wangrui261 said, I replaced all tf.keras.layers.BatchNormalization() with tf.keras.layers.BatchNormalization(fused=False, trainable=False), but the output is still NaN. I tried removing all batch normalization layers and, even though the output is wrong, it has values.

I am using TensorFlow version 2.3.0 and Python 3.7.7.

Has anyone been able to solve this or found a workaround?

@MeghnaNatraj
Member

Re-opening unresolved issue.

@MeghnaNatraj MeghnaNatraj reopened this Aug 18, 2020
@javirk

javirk commented Aug 21, 2020

I discovered that this generator also returns NaN values when running on CPU. As @hubert0527 said, this is due to large feature-map values after the Batch Normalization layers. I found somewhere on Stack Overflow that the GPU zeroes out these NaN values while the CPU keeps them, so I added a Lambda layer with a tf.where to handle this on CPU: tf.keras.layers.Lambda(lambda x: tf.where(tf.math.is_nan(x), tf.zeros_like(x), x)). With that, CPU and GPU executions return almost the same values (1e-6 maximum error in my case).

I have the feeling that this would fix the TFLite conversion, but when calling allocate_tensors on the new TFLite model I get the following error: Encountered unresolved custom op: IsNan. Node number 3 (IsNan) failed to prepare.
I have never added a new operation to TFLite, so I don't know how to proceed (the official documentation assumes too much prior knowledge for me). Is there any "novice" documentation I can look at in order to add the IsNan custom operation to TFLite?

@zhenzey

zhenzey commented Dec 9, 2020

@javirk: I have the same issue as you do. Do you find any solution for transferring pix2pix GAN to tf lite model? Thank you so much!

@karimnosseir
Contributor

@javirk One way is to use TF Select, which falls back to the TF kernels for the missing TF ops:
https://www.tensorflow.org/lite/guide/ops_select
Another way is to write a TFLite custom op:
https://www.tensorflow.org/lite/guide/ops_custom
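
A minimal sketch of the TF Select route, assuming a TF 2.x Keras model (model stands for the pix2pix generator discussed above):

    import tensorflow as tf

    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,  # TFLite builtin ops where available
        tf.lite.OpsSet.SELECT_TF_OPS,    # fall back to TF kernels (e.g. IsNan)
    ]
    tflite_model = converter.convert()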

@karimnosseir karimnosseir removed their assignment Dec 9, 2020
@paneda1998

I still have this problem. After removing all BN layers the output is still NaN, whether on CPU or GPU!

@paneda1998

paneda1998 commented Dec 29, 2020

Finally I solved the problem by quantizing the model...
Probably this problem appears on devices with low processing power
(my device is a Samsung A50).
This is the code:

    converter = tf.lite.TFLiteConverter.from_keras_model(self.deep_model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS  # enable TensorFlow Lite ops.
    ]
    tflite_model = converter.convert()
    open("converted_model.tflite", "wb").write(tflite_model)
