TF-TRT Add n_build_pass attribute #52033
Conversation
In the current PR, `_n_build_pass` is a value passed from a user-facing config to the TRTEngineOp, to help implement a TF-TRT API that takes n user inputs, performs shape profiling, and builds CUDA engines. Changing the user-facing config for this purpose is not necessary, as the implementation of the TF-TRT API can perform these steps itself.
Thanks @bixia1 for the comment. Indeed, if the approach you describe works, then we do not need this PR and we can close it. There is one complication though:
What is the right way to overcome this problem?
We can modify the `_profile_generation_mode` attribute of the TRTEngineOp inside the SavedModelBundle GraphDef without creating a new session. I think we are already doing something similar in the Python `converterV2.build` method. Do you agree? Another problem with having an `_n_build_pass` attribute in the TRTEngineOp, and relying on its value reaching 0 to trigger the build of the CUDA engine, is that the number of inputs provided by the user may not be the same as the number of times the TRTEngineOp is executed, in the presence of loops and if-statements.
I agree with you that it is desirable to use the existing `_profile_generation_mode` attribute.
How to do that? Could you point me to an example? It is not clear to me which API to use. Below is one example where I rewrite the GraphDef, but that does not change the actual graph used by the session; I need to create a new session for the changes to take effect.

```cpp
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/public/session.h"

using namespace tensorflow;
using namespace tensorflow::ops;

// Create and execute a graph with a single const op.
Scope root = Scope::NewRootScope();
auto c = Const(root.WithOpName("my_const"), {{42.0f, 137.0f}});
ClientSession session(root);
std::vector<Tensor> outputs;
Status status = session.Run({c}, &outputs);
if (status.ok()) std::cout << outputs[0].DebugString() << std::endl;

// Get the graph def and rewrite the constant value.
GraphDef gdef;
status = root.ToGraphDef(&gdef);
Tensor new_val(DT_FLOAT, TensorShape({1, 2}));
float *tensor_flat = new_val.flat<float>().data();
tensor_flat[0] = 31;
tensor_flat[1] = 41;
NodeDef *node = gdef.mutable_node(0);
TensorProto *tensor_attr = (*node->mutable_attr())["value"].mutable_tensor();
new_val.AsProtoTensorContent(tensor_attr);

// Changing the graph def has no effect on the results of the existing session.
status = session.Run({c}, &outputs);
if (status.ok()) std::cout << outputs[0].DebugString() << std::endl;

// Alternative: create a new session with the modified graph def.
std::unique_ptr<Session> session2(NewSession(SessionOptions()));
status = session2->Create(gdef);
status = session2->Run({}, {"my_const"}, {}, &outputs);
std::cout << outputs[0].DebugString() << std::endl;
```
It is not clear to me how to get the Graph object from the current session, and which API to use to manage the session. Could you give suggestions, based on the above example, on how to import the modified GraphDef back into the original session?
Closing this PR as it is not necessary; one can use the existing `_profile_generation_mode` attribute. Some notes:
To convert a model with TF-TRT dynamic shapes, one can provide profile information by running inference on the segmented model a number of times. Currently the build mode uses the `_profile_generation_mode` attribute of the graph to declare that we are in build mode. This requires two rewrites of the graph. Here is the current workflow:

1. Rewrite the graph to set `_profile_generation_mode=True`, then run inference to collect shape profiles.
2. Rewrite the graph to set `_profile_generation_mode=False`, then build the engines.
This PR introduces a new attribute for TRTEngineOp, `_n_build_pass`, which can be set through the rewriter config (example). If we set this parameter at conversion time, we can avoid rewriting the graph to enable/disable build mode. This PR is not essential to the C++ conversion of TF-TRT; it just decreases the number of graph rewrite steps. Alternatively to this PR we could:

- Keep using `_profile_generation_mode`. No changes would be needed in the TRT optimization pass, but it means extra work in the C++ converter.
- Set `_profile_generation_mode` to true already during graph conversion. We would need to read this value from the rewriter config (the same way this PR reads the `_n_build_pass` attribute).

Tagging @bixia1 for discussing these points and for review.
Tracker: #52012