
Commit f6df224

a-r-r-o-w and DN6 authored

[feat] allow sparsectrl to be loaded from single file (huggingface#9073)

* allow sparsectrl to be loaded with single file

* update

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
1 parent 9b5180c commit f6df224
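
What the change enables, in practice: SparseControlNetModel now picks up from_single_file via FromOriginalModelMixin, so an original-format SparseCtrl checkpoint can be loaded directly. A minimal sketch (the checkpoint filename is a placeholder for any original SparseCtrl scribble or RGB file):

import torch
from diffusers import SparseControlNetModel

# Placeholder path: any original-format SparseCtrl checkpoint
# (.ckpt or .safetensors) should be detected and converted.
controlnet = SparseControlNetModel.from_single_file(
    "v3_sd15_sparsectrl_scribble.ckpt",
    torch_dtype=torch.float16,
)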

File tree

4 files changed, +23 −3 lines changed

src/diffusers/loaders/single_file_model.py
src/diffusers/loaders/single_file_utils.py
src/diffusers/models/controlnet_sparsectrl.py
src/diffusers/models/unets/unet_motion_model.py

src/diffusers/loaders/single_file_model.py (+3)

@@ -75,6 +75,9 @@
     "MotionAdapter": {
         "checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
     },
+    "SparseControlNetModel": {
+        "checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
+    },
     "FluxTransformer2DModel": {
         "checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
         "default_subfolder": "transformer",

src/diffusers/loaders/single_file_utils.py (+12 −2)

@@ -74,9 +74,11 @@
     "stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",
     "stable_cascade_stage_c": "clip_txt_mapper.weight",
     "sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
-    "animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.pos_encoder.pe",
+    "animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe",
     "animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias",
     "animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
+    "animatediff_scribble": "controlnet_cond_embedding.conv_in.weight",
+    "animatediff_rgb": "controlnet_cond_embedding.weight",
     "flux": "double_blocks.0.img_attn.norm.key_norm.scale",
 }

@@ -111,6 +113,8 @@
     "animatediff_v2": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-2"},
     "animatediff_v3": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-3"},
     "animatediff_sdxl_beta": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-sdxl-beta"},
+    "animatediff_scribble": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-scribble"},
+    "animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"},
     "flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
     "flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
 }

@@ -494,7 +498,13 @@ def infer_diffusers_model_type(checkpoint):
         model_type = "sd3"

     elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint:
-        if CHECKPOINT_KEY_NAMES["animatediff_v2"] in checkpoint:
+        if CHECKPOINT_KEY_NAMES["animatediff_scribble"] in checkpoint:
+            model_type = "animatediff_scribble"
+
+        elif CHECKPOINT_KEY_NAMES["animatediff_rgb"] in checkpoint:
+            model_type = "animatediff_rgb"
+
+        elif CHECKPOINT_KEY_NAMES["animatediff_v2"] in checkpoint:
             model_type = "animatediff_v2"

         elif checkpoint[CHECKPOINT_KEY_NAMES["animatediff_sdxl_beta"]].shape[-1] == 320:
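
Two details worth noting in this hunk set: the shared "animatediff" probe key moves from attention_blocks.1 to attention_blocks.0, which lets it match SparseCtrl checkpoints too (their temporal blocks carry a single attention block, consistent with the temporal_double_self_attention=False change below), and the new scribble/RGB marker keys are tested before the other animatediff branches because SparseCtrl checkpoints also contain motion-module keys. A standalone sketch of the detection idea (the file path and helper name are illustrative, not from this commit):

from safetensors.torch import load_file

# Marker keys mirrored from CHECKPOINT_KEY_NAMES, most specific first.
MARKER_KEYS = {
    "animatediff_scribble": "controlnet_cond_embedding.conv_in.weight",
    "animatediff_rgb": "controlnet_cond_embedding.weight",
    "animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias",
}

def sketch_infer_model_type(checkpoint: dict) -> str:
    # dicts preserve insertion order, so the specific variants win.
    for model_type, marker_key in MARKER_KEYS.items():
        if marker_key in checkpoint:
            return model_type
    return "animatediff"

state_dict = load_file("sparsectrl_checkpoint.safetensors")  # placeholder path
print(sketch_infer_model_type(state_dict))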

src/diffusers/models/controlnet_sparsectrl.py (+4 −1)

@@ -20,6 +20,7 @@
 from torch.nn import functional as F

 from ..configuration_utils import ConfigMixin, register_to_config
+from ..loaders import FromOriginalModelMixin
 from ..utils import BaseOutput, logging
 from .attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,

@@ -92,7 +93,7 @@ def forward(self, conditioning: torch.Tensor) -> torch.Tensor:
         return embedding


-class SparseControlNetModel(ModelMixin, ConfigMixin):
+class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     """
     A SparseControlNet model as described in [SparseCtrl: Adding Sparse Controls to Text-to-Video Diffusion
     Models](https://arxiv.org/abs/2311.16933).

@@ -314,6 +315,7 @@ def __init__(
                     temporal_num_attention_heads=motion_num_attention_heads[i],
                     temporal_max_seq_length=motion_max_seq_length,
                     temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i],
+                    temporal_double_self_attention=False,
                 )
             elif down_block_type == "DownBlockMotion":
                 down_block = DownBlockMotion(

@@ -331,6 +333,7 @@ def __init__(
                     temporal_num_attention_heads=motion_num_attention_heads[i],
                     temporal_max_seq_length=motion_max_seq_length,
                     temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i],
+                    temporal_double_self_attention=False,
                 )
             else:
                 raise ValueError(
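
The two temporal_double_self_attention=False arguments are why the unet_motion_model.py changes below exist: SparseCtrl's temporal transformer blocks use a single self-attention layer, while the AnimateDiff default builds two. In diffusers this flag ultimately lands on BasicTransformerBlock's double_self_attention. A toy instantiation to show the effect (dimensions are made up):

import torch
from diffusers.models.attention import BasicTransformerBlock

# With cross_attention_dim=None and double_self_attention=False, the block
# builds only one self-attention layer (no attn2); flipping the flag to True
# would add a second self-attention layer, which SparseCtrl weights lack.
block = BasicTransformerBlock(
    dim=32,
    num_attention_heads=4,
    attention_head_dim=8,
    cross_attention_dim=None,
    double_self_attention=False,
)
hidden_states = torch.randn(1, 16, 32)  # (batch, sequence, channels)
print(block(hidden_states).shape)  # torch.Size([1, 16, 32])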

src/diffusers/models/unets/unet_motion_model.py (+4)

@@ -233,6 +233,7 @@ def __init__(
         temporal_cross_attention_dim: Optional[int] = None,
         temporal_max_seq_length: int = 32,
         temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+        temporal_double_self_attention: bool = True,
     ):
         super().__init__()
         resnets = []

@@ -282,6 +283,7 @@ def __init__(
                     positional_embeddings="sinusoidal",
                     num_positional_embeddings=temporal_max_seq_length,
                     attention_head_dim=out_channels // temporal_num_attention_heads[i],
+                    double_self_attention=temporal_double_self_attention,
                 )
             )

@@ -385,6 +387,7 @@ def __init__(
         temporal_num_attention_heads: int = 8,
         temporal_max_seq_length: int = 32,
         temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+        temporal_double_self_attention: bool = True,
     ):
         super().__init__()
         resnets = []

@@ -466,6 +469,7 @@ def __init__(
                     positional_embeddings="sinusoidal",
                     num_positional_embeddings=temporal_max_seq_length,
                     attention_head_dim=out_channels // temporal_num_attention_heads,
+                    double_self_attention=temporal_double_self_attention,
                 )
             )
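
Putting it together, the single-file controlnet drops into the usual SparseCtrl pipeline setup. A sketch assuming the ids below (the base model and motion adapter hub ids are illustrative, not taken from this commit):

import torch
from diffusers import AnimateDiffSparseControlNetPipeline, MotionAdapter, SparseControlNetModel

controlnet = SparseControlNetModel.from_single_file(
    "v3_sd15_sparsectrl_scribble.ckpt", torch_dtype=torch.float16  # placeholder path
)
motion_adapter = MotionAdapter.from_pretrained(
    "guoyww/animatediff-motion-adapter-v1-5-3", torch_dtype=torch.float16
)
pipe = AnimateDiffSparseControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    motion_adapter=motion_adapter,
    controlnet=controlnet,
    torch_dtype=torch.float16,
)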
