Recent advances have established diffusion language models (DLMs) as a promising alternative to traditional autoregressive large language models (LLMs). However, DLMs still lag behind LLMs in reasoning performance, especially as the number of denoising steps decreases. Our analysis reveals that this shortcoming arises primarily from the independent generation of masked tokens across denoising steps, which fails to capture token correlations. In this paper, we define two types of token correlation: intra-sequence correlation and inter-sequence correlation, and demonstrate that enhancing these correlations improves reasoning performance. To this end, we propose a Multi-Reward Optimization (MRO) approach, which encourages DLMs to account for token correlations during the denoising process. More specifically, MRO leverages test-time scaling, rejection sampling, and reinforcement learning to directly optimize token correlations with multiple carefully designed rewards. Additionally, we introduce group step and importance sampling strategies to mitigate reward variance and enhance sampling efficiency. Through extensive experiments, we demonstrate that MRO not only improves reasoning performance but also achieves significant sampling speedups while maintaining high accuracy on reasoning benchmarks.
Preparation:
pip install -e ".[torch,metrics]"

Training with RL (Policy Gradient):
llamafactory-cli train examples/train_full/llada_full_reinforce.yaml

Training with Rejection Sampling:
llamafactory-cli train examples/train_full/llada_full_rejection_sampling.yaml

Normal Inference:
cd diffusion_llm_evaluation/
python eval_datasets.py \
-m /path/to/model -d gsm8k -r /path/to/res.jsonl \
--batch-size 16 --steps 128 --gen-length 256 --block-length 8 \
--temperature 0.0 --shot 0 \
--cards 0,1,2,3,4,5,6,7 --cards-per-model 2

Inference with Test-Time Scaling:
cd diffusion_llm_evaluation/
python eval_datasets.py \
-m /path/to/model -d gsm8k -r /path/to/res.jsonl \
--batch-size 16 --steps 128 --gen-length 256 --block-length 8 \
--temperature 1.0 --shot 0 \
--cards 0,1,2,3,4,5,6,7 --cards-per-model 2 \
--test-time-scaling

Confidence Reward:
x is the input to the model, consisting of the prompt and the masked response, like:
[prompt_token_0] [prompt_token_1] [prompt_token_2] ... [mask] [mask] [mask] ...
x_temp is the output of one denoising step, replacing some [mask] with predicted tokens, like:
[prompt_token_0] ... [mask] [predicted_token_0] [mask] ... [predicted_token_1] [mask] ...
transfer_elements_idx is a list containing the indices of the newly predicted tokens, like:
[12, 16], which means the 12th and the 16th tokens were predicted from [mask] in the current denoising step.
import numpy as np
import torch.nn.functional as F

x, x_temp, transfer_elements_idx = ...
x_temp_logits = model(x_temp).logits  # logits over the full sequence; assumes shape [seq_len, vocab_size]
# Probability the model assigns to each token unmasked in the current denoising step
transfer_elements_p = [F.softmax(x_temp_logits, dim=-1)[idx][x_temp[idx]].item() for idx in transfer_elements_idx]
reward = np.mean(transfer_elements_p)
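For orientation, here is a minimal, hypothetical sketch of how one denoising step could produce x_temp and transfer_elements_idx, assuming a confidence-based rule that unmasks the k most confident masked positions. MASK_ID, k, and denoise_one_step are illustrative names; the actual transfer schedule follows the model's own sampler.

import torch
import torch.nn.functional as F

MASK_ID = 0  # placeholder; use the tokenizer's actual mask token id

def denoise_one_step(model, x, k):
    # x: 1-D LongTensor of token ids with masked positions set to MASK_ID
    logits = model(x.unsqueeze(0)).logits[0]            # [seq_len, vocab_size]
    conf, pred = F.softmax(logits, dim=-1).max(dim=-1)  # per-position confidence and argmax token
    masked = (x == MASK_ID)
    # Only still-masked positions are candidates for unmasking
    conf = conf.masked_fill(~masked, float("-inf"))
    k = min(k, int(masked.sum().item()))
    transfer_elements_idx = conf.topk(k).indices.tolist()
    x_temp = x.clone()
    x_temp[transfer_elements_idx] = pred[transfer_elements_idx]
    return x_temp, transfer_elements_idx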
Perplexity Reward:
res_str is the generated response of the model. The perplexity is computed with the lmppl package:

import lmppl

ppl_scorer = lmppl.LM(model_path)
ppl_reward = ppl_scorer.get_perplexity([res_str])
# Map perplexity onto [0, 1]: lower perplexity gives a higher reward,
# and perplexities of 100 or more are clipped to 0
ppl_reward = (100 - ppl_reward[0]) / 100
reward = max(0, ppl_reward)

Format Reward:
The format_reward(...) function checks whether the model's response matches the format <think>...</think><answer>...</answer>:
import re

def format_reward(text_list, **kwargs):
    pattern = r"^<think>\n.*?\n</think>\n<answer>\n(.*?)\n</answer>$"
    completion_contents = text_list
    matches = [re.findall(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents]
    return [1.0 if len(match) > 0 else 0.0 for match in matches], matches
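For example, a response that follows the template scores 1.0 while a free-form answer scores 0.0 (the example strings below are illustrative):

good = "<think>\nstep-by-step reasoning\n</think>\n<answer>\n42\n</answer>"
bad = "The answer is 42."
scores, matches = format_reward([good, bad])
print(scores)         # [1.0, 0.0]
print(matches[0][0])  # "42" -- the extracted <answer> content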
Accuracy Reward:
The reward is 2 if and only if the response matches the format and the answer is correct.

group_format_reward, matches = format_reward([res_str])
gold = ...    # parse the gold answer
answer = ...  # parse the model's answer from matches
reward = int(verify(gold, answer) or (ground_truth in matches[-1]))
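For reference, here is a minimal sketch of how this format-plus-accuracy reward could be wired together, assuming the math_verify package supplies parse and verify (the repository does not name the exact verifier); the helper accuracy_reward and the fallback to the raw response text are illustrative.

from math_verify import parse, verify  # assumed verification backend

def accuracy_reward(res_str, ground_truth):
    # 2.0 only when the format matches AND the extracted answer is correct
    fmt_scores, matches = format_reward([res_str])
    answer_text = matches[0][0] if matches[0] else res_str  # extracted <answer> block, else raw text
    correct = verify(parse(ground_truth), parse(answer_text)) or (ground_truth in answer_text)
    return fmt_scores[0] + float(correct)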
The pseudocode of test-time scaling is as follows:

FUNCTION test_time_scaling(model, prompt, total_steps, search_every_steps, beam_size, top_k):
    // Initialize beam with fully masked response
    beam = [fully_masked_response]
    FOR step FROM 0 TO total_steps BY search_every_steps:
        new_beam = []
        FOR each candidate IN beam:
            FOR i IN 1..beam_size:
                // Generate next step candidates
                candidate_next = model.generate_step(candidate, step, step + search_every_steps)
                // Calculate rewards for the candidate
                confidence_reward = calculate_confidence_reward(candidate_next)
                format_reward = check_format(candidate_next)
                accuracy_reward = check_answer_accuracy(candidate_next, ground_truth)
                ppl_reward = calculate_perplexity(candidate_next)
                total_reward = confidence_reward + format_reward + accuracy_reward + ppl_reward
                // Store candidate with its reward
                new_beam.append((candidate_next, total_reward))
        // Keep top-k candidates
        sort new_beam by total_reward descending
        beam = [candidate for candidate, _ in new_beam[:top_k]]
    RETURN highest_reward_candidate from beam
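A Python rendering of the same search is sketched below, assuming the reward helpers above and a model.generate_step(...) method with the signature used in the pseudocode; fully_masked_response and all helper names are placeholders, not the released API.

def test_time_scaling(model, prompt, total_steps, search_every_steps,
                      beam_size, top_k, ground_truth=None):
    beam = [(fully_masked_response, 0.0)]  # (candidate, reward) pairs
    for step in range(0, total_steps, search_every_steps):
        new_beam = []
        for candidate, _ in beam:
            for _ in range(beam_size):
                # Run `search_every_steps` denoising steps from the current candidate
                nxt = model.generate_step(candidate, step, step + search_every_steps)
                reward = (calculate_confidence_reward(nxt)
                          + check_format(nxt)
                          + check_answer_accuracy(nxt, ground_truth)
                          + calculate_perplexity(nxt))
                new_beam.append((nxt, reward))
        # Keep only the top-k candidates by total reward
        new_beam.sort(key=lambda pair: pair[1], reverse=True)
        beam = new_beam[:top_k]
    return max(beam, key=lambda pair: pair[1])[0]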
The pseudocode of policy gradient during training is as follows:

...
max_reward, max_reward_logprobs = -inf, None
FOR sampling_idx IN 1..sampling_size:
    logprobs, response = generate(model, prompt)
    reward = compute_reward(response)
    IF reward > max_reward:
        max_reward = reward
        max_reward_logprobs = logprobs
loss = - max_reward * max_reward_logprobs
model.backward(loss)
...
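A hypothetical PyTorch-style rendering of this update is sketched below; generate (returning the summed token log-probabilities of the sampled response) and compute_reward stand in for the sampling and reward code above, and the optimizer handling is illustrative.

def policy_gradient_step(model, optimizer, prompt, sampling_size):
    max_reward, max_reward_logprobs = float("-inf"), None
    for _ in range(sampling_size):
        logprobs, response = generate(model, prompt)  # logprobs: scalar tensor carrying gradients
        reward = compute_reward(response)
        if reward > max_reward:
            max_reward, max_reward_logprobs = reward, logprobs
    loss = -max_reward * max_reward_logprobs  # reward-weighted, REINFORCE-style loss on the best sample
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return max_reward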
The only difference between the policy gradient above and rejection sampling is that the loss is not weighted by the reward.

The pseudocode of rejection sampling during training is as follows:
// same as the policy gradient above
...
loss = -max_reward_logprobs
model.backward(loss)
...

Citation:

@misc{wang2025mro,
title={MRO: Enhancing Reasoning in Diffusion Language Models via Multi-Reward Optimization},
author={Chenglong Wang and Yang Gan and Hang Zhou and Chi Hu and Yongyu Mu and Kai Song and Murun Yang and Bei Li and Chunliang Zhang and Tongran Liu and Jingbo Zhu and Zhengtao Yu and Tong Xiao},
year={2025},
eprint={2510.21473},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2510.21473},
}