Used per-parameter FSDP #165

Merged: 1 commit into pytorch:main from the per_param_land branch on Mar 28, 2024

Conversation

@awgu (Contributor) commented Mar 26, 2024

Numeric Parity
1D FSDP

  • Eager: 1k steps of minipile on 8 H100 GPUs, local batch size 8, sequence length 2048, AC/SAC, bf16 mixed precision, fp32 reduce-scatter
    • FSDP1 (AC): 24.81% peak active, 33.82% peak reserved, 6100-6200 WPS
    • FSDP1 (SAC): 52.98% peak active, 67.23% peak reserved, 6500-6700 WPS
    • FSDP2 (AC): 23.92% peak active, 32.64% peak reserved, 6100-6300 WPS
    • FSDP2 (SAC): 52.13% peak active, 62.51% peak reserved, 6600-6800 WPS
    • Loss curves match between FSDP1 and FSDP2
    • Memory numbers are reported as percentages since that is how they are logged; they can be converted against 95.0396 GiB of total GPU memory
  • Compile: same setup as eager
    • FSDP2 (AC), buffer reuse disabled: 28.72 GiB (30.22%) peak reserved, 7200-7500 WPS, 33% MFU
    • FSDP2 (AC), buffer reuse enabled: 28.90 GiB (30.40%) peak reserved, 7200-7500 WPS, 33% MFU
    • FSDP2 (SAC), buffer reuse enabled: 53.83 GiB (56.64%) peak reserved, 8100-8400 WPS, 36% MFU
    • Loss curves slightly better than eager
    • For fun -- how much can we push MFU?
      • If we use FSDP2 (SAC) with 16 local batch size (doubled), we get 88.23 GiB (92.84%) peak reserved, 8600 WPS, 38% MFU.
      • If we use FSDP2 (no AC) with 8 local batch size, we get 90.28 GiB (94.99%) peak reserved, 9100-9300 WPS, 40% MFU.
  • Why is FSDP2 faster? (1) The fp32 reduce-scatter uses only one div kernel instead of two, and (2) reshard_after_forward=False is set for the last transformer block (a minimal FSDP2 configuration sketch follows this list)
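
For reference, here is a minimal sketch (not this repo's exact code) of how the bf16-compute / fp32-reduce-scatter configuration above can be expressed with per-parameter FSDP; the import path is the one in use at the time of this PR, and model / model.layers are illustrative names:

import torch
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard

# bf16 parameters/compute with an fp32 reduce-scatter, matching the runs above.
mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
)

# `model` / `model.layers` stand in for a transformer and its list of blocks.
for transformer_block in model.layers:
    fully_shard(transformer_block, mp_policy=mp_policy)
fully_shard(model, mp_policy=mp_policy)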

2D FSDP

  • Eager (2-way SP, 4-way FSDP): 1k steps of minipile on 8 H100 GPUs, local batch size 16 (to preserve global batch size), sequence length 2048, bf16 mixed precision, fp32 reduce-scatter (a setup sketch follows this list)
    • FSDP2 (AC): 50.12% peak active, 60.97% peak reserved, 5800-5900 WPS
    • FSDP2 (SAC): 76.49% peak active, 90.14% peak reserved, 6100-6300 WPS
  • Loss curves match 8-way FSDP
  • FSDP1 + SP has incorrect numerics because FSDP.clip_grad_norm_ does not all-reduce over the TP mesh dimension
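
For reference, a minimal sketch of the 2D setup described above (2-way SP/TP x 4-way FSDP on 8 GPUs). The import paths are those in use at the time of this PR, the parallelization plan is illustrative rather than the actual parallelize_llama plan, and model / model.layers / attention.wq / attention.wo are assumed names:

from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

# 8 GPUs -> 4-way FSDP ("dp") x 2-way TP/SP ("tp").
mesh_2d = init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp", "tp"))

for transformer_block in model.layers:
    # Shard attention projections over the "tp" mesh dimension ...
    parallelize_module(
        transformer_block,
        mesh_2d["tp"],
        {"attention.wq": ColwiseParallel(), "attention.wo": RowwiseParallel()},
    )
    # ... then apply per-parameter FSDP over the "dp" mesh dimension.
    fully_shard(transformer_block, mesh=mesh_2d["dp"])
fully_shard(model, mesh=mesh_2d["dp"])
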
Loss curves: screenshot attached to the PR (https://github.com/pytorch/torchtrain/assets/31054793/59ec71cc-ad0a-4dd1-b5c6-a8cbf9ab5e85)

Meta-Device Initialization

  • The PyTorch Core guideline is for module.reset_parameters() to only initialize parameters/buffers immediately owned by module (i.e. module.parameters(recurse=False) and module.buffers(recurse=False)).
  • This makes it challenging to specify custom initializations for core modules like nn.Linear and nn.Embedding. For example, in @lessw2020's depth-wise truncated normal initialization, the trunc_normal_ standard deviation depends on the layer ID, which is a property of the TransformerBlock but affects the child nn.Linears.
  • To disambiguate, I suggest avoiding the name reset_parameters() in cases where we violate the PyTorch Core guideline and instead using a different name (e.g. init_weights); a sketch follows below.
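
As a concrete illustration of the point above, a sketch (with assumed attribute names and an assumed std formula, not the repo's actual init code) of a block-level init_weights() that applies a depth-dependent std to its child Linears:

import torch.nn as nn


class TransformerBlock(nn.Module):
    def __init__(self, dim: int, layer_id: int):
        super().__init__()
        self.wq = nn.Linear(dim, dim, bias=False)
        self.wo = nn.Linear(dim, dim, bias=False)
        # Depth-dependent std: a property of the block, not of the child Linears.
        self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5

    def init_weights(self):
        # Deliberately not named reset_parameters(): it initializes parameters
        # owned by child modules, which the Core guideline reserves for the
        # child modules themselves.
        for linear in (self.wq, self.wo):
            nn.init.trunc_normal_(linear.weight, mean=0.0, std=self.weight_init_std)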

DCP & Save/Load

  • Tested 1D and 2D by specifying checkpoint_folder = "/tmp/checkpoint_andgu" in the .toml, training until a checkpoint was saved, terminating the run, and restarting training to load the checkpoint -- the loss after loading looks reasonable (a minimal sketch follows)
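
A minimal sketch of that save/load flow using torch.distributed.checkpoint (DCP); the helper names, state-dict layout, and per-step folder naming are assumptions rather than the repo's checkpointing code:

import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict

CHECKPOINT_FOLDER = "/tmp/checkpoint_andgu"


def save_checkpoint(model, optimizer, step: int) -> None:
    # Gather FSDP/DTensor-aware state dicts, then write them with DCP.
    model_sd, optim_sd = get_state_dict(model, optimizer)
    dcp.save(
        {"model": model_sd, "optim": optim_sd},
        checkpoint_id=f"{CHECKPOINT_FOLDER}/step-{step}",
    )


def load_checkpoint(model, optimizer, step: int) -> None:
    model_sd, optim_sd = get_state_dict(model, optimizer)
    dcp.load(
        {"model": model_sd, "optim": optim_sd},
        checkpoint_id=f"{CHECKPOINT_FOLDER}/step-{step}",
    )
    set_state_dict(
        model, optimizer, model_state_dict=model_sd, optim_state_dict=optim_sd
    )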

awgu force-pushed the per_param_land branch 2 times, most recently from e9a9c11 to 52e7e01, on March 26, 2024 at 19:48
awgu marked this pull request as ready for review on March 26, 2024 at 21:18
        transformer_block = checkpoint_wrapper(
            transformer_block, job_config.activation_checkpoint
        )
        # As an optimization, do not reshard after forward for the last
awgu (author) commented:

I am open to not including this 'trick' since it might be confusing. The idea is that we can basically get reshard_after_forward=False for the last transformer block for free (sketch below).
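
A sketch of the trick (assumes FSDP2's fully_shard applied per block; model / model.layers are illustrative names):

from torch.distributed._composable.fsdp import fully_shard

num_blocks = len(model.layers)
for layer_id, transformer_block in enumerate(model.layers):
    # The last block's parameters are the first ones needed in backward, so
    # keeping them unsharded after forward adds no extra steady-state memory.
    reshard = layer_id < num_blocks - 1
    fully_shard(transformer_block, reshard_after_forward=reshard)
fully_shard(model)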

@tianyu-l (Contributor) left a comment:

This is wonderful work!
Left some comments, some of which are my questions.

(Outdated, resolved comment thread on torchtrain/models/llama/model.py)
@@ -333,13 +313,13 @@ def __init__(self, model_args: ModelArgs):
        super().__init__()
        self.model_args = model_args
        self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)
        self.init_weights()
Reviewer (Contributor) commented:

It seems self.init_weights() or self.reset_parameters() is called in all modules except Attention and FeedForward (probably because init_std is not available during __init__?).

This creates a bit of inconsistency in how many times a parameter/buffer gets initialized. Does it make sense to unify the behavior, e.g. have every init_weights() or reset_parameters() be called from its parent module rather than only from the Transformer itself?

awgu (author) replied:

Following offline discussion, I changed it so that self.init_weights() is only called in Transformer.__init__() and not in any other __init__(). This meant one change to the RotaryEmbedding.__init__() to register the freqs_cis buffer. The rest remains the same.

@@ -359,6 +339,16 @@ def forward(self, tokens: torch.Tensor):
        freqs_cis = self.freqs_cis[0:seqlen]
        return h, freqs_cis

    def init_weights(self):
        if hasattr(self, "freqs_cis"):
Reviewer (Contributor) commented:

Am I understanding correctly that, currently, each branch of this if-else will be called once during meta init, and the first branch will be called again when model.init_weights() is called?

awgu (author) replied:

Yep!

tianyu-l mentioned this pull request on Mar 26, 2024
@wanchaol (Contributor) left a comment:

Looks great on first pass! I mainly have some confusion about the meta-init part.

@@ -207,19 +205,10 @@ def __init__(self, model_args: ModelArgs):
            model_args.n_heads * self.head_dim, model_args.dim, bias=False
        )

    def reset_parameters(self, init_std):
Reviewer (Contributor) commented:

Actually, I have some confusion about the reset_parameters guideline: reset_parameters is an optional method on nn.Module, and calling the parent module's reset_parameters() does not "recursively" call into the submodules' reset_parameters().

Does this mean that, if the guideline is that each module should ONLY be responsible for its own parameters, the user has to loop over all the submodules in the module tree and call them individually?

And if that's the case, a user who decides not to loop over submodules can simply define reset_parameters() to re-init its own parameters plus its leaf modules' parameters, just like we did previously (i.e. in Attention we can also re-init the q/k/v linears). Then the user can simply call reset_parameters() on their defined root module and not worry about the attention layer's wq/wk/wv being overridden by the built-in nn.Linear.reset_parameters() call, since that would never be called. This might be something users already do, since they may want to control how the submodule init works themselves?

Not sure if I'm explaining my question clearly, haha. Am I missing something here?

awgu (author) replied:

> Does this mean that, if the guideline is that each module should ONLY be responsible for its own parameters, the user has to loop over all the submodules in the module tree and call them individually?

This is my understanding.

> And if that's the case, a user who decides not to loop over submodules can simply define reset_parameters() to re-init its own parameters plus its leaf modules' parameters, just like we did previously (i.e. in Attention we can also re-init the q/k/v linears). Then the user can simply call reset_parameters() on their defined root module and not worry about the attention layer's wq/wk/wv being overridden by the built-in nn.Linear.reset_parameters() call, since that would never be called. This might be something users already do, since they may want to control how the submodule init works themselves?

I agree with the approach you are mentioning:

  • if we ignore FSDP, or
  • if we are using FSDP1 and no weight init depends on the original tensor shape.

It happens to be that the weight init used for the Llama model in torchtrain does not depend on the original tensor shape (namely, the weight init is elementwise). However, this may not be the case for other models (e.g. those that compute fan-in/fan-out), in which case this approach would silently sample from the incorrect distribution.

FSDP1 calls reset_parameters() before sharding.

  • The current approach is aligned with the core guideline, so for FullyShardedDataParallel(module), FSDP1 calls submodule.reset_parameters() for each managed submodule in module.modules() (managed is defined by excluding any nested FullyShardedDataParallel modules or their children). This is the only way to ensure that each parameter is initialized exactly once.
  • If a parent Attention module re-initialized its Q/K/V linear modules, then FSDP1 would initialize the Q/K/V linears twice (once from Linear.reset_parameters() and once from Attention.reset_parameters()). This can still give a valid probability distribution, but it could give different values for a fixed seed compared to if the Linear.reset_parameters() were skipped (e.g. if not using FSDP and just calling model.reset_parameters() on the root model). This is not a major problem since it does not mean incorrect randomness but is still worth mentioning.
  • If we further call model.reset_parameters() after sharding with FSDP1, then we have 1D flattened sharded tensors, which no longer preserve the original tensor shape. Therefore, calling model.reset_parameters() at this point will give incorrect randomness in cases depending on the shape.

In summary, following the core guideline is the only way to guarantee that each parameter is initialized once and before sharding. The constraint to initialize once is not required for correct randomness but may help reproducibility.
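
A small sketch of what the guideline implies in practice; the fan_in_init_ helper is only there to illustrate a shape-dependent init, and neither function is from this repo:

import math

import torch
import torch.nn as nn


def init_all_parameters(model: nn.Module) -> None:
    # Each module initializes only the parameters it immediately owns
    # (recurse=False), so every parameter is initialized exactly once; run
    # this before sharding.
    for submodule in model.modules():
        if hasattr(submodule, "reset_parameters"):
            submodule.reset_parameters()


def fan_in_init_(weight: torch.Tensor) -> None:
    # A shape-dependent init: only correct on the full 2D weight. On a 1D
    # flattened shard, the fan-in read from the shape would be wrong.
    fan_in = weight.shape[1]
    nn.init.normal_(weight, mean=0.0, std=1.0 / math.sqrt(fan_in))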

Reviewer (Contributor) replied:

I see, OK, this makes sense. So it is critical to initialize only once for reproducibility when starting from a fixed seed.

awgu (author) replied:

At the same time, though, the DTensor RNG will be different from the local RNG, so I am not sure the reproducibility argument holds. We would not be able to ensure the same results for FSDP2 compared to a single-GPU, non-DTensor setup.

(Two resolved comment threads on torchtrain/parallelisms/parallelize_llama.py, not shown)
awgu force-pushed the per_param_land branch 3 times, most recently from ee5087b to dbb793a, on March 27, 2024 at 19:09
awgu requested review from tianyu-l and wanchaol on March 27, 2024 at 19:19
@wanchaol (Contributor) left a comment:

Nice work! lgtm :)

        torch.nn.utils.clip_grad_norm_(
            model.parameters(), job_config.training.max_norm
        )
Reviewer (Contributor) commented:

I like the fact that it composes with the existing impl instead of using a separate impl!

@awgu (Contributor, author) commented Mar 27, 2024

After pytorch/pytorch#122801 lands, the save/load with torch.compile should work. (I tested locally.)

@tianyu-l (Contributor) left a comment:

Looks great to me!

@@ -199,7 +197,6 @@ def main(job_config: JobConfig):

    # torch.compile model for improved performance
    if job_config.training.compile:
        torch._inductor.config.allow_buffer_reuse = False
awgu (author) commented:

Since pytorch/pytorch#122444 landed, we can re-enable buffer reuse.

@@ -186,6 +179,11 @@ def main(job_config: JobConfig):
    model = models_parallelize_fns[model_name](
        model, world_mesh, parallel_dims, job_config
    )
    # set this as required by DTensor to work with `to_empty`
    # TODO: remove in the future when enabled by default for wrapper subclasses
    torch.__future__.set_swap_module_params_on_conversion(True)
awgu (author) commented:

After pytorch/pytorch#122755, we can remove this call.
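
For context, a sketch of the meta-device initialization flow this line supports; Transformer, model_args, model.layers, and init_weights() are the names discussed above, and the fully_shard import path is the one in use at the time of this PR:

import torch
from torch.distributed._composable.fsdp import fully_shard

with torch.device("meta"):
    model = Transformer(model_args)  # parameters allocated on the meta device

for transformer_block in model.layers:
    fully_shard(transformer_block)
fully_shard(model)  # shard while parameters are still on meta

# Needed (for now) so that to_empty() swaps the DTensor parameters in place.
torch.__future__.set_swap_module_params_on_conversion(True)

model.to_empty(device="cuda")  # materialize the sharded parameters
model.init_weights()  # then initialize them for real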

@awgu (Contributor, author) commented Mar 28, 2024

If anything breaks because of this PR, please ping me :)

awgu merged commit 6d3d906 into pytorch:main on Mar 28, 2024
4 checks passed
awgu deleted the per_param_land branch on March 28, 2024 at 18:54
@awgu (Contributor, author) commented Mar 28, 2024

Local batch size 6, torch.compile, bf16 mixed precision, no AC, reshard_after_forward=False for all transformer blocks, 8x H100s:
9250-9400 WPS, 40.9-41.5% MFU

lessw2020 pushed a commit that referenced this pull request Apr 18, 2024