
Commit d61889f

[Feat] PixArt-Alpha (huggingface#5642)
Squashed commit messages:

* init pixart alpha pipeline
* fix: import
* conversion script
* add: vae to the pipeline
* add: vae_scale_factor
* add: checkpoint_path
* clean conversion script a bit
* size embeddings; fix: size embedding
* update script
* support for interpolation of position embedding
* support for conditioning
* final layer
* align with encode_prompt
* support for caption embedding
* refactor
* cross attention; cross_attention_dim
* support for resolution and aspect_ratio
* support for caption projection
* refactor patch embeddings
* batch_size
* fix final block
* fix: interpolation scale
* repeated debugging iterations (squashed)
* make --checkpoint_path non-required
* remove num_tokens
* timesteps -> timestep
* update conversion script
* repeated debug/fix iterations (squashed)
* DPMSolverMultistepScheduler
* fix more tests
* offloading
* fix conversion script; remove print
* remove support for negative prompt embeds
* typo; remove extra kwargs
* bring conversion script to where it was
* clean up
* update example
* support for 512
* remove spacing
* finalize docs
* fix: assertion values
* fix: repeat; remove prints
* apply suggestions from code review
* address patrick's comments
* remove unneeded args; clean up pipeline; style
* make the use of additional conditions better conditioned
* dtype
* height and width validation
* add a note about size brackets
* spit out slow test outputs
* fix optional test
* remove unneeded comment

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
1 parent: 2b23ec8

15 files changed: +1501 −30 lines

Diff for: docs/source/en/_toctree.yml (+2)

@@ -268,6 +268,8 @@
     title: Parallel Sampling of Diffusion Models
   - local: api/pipelines/pix2pix_zero
     title: Pix2Pix Zero
+  - local: api/pipelines/pixart
+    title: PixArt
   - local: api/pipelines/pndm
     title: PNDM
   - local: api/pipelines/repaint

Diff for: docs/source/en/api/pipelines/pixart.md (+36, new file)

@@ -0,0 +1,36 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# PixArt

![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/pixart/header_collage.png)

[PixArt-α: Fast Training of Diffusion Transformer for Photorealistic Text-to-Image Synthesis](https://huggingface.co/papers/2310.00426) is by Junsong Chen, Jincheng Yu, Chongjian Ge, Lewei Yao, Enze Xie, Yue Wu, Zhongdao Wang, James Kwok, Ping Luo, Huchuan Lu, and Zhenguo Li.

The abstract from the paper is:

*The most advanced text-to-image (T2I) models require significant training costs (e.g., millions of GPU hours), seriously hindering the fundamental innovation for the AIGC community while increasing CO2 emissions. This paper introduces PIXART-α, a Transformer-based T2I diffusion model whose image generation quality is competitive with state-of-the-art image generators (e.g., Imagen, SDXL, and even Midjourney), reaching near-commercial application standards. Additionally, it supports high-resolution image synthesis up to 1024px resolution with low training cost, as shown in Figure 1 and 2. To achieve this goal, three core designs are proposed: (1) Training strategy decomposition: We devise three distinct training steps that separately optimize pixel dependency, text-image alignment, and image aesthetic quality; (2) Efficient T2I Transformer: We incorporate cross-attention modules into Diffusion Transformer (DiT) to inject text conditions and streamline the computation-intensive class-condition branch; (3) High-informative data: We emphasize the significance of concept density in text-image pairs and leverage a large Vision-Language model to auto-label dense pseudo-captions to assist text-image alignment learning. As a result, PIXART-α's training speed markedly surpasses existing large-scale T2I models, e.g., PIXART-α only takes 10.8% of Stable Diffusion v1.5's training time (675 vs. 6,250 A100 GPU days), saving nearly $300,000 ($26,000 vs. $320,000) and reducing 90% CO2 emissions. Moreover, compared with a larger SOTA model, RAPHAEL, our training cost is merely 1%. Extensive experiments demonstrate that PIXART-α excels in image quality, artistry, and semantic control. We hope PIXART-α will provide new insights to the AIGC community and startups to accelerate building their own high-quality yet low-cost generative models from scratch.*

You can find the original codebase at [PixArt-alpha/PixArt-alpha](https://github.com/PixArt-alpha/PixArt-alpha) and all the available checkpoints at [PixArt-alpha](https://huggingface.co/PixArt-alpha).

Some notes about this pipeline:

* It uses a Transformer backbone (instead of a UNet) for denoising, so its architecture is similar to [DiT](./dit.md).
* It was trained using text conditions computed from T5, which makes the pipeline better at following complex text prompts with intricate details.
* It is good at producing high-resolution images at different aspect ratios. For the best results, the authors recommend the size brackets listed [here](https://github.com/PixArt-alpha/PixArt-alpha/blob/08fbbd281ec96866109bdd2cdb75f2f58fb17610/diffusion/data/datasets/utils.py).
* It rivals the quality of state-of-the-art text-to-image generation systems (as of this writing) such as Stable Diffusion XL, Imagen, and DALL-E 2, while being more efficient.

A minimal usage sketch follows the class reference below.

## PixArtAlphaPipeline

[[autodoc]] PixArtAlphaPipeline
  - all
  - __call__
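The following is an illustrative snippet rather than part of this commit: a minimal text-to-image sketch with the new pipeline. The checkpoint id and prompt are assumptions chosen for illustration; any converted or published PixArt-alpha checkpoint should work the same way.

```python
# Minimal sketch (illustrative, not part of this commit).
# "PixArt-alpha/PixArt-XL-2-1024-MS" is an assumed Hub id; substitute the
# checkpoint you converted with the script below or one from the PixArt-alpha org.
import torch

from diffusers import PixArtAlphaPipeline

pipe = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16
)
pipe.to("cuda")

image = pipe(prompt="A small cactus wearing a straw hat in the desert").images[0]
image.save("pixart_sample.png")
```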

Diff for: scripts/convert_pixart_alpha_to_diffusers.py (+198, new file)

@@ -0,0 +1,198 @@
import argparse
import os

import torch
from transformers import T5EncoderModel, T5Tokenizer

from diffusers import AutoencoderKL, DPMSolverMultistepScheduler, PixArtAlphaPipeline, Transformer2DModel


ckpt_id = "PixArt-alpha/PixArt-alpha"
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/scripts/inference.py#L125
interpolation_scale = {512: 1, 1024: 2}


def main(args):
    all_state_dict = torch.load(args.orig_ckpt_path)
    state_dict = all_state_dict.pop("state_dict")
    converted_state_dict = {}

    # Patch embeddings.
    converted_state_dict["pos_embed.proj.weight"] = state_dict.pop("x_embedder.proj.weight")
    converted_state_dict["pos_embed.proj.bias"] = state_dict.pop("x_embedder.proj.bias")

    # Caption projection.
    converted_state_dict["caption_projection.y_embedding"] = state_dict.pop("y_embedder.y_embedding")
    converted_state_dict["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight")
    converted_state_dict["caption_projection.linear_1.bias"] = state_dict.pop("y_embedder.y_proj.fc1.bias")
    converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight")
    converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias")

    # AdaLN-single LN
    converted_state_dict["adaln_single.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
        "t_embedder.mlp.0.weight"
    )
    converted_state_dict["adaln_single.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
    converted_state_dict["adaln_single.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
        "t_embedder.mlp.2.weight"
    )
    converted_state_dict["adaln_single.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")

    if args.image_size == 1024:
        # Resolution.
        converted_state_dict["adaln_single.emb.resolution_embedder.linear_1.weight"] = state_dict.pop(
            "csize_embedder.mlp.0.weight"
        )
        converted_state_dict["adaln_single.emb.resolution_embedder.linear_1.bias"] = state_dict.pop(
            "csize_embedder.mlp.0.bias"
        )
        converted_state_dict["adaln_single.emb.resolution_embedder.linear_2.weight"] = state_dict.pop(
            "csize_embedder.mlp.2.weight"
        )
        converted_state_dict["adaln_single.emb.resolution_embedder.linear_2.bias"] = state_dict.pop(
            "csize_embedder.mlp.2.bias"
        )
        # Aspect ratio.
        converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_1.weight"] = state_dict.pop(
            "ar_embedder.mlp.0.weight"
        )
        converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_1.bias"] = state_dict.pop(
            "ar_embedder.mlp.0.bias"
        )
        converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_2.weight"] = state_dict.pop(
            "ar_embedder.mlp.2.weight"
        )
        converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_2.bias"] = state_dict.pop(
            "ar_embedder.mlp.2.bias"
        )
    # Shared norm.
    converted_state_dict["adaln_single.linear.weight"] = state_dict.pop("t_block.1.weight")
    converted_state_dict["adaln_single.linear.bias"] = state_dict.pop("t_block.1.bias")

    for depth in range(28):
        # Transformer blocks.
        converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
            f"blocks.{depth}.scale_shift_table"
        )

        # Attention is all you need 🤘

        # Self attention.
        q, k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.weight"), 3, dim=0)
        q_bias, k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.bias"), 3, dim=0)
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.bias"] = q_bias
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.bias"] = k_bias
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.bias"] = v_bias
        # Projection.
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop(
            f"blocks.{depth}.attn.proj.weight"
        )
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict.pop(
            f"blocks.{depth}.attn.proj.bias"
        )

        # Feed-forward.
        converted_state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.weight"] = state_dict.pop(
            f"blocks.{depth}.mlp.fc1.weight"
        )
        converted_state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.bias"] = state_dict.pop(
            f"blocks.{depth}.mlp.fc1.bias"
        )
        converted_state_dict[f"transformer_blocks.{depth}.ff.net.2.weight"] = state_dict.pop(
            f"blocks.{depth}.mlp.fc2.weight"
        )
        converted_state_dict[f"transformer_blocks.{depth}.ff.net.2.bias"] = state_dict.pop(
            f"blocks.{depth}.mlp.fc2.bias"
        )

        # Cross-attention.
        q = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.weight")
        q_bias = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.bias")
        k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.weight"), 2, dim=0)
        k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.bias"), 2, dim=0)

        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.weight"] = q
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.bias"] = q_bias
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.weight"] = k
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias

        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop(
            f"blocks.{depth}.cross_attn.proj.weight"
        )
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.bias"] = state_dict.pop(
            f"blocks.{depth}.cross_attn.proj.bias"
        )

    # Final block.
    converted_state_dict["proj_out.weight"] = state_dict.pop("final_layer.linear.weight")
    converted_state_dict["proj_out.bias"] = state_dict.pop("final_layer.linear.bias")
    converted_state_dict["scale_shift_table"] = state_dict.pop("final_layer.scale_shift_table")

    # DiT XL/2
    transformer = Transformer2DModel(
        sample_size=args.image_size // 8,
        num_layers=28,
        attention_head_dim=72,
        in_channels=4,
        out_channels=8,
        patch_size=2,
        attention_bias=True,
        num_attention_heads=16,
        cross_attention_dim=1152,
        activation_fn="gelu-approximate",
        num_embeds_ada_norm=1000,
        norm_type="ada_norm_single",
        norm_elementwise_affine=False,
        norm_eps=1e-6,
        caption_channels=4096,
    )
    transformer.load_state_dict(converted_state_dict, strict=True)

    assert transformer.pos_embed.pos_embed is not None
    state_dict.pop("pos_embed")
    assert len(state_dict) == 0, f"State dict is not empty, {state_dict.keys()}"

    num_model_params = sum(p.numel() for p in transformer.parameters())
    print(f"Total number of transformer parameters: {num_model_params}")

    if args.only_transformer:
        transformer.save_pretrained(os.path.join(args.dump_path, "transformer"))
    else:
        scheduler = DPMSolverMultistepScheduler()

        vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder="sd-vae-ft-ema")

        tokenizer = T5Tokenizer.from_pretrained(ckpt_id, subfolder="t5-v1_1-xxl")
        text_encoder = T5EncoderModel.from_pretrained(ckpt_id, subfolder="t5-v1_1-xxl")

        pipeline = PixArtAlphaPipeline(
            tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, vae=vae, scheduler=scheduler
        )

        pipeline.save_pretrained(args.dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--orig_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
    )
    parser.add_argument(
        "--image_size",
        default=1024,
        type=int,
        choices=[512, 1024],
        required=False,
        help="Image size of pretrained model, either 512 or 1024.",
    )
    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
    parser.add_argument("--only_transformer", default=True, type=bool, required=True)

    args = parser.parse_args()
    main(args)
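The core of the conversion above is remapping the checkpoint's fused attention projections onto the separate `to_q`/`to_k`/`to_v` linears that diffusers expects. Here is a small, self-contained sketch of that splitting step using random stand-in tensors; the hidden size is taken from the `Transformer2DModel` config above, and nothing here touches a real checkpoint.

```python
# Standalone illustration of the fused-QKV split performed in the loop above.
# The tensors are random stand-ins, not real PixArt weights.
import torch

hidden_size = 1152  # num_attention_heads (16) * attention_head_dim (72)

# The original checkpoint stores self-attention as a single fused projection.
fused_qkv_weight = torch.randn(3 * hidden_size, hidden_size)
fused_qkv_bias = torch.randn(3 * hidden_size)

# diffusers keeps separate to_q / to_k / to_v layers, so chunk along dim 0.
q_w, k_w, v_w = torch.chunk(fused_qkv_weight, 3, dim=0)
q_b, k_b, v_b = torch.chunk(fused_qkv_bias, 3, dim=0)

assert q_w.shape == (hidden_size, hidden_size)
assert q_b.shape == (hidden_size,)
```

Whether the script saves only the converted transformer or the full `PixArtAlphaPipeline` (VAE, T5 tokenizer/encoder, and scheduler included) is controlled by the `--only_transformer` and `--dump_path` arguments parsed above.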

Diff for: src/diffusers/__init__.py (+2)

@@ -235,6 +235,7 @@
             "LDMTextToImagePipeline",
             "MusicLDMPipeline",
             "PaintByExamplePipeline",
+            "PixArtAlphaPipeline",
             "SemanticStableDiffusionPipeline",
             "ShapEImg2ImgPipeline",
             "ShapEPipeline",
@@ -579,6 +580,7 @@
             LDMTextToImagePipeline,
             MusicLDMPipeline,
             PaintByExamplePipeline,
+            PixArtAlphaPipeline,
             SemanticStableDiffusionPipeline,
             ShapEImg2ImgPipeline,
             ShapEPipeline,
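Since the pipeline is now re-exported from the package root, a trivial sanity check (illustrative only, assuming an installed diffusers build that contains this commit) is to import it from the top-level namespace:

```python
# Illustrative check that the new public export is visible at the package root.
import diffusers

assert hasattr(diffusers, "PixArtAlphaPipeline")
print(diffusers.PixArtAlphaPipeline)
```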
