
Commit e1b603d

DN6 and sayakpaul authored
[Single File] Add single file support for Flux Transformer (huggingface#9083)
* update
* update
* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent e432560 commit e1b603d

File tree

4 files changed: +260 −2 lines changed


docs/source/en/api/pipelines/flux.md

+53
@@ -77,6 +77,59 @@ out = pipe(
 out.save("image.png")
 ```
 
+## Single File Loading for the `FluxTransformer2DModel`
+
+The `FluxTransformer2DModel` supports loading checkpoints in the original format shipped by Black Forest Labs. This is also useful when trying to load finetunes or quantized versions of the models that have been published by the community.
+
+<Tip>
+`FP8` inference can be brittle depending on the GPU type, CUDA version, and `torch` version that you are using. It is recommended that you use the `optimum-quanto` library in order to run FP8 inference on your machine.
+</Tip>
+
+The following example demonstrates how to run Flux with less than 16GB of VRAM.
+
+First install `optimum-quanto`
+
+```shell
+pip install optimum-quanto
+```
+
+Then run the following example
+
+```python
+import torch
+from diffusers import FluxTransformer2DModel, FluxPipeline
+from transformers import T5EncoderModel, CLIPTextModel
+from optimum.quanto import freeze, qfloat8, quantize
+
+bfl_repo = "black-forest-labs/FLUX.1-dev"
+dtype = torch.bfloat16
+
+transformer = FluxTransformer2DModel.from_single_file("https://huggingface.co/Kijai/flux-fp8/blob/main/flux1-dev-fp8.safetensors", torch_dtype=dtype)
+quantize(transformer, weights=qfloat8)
+freeze(transformer)
+
+text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
+quantize(text_encoder_2, weights=qfloat8)
+freeze(text_encoder_2)
+
+pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, text_encoder_2=None, torch_dtype=dtype)
+pipe.transformer = transformer
+pipe.text_encoder_2 = text_encoder_2
+
+pipe.enable_model_cpu_offload()
+
+prompt = "A cat holding a sign that says hello world"
+image = pipe(
+    prompt,
+    guidance_scale=3.5,
+    output_type="pil",
+    num_inference_steps=20,
+    generator=torch.Generator("cpu").manual_seed(0)
+).images[0]
+
+image.save("flux-fp8-dev.png")
+```
+
 ## FluxPipeline
 
 [[autodoc]] FluxPipeline

src/diffusers/loaders/single_file_model.py

+5
@@ -24,6 +24,7 @@
     SingleFileComponentError,
     convert_animatediff_checkpoint_to_diffusers,
     convert_controlnet_checkpoint,
+    convert_flux_transformer_checkpoint_to_diffusers,
     convert_ldm_unet_checkpoint,
     convert_ldm_vae_checkpoint,
     convert_sd3_transformer_checkpoint_to_diffusers,
@@ -74,6 +75,10 @@
     "MotionAdapter": {
         "checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
     },
+    "FluxTransformer2DModel": {
+        "checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
+        "default_subfolder": "transformer",
+    },
 }
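For orientation, here is a rough sketch of what this registry entry wires up. It is illustrative only, not the loader's internal code path, and the local checkpoint filename is hypothetical: the original-format state dict is remapped by `convert_flux_transformer_checkpoint_to_diffusers`, while the model config is resolved from the `transformer` subfolder of the matched repository.

```python
from safetensors.torch import load_file

from diffusers.loaders.single_file_utils import convert_flux_transformer_checkpoint_to_diffusers

# Hypothetical local copy of an original-format Black Forest Labs checkpoint.
original_state_dict = load_file("flux1-dev.safetensors")

# Remap original key names (time_in.*, double_blocks.*, single_blocks.*, ...) to the
# diffusers FluxTransformer2DModel layout (time_text_embed.*, transformer_blocks.*, ...).
diffusers_state_dict = convert_flux_transformer_checkpoint_to_diffusers(original_state_dict)
```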

src/diffusers/loaders/single_file_utils.py

+200
@@ -77,6 +77,7 @@
     "animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.pos_encoder.pe",
     "animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias",
     "animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
+    "flux": "double_blocks.0.img_attn.norm.key_norm.scale",
 }
 
 DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -110,6 +111,8 @@
     "animatediff_v2": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-2"},
     "animatediff_v3": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-3"},
     "animatediff_sdxl_beta": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-sdxl-beta"},
+    "flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
+    "flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
 }
 
 # Use to configure model sample size when original config is provided
@@ -503,6 +506,11 @@ def infer_diffusers_model_type(checkpoint):
         else:
             model_type = "animatediff_v3"
 
+    elif CHECKPOINT_KEY_NAMES["flux"] in checkpoint:
+        if "guidance_in.in_layer.bias" in checkpoint:
+            model_type = "flux-dev"
+        else:
+            model_type = "flux-schnell"
     else:
         model_type = "v1"

@@ -1859,3 +1867,195 @@ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
             ] = v
 
     return converted_state_dict
+
+
+def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+    converted_state_dict = {}
+
+    num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1  # noqa: C401
+    num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1  # noqa: C401
+    mlp_ratio = 4.0
+    inner_dim = 3072
+
+    # in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale;
+    # while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation
+    def swap_scale_shift(weight):
+        shift, scale = weight.chunk(2, dim=0)
+        new_weight = torch.cat([scale, shift], dim=0)
+        return new_weight
+
+    ## time_text_embed.timestep_embedder <- time_in
+    converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop(
+        "time_in.in_layer.weight"
+    )
+    converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("time_in.in_layer.bias")
+    converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop(
+        "time_in.out_layer.weight"
+    )
+    converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("time_in.out_layer.bias")
+
+    ## time_text_embed.text_embedder <- vector_in
+    converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("vector_in.in_layer.weight")
+    converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("vector_in.in_layer.bias")
+    converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop(
+        "vector_in.out_layer.weight"
+    )
+    converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("vector_in.out_layer.bias")
+
+    # guidance
+    has_guidance = any("guidance" in k for k in checkpoint)
+    if has_guidance:
+        converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = checkpoint.pop(
+            "guidance_in.in_layer.weight"
+        )
+        converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = checkpoint.pop(
+            "guidance_in.in_layer.bias"
+        )
+        converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = checkpoint.pop(
+            "guidance_in.out_layer.weight"
+        )
+        converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = checkpoint.pop(
+            "guidance_in.out_layer.bias"
+        )
+
+    # context_embedder
+    converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight")
+    converted_state_dict["context_embedder.bias"] = checkpoint.pop("txt_in.bias")
+
+    # x_embedder
+    converted_state_dict["x_embedder.weight"] = checkpoint.pop("img_in.weight")
+    converted_state_dict["x_embedder.bias"] = checkpoint.pop("img_in.bias")
+
+    # double transformer blocks
+    for i in range(num_layers):
+        block_prefix = f"transformer_blocks.{i}."
+        # norms.
+        ## norm1
+        converted_state_dict[f"{block_prefix}norm1.linear.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.img_mod.lin.weight"
+        )
+        converted_state_dict[f"{block_prefix}norm1.linear.bias"] = checkpoint.pop(
+            f"double_blocks.{i}.img_mod.lin.bias"
+        )
+        ## norm1_context
+        converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_mod.lin.weight"
+        )
+        converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_mod.lin.bias"
+        )
+        # Q, K, V
+        sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0)
+        context_q, context_k, context_v = torch.chunk(
+            checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0
+        )
+        sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
+            checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0
+        )
+        context_q_bias, context_k_bias, context_v_bias = torch.chunk(
+            checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0
+        )
+        converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q])
+        converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias])
+        converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k])
+        converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias])
+        converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v])
+        converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias])
+        converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q])
+        converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias])
+        converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k])
+        converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias])
+        converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v])
+        converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias])
+        # qk_norm
+        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.img_attn.norm.query_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.img_attn.norm.key_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
+        )
+        # ff img_mlp
+        converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.img_mlp.0.weight"
+        )
+        converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.0.bias")
+        converted_state_dict[f"{block_prefix}ff.net.2.weight"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.weight")
+        converted_state_dict[f"{block_prefix}ff.net.2.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.bias")
+        converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_mlp.0.weight"
+        )
+        converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_mlp.0.bias"
+        )
+        converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_mlp.2.weight"
+        )
+        converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_mlp.2.bias"
+        )
+        # output projections.
+        converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.img_attn.proj.weight"
+        )
+        converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = checkpoint.pop(
+            f"double_blocks.{i}.img_attn.proj.bias"
+        )
+        converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_attn.proj.weight"
+        )
+        converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_attn.proj.bias"
+        )
+
+    # single transformer blocks
+    for i in range(num_single_layers):
+        block_prefix = f"single_transformer_blocks.{i}."
+        # norm.linear <- single_blocks.0.modulation.lin
+        converted_state_dict[f"{block_prefix}norm.linear.weight"] = checkpoint.pop(
+            f"single_blocks.{i}.modulation.lin.weight"
+        )
+        converted_state_dict[f"{block_prefix}norm.linear.bias"] = checkpoint.pop(
+            f"single_blocks.{i}.modulation.lin.bias"
+        )
+        # Q, K, V, mlp
+        mlp_hidden_dim = int(inner_dim * mlp_ratio)
+        split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
+        q, k, v, mlp = torch.split(checkpoint.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0)
+        q_bias, k_bias, v_bias, mlp_bias = torch.split(
+            checkpoint.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0
+        )
+        converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q])
+        converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias])
+        converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k])
+        converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias])
+        converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v])
+        converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias])
+        converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp])
+        converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias])
+        # qk norm
+        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
+            f"single_blocks.{i}.norm.query_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
+            f"single_blocks.{i}.norm.key_norm.scale"
+        )
+        # output projections.
+        converted_state_dict[f"{block_prefix}proj_out.weight"] = checkpoint.pop(f"single_blocks.{i}.linear2.weight")
+        converted_state_dict[f"{block_prefix}proj_out.bias"] = checkpoint.pop(f"single_blocks.{i}.linear2.bias")
+
+    converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
+    converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
+    converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
+        checkpoint.pop("final_layer.adaLN_modulation.1.weight")
+    )
+    converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
+        checkpoint.pop("final_layer.adaLN_modulation.1.bias")
+    )
+
+    return converted_state_dict
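The model-type detection added above only inspects key names, so the same check can be reproduced outside the library. A minimal sketch, assuming a local `.safetensors` file and reading only the header via `safetensors.safe_open` (this mirrors the logic rather than calling the library's own code path):

```python
from safetensors import safe_open

FLUX_KEY = "double_blocks.0.img_attn.norm.key_norm.scale"  # marks a Flux transformer checkpoint
GUIDANCE_KEY = "guidance_in.in_layer.bias"  # present in the guidance-distilled dev model, absent in schnell


def guess_flux_variant(path):
    # Only the safetensors header is read; no tensor data is loaded.
    with safe_open(path, framework="pt") as f:
        keys = set(f.keys())
    if FLUX_KEY not in keys:
        return None  # not a Flux transformer checkpoint
    return "flux-dev" if GUIDANCE_KEY in keys else "flux-schnell"


print(guess_flux_variant("flux1-dev.safetensors"))  # hypothetical local path
```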

src/diffusers/models/transformers/transformer_flux.py

+2 −2
@@ -20,7 +20,7 @@
 import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders import PeftAdapterMixin
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...models.attention import FeedForward
 from ...models.attention_processor import Attention, FluxAttnProcessor2_0, FluxSingleAttnProcessor2_0
 from ...models.modeling_utils import ModelMixin
@@ -227,7 +227,7 @@ def forward(
         return encoder_hidden_states, hidden_states
 
 
-class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
     """
     The Transformer model introduced in Flux.
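Because `FromOriginalModelMixin` is what provides `from_single_file`, the transformer can now be loaded directly from an original-format checkpoint and dropped into the pipeline. A minimal usage sketch (the local file path is hypothetical; the FP8 example in the docs above covers the quantized path):

```python
import torch

from diffusers import FluxPipeline, FluxTransformer2DModel

# `from_single_file` is inherited from FromOriginalModelMixin, added to the class above.
transformer = FluxTransformer2DModel.from_single_file(
    "./flux1-schnell.safetensors",  # hypothetical local path to an original-format checkpoint
    torch_dtype=torch.bfloat16,
)

# Swap the converted transformer into the regular pipeline.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
)
```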
