Integration of Large Model Support in PyTorch #35633

Open
UlrichFreitag opened this issue Mar 29, 2020 · 20 comments
Labels
feature A request for a proper, new feature. low priority We're unlikely to get around to doing this in the near future module: internals Related to internal abstractions in c10 and ATen module: memory usage PyTorch is using more memory than it should, or it is leaking memory triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@UlrichFreitag

UlrichFreitag commented Mar 29, 2020

🚀 Feature

PyTorch Large Model Support (LMS) is a feature for PyTorch provided by IBM, available here (official IBM repo) and here (fork by the main maintainer of LMS), that allows the successful training of deep learning models that would otherwise exhaust GPU memory and abort with "out-of-memory" errors. LMS manages this oversubscription of GPU memory by temporarily swapping tensors to host memory when they are not needed.

With LMS, deep learning models can scale significantly beyond what was previously possible and, ultimately, generate more accurate results.

Motivation

  • When training recurrent models with back-propagation through time (BPTT) it is often useful to 'truncate' the sequence length as little as possible, especially when dealing with audio inputs or EEG data that have high temporal resolution. This results in a larger memory footprint, and this is where LMS can save the day.
  • Also, the amount of compute needed to train state-of-the-art models doubles on average every 3.5 months (see https://openai.com/blog/ai-and-compute/). This comes both from the use of larger batch sizes and the use of larger models (like the now famous GPT-2 with 1.5B parameters). For instance, Transformer-XL can have a big memory footprint (https://openai.com/blog/sparse-transformer/). Using LMS is very useful when you want to test something out without reaching for gradient checkpointing right away.
  • LMS can be extremely beneficial to anyone who cannot afford access to high-end GPUs (within small startups or in academic research). Using cloud services or buying the Titan RTX ($2,499) to run models is often too expensive.
  • GPU RAM is usually limited to about 8GB and is not extensible. Regular RAM, on the other hand, can easily be increased to 128GB or more and is underused during training.
  • Finally, LMS could be useful when smoke testing runs with small GPUs (either manually or within the context of a CI). This leaves the small (often older) GPUs still busy while the larger ones are used for real runs with or without LMS.

Pitch (copied from the LMS documentation)

One or more elements of a deep learning model can lead to GPU memory exhaustion.

These include:

  • Model depth and complexity
  • Base data size (for example, high-resolution images)
  • Batch size

Traditionally, the solution to this problem has been to modify the model until it fits in GPU memory. This approach, however, can negatively impact accuracy – especially if concessions are made by reducing data fidelity or model complexity.

Alternatives

Checkpointing can sometimes help, but that is not always the case...
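
For reference, this is roughly what the explicit gradient-checkpointing alternative looks like with torch.utils.checkpoint (a minimal sketch; the module and sizes are arbitrary placeholders):

import torch
from torch.utils.checkpoint import checkpoint

# The block's activations are not stored during the forward pass; they are
# recomputed during backward, trading extra compute for a smaller memory footprint.
block = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024),
    torch.nn.ReLU(),
    torch.nn.Linear(1024, 1024),
)

x = torch.randn(32, 1024, requires_grad=True)
y = checkpoint(block, x)   # recent versions also accept use_reentrant=False
y.sum().backward()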

Additional context

This feature has been maintained for a while (since at least PyTorch 1.1) by @mtbrandy and has been proposed for contribution to PyTorch since at least August 2019 (I did not find any mention of it in this repo):

https://www.reddit.com/r/pytorch/comments/cgyppk/large_model_support_for_pytorch/
https://discuss.pytorch.org/t/very-large-model-on-single-gpu/28962

It is also mentioned here:
https://www.ibm.com/support/knowledgecenter/SS5SF7_1.5.4/navigation/pai_getstarted_pytorch.html

Official repos:
https://github.com/IBM/pytorch-large-model-support/
https://github.com/mtbrandy/pytorch


I am basically creating this issue because I really like LMS. So far I have waited for LMS support for each version of PyTorch, and each time I had to manually compile PyTorch (and create wheels) to get it. (BTW, many thanks to @mtbrandy, who still maintains this fork.)

What I don't understand is why this feature has not been integrated into PyTorch even though the code is written (and maintained) by IBM 😅.

I mean, it needs an "opt-in" from the user, so it is not enabled by default! If the reason is "it can reduce performance", I agree with you, but it also allows people to experiment more without needing a super-expensive GPU. I really think that the community, small start-ups, students, etc. would benefit from this even if they surely won't use it most of the time.

@mtbrandy
Contributor

@UlrichFreitag Thank you for your testimonial, and for starting up this conversation again -- we would love to get the feature integrated into an official version of PyTorch.

I'll also take this opportunity to mention some recent developments:

  • The latest iteration of the implementation includes improved usability (no tuning required) and performance (speculative paging to help hide swapping latencies).
  • We have also ported this design to Tensorflow 2.0 (see https://github.com/IBM/tensorflow-large-model-support/).

We've been getting good feedback and results due to its ease of use and low performance overhead (especially when used on a system with NVLINK) -- particularly from researchers wanting to push the boundaries of model complexity and/or data fidelity.

@ailzhang ailzhang added needs research We need to decide whether or not this merits inclusion, based on research world triage review labels Mar 30, 2020
@ezyang
Contributor

ezyang commented Mar 31, 2020

Related #31252

@ezyang
Contributor

ezyang commented Mar 31, 2020

Also related #28997

@ezyang
Contributor

ezyang commented Mar 31, 2020

To elaborate on this comment: in general, the PyTorch philosophy is to give users full control over the things they want to do, and to avoid having the framework be in the business of applying some sort of black-box heuristics to try to speed code up magically. A good example of this is how we have implemented checkpointing: a traditional checkpointing framework will apply heuristics to determine where a checkpoint happens, whereas in PyTorch the user can express exactly where the checkpoint happens. Under this philosophy, LMS as currently implemented is not a good fit for PyTorch core. Greg's suggestion above is to try to figure out if there is a way to do LMS in the "PyTorch" style.

This might be quite hard to do, and sometimes a black-box heuristic framework is just what the doctor ordered. We don't mind if our users decide that this is the right thing for their application! (We just don't want to support it.) So to me, the biggest problem you are facing here is that LMS is a fork of PyTorch, and quite an invasive fork at that. That makes keeping it up to date quite hard. But there might be some other ways to implement LMS without needing to fork PyTorch. One way is to represent LMS tensors as a new tensor type in PyTorch. Users could specify that their tensors are lms() (similar to how they move things to cuda()), and then the implementation of operations on LMS tensors would transparently migrate tensors on/off CUDA in the same way that the existing LMS implementation does. It sounds like you needed your own CUDA allocator for this case too, so you'd just maintain your own CUDA allocator in that case. How does this sound?
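
For concreteness, here is a rough sketch of the host/device migration an lms()-style tensor type would perform under the hood, written out by hand (this assumes a CUDA device and is only an illustration -- nothing here is the proposed API):

import torch

device = torch.device("cuda")
activation = torch.randn(64, 512, 1024, device=device)   # some large tensor to park

# Swap out: copy into a pinned host buffer so the transfer can overlap other work,
# then drop the GPU copy to free device memory.
host_buffer = torch.empty(activation.shape, dtype=activation.dtype, pin_memory=True)
host_buffer.copy_(activation, non_blocking=True)
torch.cuda.synchronize()   # make sure the copy has finished before freeing the source
del activation

# ... other GPU work runs here while the tensor lives in host memory ...

# Swap back in when the tensor is needed again.
activation = host_buffer.to(device, non_blocking=True)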

@mtbrandy
Contributor

@ezyang It sounds like you're saying that there is a distinction, in both maintenance responsibility and design philosophy, between a "core" tensor type like cuda and another type like the proposed lms.

Can you elaborate on that? For example, how is one expected to be maintained versus the other and what makes one exempt from some core design principles that may be enforced elsewhere?

@ezyang
Contributor

ezyang commented Mar 31, 2020

Well, when things get merged into PyTorch master, we collectively take on maintenance responsibilities for it. So if we say that we don't want a type like "lms" in core, it's because we don't want to be on the hook for maintaining it :) Part of how we make that decision has to do with our design philosophy. But we're not here to tell you what you, gentle user, are allowed to or not allowed to do. We just don't want to maintain it :)

@mtbrandy
Contributor

OK, I'm still trying to understand your suggestion. As an alternative to forking PyTorch (or accepting LMS into core), you are proposing a separate python package that can serve as a PyTorch extension and provide the implementation of a new tensor type? Are there existing examples of non-core tensor types you can point me at?

Also, I wonder if there is some middle ground where we introduce the concept of an extensible and dynamic memory management API to core. This implementation that we call "Large Model Support" is really just one way to virtualize the GPU memory. In fact, LMS isn't even a particularly great (or loved) name for what it does. Would you be any more interested in going down that path?

@thyeros

thyeros commented Mar 31, 2020

This is the feature I discussed with @gottbrath at SC19. The user base favoring this capability (aka LMS, but essentially GPU memory virtualization) is growing fast, and we (@mtbrandy leads the effort here) would love to talk about getting this tech into PyTorch in some way.

@ezyang
Contributor

ezyang commented Apr 1, 2020

Are there existing examples of non-core tensor types you can point me at?

The best and most complete external integration is https://github.com/pytorch/xla which adds XLA tensor support. The repo is quite big though. There is a small toy example C++ extension integration here https://github.com/pytorch/pytorch/blob/master/test/cpp_extensions/msnpu_extension.cpp which you might be able to work off of, although some aspects of the API here are still likely to be in flux.

@ezyang
Contributor

ezyang commented Apr 1, 2020

One thing to note is that in both of these cases each function is implemented individually. This is likely the wrong choice for LMS, since it means whenever we add a new function you would need to add a new LMS wrapper for it (which is frankly worse than what your current fork's maintenance needs are). So you'd probably want to write a polymorphic fallback function as in https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/test/backend_fallback_test.cpp#L74

This is also in flux. So the correct answer in the short term might be to wait until we stabilize these APIs some more. But I think an LMS-like mechanism is something that should be possible to implement on top of the infrastructure here.
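
At the Python level, the rough shape of such a polymorphic fallback can be sketched with the __torch_function__ override protocol available in more recent PyTorch releases: a single handler intercepts every operator call, migrates storage onto the GPU for the duration of the call, and parks the result back in host memory. This is only an illustration of the design, not the C++ boxed fallback linked above, and it leans on the semi-private torch._C.DisableTorchFunction guard to avoid re-dispatching inside the handler:

import torch

class OffloadTensor(torch.Tensor):
    """Sketch: data normally lives on the host; each op temporarily runs on the GPU."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        with torch._C.DisableTorchFunction():
            # Move tensor arguments to the GPU just for this op.
            gpu_args = [a.cuda() if isinstance(a, torch.Tensor) else a for a in args]
            gpu_kwargs = {k: (v.cuda() if isinstance(v, torch.Tensor) else v)
                          for k, v in kwargs.items()}
            out = func(*gpu_args, **gpu_kwargs)
            # Park the result back in host memory and re-wrap it (single-tensor
            # outputs only; a real implementation would handle tuples, views, etc.).
            if isinstance(out, torch.Tensor):
                out = out.cpu().as_subclass(cls)
            return out

a = torch.randn(1024, 1024).as_subclass(OffloadTensor)
b = torch.randn(1024, 1024).as_subclass(OffloadTensor)
c = a @ b   # the matmul runs on the GPU; a, b, and c otherwise stay in host memory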

@danarteaga

Very promising to hear that progress is being made with LMS and official integration. I have been using the LMS fork by @mtbrandy in PyTorch for ~9 months on a system with limited GPU memory and have found that swapping tensors to host memory has allowed me to fit many more parameters into my model and achieve much faster training times thanks to the ability to use larger patch sizes.

@jqyin

jqyin commented Apr 2, 2020

LMS is a valuable feature when we try BERT-large, which cannot fit into GPU memory. There are model-parallel workarounds, but they increase software complexity and are usually not generic.

@ezyang ezyang added feature A request for a proper, new feature. triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module module: internals Related to internal abstractions in c10 and ATen module: memory usage PyTorch is using more memory than it should, or it is leaking memory and removed triage review labels Apr 6, 2020
@gchanan gchanan added low priority We're unlikely to get around to doing this in the near future and removed needs research We need to decide whether or not this merits inclusion, based on research world labels Apr 28, 2020
@aluo-x

aluo-x commented Jun 18, 2020

I would also like to add a note of support. It would be nice to have some kind of explicit API.

A few questions: how well does the fork, in its current state, work with DataParallel & DistributedDataParallel?

I imagine an ideal API would allow us to call .lms() on .cuda() tensors, models, or nn.Sequential objects. For example:

network = resnet50().cuda()
network.lms(threshold=??, aggressiveness=??, allowed_types=(float32, long), exclude=filter_func)
# This enables LMS support for all tensors in the network.
# threshold / aggressiveness provide some control over how proactive the swapping is,
# allowed_types controls which tensor dtypes can be moved, and exclude is an explicit
# blacklist via a filtering function.

Alternatively, we could apply this in a modular fashion when defining nn.Module objects.

@sandias42

I wanted to add my support for this feature (or a variation) that helps make something like LMS more easily accessible to the PyTorch community. Everyone I have shared the LMS fork with has been amazed at what a big difference having this feature makes in the convenience of model development, particularly for people who don't have unlimited access to 32 GB V100 servers or whatever and need to rely on smaller-memory GPUs in their research for cost and access reasons.

I think anything that makes LMS more accessible will make it significantly easier for researchers without large-memory GPUs to experiment with SOTA-sized models without imposing extra development costs for model-parallel workarounds. In that way, I'd expect LMS to disproportionately help researchers in low-resource settings, which hasn't been explicitly mentioned above but seems like a very worthwhile goal (on top of the other benefits already mentioned).

@markdjwilliams

Has there been any further progress here?

@stanleyshly

Same here, this seems extremely useful. I know this is marked low priority, but I was wondering if any plans have been made to integrate Large Model Support yet.

@aluo-x

aluo-x commented Aug 22, 2021

I feel like, as of now, a better approach is to use fairscale. It is semi-official, coming from many of the same authors as PyTorch, and it supports CPU offloading of activations (using checkpointing) and CPU offloading of weights (using OffloadModel).
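
As a rough sketch of the weight-offloading path, based on fairscale's documented OffloadModel API (argument names may differ between versions, and the model and sizes here are arbitrary placeholders):

import torch
import torch.nn as nn
from fairscale.experimental.nn.offload import OffloadModel

# The model is split into slices that live in host memory; each slice is copied
# to the GPU only while its part of the forward/backward pass runs.
model = nn.Sequential(
    nn.Linear(4096, 4096), nn.ReLU(),
    nn.Linear(4096, 4096), nn.ReLU(),
    nn.Linear(4096, 10),
)

offload_model = OffloadModel(
    model=model,
    device=torch.device("cuda"),          # where the compute happens
    offload_device=torch.device("cpu"),   # where the parameters are parked
    num_slices=3,                         # how many shards the model is cut into
    checkpoint_activation=True,           # also recompute activations to save memory
)

optimizer = torch.optim.SGD(offload_model.parameters(), lr=1e-3)
x = torch.randn(8, 4096, device="cuda")
loss = offload_model(x).sum()
loss.backward()
optimizer.step()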

Currently fairscale is still very actively developed.
@stanleyshly @markdjwilliams @sandias42

@stalek71

stalek71 commented Nov 4, 2021

Hello @mtbrandy,
I was trying to integrate the changes from your patch (applied to PyTorch v1.5) into PyTorch 1.10, but I was not able to complete it.

For instance, the c10/cuda/CUDACachingAllocator.h file in v1.10 has its own definition of the init() function, which your patch also introduces. The function_wrapper.py file is empty now, but your patch modifies it as well, etc.

Is it possible you could provide an updated patch in the ibm/pytorch-large-model-support repo? :)

@hwkim1127

So... um... sell kidney if I need more VRAM then?

@aluo-x

aluo-x commented May 8, 2023

There are tons of tools now that can do any of the following:

  1. Model weight offloading to CPU/Disk
  2. Activation checkpointing/offloading
  3. Optimizer sharding
  4. Automatic mixed precision (AMP)

Check out PyTorch's FSDP & AMP, the fairscale toolbox (partially merged into PyTorch as FSDP), and the Hugging Face accelerate wrapper (which wraps DeepSpeed, FSDP, and Megatron-LM).

Edit: Also, for more generic offloading, check out PyTorch KeOps for very dense + differentiable operations.
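
For the FSDP route specifically, parameter offloading to host memory looks roughly like this on recent PyTorch versions (a minimal single-process, single-GPU sketch; the module and sizes are placeholders):

import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, CPUOffload

# FSDP needs a process group even for a single rank.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("nccl", rank=0, world_size=1)
torch.cuda.set_device(0)

model = nn.Sequential(
    nn.Linear(4096, 4096), nn.ReLU(),
    nn.Linear(4096, 10),
).cuda()

# Parameters are kept in host memory and streamed to the GPU around each forward/backward.
fsdp_model = FSDP(model, cpu_offload=CPUOffload(offload_params=True))

optimizer = torch.optim.SGD(fsdp_model.parameters(), lr=1e-3)
loss = fsdp_model(torch.randn(8, 4096, device="cuda")).sum()
loss.backward()
optimizer.step()
dist.destroy_process_group()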
