
Shieldgemma2 #36678

Merged
merged 14 commits into from
Mar 20, 2025

Conversation

RyanMullins
Contributor

What does this PR do?

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@github-actions github-actions bot marked this pull request as draft March 12, 2025 15:43

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. When it is ready for review, please click the Ready for review button (at the bottom of the PR page).

@RyanMullins RyanMullins marked this pull request as ready for review March 12, 2025 18:22
Collaborator

@ArthurZucker ArthurZucker left a comment


Thanks, small nits but good to go otherwise!

@ain-soph
Contributor

ain-soph commented Mar 17, 2025

@RyanMullins Submit a PR for this branch as bug fixes.
RyanMullins#2

@ain-soph
Contributor

There is still a logging issue that outputs incorrect info:

from PIL import Image
import requests
from transformers import AutoProcessor, ShieldGemma2ForImageClassification

model_id = "google/shieldgemma-2-4b-it"
model = ShieldGemma2ForImageClassification.from_pretrained(model_id, device_map="auto")
processor = AutoProcessor.from_pretrained(model_id)

url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
image = Image.open(requests.get(url, stream=True).raw)

custom_policies = {
    "key_a": "description_a",
    "key_b": "description_b",
}

inputs = processor(
    images=[image],
    custom_policies=custom_policies,
    policies=["dangerous", "key_b"],
    return_tensors="pt",
).to(model.device)

output = model(**inputs)
print(output.probabilities)
$ python test.py
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:06<00:00,  3.12s/it]
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.50, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
Keyword argument `custom_policies` is not a valid argument for this processor and will be ignored.
Keyword argument `policies` is not a valid argument for this processor and will be ignored.
tensor([[5.3613e-11, 1.0000e+00],
        [3.7379e-09, 1.0000e+00]], device='cuda:1', grad_fn=<ToCopyBackward0>)

As you can see above, the arguments `custom_policies` and `policies` are already functioning (the output tensor now has only 2 classes), but we still get:

Keyword argument `custom_policies` is not a valid argument for this processor and will be ignored.
Keyword argument `policies` is not a valid argument for this processor and will be ignored.
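The warning comes from kwarg filtering, not from the arguments being dropped. The following is a minimal sketch (not the actual transformers implementation) of that mechanism: a processor only forwards kwargs it has declared and warns about the rest, so registering `policies`/`custom_policies` in an explicit kwargs class (as the later `ShieldGemma2ProcessorKwargs` fix does) silences the message while the arguments keep working.

```python
# Hypothetical sketch of processor kwarg filtering; names are illustrative.
ACCEPTED_KWARGS = {"images", "text", "return_tensors"}

def split_kwargs(accepted, **kwargs):
    """Split kwargs into (forwarded, ignored) the way a processor might."""
    forwarded = {k: v for k, v in kwargs.items() if k in accepted}
    ignored = sorted(k for k in kwargs if k not in accepted)
    for k in ignored:
        print(f"Keyword argument `{k}` is not a valid argument for this processor and will be ignored.")
    return forwarded, ignored

# Before the fix: policy-related kwargs trigger the warning.
_, ignored = split_kwargs(ACCEPTED_KWARGS, policies=["dangerous"], return_tensors="pt")
assert ignored == ["policies"]

# After declaring them explicitly: no warning.
_, ignored = split_kwargs(ACCEPTED_KWARGS | {"policies", "custom_policies"},
                          policies=["dangerous"], return_tensors="pt")
assert ignored == []
```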

@ain-soph
Contributor

ain-soph commented Mar 17, 2025

Also, the output tensor is currently confusing:

  • What does each dimension of the output tensor shape mean?
  • What does each number mean, and which harmful category does it map to? I assume the second column is redundant, since it seems to be 1 - first_column.

It would be nice to have an explanation of this in the README / API docs.
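For illustration, here is a hedged sketch of one plausible reading of the tensor, assuming each row corresponds to one (image, policy) pair in request order and the two columns are the model's Yes/No probabilities (Yes meaning the image violates that policy). Plain lists stand in for the torch tensor; the exact semantics are what the docstrings added later in this PR pin down.

```python
# Values copied from the output shown earlier in this thread; the Yes/No
# column interpretation is an assumption, not confirmed library behavior.
probabilities = [
    [5.3613e-11, 1.0000e+00],  # row for policy "dangerous"
    [3.7379e-09, 1.0000e+00],  # row for policy "key_b"
]
policies = ["dangerous", "key_b"]

for policy, (p_yes, p_no) in zip(policies, probabilities):
    verdict = "violates" if p_yes > p_no else "does not violate"
    print(f"{policy}: P(yes)={p_yes:.2e} -> image {verdict} the policy")

# If the two columns are softmaxed over just the Yes/No tokens, the second
# column indeed carries no extra information: p_no == 1 - p_yes up to rounding.
```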

@RyanMullins
Contributor Author

@ain-soph thanks for the contribs. For some reason I can't reply to your comments, but I've added docstrings to address your questions about the shape of the output tensors, and I added an explicit ShieldGemma2ProcessorKwargs to address the logs you saw.

@ghunkins

When running the default example, an IndexError is raised. Passing in custom_policies fixes that.

Example

from PIL import Image
import requests
from transformers import AutoProcessor, ShieldGemma2ForImageClassification

model_id = "google/shieldgemma-2-4b-it"
model = ShieldGemma2ForImageClassification.from_pretrained(model_id, device_map="auto")
processor = AutoProcessor.from_pretrained(model_id)

url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(images=[image], return_tensors="pt").to(model.device)

output = model(**inputs)
print(output.probabilities)

Stack Trace

Loading checkpoint shards: 100% 2/2 [00:06<00:00,  2.86s/it]
WARNING:accelerate.big_modeling:Some parameters are on the meta device because they were offloaded to the cpu.
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-26-290947fc3325> in <cell line: 0>()
     10 image = Image.open(requests.get(url, stream=True).raw)
     11 
---> 12 inputs = processor(images=[image], return_tensors="pt").to(model.device)
     13 
     14 output = model(**inputs)

1 frames
/usr/local/lib/python3.11/dist-packages/transformers/processing_utils.py in apply_chat_template(self, conversation, chat_template, **kwargs)
   1326 
   1327         if isinstance(conversation, (list, tuple)) and (
-> 1328             isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "content")
   1329         ):
   1330             is_batched = True

IndexError: list index out of range
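The traceback shows `apply_chat_template` indexing `conversation[0]` on an empty list. A minimal sketch of that failure mode, with a hypothetical length guard (the actual fix in this PR was different: shipping default policies so the list is never empty):

```python
# Simplified reconstruction of the batching check from the traceback above;
# the length guard is an illustrative addition, not the shipped fix.
def is_batched(conversation):
    return isinstance(conversation, (list, tuple)) and len(conversation) > 0 and (
        isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "content")
    )

# Without `len(conversation) > 0`, conversation[0] raises IndexError on [].
assert is_batched([]) is False
assert is_batched([["user msg"]]) is True  # batched: a list of lists
```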

Collaborator

@ArthurZucker ArthurZucker left a comment


LGTM! Thanks for iterating!

@RyanMullins
Contributor Author

@ghunkins thanks for the bug report!

The problem was that the files on the Hub were out of date and the processor.json was missing the policy_definitions property. I was using updated files locally, which is why I didn't see this until I repro'd your example in Colab. The files are now updated on the Hub.

A suggestion from @ArthurZucker incidentally provided a local solution to the problem as well, by adding a default set of policies. So this should be fully fixed now and ready for merging.
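The default-policies fallback described above can be sketched as follows. This is a hypothetical illustration, not the library's code: the constant and function names are invented, and the point is only the resolution order (defaults, then custom overrides, then an explicit selection) that keeps the prompt list from ever being empty.

```python
# Illustrative defaults; the real policy texts live in the model's processor config.
DEFAULT_POLICY_DEFINITIONS = {
    "dangerous": "Content that facilitates serious harm.",
    "sexual": "Sexually explicit content.",
    "violence": "Gory or violent content.",
}

def resolve_policies(policy_definitions=None, custom_policies=None, policies=None):
    """Merge defaults, config, and call-site policies; never return an empty set."""
    defs = dict(policy_definitions or DEFAULT_POLICY_DEFINITIONS)
    defs.update(custom_policies or {})
    selected = policies or list(defs)
    return {name: defs[name] for name in selected}

# A missing processor config no longer yields an empty prompt list.
assert resolve_policies() == DEFAULT_POLICY_DEFINITIONS
assert list(resolve_policies(policies=["dangerous"])) == ["dangerous"]
```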

@ghunkins

Brilliant, thanks for the amazing work on this @RyanMullins !

Collaborator

@ArthurZucker ArthurZucker left a comment


@RyanMullins let's just fix the CIs! 🤗

@ArthurZucker
Collaborator

Do you need help on my side for the last remaining tests?

@RyanMullins RyanMullins force-pushed the shieldgemma2 branch 3 times, most recently from 03efe5e to 36135cf Compare March 20, 2025 13:54
@ArthurZucker ArthurZucker merged commit 487dab1 into huggingface:main Mar 20, 2025
19 of 21 checks passed
@ain-soph
Contributor

nit: I just tested, and there is the same kind of warning for `policy_definitions`:
Some kwargs in processor config are unused and will not have any effect: policy_definitions.
