Skip to content

Conversation

david-stan
Copy link
Contributor

@david-stan david-stan commented Sep 12, 2025

Types of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Docs change / refactoring / dependency upgrade

Motivation and Context / Related issue

It prevents TypeError: DPLossFastGradientAdaptiveClipping.__call__() got an unexpected keyword argument 'vocab_size' error from triggering when assigning DPLossFastGradientAdaptiveClipping or DPLossFastGradientClipping to the .loss_function property of any PreTrainedModel.

Every PreTrainedModel.loss_function() call expects vocab_size amongst it's keyword arguments:

# transformers.models.gpt2.modeling_gpt2.py:1099
# Flatten the tokens
loss = self.loss_function(
        logits,
        labels,
        vocab_size=self.config.vocab_size,
        **kwargs,
    )

Meanwhile, DPLossFastGradientAdaptiveClipping.__call__ and DPLossFastGradientClipping.__call__ don't have this keyword argument vocab_size in their signature. vocab size is later needed for tensor flattening:

def ForCausalLMLoss(
    logits,
    labels,
    vocab_size: int,
    num_items_in_batch: Optional[torch.Tensor] = None,
    ignore_index: int = -100,
    shift_labels: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    # Upcast to float if we need to compute the loss to avoid potential precision issues
    logits = logits.float()

    if shift_labels is None:
        # Shift so that tokens < n predict n
        labels = nn.functional.pad(labels, (0, 1), value=ignore_index)
        shift_labels = labels[..., 1:].contiguous()

    # Flatten the tokens
    logits = logits.view(-1, vocab_size) <------ used here

How Has This Been Tested (if it applies)

Tested and trained on transformers' GPT2LMHeadModel with LoRA and 4B parameter Llama LlamaForCausalLM model, purposefully targeting different model architectures.

Checklist

  • The documentation is up-to-date with the changes I made.
  • I have read the CONTRIBUTING document and completed the CLA (see CONTRIBUTING).
  • All tests passed, and additional code has been covered with new tests.

Copy link

meta-cla bot commented Sep 12, 2025

Hi @david-stan!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

…t passing to be compatible with other model architectures.
@david-stan david-stan changed the title Add support for passing additional kwargs to per-sample loss functions Add support for passing args and kwargs to per-sample loss functions Sep 12, 2025
Copy link

meta-cla bot commented Sep 12, 2025

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Sep 12, 2025
Copy link
Contributor

@iden-kalemaj iden-kalemaj left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Review automatically exported from Phabricator review in Meta.

@facebook-github-bot
Copy link
Contributor

@aparna-aketi merged this pull request in c9032e9.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. Merged
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants