Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/diffusers/hooks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig
from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache
from .text_kv_cache import TextKVCacheConfig, apply_text_kv_cache
109 changes: 109 additions & 0 deletions src/diffusers/hooks/text_kv_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass

import torch

from .hooks import HookRegistry, ModelHook


_TEXT_KV_CACHE_HOOK = "text_kv_cache"


@dataclass
class TextKVCacheConfig:
    """Configuration for exact (lossless) caching of text key/value projections.

    When enabled, each transformer block computes its text key and value
    projections a single time and reuses them for every denoising step. Cache
    entries are keyed by the ``data_ptr()`` of the ``encoder_hidden_states``
    tensor, so separate prompt tensors — e.g. the positive and negative prompts
    used when ``true_cfg_scale > 1`` — each receive their own entry.
    """

    pass  # intentionally empty: the cache is exact, so there is nothing to tune


class TextKVCacheHook(ModelHook):
    """Block-level hook that caches the text ``(key, value)`` projections per unique prompt.

    Entries are keyed by ``encoder_hidden_states.data_ptr()``, so each distinct
    prompt tensor (e.g. positive vs. negative prompt under CFG) gets its own
    cached pair. ``reset_state`` must run between generations — a new prompt
    tensor reusing a freed allocation could otherwise alias a stale entry.
    """

    _is_stateful = True

    def __init__(self):
        super().__init__()
        # Maps encoder_hidden_states.data_ptr() → (txt_key, txt_value)
        self.kv_cache: dict[int, tuple] = {}

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        """Compute (or reuse) this block's text K/V and inject it into the forward call.

        Accepts the wrapped block's arguments; positionally they are expected as
        ``(hidden_states, encoder_hidden_states, temb, image_rotary_emb, ...)``.
        """
        from ..models.transformers.transformer_nucleusmoe_image import _apply_rotary_emb_nucleus

        # --- extract encoder_hidden_states (keyword or positional) ---
        if "encoder_hidden_states" in kwargs:
            encoder_hidden_states = kwargs["encoder_hidden_states"]
        else:
            # positional layout: (hidden_states, encoder_hidden_states, temb, ...)
            encoder_hidden_states = args[1]

        # --- extract image_rotary_emb (keyword or positional) ---
        if "image_rotary_emb" in kwargs:
            image_rotary_emb = kwargs["image_rotary_emb"]
        elif len(args) > 3:
            image_rotary_emb = args[3]
        else:
            image_rotary_emb = None

        ptr = encoder_hidden_states.data_ptr()

        if ptr not in self.kv_cache:
            # First time this prompt tensor is seen: run the text projections once.
            context = module.encoder_proj(encoder_hidden_states)

            attn = module.attn
            head_dim = attn.inner_dim // attn.heads
            num_kv_heads = attn.inner_kv_dim // head_dim

            txt_key = attn.add_k_proj(context).unflatten(-1, (num_kv_heads, -1))
            txt_value = attn.add_v_proj(context).unflatten(-1, (num_kv_heads, -1))

            if attn.norm_added_k is not None:
                txt_key = attn.norm_added_k(txt_key)

            if image_rotary_emb is not None:
                # NOTE(review): applying RoPE here assumes the text frequencies are
                # timestep-independent, which is required for the cache to be exact
                # across denoising steps — confirm against the rotary-emb builder.
                _, txt_freqs = image_rotary_emb
                txt_key = _apply_rotary_emb_nucleus(txt_key, txt_freqs, use_real=False)

            self.kv_cache[ptr] = (txt_key, txt_value)

        txt_key, txt_value = self.kv_cache[ptr]

        # Inject the cached K/V. Shallow-copy the caller's attention_kwargs instead
        # of mutating it in place, so the cached tensors never leak back into a
        # user-supplied dict that may be shared across blocks or calls.
        attn_kwargs = dict(kwargs.get("attention_kwargs") or {})
        attn_kwargs["cached_txt_key"] = txt_key
        attn_kwargs["cached_txt_value"] = txt_value
        kwargs["attention_kwargs"] = attn_kwargs

        return self.fn_ref.original_forward(*args, **kwargs)

    def reset_state(self, module: torch.nn.Module):
        """Drop every cached prompt K/V pair; called between generations."""
        self.kv_cache.clear()
        return module


def apply_text_kv_cache(module: torch.nn.Module, config: TextKVCacheConfig) -> None:
    """Register a ``TextKVCacheHook`` on every NucleusMoE image transformer block.

    Args:
        module: The transformer (or any ancestor module) whose blocks should
            cache their text key/value projections.
        config: Carries no options today; accepted for API symmetry with the
            other caching techniques.
    """
    from ..models.transformers.transformer_nucleusmoe_image import NucleusMoEImageTransformerBlock

    target_blocks = (
        child for _, child in module.named_modules() if isinstance(child, NucleusMoEImageTransformerBlock)
    )
    for block in target_blocks:
        registry = HookRegistry.check_if_exists_or_initialize(block)
        registry.register_hook(TextKVCacheHook(), _TEXT_KV_CACHE_HOOK)
