Used per-parameter FSDP #165
Conversation
Force-pushed from `e9a9c11` to `52e7e01`
```python
transformer_block = checkpoint_wrapper(
    transformer_block, job_config.activation_checkpoint
)
# As an optimization, do not reshard after forward for the last
```
I am open to not including this 'trick' since it might be confusing. The idea is that we can effectively get `reshard_after_forward=False` for the last transformer block for free.
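To make the "free" part concrete, here is a minimal, hypothetical sketch (the helper and loop below are illustrative, not the PR's actual code): every transformer block reshards its parameters after forward except the last one, since backward begins at the last block and would immediately re-gather its parameters anyway.

```python
def reshard_after_forward_flags(n_layers: int) -> list:
    """One flag per transformer block: reshard after forward for all blocks
    except the last, whose parameters are needed first in backward."""
    return [layer_id < n_layers - 1 for layer_id in range(n_layers)]

# With FSDP2 (per-parameter FSDP), these flags would drive calls like:
#   for block, flag in zip(model.layers, reshard_after_forward_flags(n)):
#       fully_shard(block, reshard_after_forward=flag)
flags = reshard_after_forward_flags(4)
print(flags)  # [True, True, True, False]
```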
This is wonderful work!
Left some comments, some of which are my questions.
torchtrain/models/llama/model.py
Outdated
```diff
@@ -333,13 +313,13 @@ def __init__(self, model_args: ModelArgs):
        super().__init__()
        self.model_args = model_args
        self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)
        self.init_weights()
```
It seems `self.init_weights()` or `self.reset_parameters()` is called in all but the `Attention` and `FeedForward` modules (probably because `init_std` is not available during `__init__`?).

This creates some inconsistency in how many times a parameter/buffer is initialized. Does it make sense to unify the behavior, e.g. have all `init_weights()` or `reset_parameters()` calls made from the parent module rather than the `Transformer` itself?
Following offline discussion, I changed it so that `self.init_weights()` is only called in `Transformer.__init__()` and not in any other `__init__()`. This meant one change to `RotaryEmbedding.__init__()` to register the `freqs_cis` buffer. The rest remains the same.
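A minimal sketch of the resulting pattern, using hypothetical module names and a made-up depth-dependent std (not the PR's exact code): child modules define how to initialize their own weights, but only the root `Transformer.__init__()` actually triggers initialization.

```python
import torch
import torch.nn as nn

class FeedForward(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, dim, bias=False)

    def init_weights(self, init_std: float):
        # The depth-dependent std is passed in by the parent, since the
        # layer ID is not known inside this module's own __init__.
        nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=init_std)

class Transformer(nn.Module):
    def __init__(self, dim: int, n_layers: int):
        super().__init__()
        self.layers = nn.ModuleList(FeedForward(dim) for _ in range(n_layers))
        self.init_weights()  # the only __init__ that calls init_weights

    def init_weights(self):
        for layer_id, layer in enumerate(self.layers):
            layer.init_weights(init_std=0.02 / (2 * (layer_id + 1)) ** 0.5)

model = Transformer(dim=8, n_layers=2)
```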
torchtrain/models/llama/model.py
Outdated
```diff
@@ -359,6 +339,16 @@ def forward(self, tokens: torch.Tensor):
        freqs_cis = self.freqs_cis[0:seqlen]
        return h, freqs_cis

    def init_weights(self):
        if hasattr(self, "freqs_cis"):
```
Am I understanding correctly that currently, each branch of this if-else will be called once during meta init, and the first branch will be called again when `model.init_weights()` is called?
Yep!
Looks great on first pass! I mainly have some confusion about the meta-init part.
```diff
@@ -207,19 +205,10 @@ def __init__(self, model_args: ModelArgs):
            model_args.n_heads * self.head_dim, model_args.dim, bias=False
        )

    def reset_parameters(self, init_std):
```
Actually I have some confusion about the `reset_parameters` guideline: `reset_parameters` is an optional method on `nn.Module`, and calling the parent module's `reset_parameters()` does not recursively call into the submodules' `reset_parameters`.

This means that if the guideline is that each module should only be responsible for its own parameters, the user has to loop over all the submodules in the module tree and call them individually?

And if that's the case, if the user decides not to recursively loop over submodules, one can simply define `reset_parameters` to re-init its own parameters plus its leaf modules' parameters, just like we did previously (i.e. in `Attention` we can also re-init the q/k/v linears). Then the user can simply call `reset_parameters()` on their defined root module and not worry about the attention layer's `wq`/`wk`/`wv` being overridden by the built-in `nn.Linear.reset_parameters` call, since it would never be called. This might be something users are already doing, as they might want to control how the submodule init works themselves?

Not sure if you get my question haha, am I missing something there?
> This means that if the guideline is that each module should only be responsible for its own parameters, the user has to loop over all the submodules in the module tree and call them individually?

This is my understanding.
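For illustration, walking the module tree explicitly would look something like this (a sketch, not code from the PR), since `reset_parameters()` is not called recursively:

```python
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))

# reset_parameters() only touches a module's own parameters, so the caller
# must visit every submodule; modules without it (e.g. ReLU) are skipped.
for submodule in model.modules():
    reset = getattr(submodule, "reset_parameters", None)
    if callable(reset):
        reset()
```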
> And if that's the case, if the user decides not to recursively loop over submodules, one can simply define `reset_parameters` to re-init its own parameters plus its leaf modules' parameters just like we did previously (i.e. in `Attention` we can also re-init the q/k/v linears), so that the user can simply call `reset_parameters()` on their defined root module and not worry about the attention layer's `wq`/`wk`/`wv` being overridden by the built-in `nn.Linear.reset_parameters` call, since it would never be called. This might be something users are already doing, as they might want to control how the submodule init works themselves?

I agree with the approach you are mentioning

- if we ignore FSDP, or
- if we are using FSDP1 and every weight init does not depend on the original tensor shape.

It happens that the weight init used for the Llama model in torchtrain does not depend on the original tensor shape (namely, the weight init is elementwise). However, this may not be the case for other models (e.g. those that compute fan-in/fan-out), in which case this approach would silently sample from the incorrect distribution.

FSDP1 calls `reset_parameters()` before sharding:

- The current approach is aligned with the core guideline, so for `FullyShardedDataParallel(module)`, FSDP1 calls `submodule.reset_parameters()` for each managed `submodule` in `module.modules()` (managed meaning excluding any nested `FullyShardedDataParallel` modules or their children). This is the only way to ensure that each parameter is initialized exactly once.
- If a parent `Attention` module re-initialized its Q/K/V linear modules, then FSDP1 would initialize the Q/K/V linears twice (once from `Linear.reset_parameters()` and once from `Attention.reset_parameters()`). This can still give a valid probability distribution, but it could give different values for a fixed seed compared to if `Linear.reset_parameters()` were skipped (e.g. if not using FSDP and just calling `model.reset_parameters()` on the root `model`). This is not a major problem since it does not mean incorrect randomness, but it is still worth mentioning.
- If we further call `model.reset_parameters()` after sharding with FSDP1, then we have 1D flattened sharded tensors, which no longer preserve the original tensor shape. Therefore, calling `model.reset_parameters()` at this point will give incorrect randomness in cases that depend on the shape.

In summary, following the core guideline is the only way to guarantee that each parameter is initialized once and before sharding. The constraint to initialize once is not required for correct randomness but may help reproducibility.
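The reproducibility point can be demonstrated in isolation (a toy sketch, not the PR's init code): from the same seed, a tensor initialized twice consumes extra RNG draws and ends up with different values than one initialized exactly once, even though both are samples from a valid distribution.

```python
import torch

def init_tensor(num_inits: int) -> torch.Tensor:
    torch.manual_seed(0)
    t = torch.empty(4)
    for _ in range(num_inits):
        torch.nn.init.normal_(t)  # each call advances the RNG state
    return t

once = init_tensor(1)   # e.g. Linear.reset_parameters() skipped by the parent
twice = init_tensor(2)  # e.g. Linear, then Attention, re-init the same weight
print(torch.equal(once, twice))  # False: same seed, different final values
```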
I see, OK, this makes sense: it is critical to initialize each parameter only once for reproducibility when starting from a fixed seed.
At the same time though, the DTensor RNG will differ from the local RNG, so I am not sure the reproducibility argument holds: we would not be able to ensure the same results for FSDP2 compared to a single-GPU non-DTensor setup.
Force-pushed from `ee5087b` to `dbb793a`
Nice work! lgtm :)
```python
torch.nn.utils.clip_grad_norm_(
    model.parameters(), job_config.training.max_norm
)
```
I like the fact that it composes with the existing implementation instead of using a separate one!
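A standalone illustration of the utility being composed with (plain tensors here; with FSDP2 the gradients are `DTensor`s, but the same call applies): `clip_grad_norm_` returns the total norm computed before clipping and scales gradients in place when it exceeds `max_norm`.

```python
import torch

p = torch.nn.Parameter(torch.ones(4))
p.grad = torch.full((4,), 2.0)  # total grad norm = sqrt(4 * 2**2) = 4.0

total_norm = torch.nn.utils.clip_grad_norm_([p], max_norm=1.0)
print(float(total_norm))               # 4.0 (the pre-clipping norm)
print(round(p.grad.norm().item(), 4))  # ~1.0 after in-place scaling
```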
After pytorch/pytorch#122801 lands, the save/load with
Looks great to me!
```diff
@@ -199,7 +197,6 @@ def main(job_config: JobConfig):

    # torch.compile model for improved performance
    if job_config.training.compile:
        torch._inductor.config.allow_buffer_reuse = False
```
Since pytorch/pytorch#122444 landed, we can re-enable buffer reuse.
```diff
@@ -186,6 +179,11 @@ def main(job_config: JobConfig):
    model = models_parallelize_fns[model_name](
        model, world_mesh, parallel_dims, job_config
    )
    # set this as required by DTensor to work with `to_empty`
    # TODO: remove in the future when enabled by default for wrapper subclasses
    torch.__future__.set_swap_module_params_on_conversion(True)
```
After pytorch/pytorch#122755, we can remove this call.
If anything breaks because of this PR, please ping me :)
Local batch size 6,
**Numeric Parity**

1D FSDP

- Eager: 1k steps of minipile on 8 H100 GPUs, local batch size 8, sequence length 2048, AC/SAC, bf16 mixed precision, fp32 reduce-scatter
  - FSDP1 (AC): 24.81% peak active, 33.82% peak reserved, 6100-6200 WPS
  - FSDP1 (SAC): 52.98% peak active, 67.23% peak reserved, 6500-6700 WPS
  - FSDP2 (AC): 23.92% peak active, 32.64% peak reserved, 6100-6300 WPS
  - FSDP2 (SAC): 52.13% peak active, 62.51% peak reserved, 6600-6800 WPS
  - Loss curves match between FSDP1 and FSDP2
  - Memory numbers reported as percentages since that is how they are logged; can convert against 95.0396 GiB GPU memory
- Compile: same setup as eager
  - FSDP2 (AC), buffer reuse disabled: 28.72 GiB (30.22%) peak reserved, 7200-7500 WPS, 33% MFU
  - FSDP2 (AC), buffer reuse enabled: 28.90 GiB (30.40%) peak reserved, 7200-7500 WPS, 33% MFU
  - FSDP2 (SAC), buffer reuse enabled: 53.83 GiB (56.64%) peak reserved, 8100-8400 WPS, 36% MFU
  - Loss curves slightly better than eager
- For fun -- how much can we push MFU?
  - If we use FSDP2 (SAC) with 16 local batch size (doubled), we get 88.23 GiB (92.84%) peak reserved, 8600 WPS, 38% MFU.
  - If we use FSDP2 (no AC) with 8 local batch size, we get 90.28 GiB (94.99%) peak reserved, 9100-9300 WPS, 40% MFU.
- Why is FSDP2 faster? (1) fp32 reduce-scatter only uses one div kernel instead of two, and (2) `reshard_after_forward=False` for the last transformer block

2D FSDP

- Eager (2-way SP, 4-way FSDP): 1k steps of minipile on 8 H100 GPUs, local batch size 16 (to preserve global batch size), sequence length 2048, bf16 mixed precision, fp32 reduce-scatter
  - FSDP2 (AC): 50.12% peak active, 60.97% peak reserved, 5800-5900 WPS
  - FSDP2 (SAC): 76.49% peak active, 90.14% peak reserved, 6100-6300 WPS
  - Loss curves match 8-way FSDP
- FSDP1 + SP has incorrect numerics due to `FSDP.clip_grad_norm_` not all-reducing over the TP mesh dimension

<details>
<summary> Loss curves </summary>
<img width="732" alt="Screenshot 2024-03-26 at 3 31 19 PM" src="https://github.com/pytorch/torchtrain/assets/31054793/59ec71cc-ad0a-4dd1-b5c6-a8cbf9ab5e85">
</details>

**Meta-Device Initialization**

- The PyTorch Core guideline is for `module.reset_parameters()` to only initialize parameters/buffers immediately owned by `module` (i.e. `module.parameters(recurse=False)` and `module.buffers(recurse=False)`).
- This makes it challenging to specify custom initializations for core modules like `nn.Linear` and `nn.Embedding`. For example, in @lessw2020's depth-wise truncated normal initialization, the `trunc_normal_` standard deviation depends on the layer ID, which is a property of the `TransformerBlock` but affects the child `nn.Linear`s.
- To disambiguate, I suggest avoiding the name `reset_parameters()` in the case that we violate the PyTorch Core guideline and instead use a different name (e.g. `init_weights`).

**DCP & Save/Load**

- Tested 1D and 2D by specifying `checkpoint_folder = "/tmp/checkpoint_andgu` in the `.toml`, training until saving a checkpoint, terminating the run, and restarting the training to load the checkpoint -- the loss after loading looks reasonable