Shieldgemma2 #36678
Conversation
Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. When it is ready for review, please click the
Thanks, small nits but good to go otherwise!
src/transformers/models/shieldgemma2/processing_shieldgemma2.py
@RyanMullins Submit a PR for this branch as bug fixes.
There is still a logging issue that outputs incorrect info:

```python
from PIL import Image
import requests
from transformers import AutoProcessor, ShieldGemma2ForImageClassification

model_id = "google/shieldgemma-2-4b-it"
model = ShieldGemma2ForImageClassification.from_pretrained(model_id, device_map="auto")
processor = AutoProcessor.from_pretrained(model_id)

url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
image = Image.open(requests.get(url, stream=True).raw)

custom_policies = {
    "key_a": "description_a",
    "key_b": "description_b",
}
inputs = processor(
    images=[image],
    custom_policies=custom_policies,
    policies=["dangerous", "key_b"],
    return_tensors="pt",
).to(model.device)

output = model(**inputs)
print(output.probabilities)
```
As you can see above, the argument
Also, the output tensor is currently confusing. It would be nice to have an explanation in the README / API docs.
@ain-soph thanks for the contribs. For some reason I can't reply to your comments, but I've added docstrings to address your questions about the shape of the output tensors, and I added an explicit
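For readers following the discussion of the output tensor's shape, here is a minimal sketch of how a flat list of per-(image, policy) probability pairs could be grouped back by policy name. The helper `unpack_probabilities`, the `[P(Yes), P(No)]` pair layout, and the image-major ordering are all assumptions for illustration, not the documented transformers API.

```python
def unpack_probabilities(probs, policies, num_images):
    """Group flat per-(image, policy) probability pairs by image.

    ASSUMPTION (not the real API): `probs` is a flat list of
    [p_yes, p_no] pairs, one per (image, policy) combination,
    ordered image-major.
    """
    assert len(probs) == num_images * len(policies)
    results = []
    for i in range(num_images):
        per_image = {}
        for j, policy in enumerate(policies):
            p_yes, p_no = probs[i * len(policies) + j]
            per_image[policy] = {"yes": p_yes, "no": p_no}
        results.append(per_image)
    return results

# One image checked against two policies:
result = unpack_probabilities([[0.9, 0.1], [0.2, 0.8]], ["dangerous", "key_b"], 1)
```

With this layout, `result[0]["dangerous"]` would be `{"yes": 0.9, "no": 0.1}` — the kind of per-policy mapping the requested docstrings would pin down.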
When running the default example, an IndexError is raised. Passing in custom_policies fixes that.

Example:

```python
from PIL import Image
import requests
from transformers import AutoProcessor, ShieldGemma2ForImageClassification

model_id = "google/shieldgemma-2-4b-it"
model = ShieldGemma2ForImageClassification.from_pretrained(model_id, device_map="auto")
processor = AutoProcessor.from_pretrained(model_id)

url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(images=[image], return_tensors="pt").to(model.device)
output = model(**inputs)
print(output.probabilities)
```

Stack trace:

```
Loading checkpoint shards: 100% 2/2 [00:06<00:00, 2.86s/it]
WARNING:accelerate.big_modeling:Some parameters are on the meta device because they were offloaded to the cpu.
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-26-290947fc3325> in <cell line: 0>()
     10 image = Image.open(requests.get(url, stream=True).raw)
     11
---> 12 inputs = processor(images=[image], return_tensors="pt").to(model.device)
     13
     14 output = model(**inputs)

1 frames
/usr/local/lib/python3.11/dist-packages/transformers/processing_utils.py in apply_chat_template(self, conversation, chat_template, **kwargs)
   1326
   1327         if isinstance(conversation, (list, tuple)) and (
-> 1328             isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "content")
   1329         ):
   1330             is_batched = True

IndexError: list index out of range
```
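The traceback shows the failure mode: with no policies available, the processor builds an empty conversation list, and the batch-detection check indexes `conversation[0]` without guarding the empty case. The snippet below is a simplified reproduction of that check (not the real `apply_chat_template` code), plus the obvious defensive variant:

```python
def is_batched_conversation(conversation):
    # Simplified from the check in transformers' apply_chat_template
    # that raised the IndexError above: conversation[0] on an empty
    # list fails before either branch of the `or` can run.
    return isinstance(conversation, (list, tuple)) and (
        isinstance(conversation[0], (list, tuple))
        or hasattr(conversation[0], "content")
    )


def is_batched_conversation_safe(conversation):
    # Guarding the empty case first avoids the crash: no policies
    # means no messages, hence an empty (non-batched) conversation.
    if not conversation:
        return False
    return is_batched_conversation(conversation)
```

This is only a sketch of the failure mechanism; the actual fix discussed in the thread was to ensure a default set of policies so the conversation is never empty in the first place.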
LGTM! Thanks for iterating!
src/transformers/models/shieldgemma2/processing_shieldgemma2.py
@ghunkins thanks for the bug report! The problem was that the files on the Hub were out of date and the processor.json was missing the
@ArthurZucker accidentally suggested a local solution to the problem by providing a default set of policies. So this should be fully fixed now and ready for merging.
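The "default set of policies" fix described above can be sketched roughly as follows. The names `DEFAULT_POLICY_DEFINITIONS` and `resolve_policies`, and the example policy texts, are hypothetical stand-ins for illustration — not the actual ShieldGemma 2 processor API:

```python
# Hypothetical built-in fallback, so the processor never ends up with an
# empty policy list (and hence an empty conversation). Policy texts are
# illustrative placeholders.
DEFAULT_POLICY_DEFINITIONS = {
    "dangerous": "No dangerous content.",
    "sexually_explicit": "No sexually explicit content.",
    "violence": "No violent content.",
}


def resolve_policies(policies=None, custom_policies=None):
    """Merge built-in defaults with user-supplied custom policies.

    custom_policies extend (or override) the defaults; `policies` selects
    which ones to run, defaulting to all known policies.
    """
    definitions = dict(DEFAULT_POLICY_DEFINITIONS)
    if custom_policies:
        definitions.update(custom_policies)
    if policies is None:
        policies = list(definitions)
    unknown = [p for p in policies if p not in definitions]
    if unknown:
        raise ValueError(f"Unknown policies: {unknown}")
    return [(p, definitions[p]) for p in policies]
```

With this shape, `processor(images=[image])` would fall back to the defaults instead of crashing, while the `custom_policies` / `policies` combination from the earlier example still works.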
Brilliant, thanks for the amazing work on this @RyanMullins !
@RyanMullins let's just fix the CIs! 🤗
Do you need help from my side for the last remaining tests?
nit: I just tested and there is the same warning for
What does this PR do?
Fixes # (issue)
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.