Skip to content

Commit 9a3e0ca

Browse files
committed
multimodal
1 parent b446733 commit 9a3e0ca

File tree

6 files changed

+261
-146
lines changed

6 files changed

+261
-146
lines changed

docs/source/en/model_doc/bert.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ rendered properly in your Markdown viewer.
2828

2929
[BERT](https://huggingface.co/papers/1810.04805) is a bidirectional transformer pretrained on unlabeled text to predict masked tokens in a sentence and to predict whether one sentence follows another. The main idea is that by randomly masking some tokens, the model can train on text to the left and right, giving it a more thorough understanding. BERT is also very versatile because its learned language representations can be adapted for other NLP tasks by fine-tuning an additional layer or head.
3030

31-
You can find all the original BERT checkpoints under the [BERT collection](https://huggingface.co/collections/google/bert-release-64ff5e7a4be99045d1896dbc).
31+
You can find all the original BERT checkpoints under the BERT [collection](https://huggingface.co/collections/google/bert-release-64ff5e7a4be99045d1896dbc).
3232

3333
> [!TIP]
3434
> Click on the BERT models in the right sidebar for more examples of how to apply BERT to different language tasks.

docs/source/en/model_doc/gemma3.md

+136-85
Original file line numberDiff line numberDiff line change
@@ -15,36 +15,63 @@ rendered properly in your Markdown viewer.
1515
1616
-->
1717

18-
# Gemma3
18+
<div style="float: right;">
19+
<div class="flex flex-wrap space-x-1">
20+
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
21+
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
22+
</div>
23+
</div>
1924

20-
## Overview
25+
# Gemma 3
2126

22-
The Gemma 3 model was proposed in the [Gemma 3 Techncial Report](https://goo.gle/Gemma3Report) by Google. It is a vision-language model composed by a [SigLIP](siglip) vision encoder and a [Gemma 2](gemma_2) language decoder, linked by a multimodal linear projection. It cuts an image into a fixed number of tokens, in the same way as SigLIP, as long as the image does not exceed certain aspect ratio. For images that exceed the given aspect ratio, it crops the image into multiple smaller patches and concatenates them with the base image embedding. One particularity is that the model uses bidirectional attention on all the image tokens. In addition, the model interleaves sliding window local attention with full causal attention in the language backbone, where each sixth layer is a full causal attention layer.
27+
[Gemma 3](https://goo.gle/Gemma3Report) is a multimodal model, available in pretrained and instruction-tuned variants, available in 1B, 13B, and 27B parameters. The architecture is mostly the same as the previous Gemma versions. The key differences are alternating 5 local sliding window self-attention layers for every global self-attention layer, support for a longer context length of 128K tokens, and a [SigLip](./siglip) encoder that can "pan & scan" high-resolution images to prevent information in images from disappearing.
2328

24-
This model was contributed by [Ryan Mullins](https://huggingface.co/RyanMullins), [Raushan Turganbay](https://huggingface.co/RaushanTurganbay) [Arthur Zucker](https://huggingface.co/ArthurZ), and [Pedro Cuenca](https://huggingface.co/pcuenq).
29+
The instruction-tuned Gemma 3 model was post-trained with knowledge distillation and reinforcement learning.
2530

31+
You can find all the original Gemma 3 checkpoints under the [Gemma 3](https://huggingface.co/collections/meta-llama/llama-2-family-661da1f90a9d678b6f55773b) release.
2632

27-
## Usage tips
33+
> [!TIP]
34+
> Click on the Gemma 3 models in the right sidebar for more examples of how to apply Gemma to different vision and language tasks.
2835
36+
The example below demonstrates how to generate text based on an image with [`Pipeline`] or the [`AutoModel`] class.
2937

30-
- For image+text and image-only inputs use `Gemma3ForConditionalGeneration`.
31-
- For text-only inputs use `Gemma3ForCausalLM` for generation to avoid loading the vision tower.
32-
- Each sample can contain multiple images, and the number of images can vary between samples. However, make sure to pass correctly batched images to the processor, where each batch is a list of one or more images.
33-
- The text passed to the processor should have a `<start_of_image>` token wherever an image should be inserted.
34-
- The processor has its own `apply_chat_template` method to convert chat messages to model inputs. See the examples below for more details on how to use it.
38+
<hfoptions id="usage">
39+
<hfoption id="Pipeline">
3540

41+
```py
42+
import torch
43+
from transformers import pipeline
3644

37-
### Image cropping for high resolution images
38-
39-
The model supports cropping images into smaller patches when the image aspect ratio exceeds a certain value. By default the images are not cropped and only the base image is forwarded to the model. Users can set `do_pan_and_scan=True` to obtain several crops per image along with the base image to improve the quality in DocVQA or similar tasks requiring higher resolution images.
45+
pipeline = pipeline(
46+
task="image-text-to-text",
47+
model="google/gemma-3-4b-pt",
48+
device=0,
49+
torch_dtype=torch.bfloat16
50+
)
51+
pipeline(
52+
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg",
53+
text="<start_of_image> What is shown in this image?"
54+
)
55+
```
4056

41-
Pan and scan is an inference time optimization to handle images with skewed aspect ratios. When enabled, it improves performance on tasks related to document understanding, infographics, OCR, etc.
57+
</hfoption>
58+
<hfoption id="AutoModel">
4259

43-
```python
60+
```py
61+
import torch
62+
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
4463

45-
processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it", padding_side="left")
64+
model = Gemma3ForConditionalGeneration.from_pretrained(
65+
"google/gemma-3-4b-it",
66+
torch_dtype=torch.bfloat16,
67+
device_map="auto",
68+
attn_implementation="sdpa"
69+
)
70+
processor = AutoProcessor.from_pretrained(
71+
"google/gemma-3-4b-it",
72+
padding_side="left"
73+
)
4674

47-
url = "https://media.istockphoto.com/id/1192867753/photo/cow-in-berchida-beach-siniscola.jpg?s=612x612&w=0&k=20&c=v0hjjniwsMNfJSuKWZuIn8pssmD5h5bSN1peBd1CmH4="
4875
messages = [
4976
{
5077
"role": "system",
@@ -54,7 +81,7 @@ messages = [
5481
},
5582
{
5683
"role": "user", "content": [
57-
{"type": "image", "url": url},
84+
{"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
5885
{"type": "text", "text": "What is shown in this image?"},
5986
]
6087
},
@@ -65,24 +92,36 @@ inputs = processor.apply_chat_template(
6592
return_dict=True,
6693
return_tensors="pt",
6794
add_generation_prompt=True,
68-
do_pan_and_scan=True,
69-
).to(model.device)
95+
).to("cuda")
7096

97+
output = model.generate(**inputs, max_new_tokens=50, cache_implementation="static")
98+
print(processor.decode(output[0], skip_special_tokens=True))
7199
```
72100

101+
</hfoption>
102+
</hfoptions>
73103

74-
## Usage Example
104+
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.
75105

76-
### Single-image Inference
106+
The example below uses [torchao](../quantization/torchao) to only quantize the weights to int4.
77107

78-
```python
79-
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
108+
```py
109+
# pip install torchao
110+
import torch
111+
from transformers import TorchAoConfig, Gemma3ForConditionalGeneration, AutoProcessor
80112

81-
model_id = "google/gemma-3-4b-it"
82-
model = Gemma3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
83-
processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
113+
quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
114+
model = Gemma3ForConditionalGeneration.from_pretrained(
115+
"google/gemma-3-27b-it",
116+
torch_dtype=torch.bfloat16,
117+
device_map="auto",
118+
quantization_config=quantization_config
119+
)
120+
processor = AutoProcessor.from_pretrained(
121+
"google/gemma-3-27b-it",
122+
padding_side="left"
123+
)
84124

85-
url = "https://media.istockphoto.com/id/1192867753/photo/cow-in-berchida-beach-siniscola.jpg?s=612x612&w=0&k=20&c=v0hjjniwsMNfJSuKWZuIn8pssmD5h5bSN1peBd1CmH4="
86125
messages = [
87126
{
88127
"role": "system",
@@ -92,7 +131,7 @@ messages = [
92131
},
93132
{
94133
"role": "user", "content": [
95-
{"type": "image", "url": url},
134+
{"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
96135
{"type": "text", "text": "What is shown in this image?"},
97136
]
98137
},
@@ -103,69 +142,81 @@ inputs = processor.apply_chat_template(
103142
return_dict=True,
104143
return_tensors="pt",
105144
add_generation_prompt=True,
106-
).to(model.device)
145+
).to("cuda")
107146

108-
output = model.generate(**inputs, max_new_tokens=50)
109-
print(processor.decode(output[0], skip_special_tokens=True)[inputs.input_ids.shape[1]: ])
147+
output = model.generate(**inputs, max_new_tokens=50, cache_implementation="static")
148+
print(processor.decode(output[0], skip_special_tokens=True))
110149
```
111150

112-
### Multi-image Inference
113-
114-
```python
115-
model_id = "google/gemma-3-4b-it"
116-
model = Gemma3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
117-
processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
118-
119-
url_cow = "https://media.istockphoto.com/id/1192867753/photo/cow-in-berchida-beach-siniscola.jpg?s=612x612&w=0&k=20&c=v0hjjniwsMNfJSuKWZuIn8pssmD5h5bSN1peBd1CmH4="
120-
url_stop = "https://www.ilankelman.org/stopsigns/australia.jpg"
121-
messages = [
122-
{
123-
"role": "system",
124-
"content": [
125-
{"type": "text", "text": "You are a helpful assistant."}
126-
]
127-
},
128-
{
129-
"role": "user", "content": [
130-
{"type": "image", "url": url_cow},
131-
{"type": "image", "url": url_stop},
132-
{"type": "text", "text": "Are these two images identical?"},
133-
]
134-
},
135-
]
136-
inputs = processor.apply_chat_template(
137-
messages,
138-
tokenize=True,
139-
return_dict=True,
140-
return_tensors="pt",
141-
add_generation_prompt=True,
142-
).to(model.device)
143-
144-
output = model.generate(**inputs, max_new_tokens=50)
145-
print(processor.decode(output[0], skip_special_tokens=True)[inputs.input_ids.shape[1]: ])
146-
147-
```
148-
149-
### Text-only inference
150-
151-
You can use the VLMs for text-only generation by omitting images in your input. However, you can also load the models in text-only mode as shown below. This will skip loading the vision tower and will save resources when you just need the LLM capabilities.
152-
```python
153-
from transformers import AutoTokenizer, Gemma3ForCausalLM
154-
155-
model_id = "google/gemma-3-1b-it"
156-
157-
tokenizer = AutoTokenizer.from_pretrained(model_id)
158-
model = Gemma3ForCausalLM.from_pretrained(model_id, device_map="auto")
159-
160-
input_ids = tokenizer("Write me a poem about Machine Learning.", return_tensors="pt").to(model.device)
161-
162-
outputs = model.generate(**input_ids, max_new_tokens=100)
163-
text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
151+
Use the [`~transformers.utils.AttentionMaskVisualizer`] to better understand what tokens the model can and cannot attend to.
164152

165-
print(text)
153+
```py
154+
from transformers.utils.attention_visualizer import AttentionMaskVisualizer
166155

156+
visualizer = AttentionMaskVisualizer("google/gemma-3-4b-it")
157+
visualizer("<img>What is shown in this image?")
167158
```
168159

160+
## Notes
161+
162+
- Use [`Gemma3ForConditionalGeneration`] for image-and-text and image-only inputs.
163+
- Gemma 3 supports multiple input images, but make sure the images are correctly batched before passing them to the processor. Each batch should be a list of one or more images.
164+
165+
```py
166+
url_cow = "https://media.istockphoto.com/id/1192867753/photo/cow-in-berchida-beach-siniscola.jpg?s=612x612&w=0&k=20&c=v0hjjniwsMNfJSuKWZuIn8pssmD5h5bSN1peBd1CmH4="
167+
url_cat = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
168+
169+
messages =[
170+
{
171+
"role": "system",
172+
"content": [
173+
{"type": "text", "text": "You are a helpful assistant."}
174+
]
175+
},
176+
{
177+
"role": "user",
178+
"content": [
179+
{"type": "image", "url": url_cow},
180+
{"type": "image", "url": url_cat},
181+
{"type": "text", "text": "Which image is cuter?"},
182+
]
183+
},
184+
]
185+
```
186+
- Text passed to the processor should have a `<start_of_image>` token wherever an image should be inserted.
187+
- The processor has its own [`~ProcessorMixin.apply_chat_template`] method to convert chat messages to model inputs.
188+
- By default, the images aren't cropped and only the base image is forwarded to the model. In high resolution images or images with non-square aspect ratios, artifacts can result because the vision encoder uses a fixed resolution of 896x896. To prevent these artifacts and improve performance during inference, set `do_pan_and_scan=True` to crop the image into multiple smaller patches and concatenate them with the base image embedding. You can disable pan and scan for faster inference.
189+
190+
```diff
191+
inputs = processor.apply_chat_template(
192+
messages,
193+
tokenize=True,
194+
return_dict=True,
195+
return_tensors="pt",
196+
add_generation_prompt=True,
197+
+ do_pan_and_scan=True,
198+
).to("cuda")
199+
```
200+
- For text-only inputs, use [`AutoModelForCausalLM`] instead to skip loading the vision components and save resources.
201+
202+
```py
203+
import torch
204+
from transformers import AutoModelForCausalLM, AutoTokenizer
205+
206+
tokenizer = AutoTokenizer.from_pretrained(
207+
"google/gemma-3-1b-pt",
208+
)
209+
model = AutoModelForCausalLM.from_pretrained(
210+
"google/gemma-3-1b-pt",
211+
torch_dtype=torch.bfloat16,
212+
device_map="auto",
213+
attn_implementation="sdpa"
214+
)
215+
input_ids = tokenizer("Plants create energy through a process known as", return_tensors="pt").to("cuda")
216+
217+
output = model.generate(**input_ids, cache_implementation="static")
218+
print(tokenizer.decode(output[0], skip_special_tokens=True))
219+
```
169220

170221
## Gemma3ImageProcessor
171222

docs/source/en/model_doc/llama.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ You can find all the original Llama checkpoints under the [Huggy Llama](https://
3333
> [!TIP]
3434
> Click on the Llama models in the right sidebar for more examples of how to apply Llama to different language tasks.
3535
36-
The example below demonstrates how to generate text with [`Pipeline`], [`AutoModel`], and from the command line.
36+
The example below demonstrates how to generate text with [`Pipeline`] or the [`AutoModel`], and from the command line.
3737

3838
<hfoptions id="usage">
3939
<hfoption id="Pipeline">
@@ -107,7 +107,7 @@ output = model.generate(**input_ids, cache_implementation="static")
107107
print(tokenizer.decode(output[0], skip_special_tokens=True))
108108
```
109109

110-
Use the `visualize_attention_mask` method to better understand what tokens the model can and cannot attend to.
110+
Use the [`~transformers.utils.AttentionMaskVisualizer`] utility to better understand what tokens the model can and cannot attend to.
111111

112112
```py
113113
from transformers.utils.attention_visualizer import AttentionMaskVisualizer

docs/source/en/model_doc/llama2.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ rendered properly in your Markdown viewer.
2828

2929
Llama 2-Chat is trained with supervised fine-tuning (SFT), and reinforcement learning with human feedback (RLHF) - rejection sampling and proximal policy optimization (PPO) - is applied to the fine-tuned model to align the chat model with human preferences.
3030

31-
You can find all the original Llama 2 checkpoints under the [Llama 2 Family collection](https://huggingface.co/collections/meta-llama/llama-2-family-661da1f90a9d678b6f55773b).
31+
You can find all the original Llama 2 checkpoints under the [Llama 2 Family](https://huggingface.co/collections/meta-llama/llama-2-family-661da1f90a9d678b6f55773b) collection.
3232

3333
> [!TIP]
3434
> Click on the Llama 2 models in the right sidebar for more examples of how to apply Llama to different language tasks.
@@ -107,7 +107,7 @@ output = model.generate(**input_ids, cache_implementation="static")
107107
print(tokenizer.decode(output[0], skip_special_tokens=True))
108108
```
109109

110-
Use the `visualize_attention_mask` method to better understand what tokens the model can and cannot attend to.
110+
Use the [`~transformers.utils.AttentionMaskVisualizer`] to better understand what tokens the model can and cannot attend to.
111111

112112
```py
113113
from transformers.utils.attention_visualizer import AttentionMaskVisualizer

0 commit comments

Comments
 (0)