fix bug when using DP in trl, the batch size of input and output dism… #38938


Open · wants to merge 10 commits into base: main
Conversation

kaixuanliu
Contributor

No description provided.

@kaixuanliu
Contributor Author

kaixuanliu commented Jun 20, 2025

Steps to reproduce the bug:

git clone https://github.com/huggingface/trl.git
cd trl
git checkout 3ef9faf257
pip install .
export CUDA_VISIBLE_DEVICES=0,1,2
pytest -sv -rA tests/slow/test_sft_slow.py::SFTTrainerSlowTester::test_train_offloading_0_trl_internal_testing_tiny_LlamaForCausalLM_3_2

It fails with the following error:

def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        Compute training loss and additionally compute token accuracies
        """
        mode = "train" if self.model.training else "eval"
        (loss, outputs) = super().compute_loss(
            model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch
        )
        if mode == "train":
            # When using padding-free, the attention_mask is not present in the inputs, instead we have cu_seq_lens_q,
            # cu_seq_lens_k, and max_length_k, max_length_q and position_ids.
            if "attention_mask" in inputs:
                num_tokens_in_batch = self.accelerator.gather_for_metrics(inputs["attention_mask"].sum()).sum().item()
            elif "position_ids" in inputs:
                local_num_tokens = torch.tensor(inputs["position_ids"].size(1), device=inputs["position_ids"].device)
                num_tokens_in_batch = self.accelerator.gather_for_metrics(local_num_tokens).sum().item()
            else:
                raise ValueError("Expected 'attention_mask' or 'position_ids' in inputs.")
            self._total_train_tokens += num_tokens_in_batch
        self._metrics[mode]["num_tokens"] = [self._total_train_tokens]

        # Compute token accuracy if we have labels and if the model is not using Liger (no logits)
        if "labels" in inputs and not self.args.use_liger_kernel:
            shift_logits = outputs.logits[..., :-1, :].contiguous()
            shift_labels = inputs["labels"][..., 1:].contiguous()

            # Get predictions
            predictions = shift_logits.argmax(dim=-1)

            # Create mask for non-padding tokens (assuming ignore_index is -100)
            mask = shift_labels != -100

            # Calculate accuracy only on non-padding tokens
>           correct_predictions = (predictions == shift_labels) & mask
E           RuntimeError: The size of tensor a (2) must match the size of tensor b (6) at non-singleton dimension 0
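The token-counting branches in compute_loss above can be exercised on CPU without accelerate; a minimal single-process sketch, assuming gather_for_metrics reduces to a no-op (the helper name count_tokens is hypothetical):

```python
import torch

# Minimal single-process sketch of the token counting in compute_loss above.
# Assumption: accelerator.gather_for_metrics is a no-op here, since there is
# only one process; the helper name count_tokens is hypothetical.
def count_tokens(inputs):
    if "attention_mask" in inputs:
        # Padded batches: count only the non-padding positions.
        return inputs["attention_mask"].sum().item()
    elif "position_ids" in inputs:
        # Padding-free (packed) batches: every position is a real token.
        return inputs["position_ids"].size(1)
    raise ValueError("Expected 'attention_mask' or 'position_ids' in inputs.")

padded = {"attention_mask": torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])}
packed = {"position_ids": torch.arange(5).unsqueeze(0)}
print(count_tokens(padded))  # 5
print(count_tokens(packed))  # 5
```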

It crashes because num_items_in_batch at L3837 is a 1-D tensor that cannot be scattered across multiple GPUs correctly; as a result, although the input batch size is 6 at L3839, the output batch size is 2, and the test case fails.
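The resulting shape mismatch can be reproduced on CPU with plain tensors; a minimal sketch using the batch sizes from the traceback (the vocabulary size and sequence length are arbitrary assumptions):

```python
import torch

# Shapes taken from the error above: under DataParallel the gathered outputs
# ended up with batch size 2 while the labels kept the full batch size 6.
# Vocab size 32 and sequence length 8 are arbitrary assumptions.
shift_logits = torch.randn(2, 8, 32)         # outputs: batch size 2
shift_labels = torch.randint(0, 32, (6, 8))  # labels: batch size 6
predictions = shift_logits.argmax(dim=-1)    # shape (2, 8)
mask = shift_labels != -100                  # shape (6, 8)
try:
    correct_predictions = (predictions == shift_labels) & mask
except RuntimeError as err:
    # "The size of tensor a (2) must match the size of tensor b (6) ..."
    print(type(err).__name__)  # RuntimeError
```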

@kaixuanliu
Contributor Author

@zach-huggingface, @SunMarc and @qgallouedec, please help review.

@SunMarc (Member) left a comment

Thanks! Can you add a test that covers this specific case?

@kaixuanliu
Contributor Author

Hi @SunMarc, thanks for the advice. I think the existing test already covers this case:
pytest -sv -rA tests/trainer/test_trainer.py::TrainerIntegrationTest::test_num_batches_in_training_with_gradient_accumulation
I added a related assertion in the latest commit. Please check whether it is OK.

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
@yao-matrix
Contributor

@kaixuanliu, CI has failing cases, please take a look.

@kaixuanliu
Contributor Author

@yao-matrix, I updated the code and the previously failing case now passes; I also double-checked it on my own machine. @SunMarc, can you review again? Thanks!

@kaixuanliu
Contributor Author

@SunMarc Hi, this PR has been open for two weeks now; could you help review it? Many thanks!

@SunMarc SunMarc requested a review from qgallouedec July 15, 2025 13:09