diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py
index 23c8bc92b2f1..2a9aa81608e7 100644
--- a/src/diffusers/hooks/__init__.py
+++ b/src/diffusers/hooks/__init__.py
@@ -27,3 +27,4 @@
 from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
 from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig
 from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache
+from .text_kv_cache import TextKVCacheConfig, apply_text_kv_cache
diff --git a/src/diffusers/hooks/text_kv_cache.py b/src/diffusers/hooks/text_kv_cache.py
new file mode 100644
index 000000000000..53777ae185a0
--- /dev/null
+++ b/src/diffusers/hooks/text_kv_cache.py
@@ -0,0 +1,110 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+
+import torch
+
+from .hooks import HookRegistry, ModelHook
+
+
+_TEXT_KV_CACHE_HOOK = "text_kv_cache"
+
+
+@dataclass
+class TextKVCacheConfig:
+    """Enable exact (lossless) text K/V caching for transformer models.
+
+    Computes each transformer block's text key and value projections once, on the first denoising step,
+    and reuses them across all subsequent steps. The cached values are keyed by the ``data_ptr()`` of the
+    ``encoder_hidden_states`` tensor, so the positive and negative prompts (when ``true_cfg_scale > 1``)
+    are cached and reused independently.
+    """
+
+    pass  # no hyperparameters needed; the cache is always exact
+
+
+class TextKVCacheHook(ModelHook):
+    """Block-level hook that caches (txt_key, txt_value) per unique prompt."""
+
+    _is_stateful = True
+
+    def __init__(self):
+        super().__init__()
+        # Maps encoder_hidden_states.data_ptr() -> (txt_key, txt_value)
+        self.kv_cache: dict[int, tuple] = {}
+
+    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
+        from ..models.transformers.transformer_nucleusmoe_image import _apply_rotary_emb_nucleus
+
+        # --- extract encoder_hidden_states ---
+        if "encoder_hidden_states" in kwargs:
+            encoder_hidden_states = kwargs["encoder_hidden_states"]
+        else:
+            # positional: (hidden_states, encoder_hidden_states, temb, ...)
+            encoder_hidden_states = args[1]
+
+        # --- extract image_rotary_emb ---
+        if "image_rotary_emb" in kwargs:
+            image_rotary_emb = kwargs["image_rotary_emb"]
+        elif len(args) > 3:
+            image_rotary_emb = args[3]
+        else:
+            image_rotary_emb = None
+
+        ptr = encoder_hidden_states.data_ptr()
+
+        if ptr not in self.kv_cache:
+            # Mirror the uncached processor path: project, split into K/V heads, normalize keys, apply text RoPE.
+            context = module.encoder_proj(encoder_hidden_states)
+
+            attn = module.attn
+            head_dim = attn.inner_dim // attn.heads
+            num_kv_heads = attn.inner_kv_dim // head_dim
+
+            txt_key = attn.add_k_proj(context).unflatten(-1, (num_kv_heads, -1))
+            txt_value = attn.add_v_proj(context).unflatten(-1, (num_kv_heads, -1))
+
+            if attn.norm_added_k is not None:
+                txt_key = attn.norm_added_k(txt_key)
+
+            if image_rotary_emb is not None:
+                _, txt_freqs = image_rotary_emb
+                txt_key = _apply_rotary_emb_nucleus(txt_key, txt_freqs, use_real=False)
+
+            self.kv_cache[ptr] = (txt_key, txt_value)
+
+        txt_key, txt_value = self.kv_cache[ptr]
+
+        # Inject the cached K/V; when cached_txt_key is present, the block skips encoder_proj as well.
+        attn_kwargs = kwargs.get("attention_kwargs") or {}
+        attn_kwargs["cached_txt_key"] = txt_key
+        attn_kwargs["cached_txt_value"] = txt_value
+        kwargs["attention_kwargs"] = attn_kwargs
+
+        return self.fn_ref.original_forward(*args, **kwargs)
+
+    def reset_state(self, module: torch.nn.Module):
+        self.kv_cache.clear()
+        return module
+
+
+def apply_text_kv_cache(module: torch.nn.Module, config: TextKVCacheConfig) -> None:
+    from ..models.transformers.transformer_nucleusmoe_image import NucleusMoEImageTransformerBlock
+
+    for _, submodule in module.named_modules():
+        if isinstance(submodule, NucleusMoEImageTransformerBlock):
+            hook = TextKVCacheHook()
+            registry = HookRegistry.check_if_exists_or_initialize(submodule)
+            registry.register_hook(hook, _TEXT_KV_CACHE_HOOK)
diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py
index 5f9587a1b4de..3bca773d8344 100644
--- a/src/diffusers/models/cache_utils.py
+++ b/src/diffusers/models/cache_utils.py
@@ -41,11 +41,12 @@ def enable_cache(self, config) -> None:
         Enable caching techniques on the model.
 
         Args:
