Add DTensor LLaMA inference model: simple_gpt #1867
Conversation
Looks good
fabric = L.Fabric(devices=[self._rank], precision="bf16-true")
with fabric.init_module(empty_init=True):
Just curious, is the only use of Lightning this init_module? I'm not sure what this does, but can we remove it by initializing with a meta device? Then the implementation doesn't rely on third-party libraries and can stay as close to native PyTorch as possible.
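For context, a minimal sketch of the meta-device alternative being suggested (not the PR's actual code; `Transformer`, `config`, `rank`, and `checkpoint_path` are hypothetical stand-ins):

```python
import torch

# Hypothetical names: Transformer/config/rank/checkpoint_path are stand-ins,
# not the actual identifiers used in simple_gpt.

# Build the module on the meta device so no parameter storage is allocated yet.
with torch.device("meta"):
    model = Transformer(config)

# Materialize empty parameters directly on the target GPU, then copy in the real weights.
model = model.to_empty(device=f"cuda:{rank}")
model.load_state_dict(torch.load(checkpoint_path, map_location=f"cuda:{rank}"))
model.eval()
```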
Yeah good point, let me update this with the latest init code that removed the lightning dep
I've updated the init code to avoid the Lightning dependency (and it's actually much faster too!)
https://github.com/pytorch-labs/simple_gpt/blob/main/generate.py#L162
For nightly runs, do we usually load the actual weights? The load time should be quick, but the weights file is 10+ GB just for LLaMA-7B. Otherwise, I could also update the default weight initialization to use random values instead of just torch.zeros.
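As a rough illustration of the random-default option (the std value here is an assumption, not the repo's actual init):

```python
import torch

# Benchmark-only fallback: fill parameters with small random values instead of zeros
# so the numerics are more realistic when real weights aren't loaded.
with torch.no_grad():
    for param in model.parameters():
        torch.nn.init.normal_(param, mean=0.0, std=0.02)
```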
Use the real weights; otherwise you might fail accuracy checks in unpredictable ways.
Is this still needed if we only run this model through the dynamo runner (the model requires distributed, multiple GPUs)? The unit tests from this repo (i.e. `python test.py -k 'test_simple_gpt_'`) are skipped, and those include the accuracy checks.
This is because some models in torchbench (e.g. maml) do need gradient calculations in inference mode. So we leave the choice of whether to enable the gradient context to the eval test code. There is an open issue about testing that the
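A rough sketch of what leaving that choice to the eval code can look like (function and flag names here are illustrative, not the harness's actual API):

```python
import contextlib
import torch

def eval_grad_context(needs_grad: bool):
    # Some models (e.g. maml) need gradients even during inference,
    # so the harness only applies no_grad() when it is safe to do so.
    return contextlib.nullcontext() if needs_grad else torch.no_grad()

with eval_grad_context(needs_grad=False):
    output = model(example_inputs)
```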
@pytorchbot merge
@xmfan has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
…mo runner (#108438)

Adding support to pass rank and world_size to a torchbench model via its extra_args parameter: https://github.com/pytorch/benchmark/blob/main/torchbenchmark/util/model.py#L83C80-L83C90

This is used for models which distribute over multiple GPUs, e.g. simple_gpt pytorch/benchmark#1867. Also adds an option to skip multiprocess-only GPU models.

Testing via `python benchmarks/dynamo/torchbench.py -d cuda --output=benchmark_logs/performance.csv --inference --performance --timing --print-memory --multiprocess --only simple_gpt`

Pull Request resolved: #108438
Approved by: https://github.com/Chillee
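For illustration, a sketch of how a model could read rank/world_size out of `extra_args` (the flag names and helper are assumptions, not necessarily what the runner forwards):

```python
import argparse

def parse_distributed_args(extra_args):
    # Illustrative flag names; the actual runner may forward these differently.
    parser = argparse.ArgumentParser()
    parser.add_argument("--rank", type=int, default=0)
    parser.add_argument("--world_size", type=int, default=1)
    args, _ = parser.parse_known_args(extra_args)
    return args.rank, args.world_size

# Example: rank, world_size = parse_distributed_args(["--rank", "1", "--world_size", "8"])
```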
Adds simple_gpt + DTensor, implemented in https://github.com/pytorch-labs/simple_gpt/pull/7, to torchbench.

Tested via `python benchmarks/dynamo/torchbench.py -d cuda --output-directory=benchmark_logs --output=performance.csv --inference --performance --timing --print-memory --multiprocess --nothing --only simple_gpt`. Note: --nothing is used here to disable compile, since DTensor + compile isn't yet supported in main.

Two changes were required to the model: