⚡ DeepSpeed ZeRO Stage 2 model parallel training #2
To prevent out-of-memory (OOM) errors when running the Transformer models, change from distributed data parallel (DDP) to a combined data + model parallel strategy.
Current State
As of 9793587, we have been using distributed data parallel (DDP) to split the data batch-wise across multiple GPUs. However, when running on a full-size Sentinel-2 image (batch_size=1) during the test phase (#1), this already causes out-of-memory issues for our Super-Resolution Segmentation task.
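For reference, a minimal sketch of the current DDP setup in PyTorch Lightning (the `SegmentationModel` and `Sentinel2DataModule` names are hypothetical placeholders for illustration, not the actual module names in this repo):

```python
import pytorch_lightning as pl

# Hypothetical LightningModule/LightningDataModule names, for illustration only.
model = SegmentationModel()
datamodule = Sentinel2DataModule(batch_size=1)

# Current approach: DDP replicates the full model on every GPU and only splits
# the batch, so a single full-size Sentinel-2 image can still exhaust one GPU.
trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="ddp")
trainer.fit(model, datamodule=datamodule)
```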
Future State
One possible solution is to shard the neural network model itself across multiple GPUs. This reduces the per-GPU memory requirements and allows larger models and/or bigger datasets to be used for training/inference.
Specifically, we'll be switching to DeepSpeed (https://github.com/microsoft/DeepSpeed), which offers several 'stages' of model sharding. See https://devblog.pytorchlightning.ai/experiment-with-billion-parameter-models-faster-using-deepspeed-and-meta-tensors-2e9c255edd71 and https://huggingface.co/blog/zero-deepspeed-fairscale for good explainers.
Main DeepSpeed ZeRO stages (from https://pytorch-lightning.readthedocs.io/en/1.6.3/advanced/model_parallel.html#deepspeed):
- Stage 1: shards the optimizer states across GPUs
- Stage 2: additionally shards the gradients
- Stage 3: additionally shards the model parameters (weights)
💡 Suggest using Stage 2 instead of Stage 3, because while Stage 3 further reduces memory use, it comes with increased latency from the cost of extra distributed communication (see the sketch below).
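A minimal sketch of how the switch to ZeRO Stage 2 might look with PyTorch Lightning's built-in DeepSpeed strategy (the `deepspeed_stage_2` shorthand and `DeepSpeedStrategy` class are from the PyTorch Lightning 1.6 docs linked above; the model/datamodule names remain hypothetical placeholders):

```python
import pytorch_lightning as pl
from pytorch_lightning.strategies import DeepSpeedStrategy

model = SegmentationModel()                  # hypothetical placeholder
datamodule = Sentinel2DataModule(batch_size=1)

# Shorthand form: ZeRO Stage 2 shards optimizer states + gradients across GPUs.
trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="deepspeed_stage_2")

# Equivalent explicit form, which also exposes extra knobs such as CPU offloading.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=2,
    strategy=DeepSpeedStrategy(stage=2, offload_optimizer=False),
    precision=16,  # DeepSpeed is typically run with mixed precision
)
trainer.fit(model, datamodule=datamodule)
```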
Other benefits of using DeepSpeed:
Alternative strategies (and why they were not considered)
PyTorch Lightning offers several other advanced training strategies. These might work well for other cases, but probably not for our specific project.
TODO:
Use Meta Tensors, c.f. https://devblog.pytorchlightning.ai/experiment-with-billion-parameter-models-faster-using-deepspeed-and-meta-tensors-2e9c255edd71. Currently blocked by `NotImplementedError: Could not run 'aten::_local_scalar_dense' with arguments from the 'Meta' backend`. See also the general MPS op coverage tracking issue pytorch/pytorch#77764.
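For context on that error, a minimal standalone sketch of what meta tensors do and where the `aten::_local_scalar_dense` limitation shows up (plain PyTorch, no project code assumed):

```python
import torch

# A layer built on the "meta" device carries only shapes/dtypes; no memory is
# allocated, which is what lets very large models be instantiated cheaply.
layer = torch.nn.Linear(in_features=13, out_features=4, device="meta")
x = torch.empty(2, 13, device="meta")

y = layer(x)
print(y.shape)  # shape propagation works: torch.Size([2, 4])

# Anything that needs a concrete value fails, e.g. y[0, 0].item() raises:
# NotImplementedError: Could not run 'aten::_local_scalar_dense' with
# arguments from the 'Meta' backend ...
```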