-            config (`PyramidAttentionBroadcastConfig | FasterCacheConfig | FirstBlockCacheConfig`):
+            config (`PyramidAttentionBroadcastConfig | FasterCacheConfig | FirstBlockCacheConfig | TextKVCacheConfig`):
                 The configuration for applying the caching technique.
Currently supported caching techniques are: - [`~hooks.PyramidAttentionBroadcastConfig`] - [`~hooks.FasterCacheConfig`] - [`~hooks.FirstBlockCacheConfig`] + - [`~hooks.TextKVCacheConfig`] Example: @@ -71,11 +72,13 @@ def enable_cache(self, config) -> None: MagCacheConfig, PyramidAttentionBroadcastConfig, TaylorSeerCacheConfig, + TextKVCacheConfig, apply_faster_cache, apply_first_block_cache, apply_mag_cache, apply_pyramid_attention_broadcast, apply_taylorseer_cache, + apply_text_kv_cache, ) if self.is_cache_enabled: @@ -89,6 +92,8 @@ def enable_cache(self, config) -> None: apply_first_block_cache(self, config) elif isinstance(config, MagCacheConfig): apply_mag_cache(self, config) + elif isinstance(config, TextKVCacheConfig): + apply_text_kv_cache(self, config) elif isinstance(config, PyramidAttentionBroadcastConfig): apply_pyramid_attention_broadcast(self, config) elif isinstance(config, TaylorSeerCacheConfig): @@ -106,12 +111,14 @@ def disable_cache(self) -> None: MagCacheConfig, PyramidAttentionBroadcastConfig, TaylorSeerCacheConfig, + TextKVCacheConfig, ) from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK from ..hooks.mag_cache import _MAG_CACHE_BLOCK_HOOK, _MAG_CACHE_LEADER_BLOCK_HOOK from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK from ..hooks.taylorseer_cache import _TAYLORSEER_CACHE_HOOK + from ..hooks.text_kv_cache import _TEXT_KV_CACHE_HOOK if self._cache_config is None: logger.warning("Caching techniques have not been enabled, so there's nothing to disable.") @@ -129,6 +136,8 @@ def disable_cache(self) -> None: registry.remove_hook(_MAG_CACHE_BLOCK_HOOK, recurse=True) elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig): registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True) + elif isinstance(self._cache_config, TextKVCacheConfig): + registry.remove_hook(_TEXT_KV_CACHE_HOOK, recurse=True) elif isinstance(self._cache_config, TaylorSeerCacheConfig): registry.remove_hook(_TAYLORSEER_CACHE_HOOK, recurse=True) else: diff --git a/src/diffusers/models/transformers/transformer_nucleusmoe_image.py b/src/diffusers/models/transformers/transformer_nucleusmoe_image.py index 9c2aa17f162a..342919a26148 100644 --- a/src/diffusers/models/transformers/transformer_nucleusmoe_image.py +++ b/src/diffusers/models/transformers/transformer_nucleusmoe_image.py @@ -268,6 +268,8 @@ def __call__( encoder_hidden_states: torch.FloatTensor = None, attention_mask: torch.FloatTensor | None = None, image_rotary_emb: torch.Tensor | None = None, + cached_txt_key: torch.FloatTensor | None = None, + cached_txt_value: torch.FloatTensor | None = None, ) -> torch.FloatTensor: head_dim = attn.inner_dim // attn.heads num_kv_heads = attn.inner_kv_dim // head_dim @@ -287,7 +289,11 @@ def __call__( img_query = _apply_rotary_emb_nucleus(img_query, img_freqs, use_real=False) img_key = _apply_rotary_emb_nucleus(img_key, img_freqs, use_real=False) - if encoder_hidden_states is not None: + if cached_txt_key is not None and cached_txt_value is not None: + txt_key, txt_value = cached_txt_key, cached_txt_value + joint_key = torch.cat([img_key, txt_key], dim=1) + joint_value = torch.cat([img_value, txt_value], dim=1) + elif encoder_hidden_states is not None: txt_key = attn.add_k_proj(encoder_hidden_states).unflatten(-1, (num_kv_heads, -1)) txt_value = attn.add_v_proj(encoder_hidden_states).unflatten(-1, (num_kv_heads, -1)) @@ -537,17 +543,18 @@ def forward( 
        gate1 = gate1.clamp(min=-2.0, max=2.0)
         gate2 = gate2.clamp(min=-2.0, max=2.0)
 
