
[FSDP] FSDP doesn't work (random accuracy performance) when using param_init_fn and sync_module_states=True #105840

🐛 Describe the bug

Currently, when using FSDP, each of the N processes loads the full model into CPU RAM, leading to huge CPU memory usage. When training a model like Falcon-40B with FSDP on a DGX node with 8 GPUs, each process loads 160 GB (40B params x 4 bytes (FP32)) into CPU RAM, for a total requirement of 160*8 = 1280 GB, so the node runs out of CPU RAM and the script gets killed.
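
A back-of-the-envelope calculation of that CPU RAM requirement (a minimal sketch using the numbers from the example above):

# Naive per-process load: every rank materializes the full FP32 model on CPU.
params = 40e9        # Falcon-40B parameter count
bytes_per_param = 4  # FP32
processes = 8        # one process per GPU on the DGX node

per_process_gb = params * bytes_per_param / 1e9  # 160 GB per process
total_gb = per_process_gb * processes            # 1280 GB across the node
print(f"{per_process_gb:.0f} GB per process, {total_gb:.0f} GB total")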

To combat this, we are trying to load the model only on rank 0 and keep it on the meta device for rank != 0, then use param_init_fn together with sync_module_states=True so that FSDP properly initializes the parameter storage on the other ranks and broadcasts the parameters from rank 0 to them (see the sketch below). This is trying to achieve what zero.Init() from DeepSpeed does; it would be great for FSDP to support this out of the box.
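
For concreteness, a minimal sketch of the intended pattern (bert-base-uncased as in the repro; the exact param_init_fn here is an assumption for illustration, the real code is in the repo linked below):

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import AutoConfig, AutoModelForSequenceClassification

# Assumes torch.distributed is already initialized (e.g. via torchrun / accelerate launch).
rank = dist.get_rank()
model_name = "bert-base-uncased"

if rank == 0:
    # Only rank 0 materializes the pretrained weights in CPU RAM.
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
else:
    # All other ranks build the same architecture on the meta device (no storage).
    config = AutoConfig.from_pretrained(model_name)
    with torch.device("meta"):
        model = AutoModelForSequenceClassification.from_config(config)

def param_init_fn(module: torch.nn.Module):
    # Give meta parameters real (empty) storage on the GPU; the actual values
    # are expected to arrive via the rank-0 broadcast done by sync_module_states=True.
    module.to_empty(device=torch.device("cuda", torch.cuda.current_device()))

model = FSDP(
    model,
    param_init_fn=param_init_fn if rank != 0 else None,
    sync_module_states=True,  # broadcast module states from rank 0 to all ranks
    device_id=torch.cuda.current_device(),
)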

However, with the above approach, the accuracy and F1 metrics stay at random-guess levels, i.e., the model isn't learning anything, even though the weights do change and the train loss decreases a little.

Code: https://github.com/pacman100/ram_efficient_fsdp

Steps:

  1. pip install -r requirements.txt
  2. bash run.sh

The FSDP config is in config.yaml

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_forward_prefetch: true
  fsdp_offload_params: false
  fsdp_sharding_strategy: 1
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_transformer_layer_cls_to_wrap: BertLayer
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
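
For reference, this Accelerate config corresponds roughly to an FSDP constructor call like the following (a sketch; model is assumed to already be built on each rank):

import functools
import torch
from torch.distributed.fsdp import (
    BackwardPrefetch,
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.bert.modeling_bert import BertLayer

auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy, transformer_layer_cls={BertLayer}
)

model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,    # fsdp_sharding_strategy: 1
    auto_wrap_policy=auto_wrap_policy,                # TRANSFORMER_BASED_WRAP on BertLayer
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # fsdp_backward_prefetch_policy
    forward_prefetch=True,                            # fsdp_forward_prefetch
    sync_module_states=True,                          # fsdp_sync_module_states
    use_orig_params=True,                             # fsdp_use_orig_params
    device_id=torch.cuda.current_device(),
)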

Output:

[2023-07-24 16:34:14,815] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2023-07-24 16:34:19,709] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2023-07-24 16:34:19,736] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)
DistributedType.FSDP
wandb: Currently logged in as: smangrul. Use `wandb login --relogin` to force relogin
Found cached dataset glue (/raid/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
100%|██████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 1269.59it/s]
wandb: Tracking run with wandb version 0.15.5
wandb: Run data is saved locally in /home/sourab/ram_efficient_fsdp/wandb/run-20230724_163421-hshg5m1t
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run whole-fire-12
wandb: ⭐️ View project at https://wandb.ai/smangrul/fsdp_glue_no_trainer
wandb: 🚀 View run at https://wandb.ai/smangrul/fsdp_glue_no_trainer/runs/hshg5m1t
Found cached dataset glue (/raid/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
100%|██████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 1229.52it/s]
Loading cached processed dataset at /raid/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-c8be3b6cea2b5568.arrow
Loading cached processed dataset at /raid/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-98f28c5bb15f064c.arrow
accelerator.process_index=1 model.bert.pooler.dense.weight=Parameter containing:
tensor(..., device='meta', size=(768, 768), requires_grad=True)
accelerator.process_index=1 model.classifier.weight=Parameter containing:
tensor(..., device='meta', size=(2, 768), requires_grad=True)
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
accelerator.process_index=0 model.bert.pooler.dense.weight=Parameter containing:
tensor([[-0.0013, -0.0381, -0.0158,  ...,  0.0244, -0.0008,  0.0240],
        [ 0.0020,  0.0151,  0.0033,  ...,  0.0180, -0.0023,  0.0231],
        [-0.0386,  0.0145,  0.0621,  ...,  0.0374, -0.0105, -0.0395],
        ...,
        [-0.0111,  0.0136,  0.0541,  ...,  0.0666,  0.0017, -0.0090],
        [ 0.0001,  0.0024, -0.0125,  ...,  0.0046, -0.0014, -0.0079],
        [ 0.0415,  0.0751,  0.0305,  ...,  0.0317,  0.0479,  0.0080]],
       requires_grad=True)
accelerator.process_index=0 model.classifier.weight=Parameter containing:
tensor([[-0.0025,  0.0011, -0.0052,  ..., -0.0212,  0.0227,  0.0206],
        [ 0.0151, -0.0045,  0.0243,  ..., -0.0208, -0.0183, -0.0203]],
       requires_grad=True)
FullyShardedDataParallel(
  (_fsdp_wrapped_module): BertForSequenceClassification(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0-11): 12 x FullyShardedDataParallel(
            (_fsdp_wrapped_module): BertLayer(
              (attention): BertAttention(
                (self): BertSelfAttention(
                  (query): Linear(in_features=768, out_features=768, bias=True)
                  (key): Linear(in_features=768, out_features=768, bias=True)
                  (value): Linear(in_features=768, out_features=768, bias=True)
                  (dropout): Dropout(p=0.1, inplace=False)
                )
                (output): BertSelfOutput(
                  (dense): Linear(in_features=768, out_features=768, bias=True)
                  (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                  (dropout): Dropout(p=0.1, inplace=False)
                )
              )
              (intermediate): BertIntermediate(
                (dense): Linear(in_features=768, out_features=3072, bias=True)
                (intermediate_act_fn): GELUActivation()
              )
              (output): BertOutput(
                (dense): Linear(in_features=3072, out_features=768, bias=True)
                (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
            )
          )
        )
      )
      (pooler): BertPooler(
        (dense): Linear(in_features=768, out_features=768, bias=True)
        (activation): Tanh()
      )
    )
    (dropout): Dropout(p=0.1, inplace=False)
    (classifier): Linear(in_features=768, out_features=2, bias=True)
  )
)
accelerator.process_index=0 model.bert.pooler.dense.weight=Parameter containing:
tensor([[-0.0013, -0.0381, -0.0158,  ...,  0.0244, -0.0008,  0.0240],
        [ 0.0020,  0.0151,  0.0033,  ...,  0.0180, -0.0023,  0.0231],
        [-0.0386,  0.0145,  0.0621,  ...,  0.0374, -0.0105, -0.0395],
        ...,
        [-0.0111,  0.0136,  0.0541,  ...,  0.0666,  0.0017, -0.0090],
        [ 0.0001,  0.0024, -0.0125,  ...,  0.0046, -0.0014, -0.0079],
        [ 0.0415,  0.0751,  0.0305,  ...,  0.0317,  0.0479,  0.0080]],
       device='cuda:0', requires_grad=True)
accelerator.process_index=0 model.classifier.weight=Parameter containing:
tensor([[-0.0025,  0.0011, -0.0052,  ..., -0.0212,  0.0227,  0.0206],
        [ 0.0151, -0.0045,  0.0243,  ..., -0.0208, -0.0183, -0.0203]],
       device='cuda:0', requires_grad=True)
accelerator.process_index=1 model.bert.pooler.dense.weight=Parameter containing:
tensor([[-0.0013, -0.0381, -0.0158,  ...,  0.0244, -0.0008,  0.0240],
        [ 0.0020,  0.0151,  0.0033,  ...,  0.0180, -0.0023,  0.0231],
        [-0.0386,  0.0145,  0.0621,  ...,  0.0374, -0.0105, -0.0395],
        ...,
        [-0.0111,  0.0136,  0.0541,  ...,  0.0666,  0.0017, -0.0090],
        [ 0.0001,  0.0024, -0.0125,  ...,  0.0046, -0.0014, -0.0079],
        [ 0.0415,  0.0751,  0.0305,  ...,  0.0317,  0.0479,  0.0080]],
       device='cuda:1', requires_grad=True)
accelerator.process_index=1 model.classifier.weight=Parameter containing:
tensor([[-0.0025,  0.0011, -0.0052,  ..., -0.0212,  0.0227,  0.0206],
        [ 0.0151, -0.0045,  0.0243,  ..., -0.0208, -0.0183, -0.0203]],
       device='cuda:1', requires_grad=True)
/home/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/cuda/memory.py:303: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
/home/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/cuda/memory.py:303: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
Memory before entering the train : 212
Memory consumed at the end of the train (end-begin): 659
Peak Memory consumed during the train (max-begin): 1995
Total Peak Memory consumed during the train (max): 2207
epoch 0: {'accuracy': 0.6838235294117647, 'f1': 0.8122270742358079}
Memory before entering the eval : 872
Memory consumed at the end of the eval (end-begin): 94
Peak Memory consumed during the eval (max-begin): 209
Total Peak Memory consumed during the eval (max): 1081
/home/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/cuda/memory.py:303: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
/home/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/cuda/memory.py:303: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Memory before entering the train : 966
Memory consumed at the end of the train (end-begin): -94
Peak Memory consumed during the train (max-begin): 1218
Total Peak Memory consumed during the train (max): 2184
epoch 1: {'accuracy': 0.6838235294117647, 'f1': 0.8122270742358079}
Memory before entering the eval : 872
Memory consumed at the end of the eval (end-begin): 94
Peak Memory consumed during the eval (max-begin): 209
Total Peak Memory consumed during the eval (max): 1081
/home/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/cuda/memory.py:303: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
/home/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/cuda/memory.py:303: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Memory before entering the train : 966
Memory consumed at the end of the train (end-begin): -94
Peak Memory consumed during the train (max-begin): 1297
Total Peak Memory consumed during the train (max): 2263
epoch 2: {'accuracy': 0.6838235294117647, 'f1': 0.8122270742358079}
Memory before entering the eval : 872
Memory consumed at the end of the eval (end-begin): 94
Peak Memory consumed during the eval (max-begin): 209
Total Peak Memory consumed during the eval (max): 1081
wandb: Waiting for W&B process to finish... (success).
wandb: 
wandb: Run history:
wandb:                accuracy ▁▁▁
wandb:  eval_total_peak_memory ▁▁▁
wandb:                      f1 ▁▁▁
wandb:              train_loss █▂▁
wandb: train_total_peak_memory ▃▁█
wandb: 
wandb: Run summary:
wandb:                accuracy 0.68382
wandb:  eval_total_peak_memory 1081
wandb:                      f1 0.81223
wandb:              train_loss 0.63513
wandb: train_total_peak_memory 2263
wandb: 
wandb: 🚀 View run whole-fire-12 at: https://wandb.ai/smangrul/fsdp_glue_no_trainer/runs/hshg5m1t
wandb: Synced 6 W&B file(s), 0 media file(s), 2 artifact file(s) and 0 other file(s)
wandb: Find logs at: ./wandb/run-20230724_163421-hshg5m1t/logs

As you can see, accuracy and F1 stay exactly the same across all 3 epochs, i.e., the model is stuck at chance-level performance.

Expected Behaviour:

The model should learn normally when using param_init_fn and sync_module_states=True with FSDP, so that the pretrained model can be loaded only on rank 0 while staying on the meta device for rank != 0. This is required for FSDP to be usable with large models in practice.

Versions

PyTorch version: 2.0.1+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.26.4
Libc version: glibc-2.31

Python version: 3.10.11 (main, May 16 2023, 00:28:57) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.4.0-125-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB
GPU 2: NVIDIA A100-SXM4-80GB
GPU 3: NVIDIA DGX Display
GPU 4: NVIDIA A100-SXM4-80GB

Nvidia driver version: 515.65.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] lion-pytorch==0.0.6
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.4
[pip3] pytorch-lightning==1.9.0
[pip3] pytorch-triton==2.1.0+9e3e10c5ed
[pip3] torch==2.0.1+cu118
[pip3] torchaudio==2.0.2+cu118
[pip3] torchmetrics==0.11.4
[pip3] torchvision==0.15.2+cu118
[conda] blas 1.0 mkl
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] lion-pytorch 0.0.6 pypi_0 pypi
[conda] mkl 2023.1.0 h6d00ec8_46342
[conda] mkl-service 2.4.0 py310h5eee18b_1
[conda] mkl_fft 1.3.6 py310h1128e8f_1
[conda] mkl_random 1.2.2 py310h1128e8f_1
[conda] numpy 1.24.4 pypi_0 pypi
[conda] numpy-base 1.25.0 py310hb5e798b_0
[conda] pytorch-cuda 11.8 h7e8668a_5 pytorch
[conda] pytorch-lightning 1.9.0 pypi_0 pypi
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] pytorch-triton 2.1.0+9e3e10c5ed pypi_0 pypi
[conda] torch 2.0.1+cu118 pypi_0 pypi
[conda] torchaudio 2.0.2+cu118 pypi_0 pypi
[conda] torchmetrics 0.11.4 pypi_0 pypi
[conda] torchvision 0.15.2+cu118 pypi_0 pypi

cc @zhaojuanmao @mrshenli @rohan-varma @awgu

Labels

feature · module: fsdp · triaged
