
fix: Allow full model compilation with collection outputs #1599

Merged
merged 3 commits into pytorch:main on Mar 14, 2023

Conversation

Collaborator

@gs-olive gs-olive commented Jan 19, 2023

Description

  • Update graph-building in compiler to account for case where all operations are supported by Torch-TRT, but the output is a collection
  • Enable 'pseudo-partitioning' for nearly-fully-compiled models for which the only unsupported aspect of the model is the format of the output (TRT cannot output complex collections)
  • Define a small subset of operation schemas which are allowed despite the flag require_full_compilation. These operations are packing and unpacking of Tuples/Lists, and some are already used in cases of require_full_compilation
  • Display warnings to users if any portion of the pseudo-partitioning is unexpected: for example, if the partitioned model ends up in more than 3 segments (at most one Torch segment to preprocess collection inputs, one TRT segment to perform model logic, and one Torch segment to post-process collection outputs), or if schemas outside of the collection subset are encountered in a Torch segment
  • Add end-to-end test case with minimal reproducing example of a failing model, repaired with the changes to the compiler
  • Add minor fix to lowering to resolve a C++ compiler warning

This fix was designed to minimally alter the existing phases of model conversion and does not manually flatten/reconstruct complex collection outputs, but instead uses the existing partitioning infrastructure and engine-stitching paradigm to accomplish this.
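For context, the failure mode is a model whose forward returns a nested collection, e.g. `return (x, (y, z))`. A minimal sketch of the "output is a collection" classification (illustrative Python; names are hypothetical and the real check lives in the C++ compiler):

```python
def output_is_collection(out) -> bool:
    # A TRT engine can only emit Tensors directly; tuple/list/dict outputs
    # must be re-packed by a small Torch segment after the engine runs.
    return isinstance(out, (tuple, list, dict))

# Output shapes this PR enables under (pseudo) full compilation:
assert output_is_collection((1, (2, 3)))   # nested tuple output -> needs packing
assert not output_is_collection(4)         # single plain output -> no packing needed
```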

Fixes #1598
Fixes #1368

Type of change

  • Bug fix (non-breaking change which fixes an issue)
  • MVP for New feature

Checklist:

  • [x] My code follows the style guidelines of this project (you can use the linters)
  • [x] I have performed a self-review of my own code
  • [x] I have commented my code, particularly in hard-to-understand areas and hacks
  • [x] I have made corresponding changes to the documentation
  • [x] I have added tests to verify my fix or my feature
  • [x] New and existing unit tests pass locally with my changes
  • [x] I have added the relevant labels to my PR so that the relevant reviewers are notified

@gs-olive gs-olive self-assigned this Jan 19, 2023
@github-actions github-actions bot added component: core Issues re: The core compiler component: lowering Issues re: The lowering / preprocessing passes component: partitioning component: tests Issues re: Tests labels Jan 19, 2023
Collaborator

@narendasan narendasan left a comment


Looks mostly fine, just some UX and dev stuff

Collaborator Author

gs-olive commented Jan 23, 2023

Thanks for the comments and review @narendasan - I have incorporated the feedback and updated two of the user warnings to compilation-halting errors.

One note I wanted to make: despite min_block_size=1 and allowing collection-type nodes to run in Torch, this implementation still respects full compilation and will not execute intermediate pack/unpack operations in Torch. This is because prim::TupleUnpack and similar operators are not automatically added to torch_executed_ops; that is only done in the case where input_signature is used, which is not the intent of this PR (it will be a future PR). As a result, only the collection ops needed to pack the final model output are run in Torch, per this function:

// Check if the inputs and outputs of the graph are Tensor. If not, then fallback connected nodes
void setInputsOutputsConnectedNodes(PartitioningCtx* ctx, torch::jit::Block* block) {
  // fallback nodes that produce entire graph's nonTensor output
  for (auto i : block->outputs()) {
    if (!isTensor(i)) {
      ctx->setNodeExecutorDecision(i->node(), NodeExecutorDecision::kNON_TENSOR);
    }
  }
  // fallback nodes that consume entire graph's nonTensor input
  for (auto i : block->inputs()) {
    if (!isTensor(i)) {
      for (auto use : i->uses()) {
        ctx->setNodeExecutorDecision(use.user, NodeExecutorDecision::kNON_TENSOR);
      }
    }
  }
}

Any intermediate packing/unpacking is handled by the evaluators and does not cause a graph segmentation, since those nodes are not directly graph outputs.

  // executed in TRT, regardless of the size of the graph
  if (expect_full_compilation) {
    // If minimum block size is different from the default, the user must have specified it
    if (ctx->settings.min_block_size != 3) {
Collaborator


Create an issue to centralize defaults somewhere in the core

Collaborator


What if a user explicitly sets min_block_size=3 as well? Do they still get the warning message?

Collaborator Author

@gs-olive gs-olive Mar 8, 2023


No, the user would not get a warning message in that case. We currently have no way of knowing whether the user provided a value or not, since the defaults are not centralized. There is an issue #1644 to address this, but as of now, your statement is correct. Additionally, it is worth noting that prior to this PR, if a user specified min_block_size along with require_full_compilation=True, we would still ignore the min_block_size, but without any warning.
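The underlying limitation (a default value is indistinguishable from an explicitly supplied one) can be sketched as follows; an Optional-typed setting, of the kind centralizing defaults per issue #1644 could enable, makes the distinction observable. Names here are illustrative, not the project's API:

```python
from typing import Optional, Tuple

DEFAULT_MIN_BLOCK_SIZE = 3  # mirrors the hard-coded default discussed above

def resolve_min_block_size(user_value: Optional[int]) -> Tuple[int, bool]:
    """Return (effective value, whether the user explicitly set it).

    With a plain `int` default of 3, an explicit user choice of 3 and
    "no value given" collapse into the same observation; an Optional
    sentinel keeps the two cases separate.
    """
    if user_value is None:
        return DEFAULT_MIN_BLOCK_SIZE, False
    return user_value, True

assert resolve_min_block_size(None) == (3, False)  # default used silently
assert resolve_min_block_size(3) == (3, True)      # user chose 3 explicitly
assert resolve_min_block_size(1) == (1, True)
```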

Collaborator

@narendasan narendasan left a comment


LGTM, @bowang007 can you take a look?

Collaborator

@bowang007 bowang007 left a comment


LGTM

        (!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
           cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) ||
         outputIsCollection || user_requested_long)) ||
       requires_collection_handling) {
Collaborator


Could this if statement be optimized? It seems like isBlockConvertible and outputIsCollection overlap with requires_collection_handling

Collaborator Author

@gs-olive gs-olive Mar 8, 2023


I've updated this statement to make the conditions clearer by removing the ! and distributing it over the inner conditionals. Beyond this, the statement cannot be reduced further, since the requires_collection_handling boolean is independent of cfg.partitioning_info.enabled (we want to partition in this case even when require_full_compilation=True)
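The rewrite amounts to distributing the negation via De Morgan's laws; an exhaustive check (illustrative Python mirroring the C++ booleans, not project code) confirms the two forms agree on every input:

```python
from itertools import product

def old_form(enabled, fb_mod, fb_op, convertible, out_coll, long_t, req_coll):
    # Original shape: negation wrapped around the "nothing forces partitioning" case
    return (enabled and (not (not fb_mod and not fb_op and convertible)
                         or out_coll or long_t)) or req_coll

def new_form(enabled, fb_mod, fb_op, convertible, out_coll, long_t, req_coll):
    # Rewritten shape: (enabled AND (x OR y OR z ...)) OR requires_collection_handling
    return (enabled and (fb_mod or fb_op or not convertible
                         or out_coll or long_t)) or req_coll

# The two predicates are equivalent over all 2^7 boolean assignments
for combo in product([False, True], repeat=7):
    assert old_form(*combo) == new_form(*combo)
```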


// If full compilation is expected, cannot have more than 2 Torch segments
// (one for preprocessing inputs, one for post-processing outputs) and 1 TRT segment
if (expect_full_compilation && !(num_torch_segments <= 2 && num_trt_segments == 1)) {
Collaborator


Do we have edge cases like 2 torch_segments for inputs/outputs? Does merge_adjacent_segments_of_same_type always merge them into one?

Collaborator Author


There should not be a case where multiple Torch segments appear for inputs/outputs, since merge_adjacent_segments_of_same_type addresses this case, as you had mentioned. Since the tensors in question are inputs, it should not arise that segment.do_not_merge() is True, since the only approved operators falling into these segments are for collection construction, and only the prim::If or prim::Loop operators can induce a non-merge situation.
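The pseudo-partitioning invariant under discussion (at most one Torch segment on each side of a single TRT segment) can be sketched as a check over the ordered, already-merged segment types. Illustrative Python, not the partitioner's code:

```python
def valid_full_compilation_partition(segments):
    """segments: ordered list of 'torch'/'trt' labels, after adjacent
    same-type segments have been merged. Under (pseudo) full compilation
    the result must have <= 2 Torch segments and exactly 1 TRT segment."""
    return segments.count("torch") <= 2 and segments.count("trt") == 1

assert valid_full_compilation_partition(["trt"])                    # fully in TRT
assert valid_full_compilation_partition(["torch", "trt", "torch"])  # collection pack/unpack around the engine
assert not valid_full_compilation_partition(["trt", "torch", "trt"])  # 2 TRT segments: not full compilation
```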

Comment on lines 375 to 379
if ((cfg.partitioning_info.enabled &&
     (cfg.lower_info.forced_fallback_modules.size() != 0 ||
      cfg.partitioning_info.forced_fallback_operators.size() != 0 || !isBlockConvertible || outputIsCollection ||
      user_requested_long)) ||
    requires_collection_handling) {
Collaborator Author


Note the updates to the conditional logic to make it $$(cfg.partitioning\_info.enabled \wedge (x \vee y \vee z ... )) \vee requires\_collection\_handling$$

@gs-olive gs-olive requested a review from bowang007 March 10, 2023 01:48
- Update graph-building in compiler to account for case where all
operations are supported by Torch-TRT, but the output is a collection.
- Enable 'pseudo-partitioning' for nearly-fully-compiled models for
which the only non-supported aspect of the model is the format of the
output (TRT cannot output complex collections)
- Define a small subset of operation schemas which are allowed despite
the flag `require_full_compilation`. These operations are packing and
unpacking of Tuples/Lists, and some are already used in cases of
`require_full_compilation`
- Display warnings to users if any portion of the `pseudo-partitioning`
is unexpected, for example the model being partitioned ends up in more
than 3 segments (maximally - a Torch segment to preprocess collection
inputs, a TRT segment to perform model logic, a Torch segment to
post-process collection outputs) or if schemas falling outside of the
collection subset are encountered in a Torch segment
- Add end-to-end test case with minimal reproducing example of a failing
model, repaired with the changes to the compiler
- Add minor fix to lowering to resolve a C++ compiler warning
- Add function to check the equivalence of two collection-based outputs
for comparison across Torch-TRT and Torch outputs
- Improved test robustness in end-to-end to check for equivalent output
schemas in addition to successful compilation
- Add test case to elicit behavior where full compilation is requested
but TRT engine size falls below default `min_block_size=3`
- Move `min_block_size` condition to narrow scope
- Coalesce logic to improve code readability
Comment on lines +376 to +382
// Partitioning is required if:
// 1. User requested some modules/operators fallback
// 2. The block (graph) cannot be converted due to operator coverage
// 3. The output of the graph is a collection
// 4. The user requested a non-TRT data type input
auto isPartitioningRequired =
    (isFallbackRequested || !isBlockConvertible || outputIsCollection || user_requested_long);
Collaborator Author


Coalesced partitioning logic for readability

Comment on lines +329 to +332
bool userRequestedFallback(CompileSpec& cfg) {
  return cfg.lower_info.forced_fallback_modules.size() != 0 ||
      cfg.partitioning_info.forced_fallback_operators.size() != 0;
}
Collaborator Author

@gs-olive gs-olive Mar 13, 2023


Added helper function to determine if the user's input specifications imply fallback

Collaborator

@narendasan narendasan left a comment


LGTM

Collaborator

@bowang007 bowang007 left a comment


Looks great, thanks!

@gs-olive gs-olive merged commit d0af394 into pytorch:main Mar 14, 2023