-        context = self.encoder_proj(encoder_hidden_states)
+        # Skip encoder_proj when the text K/V are already cached; the processor never reads `context` then.
+        attn_kwargs = attention_kwargs or {}
+        context = None if attn_kwargs.get("cached_txt_key") is not None else self.encoder_proj(encoder_hidden_states)
 
         img_normed = self.pre_attn_norm(hidden_states)
         img_modulated = img_normed * scale1
 
-        attention_kwargs = attention_kwargs or {}
         img_attn_output = self.attn(
             hidden_states=img_modulated,
             encoder_hidden_states=context,
             image_rotary_emb=image_rotary_emb,
-            **attention_kwargs,
+            **attn_kwargs,
         )
 
         hidden_states = hidden_states + gate1.tanh() * img_attn_output
diff --git a/src/diffusers/pipelines/nucleusmoe_image/pipeline_nucleusmoe_image.py b/src/diffusers/pipelines/nucleusmoe_image/pipeline_nucleusmoe_image.py
index 70d8ec8212ad..650ec145744e 100644
--- a/src/diffusers/pipelines/nucleusmoe_image/pipeline_nucleusmoe_image.py
+++ b/src/diffusers/pipelines/nucleusmoe_image/pipeline_nucleusmoe_image.py
@@ -574,6 +574,10 @@ def __call__(
         self._num_timesteps = len(timesteps)
         self.scheduler.set_begin_index(0)
+        # Reset stateful cache hooks (e.g. the text K/V cache) so entries from a previous call are not reused.
+        if self.transformer.is_cache_enabled:
+            self.transformer._reset_stateful_cache()
+
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
diff --git a/tests/models/transformers/test_models_transformer_nucleusmoe_image.py b/tests/models/transformers/test_models_transformer_nucleusmoe_image.py
new file mode 100644
index 000000000000..edd6de53701a
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_nucleusmoe_image.py
@@ -0,0 +1,242 @@
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import warnings + +import torch + +from diffusers import NucleusMoEImageTransformer2DModel +from diffusers.utils.torch_utils import randn_tensor + +from ...testing_utils import enable_full_determinism, torch_device +from ..testing_utils import ( + AttentionTesterMixin, + BaseModelTesterConfig, + BitsAndBytesTesterMixin, + LoraHotSwappingForModelTesterMixin, + LoraTesterMixin, + MemoryTesterMixin, + ModelTesterMixin, + TorchAoTesterMixin, + TorchCompileTesterMixin, + TrainingTesterMixin, +) + + +enable_full_determinism() + + +class NucleusMoEImageTransformerTesterConfig(BaseModelTesterConfig): + @property + def model_class(self): + return NucleusMoEImageTransformer2DModel + + @property + def output_shape(self) -> tuple[int, int]: + return (16, 16) + + @property + def input_shape(self) -> tuple[int, int]: + return (16, 16) + + @property + def model_split_percents(self) -> list: + return [0.7, 0.6, 0.6] + + @property + def main_input_name(self) -> str: + return "hidden_states" + + @property + def generator(self): + return torch.Generator("cpu").manual_seed(0) + + def get_init_dict(self) -> dict: + return { + "patch_size": 2, + "in_channels": 16, + "out_channels": 4, + "num_layers": 2, + "attention_head_dim": 16, + "num_attention_heads": 4, + "joint_attention_dim": 16, + "axes_dims_rope": (8, 4, 4), + "moe_enabled": False, + "capacity_factors": [8.0, 8.0], + } + + def get_dummy_inputs(self) -> dict: + batch_size = 1 + in_channels = 16 + joint_attention_dim = 16 + height = width = 4 + sequence_length = 8 + + hidden_states = randn_tensor( + (batch_size, height * width, in_channels), generator=self.generator, device=torch_device + ) + encoder_hidden_states = randn_tensor( + (batch_size, sequence_length, joint_attention_dim), generator=self.generator, device=torch_device + ) + encoder_hidden_states_mask = torch.ones((batch_size, sequence_length), dtype=torch.long).to(torch_device) + timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size) + img_shapes = [(1, height, width)] * batch_size + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "encoder_hidden_states_mask": encoder_hidden_states_mask, + "timestep": timestep, + "img_shapes": img_shapes, + } + + +class TestNucleusMoEImageTransformer(NucleusMoEImageTransformerTesterConfig, ModelTesterMixin): + def test_txt_seq_lens_deprecation(self): + init_dict = self.get_init_dict() + inputs = self.get_dummy_inputs() + model = self.model_class(**init_dict).to(torch_device) + + inputs_with_deprecated = inputs.copy() + inputs_with_deprecated.pop("encoder_hidden_states_mask") + inputs_with_deprecated["txt_seq_lens"] = [inputs["encoder_hidden_states"].shape[1]] + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + with torch.no_grad(): + output = model(**inputs_with_deprecated) + + future_warnings = [x for x in w if issubclass(x.category, FutureWarning)] + assert len(future_warnings) > 0, "Expected FutureWarning to be raised" + warning_message = str(future_warnings[0].message) + assert "txt_seq_lens" in warning_message + assert "deprecated" in warning_message + + assert output.sample.shape[1] == inputs["hidden_states"].shape[1] + + def test_with_attention_mask(self): + init_dict = self.get_init_dict() + inputs = self.get_dummy_inputs() + model = self.model_class(**init_dict).to(torch_device) + + # Mask out some text tokens + mask = inputs["encoder_hidden_states_mask"].clone() + mask[:, 4:] = 0 + inputs["encoder_hidden_states_mask"] = mask + + with torch.no_grad(): + 
output = model(**inputs) + + assert output.sample.shape[1] == inputs["hidden_states"].shape[1] + + def test_without_attention_mask(self): + init_dict = self.get_init_dict() + inputs = self.get_dummy_inputs() + model = self.model_class(**init_dict).to(torch_device) + + inputs["encoder_hidden_states_mask"] = None + + with torch.no_grad(): + output = model(**inputs) + + assert output.sample.shape[1] == inputs["hidden_states"].shape[1] + +class TestNucleusMoEImageTransformerMemory(NucleusMoEImageTransformerTesterConfig, MemoryTesterMixin): + """Memory optimization tests for NucleusMoE Image Transformer.""" + + +class TestNucleusMoEImageTransformerTraining(NucleusMoEImageTransformerTesterConfig, TrainingTesterMixin): + """Training tests for NucleusMoE Image Transformer.""" + + +class TestNucleusMoEImageTransformerAttention(NucleusMoEImageTransformerTesterConfig, AttentionTesterMixin): + """Attention processor tests for NucleusMoE Image Transformer.""" + + +class TestNucleusMoEImageTransformerLoRA(NucleusMoEImageTransformerTesterConfig, LoraTesterMixin): + """LoRA adapter tests for NucleusMoE Image Transformer.""" + + +class TestNucleusMoEImageTransformerLoRAHotSwap( + NucleusMoEImageTransformerTesterConfig, LoraHotSwappingForModelTesterMixin +): + """LoRA hot-swapping tests for NucleusMoE Image Transformer.""" + + @property + def different_shapes_for_compilation(self): + return [(4, 4), (4, 8), (8, 8)] + + def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict: + batch_size = 1 + in_channels = 16 + joint_attention_dim = 16 + sequence_length = 8 + + hidden_states = randn_tensor( + (batch_size, height * width, in_channels), generator=self.generator, device=torch_device + ) + encoder_hidden_states = randn_tensor( + (batch_size, sequence_length, joint_attention_dim), generator=self.generator, device=torch_device + ) + encoder_hidden_states_mask = torch.ones((batch_size, sequence_length), dtype=torch.long).to(torch_device) + timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size) + img_shapes = [(1, height, width)] * batch_size + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "encoder_hidden_states_mask": encoder_hidden_states_mask, + "timestep": timestep, + "img_shapes": img_shapes, + } + + +class TestNucleusMoEImageTransformerCompile(NucleusMoEImageTransformerTesterConfig, TorchCompileTesterMixin): + """Torch compile tests for NucleusMoE Image Transformer.""" + + @property + def different_shapes_for_compilation(self): + return [(4, 4), (4, 8), (8, 8)] + + def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict: + batch_size = 1 + in_channels = 16 + joint_attention_dim = 16 + sequence_length = 8 + + hidden_states = randn_tensor( + (batch_size, height * width, in_channels), generator=self.generator, device=torch_device + ) + encoder_hidden_states = randn_tensor( + (batch_size, sequence_length, joint_attention_dim), generator=self.generator, device=torch_device + ) + encoder_hidden_states_mask = torch.ones((batch_size, sequence_length), dtype=torch.long).to(torch_device) + timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size) + img_shapes = [(1, height, width)] * batch_size + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "encoder_hidden_states_mask": encoder_hidden_states_mask, + "timestep": timestep, + "img_shapes": img_shapes, + } + + +class TestNucleusMoEImageTransformerBitsAndBytes(NucleusMoEImageTransformerTesterConfig, BitsAndBytesTesterMixin): + 
"""BitsAndBytes quantization tests for NucleusMoE Image Transformer.""" + + +class TestNucleusMoEImageTransformerTorchAo(NucleusMoEImageTransformerTesterConfig, TorchAoTesterMixin): + """TorchAO quantization tests for NucleusMoE Image Transformer.""" diff --git a/tests/pipelines/nucleusmoe_image/__init__.py b/tests/pipelines/nucleusmoe_image/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/nucleusmoe_image/test_nucleusmoe_image.py b/tests/pipelines/nucleusmoe_image/test_nucleusmoe_image.py new file mode 100644 index 000000000000..1327bfdcb88b --- /dev/null +++ b/tests/pipelines/nucleusmoe_image/test_nucleusmoe_image.py @@ -0,0 +1,197 @@ +# Copyright 2025 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from transformers import Qwen3VLConfig, Qwen3VLForConditionalGeneration, Qwen3VLProcessor + +from diffusers import ( + AutoencoderKLQwenImage, + FlowMatchEulerDiscreteScheduler, + NucleusMoEImagePipeline, + NucleusMoEImageTransformer2DModel, +) + +from ...testing_utils import enable_full_determinism, torch_device +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin, to_np + + +enable_full_determinism() + + +class NucleusMoEImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = NucleusMoEImagePipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + supports_dduf = False + test_xformers_attention = False + test_layerwise_casting = True + test_group_offloading = True + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = NucleusMoEImageTransformer2DModel( + patch_size=2, + in_channels=16, + out_channels=4, + num_layers=2, + attention_head_dim=16, + num_attention_heads=4, + joint_attention_dim=16, + axes_dims_rope=(8, 4, 4), + moe_enabled=False, + capacity_factors=[8.0, 8.0], + ) + + torch.manual_seed(0) + z_dim = 4 + vae = AutoencoderKLQwenImage( + base_dim=z_dim * 6, + z_dim=z_dim, + dim_mult=[1, 2, 4], + num_res_blocks=1, + temperal_downsample=[False, True], + # fmt: off + latents_mean=[0.0] * z_dim, + latents_std=[1.0] * z_dim, + # fmt: on + ) + + torch.manual_seed(0) + scheduler = FlowMatchEulerDiscreteScheduler() + + torch.manual_seed(0) + config = Qwen3VLConfig( + text_config={ + "hidden_size": 16, + "intermediate_size": 16, + "num_hidden_layers": 8, + "num_attention_heads": 2, + "num_key_value_heads": 2, + "rope_scaling": { + "mrope_section": [1, 1, 2], + "rope_type": "default", + "type": "default", + }, + "rope_theta": 1000000.0, + "vocab_size": 151936, + "head_dim": 8, + }, + vision_config={ + 
"depth": 2, + "hidden_size": 16, + "intermediate_size": 16, + "num_heads": 2, + "out_channels": 16, + }, + ) + text_encoder = Qwen3VLForConditionalGeneration(config).eval() + processor = Qwen3VLProcessor.from_pretrained( + "hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration" + ) + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "processor": processor, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + inputs = { + "prompt": "A cat sitting on a mat", + "negative_prompt": "bad quality", + "generator": generator, + "num_inference_steps": 2, + "true_cfg_scale": 1.0, + "height": 32, + "width": 32, + "max_sequence_length": 16, + "output_type": "pt", + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + generated_image = image[0] + self.assertEqual(generated_image.shape, (3, 32, 32)) + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1) + + def test_true_cfg(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["true_cfg_scale"] = 4.0 + inputs["negative_prompt"] = "low quality" + image = pipe(**inputs).images + self.assertEqual(image[0].shape, (3, 32, 32)) + + def test_prompt_embeds(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + prompt_embeds, prompt_embeds_mask = pipe.encode_prompt( + prompt=inputs["prompt"], + device=device, + max_sequence_length=inputs["max_sequence_length"], + ) + + inputs_with_embeds = self.get_dummy_inputs(device) + inputs_with_embeds.pop("prompt") + inputs_with_embeds["prompt_embeds"] = prompt_embeds + inputs_with_embeds["prompt_embeds_mask"] = prompt_embeds_mask + + image = pipe(**inputs_with_embeds).images + self.assertEqual(image[0].shape, (3, 32, 32))