Gemma3 (and Paligemma) position_ids 1-indexed? #36856

Open
oceanusxiv opened this issue Mar 20, 2025 · 5 comments · Fixed by #36859

@oceanusxiv

In the official Google implementation of Gemma 3, all the position_ids preparation indicates that position_ids are 0-indexed, and the same is true of PaliGemma in big_vision: https://github.com/google-deepmind/gemma/blob/91ee586fbb2f3b8bfeb07b99967008348a229689/gemma/transformer.py#L791

However, in transformers there is a comment, `# position_ids in Gemma3 are 1-indexed`, stating that position_ids are 1-indexed. This seems like a weird discrepancy between model implementations. Is this intended?
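To make the two conventions concrete, here is a minimal sketch of building position ids for a single left-padded sequence from its attention mask. The helper name `position_ids_from_mask` and its `one_indexed` flag are hypothetical and illustrate the idea only; they are not the code used in either the Google or the transformers implementation.

```python
def position_ids_from_mask(attention_mask, one_indexed=False):
    """Hypothetical helper: position ids for one left-padded sequence.

    attention_mask: list of 0/1 ints (1 = real token, 0 = padding).
    Real tokens get consecutive positions starting at 0 (or at 1 when
    one_indexed=True). Padding slots get 0 as a placeholder, since they
    are masked out of attention anyway.
    """
    offset = 1 if one_indexed else 0
    positions = []
    running = 0  # number of real tokens seen so far
    for m in attention_mask:
        running += m
        # running - 1 is the 0-indexed position of the current real token
        positions.append(running - 1 + offset if m else 0)
    return positions


mask = [0, 0, 1, 1, 1, 1]
print(position_ids_from_mask(mask))                    # [0, 0, 0, 1, 2, 3]
print(position_ids_from_mask(mask, one_indexed=True))  # [0, 0, 1, 2, 3, 4]
```

The two outputs differ only by a constant shift of +1 on the real tokens, which is exactly the discrepancy raised in this issue.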

@zucchini-nlp
Member

Hmm...

cc @molbap, I remember you said that was needed to match the original implementation. Can you take a look?

@molbap
Contributor

molbap commented Mar 20, 2025

Hey, sure. IIRC it was to make our PaliGemma implementation match JAX at the time: the BOS token was added after the other inputs / had an unusual position, and it needed to be added afterwards. Not sure why it ended up in Gemma3; taking a look.

@molbap
Contributor

molbap commented Mar 20, 2025

This is because, in the modular file, Gemma3ForConditionalGeneration inherits from PaliGemmaForConditionalGeneration where this specific fix happens. But it is specific to PaliGemma input ordering and should not have been propagated to Gemma3. I'm opening a PR to change it, although it's a global shift of position ids so it should not change much for RoPE and subsequent logits. cc @gante as we discussed this last time
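Why a global shift "should not change much for RoPE": rotary embeddings make the query-key dot product depend only on the relative offset between positions, so adding 1 to every position leaves attention scores unchanged in exact arithmetic. Below is a minimal pure-Python sketch of a generic RoPE that rotates consecutive dimension pairs. This pairing layout and `base=10000.0` are assumptions chosen for illustration; transformers' Gemma code may pair dimensions differently, but the relative-position argument is the same.

```python
import math


def rope(x, pos, base=10000.0):
    """Generic RoPE sketch: rotate each pair (x[2j], x[2j+1]) by pos * base**(-2j/d)."""
    d = len(x)
    out = []
    for i in range(0, d, 2):  # i = 2j, the first index of pair j
        theta = pos * base ** (-i / d)
        c, s = math.cos(theta), math.sin(theta)
        out.append(x[i] * c - x[i + 1] * s)
        out.append(x[i] * s + x[i + 1] * c)
    return out


def score(q, k, q_pos, k_pos):
    """Unscaled attention logit between a rotated query and a rotated key."""
    return sum(a * b for a, b in zip(rope(q, q_pos), rope(k, k_pos)))


q = [0.3, -1.2, 0.5, 0.8, -0.1, 0.4]
k = [1.0, 0.2, -0.7, 0.3, 0.9, -0.5]

positions_0 = list(range(5))                 # 0-indexed positions
positions_1 = [p + 1 for p in positions_0]   # same sequence, 1-indexed

scores_0 = [[score(q, k, i, j) for j in positions_0] for i in positions_0]
scores_1 = [[score(q, k, i, j) for j in positions_1] for i in positions_1]

# Largest elementwise difference between the two score matrices
max_diff = max(
    abs(a - b) for r0, r1 in zip(scores_0, scores_1) for a, b in zip(r0, r1)
)
print(f"max |diff| = {max_diff:.2e}")  # only floating-point rounding remains
```

In real models the shift can still cause tiny numerical differences from rounding, especially in low precision, which is consistent with "should not change much" rather than "changes nothing".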

@oceanusxiv
Author

@molbap I'm curious about PaliGemma as well. As far as I can tell, all the JAX implementations of PaliGemma I've seen also use 0-indexing, such as the implementation in https://github.com/google-research/big_vision/blob/main/big_vision/models/proj/paligemma/paligemma.py. Would you happen to know which JAX implementation with 1-indexing was the reference here?

@molbap
Contributor

molbap commented Mar 24, 2025

There is none: IIRC it was not related to the original implementation per se but to a tokenizer issue, and this was a quick fix at the time that ended up staying there. We did have 100% logit matching with the original implementation, though. You're right to bring this up; I'll reopen the issue so I remember to investigate.

molbap reopened this Mar 24, 2025