Integrate Multi-Token Prediction (MTP) Training objective #1837
base: parambole/mtp_refactor
Conversation
Overall LGTM! But I think we should load a DeepSeek v3 checkpoint before we claim support.
```python
layer_types = maxtext_utils.get_decoder_layers(self.config)
# For MTP, we use the primary (usually dense) transformer block blueprint
# to ensure architectural consistency. By convention, this is the first in the list.
mtp_layer = layer_types[0]
```
For DeepSeek, it's mixed layers (`[deepseek.DeepSeekDenseLayer, deepseek.DeepSeekMoELayer]`). Could you confirm whether this should be the dense or the MoE layer? It is an MoE layer if I recall correctly.
```python
# 2. The `shared_embedding` for both embedding future tokens and for its final
#    logit projection.
# Its only effect is to "sow" these losses; it does not alter the primary logits output.
if self.config.mtp_num_layers > 0 and model_mode == MODEL_MODE_TRAIN:
```
Shall we add an assertion in `pyconfig.py` that `mtp_num_layers > 1` is not supported for inference/serving?
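A minimal sketch of such a guard (the helper name and the boolean flag are hypothetical stand-ins; the actual `pyconfig.py` validation hooks may differ):

```python
def validate_multi_token_prediction(mtp_num_layers: int, is_training: bool) -> None:
  """Reject MTP settings for inference/serving runs (hypothetical helper)."""
  if mtp_num_layers > 0 and not is_training:
    raise ValueError(
        "Multi-Token Prediction is a training-only objective; "
        "set mtp_num_layers=0 for inference/serving.")
```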
Dependency: This PR depends on, and must be merged after, the refactoring in Refactor: Decouple Core Transformer Blocks (#1852).
PR: Multi-Token Prediction (MTP) Integration
TL;DR
This PR adds a `MultiTokenPredictionBlock` that runs after the main decoder stack during training. It computes an additional loss term which is added to the main loss.

Detailed Description
Background and Motivation
Standard language models are trained on a next-token prediction objective. Multi-Token Prediction (MTP) enhances this by adding an auxiliary task: from each position in a sequence, the model also learns to predict several tokens into the future. This encourages the model to develop richer internal representations and can lead to significant improvements in sample efficiency and final model performance.
This implementation follows the sequential prediction model, where the prediction of token `t+k+1` is causally dependent on the layer that predicted token `t+k`.

Architectural Changes
To integrate this feature cleanly and robustly, several key architectural changes were made:
The MTP Module (`layers/multi_token_prediction.py`): A new file was created to house all MTP-specific logic:

- `MultiTokenPredictionLayer`: A single block responsible for one step of future prediction. It normalizes its inputs, projects them, and processes them through a standard transformer layer.
- `MultiTokenPredictionBlock`: This module orchestrates the entire MTP process. It contains a `for` loop that runs for `mtp_num_layers` steps, instantiating a unique `MultiTokenPredictionLayer` for each step and maintaining the sequential flow of the hidden state. A minimal structural sketch follows this list.
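The following is only an illustrative Flax sketch of that structure, not the PR's actual code: the `embed_dim` field, the `Dense` stand-in for the shared `DecoderLayer` blueprint, and the two-argument call signature are assumptions.

```python
import flax.linen as nn
import jax.numpy as jnp


class MultiTokenPredictionLayer(nn.Module):
  """One prediction step: normalize inputs, merge, project, run a transformer-like layer."""
  embed_dim: int

  @nn.compact
  def __call__(self, prev_hidden, future_token_embedding):
    # Normalize both incoming streams before merging them.
    h = nn.LayerNorm()(prev_hidden)
    e = nn.LayerNorm()(future_token_embedding)
    # Project the merged features back to the model width.
    x = nn.Dense(self.embed_dim)(jnp.concatenate([h, e], axis=-1))
    # Stand-in for the shared DecoderLayer blueprint passed in by the Transformer.
    return nn.Dense(self.embed_dim)(x)


class MultiTokenPredictionBlock(nn.Module):
  """Chains mtp_num_layers prediction steps, threading the hidden state through."""
  embed_dim: int
  mtp_num_layers: int

  @nn.compact
  def __call__(self, final_hidden_state, future_token_embeddings):
    hidden = final_hidden_state
    for k in range(self.mtp_num_layers):
      layer = MultiTokenPredictionLayer(self.embed_dim, name=f"mtp_layer_{k}")
      # Prediction of token t+k+1 depends on the hidden state that predicted t+k.
      hidden = layer(hidden, future_token_embeddings[k])
    return hidden
```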
Integration with `Transformer` (`layers/models.py`): The main `Transformer` model was modified to facilitate the MTP "side-car":

- The `Decoder`'s `__call__` method now returns both the `main_logits` and the raw `final_hidden_state` (pre-normalization). This makes the dependency explicit.
- The `Transformer`'s `setup` method now instantiates the `MultiTokenPredictionBlock`, passing it the correct `DecoderLayer` blueprint to ensure architectural consistency.
- The `Transformer`'s `__call__` method calls the `MTPBlock` only during training (`model_mode == MODEL_MODE_TRAIN`), explicitly passing it the dependencies it needs (`final_hidden_state`, `shared_embedding`, etc.). A toy sketch of this wiring follows this list.
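The sketch below shows only the wiring pattern; the decoder stack and the MTP block are reduced to `Dense` stand-ins, and the real `MTPBlock` also receives `shared_embedding` and sows its own losses as described in the next section.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

MODEL_MODE_TRAIN = "train"  # stand-in for MaxText's constant


class TinyTransformer(nn.Module):
  """Illustrates the train-only MTP 'side-car' call, not MaxText's real model."""
  vocab_size: int
  embed_dim: int
  mtp_num_layers: int = 0

  def setup(self):
    self.shared_embedding = nn.Embed(self.vocab_size, self.embed_dim)
    self.decoder = nn.Dense(self.embed_dim)      # stand-in for the decoder stack
    self.logits_head = nn.Dense(self.vocab_size)
    if self.mtp_num_layers > 0:
      self.mtp_block = nn.Dense(self.embed_dim)  # stand-in for MultiTokenPredictionBlock

  def __call__(self, tokens, model_mode=MODEL_MODE_TRAIN):
    x = self.shared_embedding(tokens)
    final_hidden_state = self.decoder(x)         # decoder exposes the raw hidden state
    main_logits = self.logits_head(final_hidden_state)
    # MTP runs only during training; it never alters main_logits.
    if self.mtp_num_layers > 0 and model_mode == MODEL_MODE_TRAIN:
      self.mtp_block(final_hidden_state)
    return main_logits


model = TinyTransformer(vocab_size=32, embed_dim=8, mtp_num_layers=1)
tokens = jnp.zeros((2, 5), dtype=jnp.int32)
variables = model.init(jax.random.PRNGKey(0), tokens)
logits = model.apply(variables, tokens)
```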
Loss Calculation (`train.py`): The auxiliary loss is aggregated without changing the `Transformer`'s return signature by using Flax's `sow` mechanism:

- The `MultiTokenPredictionBlock` calls `self.sow('mtp_losses', 'losses', ...)` for each layer's calculated loss. This is guarded by an `if not self.is_initializing()` check to prevent it from running during model initialization.
- The `loss_fn` in `train.py` is now responsible for "reaping" these values by making the `'mtp_losses'` collection mutable during the training `apply` call.
- The sown losses are retrieved with the `maxtext_utils.get_nested_value` utility and the explicit path (`mtp_losses`, `mtp_block`, `losses`).
- The `loss_fn` scales the aggregated MTP loss by `mtp_loss_scaling_factor` and adds it to the main loss before backpropagation. The `mtp_loss` is also added to the training and evaluation metrics for logging. A self-contained sketch of the sow/reap pattern follows this list.
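The sketch below demonstrates the sow/reap pattern on a toy module; in MaxText the sown values live under the (`mtp_losses`, `mtp_block`, `losses`) path and are read via `maxtext_utils.get_nested_value`, which is omitted here for brevity.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class TinyMTPBlock(nn.Module):
  """Sows one auxiliary loss per call, but only outside of initialization."""

  @nn.compact
  def __call__(self, x):
    y = nn.Dense(4)(x)
    if not self.is_initializing():
      self.sow("mtp_losses", "losses", jnp.mean(y ** 2))  # toy per-layer loss
    return y


model = TinyMTPBlock()
x = jnp.ones((2, 4))
variables = model.init(jax.random.PRNGKey(0), x)           # nothing sown here

# In loss_fn, the 'mtp_losses' collection is made mutable so the
# sown values can be reaped from the returned mutables.
_, mutables = model.apply(variables, x, mutable=["mtp_losses"])
mtp_loss = jnp.mean(jnp.stack(mutables["mtp_losses"]["losses"]))

mtp_loss_scaling_factor = 0.1
main_loss = jnp.float32(2.0)                                # stand-in for the primary loss
total_loss = main_loss + mtp_loss_scaling_factor * mtp_loss
```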
Configuration

This feature is controlled via two new parameters in `base.yml`:

- `mtp_num_layers`: (`int`, default: `0`) The number of auxiliary prediction layers to use. Set to a positive integer to enable MTP.
- `mtp_loss_scaling_factor`: (`float`, default: `0.1`) The weighting factor for the final MTP loss.

How to Use: To enable MTP with 4 prediction heads and a 15% loss contribution, add the following to your config file:
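A minimal snippet, using only the two parameters documented above:

```yaml
mtp_num_layers: 4              # number of auxiliary prediction heads
mtp_loss_scaling_factor: 0.15  # 15% weight on the MTP loss
```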
Testing Strategy
A new test file, `MaxText/tests/multi_token_prediction_test.py`, has been added with a new test class, `MultiTokenPredictionBlockTest`, to ensure the implementation is robust. The testing follows these key principles:

- An `MTPBlockTestModel` is used to wrap the `MultiTokenPredictionBlock` and its dependencies. This allows Flax's `.init()` to handle all parameter creation automatically and correctly, which is a robust pattern seen in other MaxText tests.
- `test_sow_functionality`: This primary test verifies that when the block is run in training mode (`mutable=['mtp_losses']`), it correctly sows the expected number of losses and weights.
- `test_no_sow_during_init`: This test confirms that no losses are sown during the `.init()` call, by leveraging the `if not self.is_initializing()` check in the application code (see the sketch after this list).
- A test for the `roll_and_mask` utility was also added to `maxtext_utils_test.py`.
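A hedged sketch of the shape of these two tests; the wrapper and block below are toy stand-ins rather than the actual MaxText classes.

```python
import unittest

import jax
import jax.numpy as jnp
import flax.linen as nn


class MTPBlockTestModel(nn.Module):
  """Toy stand-in wrapping an MTP-style block so .init() creates all params."""
  num_layers: int = 2

  @nn.compact
  def __call__(self, hidden):
    for k in range(self.num_layers):
      hidden = nn.Dense(hidden.shape[-1], name=f"mtp_layer_{k}")(hidden)
      if not self.is_initializing():
        self.sow("mtp_losses", "losses", jnp.mean(hidden ** 2))
    return hidden


class MultiTokenPredictionBlockTest(unittest.TestCase):

  def test_sow_functionality(self):
    model = MTPBlockTestModel(num_layers=2)
    x = jnp.ones((1, 8))
    variables = model.init(jax.random.PRNGKey(0), x)
    _, mutables = model.apply(variables, x, mutable=["mtp_losses"])
    # One loss sown per prediction layer.
    self.assertEqual(len(mutables["mtp_losses"]["losses"]), 2)

  def test_no_sow_during_init(self):
    variables = MTPBlockTestModel(num_layers=2).init(
        jax.random.PRNGKey(0), jnp.ones((1, 8)))
    self.assertNotIn("mtp_losses", variables)


if __name__ == "__main__":
  unittest.main()
```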
.Checklist
Before submitting this PR, please make sure (put X in square brackets):