Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TFLite] Enable reference int8 fully connected with output shape with dimensions > 2 #54223

Conversation

SaoirseARM
Copy link
Contributor

Hi,

This PR enables reference int8 support for fully connected with output shape with dimensions > 2 in TensorFlow Lite.
Relevant issue which describes this is here: #53501

This PR also contains an update to the SimpleTestQuantizedOutputShape3DInt16 to ensure that output shape with dimensions >2 is verified.

Thanks,
Saoirse

@google-ml-butler google-ml-butler bot added the size:M CL Change Size: Medium label Jan 31, 2022
@gbaned gbaned added comp:lite TF Lite related issues prtype:bugfix PR to fix a bug labels Feb 1, 2022
@gbaned gbaned requested a review from haozha111 February 1, 2022 13:16
@google-ml-butler google-ml-butler bot added the awaiting review Pull request awaiting review label Feb 1, 2022
@haozha111 haozha111 requested review from jianlijianli and removed request for haozha111 February 2, 2022 18:07
@fredrec fredrec self-requested a review March 25, 2022 05:40
@google-ml-butler google-ml-butler bot added kokoro:force-run Tests on submitted change ready to pull PR ready for merge process labels Mar 25, 2022
@kokoro-team kokoro-team removed the kokoro:force-run Tests on submitted change label Mar 25, 2022
@abattery
Copy link
Contributor

Could you add a e2e test case under the lite_v2_test.py?

@gbaned gbaned removed awaiting review Pull request awaiting review ready to pull PR ready for merge process labels Mar 25, 2022
@SaoirseARM SaoirseARM force-pushed the toupstream/fully_connected_dims_2 branch from ea13c59 to 4badb5f Compare March 25, 2022 17:22
@SaoirseARM SaoirseARM force-pushed the toupstream/fully_connected_dims_2 branch from 4badb5f to a8ddaaf Compare March 25, 2022 17:32
@SaoirseARM
Copy link
Contributor Author

Hi, Thanks for review. I have added end to end test to lite_v2_test.py. Please let me know if there is anything I need to add/change.

Best regards,
Saoirse

@@ -34,12 +34,13 @@ inline void FullyConnected(
const int32_t output_activation_min = params.quantized_activation_min;
const int32_t output_activation_max = params.quantized_activation_max;
TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 2);
TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One more thing: shouldn't this be TFLITE_DCHECK_GE(.., 2) ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi,

I think this is ok, as the output shape can have 1 dimension? This is similar implementation to: 1e0716e#diff-2cb25dbe2523a0754af5b5057b800add3642578751a4083dd094b5b99ecdecb2

Best regards,
Saoirse

interpreter.allocate_tensors()
output_details = interpreter.get_output_details()

self.assertLen(output_details[0]['shape_signature'], 3)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you add value comparisons between TF and TFLite models?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for review. I have added some extra comparisons. Please let me know if any more needs to be updated.

Best regards,

Saoirse

@gbaned gbaned requested a review from abattery March 29, 2022 14:48
@google-ml-butler google-ml-butler bot added the awaiting review Pull request awaiting review label Mar 29, 2022
copybara-service bot pushed a commit that referenced this pull request Apr 1, 2022
@gbaned
Copy link
Contributor

gbaned commented Apr 1, 2022

Seems auto-merge is not happening but the changes are merged into master now, so we can close this. Thank you for the PR.

@gbaned gbaned closed this Apr 1, 2022
@google-ml-butler google-ml-butler bot removed the awaiting review Pull request awaiting review label Apr 1, 2022
mergify bot pushed a commit to tensorflow/tflite-micro that referenced this pull request Jan 18, 2023
3D output shape support for Fully_Connected layer under xtensa.

The fix to the reference kernels was made with tensorflow/tensorflow#54223 (and the corresponding import from TF Lite to TFLM). TFLM currently does not have a test case for this (corresponding to what was added with tensorflow/tensorflow#54223),

BUG=tensorflow/tensorflow#53501
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
comp:lite TF Lite related issues prtype:bugfix PR to fix a bug size:M CL Change Size: Medium
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants