@@ -77,6 +77,7 @@
     "animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.pos_encoder.pe",
     "animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias",
     "animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
+    "flux": "double_blocks.0.img_attn.norm.key_norm.scale",
 }
 
 DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -110,6 +111,8 @@
     "animatediff_v2": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-2"},
     "animatediff_v3": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-3"},
     "animatediff_sdxl_beta": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-sdxl-beta"},
+    "flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
+    "flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
 }
 
 # Use to configure model sample size when original config is provided
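These two additions work as a pair: the `CHECKPOINT_KEY_NAMES` entry is the fingerprint key whose presence marks a single-file state dict as Flux, and the `DIFFUSERS_DEFAULT_PIPELINE_PATHS` entries supply the Hub repos from which the matching diffusers configs are fetched. A minimal sketch of the lookup, assuming `model_type` has already been inferred by the detection branch in the next hunk (the hardcoded value here is only for illustration):

```python
# Once a checkpoint is classified, the default config repo is a plain dict lookup.
model_type = "flux-dev"  # e.g. as returned by infer_diffusers_model_type below
repo = DIFFUSERS_DEFAULT_PIPELINE_PATHS[model_type]["pretrained_model_name_or_path"]
assert repo == "black-forest-labs/FLUX.1-dev"
```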
@@ -503,6 +506,11 @@ def infer_diffusers_model_type(checkpoint):
         else:
             model_type = "animatediff_v3"
 
+    elif CHECKPOINT_KEY_NAMES["flux"] in checkpoint:
+        if "guidance_in.in_layer.bias" in checkpoint:
+            model_type = "flux-dev"
+        else:
+            model_type = "flux-schnell"
     else:
         model_type = "v1"
 
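FLUX.1-dev is guidance-distilled and therefore carries a `guidance_in` embedder MLP that FLUX.1-schnell lacks, so a single bias key is enough to tell the two checkpoints apart. A standalone sketch of just this branch (the `classify` helper and the stub state dicts are illustrative, not part of this change):

```python
import torch

# Detection is purely key-based, so placeholder tensors suffice for a stub.
dev_like = {
    "double_blocks.0.img_attn.norm.key_norm.scale": torch.zeros(128),
    "guidance_in.in_layer.bias": torch.zeros(3072),  # guidance embedder: dev only
}
schnell_like = {"double_blocks.0.img_attn.norm.key_norm.scale": torch.zeros(128)}

def classify(checkpoint):
    # Mirrors only the new branch above, not the full infer_diffusers_model_type chain.
    if "double_blocks.0.img_attn.norm.key_norm.scale" in checkpoint:
        return "flux-dev" if "guidance_in.in_layer.bias" in checkpoint else "flux-schnell"
    return "v1"

assert classify(dev_like) == "flux-dev"
assert classify(schnell_like) == "flux-schnell"
```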
@@ -1859,3 +1867,195 @@ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
             ] = v
 
     return converted_state_dict
+
+
+def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+    converted_state_dict = {}
+
+    num_layers = max(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k) + 1
+    num_single_layers = max(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k) + 1
+    mlp_ratio = 4.0
+    inner_dim = 3072
+
+    # The original AdaLayerNormContinuous implementation (as in SD3) splits its linear
+    # projection output into (shift, scale), while diffusers splits it into (scale, shift).
+    # Swap the two halves of the projection weights so the diffusers module can be used.
+    def swap_scale_shift(weight):
+        shift, scale = weight.chunk(2, dim=0)
+        new_weight = torch.cat([scale, shift], dim=0)
+        return new_weight
+
+    ## time_text_embed.timestep_embedder <- time_in
+    converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop(
+        "time_in.in_layer.weight"
+    )
+    converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("time_in.in_layer.bias")
+    converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop(
+        "time_in.out_layer.weight"
+    )
+    converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("time_in.out_layer.bias")
+
+    ## time_text_embed.text_embedder <- vector_in
+    converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("vector_in.in_layer.weight")
+    converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("vector_in.in_layer.bias")
+    converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop(
+        "vector_in.out_layer.weight"
+    )
+    converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("vector_in.out_layer.bias")
+
+    # guidance
+    has_guidance = any("guidance" in k for k in checkpoint)
+    if has_guidance:
+        converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = checkpoint.pop(
+            "guidance_in.in_layer.weight"
+        )
+        converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = checkpoint.pop(
+            "guidance_in.in_layer.bias"
+        )
+        converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = checkpoint.pop(
+            "guidance_in.out_layer.weight"
+        )
+        converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = checkpoint.pop(
+            "guidance_in.out_layer.bias"
+        )
+
+    # context_embedder
+    converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight")
+    converted_state_dict["context_embedder.bias"] = checkpoint.pop("txt_in.bias")
+
+    # x_embedder
+    converted_state_dict["x_embedder.weight"] = checkpoint.pop("img_in.weight")
+    converted_state_dict["x_embedder.bias"] = checkpoint.pop("img_in.bias")
+
+    # double transformer blocks
+    for i in range(num_layers):
+        block_prefix = f"transformer_blocks.{i}."
+        # norms
+        ## norm1
+        converted_state_dict[f"{block_prefix}norm1.linear.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.img_mod.lin.weight"
+        )
+        converted_state_dict[f"{block_prefix}norm1.linear.bias"] = checkpoint.pop(
+            f"double_blocks.{i}.img_mod.lin.bias"
+        )
+        ## norm1_context
+        converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_mod.lin.weight"
+        )
+        converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_mod.lin.bias"
+        )
+        # Q, K, V
+        sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0)
+        context_q, context_k, context_v = torch.chunk(
+            checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0
+        )
+        sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
+            checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0
+        )
+        context_q_bias, context_k_bias, context_v_bias = torch.chunk(
+            checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0
+        )
+        converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q])
+        converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias])
+        converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k])
+        converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias])
+        converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v])
+        converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias])
+        converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q])
+        converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias])
+        converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k])
+        converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias])
+        converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v])
+        converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias])
+        # qk_norm
+        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.img_attn.norm.query_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.img_attn.norm.key_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
+        )
+        # ff img_mlp
+        converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.img_mlp.0.weight"
+        )
+        converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.0.bias")
+        converted_state_dict[f"{block_prefix}ff.net.2.weight"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.weight")
+        converted_state_dict[f"{block_prefix}ff.net.2.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.bias")
+        converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_mlp.0.weight"
+        )
+        converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_mlp.0.bias"
+        )
+        converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_mlp.2.weight"
+        )
+        converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_mlp.2.bias"
+        )
+        # output projections
+        converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.img_attn.proj.weight"
+        )
+        converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = checkpoint.pop(
+            f"double_blocks.{i}.img_attn.proj.bias"
+        )
+        converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_attn.proj.weight"
+        )
+        converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_attn.proj.bias"
+        )
+
+    # single transformer blocks
+    for i in range(num_single_layers):
+        block_prefix = f"single_transformer_blocks.{i}."
+        # norm.linear <- single_blocks.0.modulation.lin
+        converted_state_dict[f"{block_prefix}norm.linear.weight"] = checkpoint.pop(
+            f"single_blocks.{i}.modulation.lin.weight"
+        )
+        converted_state_dict[f"{block_prefix}norm.linear.bias"] = checkpoint.pop(
+            f"single_blocks.{i}.modulation.lin.bias"
+        )
+        # Q, K, V, mlp
+        mlp_hidden_dim = int(inner_dim * mlp_ratio)
+        split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
+        q, k, v, mlp = torch.split(checkpoint.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0)
+        q_bias, k_bias, v_bias, mlp_bias = torch.split(
+            checkpoint.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0
+        )
+        converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q])
+        converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias])
+        converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k])
+        converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias])
+        converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v])
+        converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias])
+        converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp])
+        converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias])
+        # qk norm
+        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
+            f"single_blocks.{i}.norm.query_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
+            f"single_blocks.{i}.norm.key_norm.scale"
+        )
+        # output projections
+        converted_state_dict[f"{block_prefix}proj_out.weight"] = checkpoint.pop(f"single_blocks.{i}.linear2.weight")
+        converted_state_dict[f"{block_prefix}proj_out.bias"] = checkpoint.pop(f"single_blocks.{i}.linear2.bias")
+
+    converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
+    converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
+    converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
+        checkpoint.pop("final_layer.adaLN_modulation.1.weight")
+    )
+    converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
+        checkpoint.pop("final_layer.adaLN_modulation.1.bias")
+    )
+
+    return converted_state_dict
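With the converter in place, a raw Black Forest Labs checkpoint can be loaded directly through the single-file path. A usage sketch, assuming this conversion is registered with the single-file loader by the rest of this change (the local filename and dtype choice are illustrative):

```python
import torch
from diffusers import FluxTransformer2DModel

# from_single_file recognizes the "flux" fingerprint key, infers dev vs. schnell,
# and routes the raw state dict through convert_flux_transformer_checkpoint_to_diffusers.
transformer = FluxTransformer2DModel.from_single_file(
    "flux1-dev.safetensors", torch_dtype=torch.bfloat16
)
```

Note that the converter pops every key it maps, so finding `checkpoint` empty afterwards is a quick signal that all original parameters were accounted for.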