[model loading] don't gc.collect() if only 1 shard is used (#36721)
What does this PR do?
Fixes a test-related regression introduced by #36033. Despite speeding up `from_pretrained` in general, most tests that relied on `from_pretrained` became much slower (see the tests section below).

Measuring time, the entire slowdown can be traced to a single line, a `gc.collect()`. This line was not added in #36033; it was moved. Before #36033, if the checkpoint was not sharded, a `state_dict` would have been passed to `_load_pretrained_model`, and the `gc.collect()` branch would not be reached.

This PR adds a tiny `if` to skip `gc.collect()` when the checkpoint is not sharded. Since many tests rely on unsharded `from_pretrained`, we can immediately observe faster tests.
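To illustrate the idea, here is a minimal sketch of the single-shard fast path; the function and parameter names (`load_checkpoint_shards`, `load_fn`, `assign_fn`) are hypothetical and not the actual transformers internals:

```python
import gc

def load_checkpoint_shards(shard_files, load_fn, assign_fn):
    """Hypothetical sketch of a shard-loading loop, not the real transformers code."""
    is_sharded = len(shard_files) > 1
    for shard_file in shard_files:
        state_dict = load_fn(shard_file)  # e.g. torch.load(shard_file, map_location="cpu")
        assign_fn(state_dict)             # copy the tensors into the model
        del state_dict                    # drop the reference so memory can be reclaimed
        # The gist of the fix: a full collection is only worth its cost when several
        # large shards are loaded back to back; with a single shard it is pure overhead.
        if is_sharded:
            gc.collect()
```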
Tests

Run times for `py.test tests/models/gpt2/test_modeling_gpt2.py`, which includes a mix of tests with and without `from_pretrained`, on my machine:

- `from_pretrained` #36033: 25.41s
- `main`: 61.14s

(`from_pretrained` #36033 resulted in a 15%+ test speedup, after this fix)
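For reference, a rough way to reproduce this kind of wall-clock measurement (assuming pytest and the relevant transformers checkout are installed; exact numbers will vary by machine and branch):

```python
import subprocess
import time

start = time.perf_counter()
subprocess.run(
    ["python", "-m", "pytest", "tests/models/gpt2/test_modeling_gpt2.py", "-q"],
    check=False,  # report the elapsed time even if some tests fail
)
print(f"total wall time: {time.perf_counter() - start:.2f}s")
```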