|
74 | 74 | "stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",
|
75 | 75 | "stable_cascade_stage_c": "clip_txt_mapper.weight",
|
76 | 76 | "sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
|
77 |
| - "animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.pos_encoder.pe", |
| 77 | + "animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe", |
78 | 78 | "animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias",
|
79 | 79 | "animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
|
| 80 | + "animatediff_scribble": "controlnet_cond_embedding.conv_in.weight", |
| 81 | + "animatediff_rgb": "controlnet_cond_embedding.weight", |
80 | 82 | "flux": "double_blocks.0.img_attn.norm.key_norm.scale",
|
81 | 83 | }
|
82 | 84 |
|
|
111 | 113 | "animatediff_v2": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-2"},
|
112 | 114 | "animatediff_v3": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-3"},
|
113 | 115 | "animatediff_sdxl_beta": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-sdxl-beta"},
|
| 116 | + "animatediff_scribble": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-scribble"}, |
| 117 | + "animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"}, |
114 | 118 | "flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
|
115 | 119 | "flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
|
116 | 120 | }
|
@@ -494,7 +498,13 @@ def infer_diffusers_model_type(checkpoint):
|
494 | 498 | model_type = "sd3"
|
495 | 499 |
|
496 | 500 | elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint:
|
497 |
| - if CHECKPOINT_KEY_NAMES["animatediff_v2"] in checkpoint: |
| 501 | + if CHECKPOINT_KEY_NAMES["animatediff_scribble"] in checkpoint: |
| 502 | + model_type = "animatediff_scribble" |
| 503 | + |
| 504 | + elif CHECKPOINT_KEY_NAMES["animatediff_rgb"] in checkpoint: |
| 505 | + model_type = "animatediff_rgb" |
| 506 | + |
| 507 | + elif CHECKPOINT_KEY_NAMES["animatediff_v2"] in checkpoint: |
498 | 508 | model_type = "animatediff_v2"
|
499 | 509 |
|
500 | 510 | elif checkpoint[CHECKPOINT_KEY_NAMES["animatediff_sdxl_beta"]].shape[-1] == 320:
|
|
0 commit comments