11 changes: 10 additions & 1 deletion src/diffusers/models/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,12 @@ def enable_cache(self, config) -> None:
Enable caching techniques on the model.
Args:
config (`PyramidAttentionBroadcastConfig | FasterCacheConfig | FirstBlockCacheConfig`):
config (`PyramidAttentionBroadcastConfig | FasterCacheConfig | FirstBlockCacheConfig | TextKVCacheConfig`):
The configuration for applying the caching technique. Currently supported caching techniques are:
- [`~hooks.PyramidAttentionBroadcastConfig`]
- [`~hooks.FasterCacheConfig`]
- [`~hooks.FirstBlockCacheConfig`]
- [`~hooks.TextKVCacheConfig`]
Example:
Expand All @@ -71,11 +72,13 @@ def enable_cache(self, config) -> None:
MagCacheConfig,
PyramidAttentionBroadcastConfig,
TaylorSeerCacheConfig,
TextKVCacheConfig,
apply_faster_cache,
apply_first_block_cache,
apply_mag_cache,
apply_pyramid_attention_broadcast,
apply_taylorseer_cache,
apply_text_kv_cache,
)

if self.is_cache_enabled:
Expand All @@ -89,6 +92,8 @@ def enable_cache(self, config) -> None:
apply_first_block_cache(self, config)
elif isinstance(config, MagCacheConfig):
apply_mag_cache(self, config)
elif isinstance(config, TextKVCacheConfig):
apply_text_kv_cache(self, config)
elif isinstance(config, PyramidAttentionBroadcastConfig):
apply_pyramid_attention_broadcast(self, config)
elif isinstance(config, TaylorSeerCacheConfig):
Expand All @@ -106,12 +111,14 @@ def disable_cache(self) -> None:
MagCacheConfig,
PyramidAttentionBroadcastConfig,
TaylorSeerCacheConfig,
TextKVCacheConfig,
)
from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
from ..hooks.mag_cache import _MAG_CACHE_BLOCK_HOOK, _MAG_CACHE_LEADER_BLOCK_HOOK
from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
from ..hooks.taylorseer_cache import _TAYLORSEER_CACHE_HOOK
from ..hooks.text_kv_cache import _TEXT_KV_CACHE_HOOK

if self._cache_config is None:
logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
Expand All @@ -129,6 +136,8 @@ def disable_cache(self) -> None:
registry.remove_hook(_MAG_CACHE_BLOCK_HOOK, recurse=True)
elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
elif isinstance(self._cache_config, TextKVCacheConfig):
registry.remove_hook(_TEXT_KV_CACHE_HOOK, recurse=True)
elif isinstance(self._cache_config, TaylorSeerCacheConfig):
registry.remove_hook(_TAYLORSEER_CACHE_HOOK, recurse=True)
else:
Expand Down
15 changes: 11 additions & 4 deletions src/diffusers/models/transformers/transformer_nucleusmoe_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,8 @@ def __call__(
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: torch.FloatTensor | None = None,
image_rotary_emb: torch.Tensor | None = None,
cached_txt_key: torch.FloatTensor | None = None,
cached_txt_value: torch.FloatTensor | None = None,
) -> torch.FloatTensor:
head_dim = attn.inner_dim // attn.heads
num_kv_heads = attn.inner_kv_dim // head_dim
Expand All @@ -287,7 +289,11 @@ def __call__(
img_query = _apply_rotary_emb_nucleus(img_query, img_freqs, use_real=False)
img_key = _apply_rotary_emb_nucleus(img_key, img_freqs, use_real=False)

if encoder_hidden_states is not None:
if cached_txt_key is not None and cached_txt_value is not None:
txt_key, txt_value = cached_txt_key, cached_txt_value
joint_key = torch.cat([img_key, txt_key], dim=1)
joint_value = torch.cat([img_value, txt_value], dim=1)
elif encoder_hidden_states is not None:
txt_key = attn.add_k_proj(encoder_hidden_states).unflatten(-1, (num_kv_heads, -1))
txt_value = attn.add_v_proj(encoder_hidden_states).unflatten(-1, (num_kv_heads, -1))

Expand Down Expand Up @@ -537,17 +543,18 @@ def forward(
gate1 = gate1.clamp(min=-2.0, max=2.0)
gate2 = gate2.clamp(min=-2.0, max=2.0)

context = self.encoder_proj(encoder_hidden_states)
# Skip encoder_proj when text K/V are already cached — context won't be used by the processor
attn_kwargs = attention_kwargs or {}
context = None if attn_kwargs.get("cached_txt_key") is not None else self.encoder_proj(encoder_hidden_states)

img_normed = self.pre_attn_norm(hidden_states)
img_modulated = img_normed * scale1

attention_kwargs = attention_kwargs or {}
img_attn_output = self.attn(
hidden_states=img_modulated,
encoder_hidden_states=context,
image_rotary_emb=image_rotary_emb,
**attention_kwargs,
**attn_kwargs,
)

hidden_states = hidden_states + gate1.tanh() * img_attn_output
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,10 @@ def __call__(
self._num_timesteps = len(timesteps)

self.scheduler.set_begin_index(0)

if self.transformer.is_cache_enabled:
self.transformer._reset_stateful_cache()

with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
Expand Down
Loading