|
25 | 25 |
|
26 | 26 | logger = logging.get_logger(__name__)
|
27 | 27 |
|
| 28 | +DEFAULT_SHIELDGEMMA2_POLICIES: Mapping[str, str] = { |
| 29 | + "dangerous": ( |
| 30 | + "No Dangerous Content: The image shall not contain content that" |
| 31 | + " facilitates or encourages activities that could cause real-world harm" |
| 32 | + " (e.g., building firearms and explosive devices, promotion of" |
| 33 | + " terrorism, instructions for suicide)." |
| 34 | + ), |
| 35 | + "sexual": ( |
| 36 | + "No Sexually Explicit content: The image shall not contain content that" |
| 37 | + " depicts explicit or graphic sexual acts (e.g., pornography, erotic" |
| 38 | + " nudity, depictions of rape or sexual assault)." |
| 39 | + ), |
| 40 | + "violence": ( |
| 41 | + "No Violence/Gore content: The image shall not contain content that" |
| 42 | + " depicts shocking, sensational, or gratuitous violence (e.g.," |
| 43 | + " excessive blood and gore, gratuitous violence against animals," |
| 44 | + " extreme injury or moment of death)." |
| 45 | + ), |
| 46 | +} |
| 47 | + |
28 | 48 |
|
29 | 49 | class ShieldGemma2ProcessorKwargs(Gemma3ProcessorKwargs, total=False):
|
30 | 50 | policies: Optional[Sequence[str]]
|
@@ -65,10 +85,10 @@ def __init__(
|
65 | 85 | the base policies ShieldGemma was trained on.
|
66 | 86 | """
|
67 | 87 | super().__init__(image_processor, tokenizer, chat_template, image_seq_length, **kwargs)
|
68 |
| - if policy_definitions: |
69 |
| - self.policy_definitions = policy_definitions |
| 88 | + if policy_definitions is None: |
| 89 | + self.policy_definitions = DEFAULT_SHIELDGEMMA2_POLICIES |
70 | 90 | else:
|
71 |
| - self.policy_definitions = {} |
| 91 | + self.policy_definitions = policy_definitions |
72 | 92 |
|
73 | 93 | def __call__(
|
74 | 94 | self,
|
@@ -138,6 +158,7 @@ def __call__(
|
138 | 158 | if (policies := kwargs.get("policies")) is None:
|
139 | 159 | policies = list(policy_definitions.keys())
|
140 | 160 |
|
| 161 | + # TODO(ryanmullins): Support images from PIL or URLs. |
141 | 162 | messages = []
|
142 | 163 | expanded_images = []
|
143 | 164 | for img in images:
|
|
0 commit comments