
Integrate Multi-Token Prediction (MTP) Training objective #1837


Status: Open. Wants to merge 5 commits into parambole/mtp_refactor from parambole/maxtext_mtp_training_obective.

Conversation

@parambole (Collaborator) commented on Jun 16, 2025

Dependency: This PR depends on and must be merged after the refactoring in Refactor: Decouple Core Transformer Blocks #1852.

PR: Multi-Token Prediction (MTP) Integration

TL;DR

  • What: This PR integrates the Multi-Token Prediction (MTP) auxiliary training objective into MaxText.
  • Why: To improve model performance and training efficiency by densifying training signals, based on the architecture described in the DeepSeek-V3 paper.
  • How: By adding a MultiTokenPredictionBlock that runs after the main decoder stack during training. It computes an additional loss term which is added to the main loss.

Detailed Description

Background and Motivation

Standard language models are trained on a next-token prediction objective. Multi-Token Prediction (MTP) enhances this by adding an auxiliary task: from each position in a sequence, the model also learns to predict several tokens into the future. This encourages the model to develop richer internal representations and can lead to significant improvements in sample efficiency and final model performance.

This implementation follows the sequential prediction scheme from the DeepSeek-V3 paper, where the prediction of token t+k+1 is causally dependent on the hidden state of the layer that predicted token t+k.
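To make the shifting concrete, here is a small, self-contained sketch of how the target tokens move one step further into the future at each MTP layer. The roll_and_mask helper below is only an illustrative stand-in for the utility this PR adds to maxtext_utils; its real signature may differ.

```python
import jax.numpy as jnp

def roll_and_mask(x: jnp.ndarray, shift: int = -1, pad_value: int = 0) -> jnp.ndarray:
  """Illustrative stand-in: shift a [batch, seq] array left by |shift| and pad the tail."""
  rolled = jnp.roll(x, shift, axis=-1)
  return rolled.at[..., shift:].set(pad_value)

# Toy example: token ids for one sequence.
tokens = jnp.array([[11, 12, 13, 14, 15]])

# Each MTP step rolls the labels one token further than the previous step,
# so step k learns to predict k tokens beyond the usual next-token target.
targets_step_1 = roll_and_mask(tokens)          # [[12, 13, 14, 15, 0]]
targets_step_2 = roll_and_mask(targets_step_1)  # [[13, 14, 15, 0, 0]]
```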


Architectural Changes

To integrate this feature cleanly and robustly, several key architectural changes were made:

  1. The MTP Module (layers/multi_token_prediction.py): A new file was created to house all MTP-specific logic:

    • MultiTokenPredictionLayer: A single block responsible for one step of future prediction. It normalizes its inputs, projects them, and processes them through a standard transformer layer.
    • MultiTokenPredictionBlock: This module orchestrates the entire MTP process. It contains a for loop that runs for mtp_num_layers, instantiating a unique MultiTokenPredictionLayer for each step and maintaining the sequential flow of the hidden state.
  2. Integration with Transformer (layers/models.py): The main Transformer model was modified to facilitate the MTP "side-car":

    • The Decoder's __call__ method now returns both the main_logits and the raw final_hidden_state (pre-normalization). This makes the dependency explicit.
    • The Transformer's setup method now instantiates the MultiTokenPredictionBlock, passing it the correct DecoderLayer blueprint to ensure architectural consistency.
    • The Transformer's __call__ method calls the MTPBlock only during training (model_mode == MODEL_MODE_TRAIN), explicitly passing it the dependencies it needs (final_hidden_state, shared_embedding, etc.).
  3. Loss Calculation (train.py): The auxiliary loss is aggregated without changing the Transformer's return signature by using Flax's sow mechanism (a condensed sketch of this flow follows the list):

    • The MultiTokenPredictionBlock calls self.sow('mtp_losses', 'losses', ...) for each layer's calculated loss. This is guarded by an if not self.is_initializing() check to prevent it from running during model initialization.
    • The main loss_fn in train.py is now responsible for "reaping" these values by making the 'mtp_losses' collection mutable during the training .apply call.
    • It then retrieves the tuple of sown losses and weights using the existing maxtext_utils.get_nested_value utility and the explicit path (mtp_losses, mtp_block, losses).
    • Finally, it computes the average MTP loss, scales it by mtp_loss_scaling_factor, and adds it to the main loss before backpropagation. The mtp_loss is also added to the training and evaluation metrics for logging.
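To make the sow-and-reap flow concrete, here is a small, self-contained sketch of the mechanism (runnable on its own). It is not the PR's code: ToyMTPBlock and its squared-error placeholder loss stand in for MultiTokenPredictionBlock and the real cross-entropy, and the sown collection sits at the top level rather than under an mtp_block submodule; only the initialization guard, the mutable=['mtp_losses'] apply, and the scale-and-add aggregation mirror the description above.

```python
import jax
import jax.numpy as jnp
from flax import linen as nn


class ToyMTPBlock(nn.Module):
  """Stand-in for MultiTokenPredictionBlock: one Dense step per MTP layer, sowing a loss each time."""
  mtp_num_layers: int = 2

  @nn.compact
  def __call__(self, hidden, targets):
    for k in range(self.mtp_num_layers):
      hidden = nn.Dense(hidden.shape[-1], name=f"mtp_layer_{k + 1}")(hidden)
      step_loss = jnp.mean((hidden - targets) ** 2)  # placeholder for the real cross-entropy loss
      if not self.is_initializing():                 # mirrors the guard described above
        self.sow("mtp_losses", "losses", step_loss)
    return hidden


model = ToyMTPBlock()
x = jnp.ones((2, 8, 16))
variables = model.init(jax.random.PRNGKey(0), x, x)  # no losses are sown during init

# Training-style apply: make 'mtp_losses' mutable so the sown values are returned.
_, updates = model.apply(variables, x, x, mutable=["mtp_losses"])
sown_losses = updates["mtp_losses"]["losses"]        # tuple with one entry per MTP layer

# Aggregate as described above: average, scale, and add to the main loss.
mtp_loss_scaling_factor = 0.1
mtp_loss = jnp.mean(jnp.stack(sown_losses)) * mtp_loss_scaling_factor
main_loss = 1.234                                    # placeholder for the real next-token loss
total_loss = main_loss + mtp_loss
```

In the actual integration the block is a submodule of the Transformer, so train.py retrieves the tuple through maxtext_utils.get_nested_value at the path (mtp_losses, mtp_block, losses) rather than directly as in this toy example.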

Configuration

This feature is controlled via two new parameters in base.yml:

  • mtp_num_layers: (int, default: 0) The number of auxiliary prediction layers to use. Set to a positive integer to enable MTP.
  • mtp_loss_scaling_factor: (float, default: 0.1) The weighting factor for the final MTP loss.

How to Use: To enable MTP with 4 prediction heads and a 15% loss contribution, add the following to your config file:

```yaml
mtp_num_layers: 4
mtp_loss_scaling_factor: 0.15
```


Testing Strategy

A new test file, MaxText/tests/multi_token_prediction_test.py, has been added with a new test class, MultiTokenPredictionBlockTest, to ensure the implementation is robust.

The testing follows these key principles:

  • Wrapper Model: A lightweight MTPBlockTestModel is used to wrap the MultiTokenPredictionBlock and its dependencies. This allows Flax's .init() to handle all parameter creation automatically and correctly, which is a robust pattern seen in other MaxText tests.
  • Core Functionality Test (test_sow_functionality): This primary test verifies that when the block is run in training mode (mutable=['mtp_losses']), it correctly sows the expected number of losses and weights.
  • Initialization Test (test_no_sow_during_init): This test confirms that no losses are sown during the .init() call by leveraging the if not self.is_initializing() check in the application code.
  • A new test for the roll_and_mask utility was also added to maxtext_utils_test.py.
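For reference, a condensed illustration of what these two checks verify, using a toy sowing module rather than the real MultiTokenPredictionBlock and MaxText config, so it runs stand-alone but is not the PR's test code:

```python
import unittest

import jax
import jax.numpy as jnp
from flax import linen as nn


class _ToySowBlock(nn.Module):
  """Minimal stand-in for the MTP block: sows one loss per layer when not initializing."""
  num_layers: int = 2

  @nn.compact
  def __call__(self, x):
    for k in range(self.num_layers):
      x = nn.Dense(x.shape[-1], name=f"layer_{k}")(x)
      if not self.is_initializing():
        self.sow("mtp_losses", "losses", jnp.mean(x))
    return x


class SowBehaviorTest(unittest.TestCase):
  """Mirrors the two behaviors MultiTokenPredictionBlockTest is described as verifying."""

  def setUp(self):
    self.model = _ToySowBlock()
    self.x = jnp.ones((1, 4, 8))
    self.variables = self.model.init(jax.random.PRNGKey(0), self.x)

  def test_sow_functionality(self):
    _, updates = self.model.apply(self.variables, self.x, mutable=["mtp_losses"])
    self.assertEqual(len(updates["mtp_losses"]["losses"]), self.model.num_layers)

  def test_no_sow_during_init(self):
    self.assertNotIn("mtp_losses", self.variables)


if __name__ == "__main__":
  unittest.main()
```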

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

@parambole force-pushed the parambole/maxtext_mtp_training_obective branch from da99edb to affadb8 on June 16, 2025 18:30
@parambole marked this pull request as ready for review on June 16, 2025 18:31
@parambole force-pushed the parambole/maxtext_mtp_training_obective branch 6 times, most recently from 59ddf3f to 27ae66f, on June 17, 2025 02:34
@parambole force-pushed the parambole/maxtext_mtp_training_obective branch from ccebc8c to b1b2d95 on June 19, 2025 19:23
@parambole changed the base branch from main to parambole/mtp_refactor on June 19, 2025 19:24
@parambole force-pushed the parambole/maxtext_mtp_training_obective branch from eca208c to cd41461 on June 19, 2025 19:48
@RissyRan (Collaborator) left a comment


Overall LGTM! But I think we should load a DeepSeek-V3 checkpoint before we claim support.

```python
layer_types = maxtext_utils.get_decoder_layers(self.config)
# For MTP, we use the primary (usually dense) transformer block blueprint
# to ensure architectural consistency. By convention, this is the first in the list.
mtp_layer = layer_types[0]
```
For DeepSeek, it's mixed layers ([deepseek.DeepSeekDenseLayer, deepseek.DeepSeekMoELayer]). Could you confirm whether this picks the dense or the MoE layer? It is an MoE layer if I recall correctly.

```python
# 2. The `shared_embedding` for both embedding future tokens and for its final
#    logit projection.
# Its only effect is to "sow" these losses; it does not alter the primary logits output.
if self.config.mtp_num_layers > 0 and model_mode == MODEL_MODE_TRAIN:
```

Shall we add an assertion in pyconfig.py that mtp_num_layers > 1 is not supported for inference/serving?
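If such a guard were added, a minimal sketch could look like the following. The function name and the boolean flag are placeholders for illustration (this PR does not include the check), and it uses > 0 since the config enables MTP for any positive mtp_num_layers:

```python
def validate_mtp_config(mtp_num_layers: int, is_training: bool) -> None:
  """Hypothetical guard along the lines suggested above; not part of this PR."""
  if mtp_num_layers > 0 and not is_training:
    raise ValueError(
        "Multi-Token Prediction (mtp_num_layers > 0) adds a training-only "
        "auxiliary loss and is not supported for inference/serving."
    )
```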
