Refactoring fully_connected to share code between reference and optimized kernels. #46242
| Thanks for contributing to TensorFlow Lite Micro. To keep this process moving along, we'd like to make sure that you have completed the items on this list: 
 We would like to have a discussion on the GitHub issue first to determine the best path forward, and then proceed to the PR review. | 
…ized kernels.
This change is currently for discussion, to figure out which parts of this refactor we like and which we do not.
Also note that we have `#if !defined(XTENSA)` in fully_connected_common.cc because the linker appears to be failing to discard the unused symbols. We will discuss this further with the Cadence engineers, but having a repro case that can be merged would likely be useful.
TODO: make a GitHub issue describing this linker behavior in more detail.
Also, this refactor addresses the sign flip in fully_connected: http://b/138810107
55870c4 to 0a469be
    | context, sizeof(OpDataFullyConnectedReference)); | ||
| } | ||
|  | ||
| #if !defined(XTENSA) | 
@pnikam-cad @nyadla-sys @kpraving
These ifdefs are very undesirable but currently needed because the xtensa linker does not appear to be dropping unused symbols.
If I build the keyword_benchmark without this define the size is ~2.5 KB larger than if I keep this define. However, all the functions within this define are unused and I would expect the linker to be able to drop those symbols.
I will give more details tomorrow but wanted to flag this issue.
this was user-error on my part -- I very likely missed a make clean.
Filed #46261 to fix the underlying cause.
| @freddan80 @felix-johnny: we're trying to see what might work to increase code-sharing between reference and optimized kernels. Let us know if you have any suggestions. @yair-ehrenwald: this PR is a stab at improved code sharing. Let us know if you have any suggestions. | 
@njeffrie and I talked some more (see the review comments).
| return EvalFloat(context, node, params->activation, input, filter, bias, | ||
| output); | ||
| return EvalFloatFullyConnectedReference(context, node, params->activation, | ||
| input, filter, bias, output); | 
probably better to avoid the intermediate EvalFloatFullyConnectedReference altogether:
  tflite::reference_ops::FullyConnected(
      ToFloatParams(activation), tflite::micro::GetTensorShape(input),
      tflite::micro::GetTensorData<float>(input),
      tflite::micro::GetTensorShape(filter),
      tflite::micro::GetTensorData<float>(filter),
      tflite::micro::GetTensorShape(bias),
      tflite::micro::GetTensorData<float>(bias),
      tflite::micro::GetTensorShape(output),
      tflite::micro::GetTensorData<float>(output));
done.
| return kTfLiteOk; | ||
| } | ||
|  | ||
| TfLiteStatus EvalQuantizedInt8FullyConnectedReference( | 
remove in favor of directly calling the reference implementation, now that the OpData to OpParams is a single function call.
done
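For readers following the thread, here is a rough sketch of what "OpData to OpParams is a single function call" could look like. All type and function names ending in `Sketch` are hypothetical stand-ins, not the actual TFLM declarations, and the field list is an assumption made for illustration. The idea is that both the reference and the optimized kernels call one conversion helper, so the zero-point negation lives in exactly one place.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for the shared per-op data computed in Prepare().
struct OpDataSketch {
  int32_t output_multiplier;
  int output_shift;
  int32_t output_activation_min;
  int32_t output_activation_max;
  int32_t input_zero_point;
  int32_t filter_zero_point;
  int32_t output_zero_point;
};

// Hypothetical stand-in for the parameter struct the kernels consume.
struct ParamsSketch {
  int32_t input_offset;
  int32_t weights_offset;
  int32_t output_offset;
  int32_t output_multiplier;
  int output_shift;
  int32_t quantized_activation_min;
  int32_t quantized_activation_max;
};

// Single conversion shared by every kernel variant. The sign convention
// (input and weights offsets are negated zero points) is an assumption of
// this sketch; keeping it here means callers cannot get it wrong.
inline ParamsSketch ToQuantizedParamsSketch(const OpDataSketch& data) {
  ParamsSketch params;
  params.input_offset = -data.input_zero_point;
  params.weights_offset = -data.filter_zero_point;
  params.output_offset = data.output_zero_point;
  params.output_multiplier = data.output_multiplier;
  params.output_shift = data.output_shift;
  params.quantized_activation_min = data.output_activation_min;
  params.quantized_activation_max = data.output_activation_max;
  assert(params.quantized_activation_min <= params.quantized_activation_max);
  return params;
}
```

With a helper like this, an Eval function can pass `ToQuantizedParamsSketch(data)` straight to the kernel implementation instead of going through an intermediate `Eval...Reference` wrapper.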
| } | ||
| }; | ||
|  | ||
| extern const int kFullyConnectedInputTensor; | 
we'll make these:
inline int fully_Connected_input_tensor_index() { return 0; }
kept as extern const.
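A minimal sketch of the two options discussed here, collapsed into one file for illustration. `kFullyConnectedInputTensor` is the name from the diff; `FullyConnectedInputTensorIndexSketch` is a hypothetical name for the inline alternative, and the value 0 is taken from the suggestion above.

```cpp
#include <cassert>

// Option kept in the PR: declare the constant in the shared header...
extern const int kFullyConnectedInputTensor;

// ...and define it in exactly one .cc file. The earlier extern declaration
// gives this const object external linkage, so other translation units can
// link against it.
const int kFullyConnectedInputTensor = 0;

// Alternative that was suggested: an inline accessor that lives entirely in
// the header and needs no separate definition (hypothetical name).
inline int FullyConnectedInputTensorIndexSketch() { return 0; }
```

Both forms expose the same index to reference and optimized kernels. The extern const keeps existing call sites unchanged, while the inline accessor avoids needing a definition in a .cc file.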
| } | ||
|  | ||
| #if !defined(XTENSA) | ||
| TfLiteStatus CalculateOpDataFullyConnectedReference( | 
we'll keep this function in common
done.
| return kTfLiteOk; | ||
| } | ||
|  | ||
| TfLiteStatus EvalQuantizedFullyConnectedReference( | 
this will be removed.
done
|  | ||
| namespace tflite { | ||
|  | ||
| struct OpDataFullyConnectedReference { | 
call this OpDataFullyConnected (drop the reference suffix).
done
49ed518 to 14be7a3
14be7a3 to fa2d61e
fa2d61e to bbd388b
    | Ready for review again. | 
| @advaitjain I'll have a look at it today | 
Looks nice! I have a minor clean-up comment.
| TfLiteStatus CalculateOpData(TfLiteContext* context, | ||
| TfLiteFusedActivation activation, | ||
| TfLiteType data_type, const TfLiteTensor* input, | ||
| const TfLiteTensor* filter, | 
CalculateOpData() can be removed if we propagate "data->buffer_idx = -1;" into Prepare().
done.
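One way to read the clean-up above, as a sketch only: if the kernel-local `CalculateOpData()` does nothing beyond resetting `buffer_idx` and forwarding to the shared helper, that reset can move into `Prepare()` and the wrapper disappears. The types and the `...Sketch` names below are hypothetical stand-ins, and the assumption that the wrapper only did this one extra step is not confirmed by the thread.

```cpp
#include <cassert>
#include <cstdint>

enum StatusSketch { kOkSketch, kErrorSketch };

// Hypothetical kernel-specific op data: shared fields plus a scratch-buffer
// index that only this kernel variant uses.
struct KernelOpDataSketch {
  int32_t output_multiplier;
  int buffer_idx;
};

// Stand-in for the shared helper that stays in the common code.
inline StatusSketch CalculateOpDataCommonSketch(KernelOpDataSketch* data) {
  if (data == nullptr) return kErrorSketch;
  data->output_multiplier = 1;  // Placeholder for the real quantization math.
  return kOkSketch;
}

// After the clean-up: Prepare() resets buffer_idx itself and calls the shared
// helper directly, so no kernel-local CalculateOpData() wrapper is needed.
inline StatusSketch PrepareSketch(KernelOpDataSketch* data) {
  if (data == nullptr) return kErrorSketch;
  data->buffer_idx = -1;  // -1 marks "no scratch buffer requested yet".
  return CalculateOpDataCommonSketch(data);
}
```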
a2a53b8 to c8bc253
Summary:
Also, this refactor addresses the sign flip in fully_connected: http://b/138810107