|
| 1 | +import argparse |
| 2 | +import os |
| 3 | + |
| 4 | +import torch |
| 5 | +from transformers import T5EncoderModel, T5Tokenizer |
| 6 | + |
| 7 | +from diffusers import AutoencoderKL, DPMSolverMultistepScheduler, PixArtAlphaPipeline, Transformer2DModel |
| 8 | + |
| 9 | + |
| 10 | +ckpt_id = "PixArt-alpha/PixArt-alpha" |
| 11 | +# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/scripts/inference.py#L125 |
| 12 | +interpolation_scale = {512: 1, 1024: 2} |
| 13 | + |
| 14 | + |
| 15 | +def main(args): |
| 16 | + all_state_dict = torch.load(args.orig_ckpt_path) |
| 17 | + state_dict = all_state_dict.pop("state_dict") |
| 18 | + converted_state_dict = {} |
| 19 | + |
| 20 | + # Patch embeddings. |
| 21 | + converted_state_dict["pos_embed.proj.weight"] = state_dict.pop("x_embedder.proj.weight") |
| 22 | + converted_state_dict["pos_embed.proj.bias"] = state_dict.pop("x_embedder.proj.bias") |
| 23 | + |
| 24 | + # Caption projection. |
| 25 | + converted_state_dict["caption_projection.y_embedding"] = state_dict.pop("y_embedder.y_embedding") |
| 26 | + converted_state_dict["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight") |
| 27 | + converted_state_dict["caption_projection.linear_1.bias"] = state_dict.pop("y_embedder.y_proj.fc1.bias") |
| 28 | + converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight") |
| 29 | + converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias") |
| 30 | + |
| 31 | + # AdaLN-single LN |
| 32 | + converted_state_dict["adaln_single.emb.timestep_embedder.linear_1.weight"] = state_dict.pop( |
| 33 | + "t_embedder.mlp.0.weight" |
| 34 | + ) |
| 35 | + converted_state_dict["adaln_single.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias") |
| 36 | + converted_state_dict["adaln_single.emb.timestep_embedder.linear_2.weight"] = state_dict.pop( |
| 37 | + "t_embedder.mlp.2.weight" |
| 38 | + ) |
| 39 | + converted_state_dict["adaln_single.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias") |
| 40 | + |
| 41 | + if args.image_size == 1024: |
| 42 | + # Resolution. |
| 43 | + converted_state_dict["adaln_single.emb.resolution_embedder.linear_1.weight"] = state_dict.pop( |
| 44 | + "csize_embedder.mlp.0.weight" |
| 45 | + ) |
| 46 | + converted_state_dict["adaln_single.emb.resolution_embedder.linear_1.bias"] = state_dict.pop( |
| 47 | + "csize_embedder.mlp.0.bias" |
| 48 | + ) |
| 49 | + converted_state_dict["adaln_single.emb.resolution_embedder.linear_2.weight"] = state_dict.pop( |
| 50 | + "csize_embedder.mlp.2.weight" |
| 51 | + ) |
| 52 | + converted_state_dict["adaln_single.emb.resolution_embedder.linear_2.bias"] = state_dict.pop( |
| 53 | + "csize_embedder.mlp.2.bias" |
| 54 | + ) |
| 55 | + # Aspect ratio. |
| 56 | + converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_1.weight"] = state_dict.pop( |
| 57 | + "ar_embedder.mlp.0.weight" |
| 58 | + ) |
| 59 | + converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_1.bias"] = state_dict.pop( |
| 60 | + "ar_embedder.mlp.0.bias" |
| 61 | + ) |
| 62 | + converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_2.weight"] = state_dict.pop( |
| 63 | + "ar_embedder.mlp.2.weight" |
| 64 | + ) |
| 65 | + converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_2.bias"] = state_dict.pop( |
| 66 | + "ar_embedder.mlp.2.bias" |
| 67 | + ) |
| 68 | + # Shared norm. |
| 69 | + converted_state_dict["adaln_single.linear.weight"] = state_dict.pop("t_block.1.weight") |
| 70 | + converted_state_dict["adaln_single.linear.bias"] = state_dict.pop("t_block.1.bias") |
| 71 | + |
| 72 | + for depth in range(28): |
| 73 | + # Transformer blocks. |
| 74 | + converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop( |
| 75 | + f"blocks.{depth}.scale_shift_table" |
| 76 | + ) |
| 77 | + |
| 78 | + # Attention is all you need 🤘 |
| 79 | + |
| 80 | + # Self attention. |
| 81 | + q, k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.weight"), 3, dim=0) |
| 82 | + q_bias, k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.bias"), 3, dim=0) |
| 83 | + converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q |
| 84 | + converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.bias"] = q_bias |
| 85 | + converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k |
| 86 | + converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.bias"] = k_bias |
| 87 | + converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v |
| 88 | + converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.bias"] = v_bias |
| 89 | + # Projection. |
| 90 | + converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop( |
| 91 | + f"blocks.{depth}.attn.proj.weight" |
| 92 | + ) |
| 93 | + converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict.pop( |
| 94 | + f"blocks.{depth}.attn.proj.bias" |
| 95 | + ) |
| 96 | + |
| 97 | + # Feed-forward. |
| 98 | + converted_state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.weight"] = state_dict.pop( |
| 99 | + f"blocks.{depth}.mlp.fc1.weight" |
| 100 | + ) |
| 101 | + converted_state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.bias"] = state_dict.pop( |
| 102 | + f"blocks.{depth}.mlp.fc1.bias" |
| 103 | + ) |
| 104 | + converted_state_dict[f"transformer_blocks.{depth}.ff.net.2.weight"] = state_dict.pop( |
| 105 | + f"blocks.{depth}.mlp.fc2.weight" |
| 106 | + ) |
| 107 | + converted_state_dict[f"transformer_blocks.{depth}.ff.net.2.bias"] = state_dict.pop( |
| 108 | + f"blocks.{depth}.mlp.fc2.bias" |
| 109 | + ) |
| 110 | + |
| 111 | + # Cross-attention. |
| 112 | + q = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.weight") |
| 113 | + q_bias = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.bias") |
| 114 | + k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.weight"), 2, dim=0) |
| 115 | + k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.bias"), 2, dim=0) |
| 116 | + |
| 117 | + converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.weight"] = q |
| 118 | + converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.bias"] = q_bias |
| 119 | + converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.weight"] = k |
| 120 | + converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias |
| 121 | + converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v |
| 122 | + converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias |
| 123 | + |
| 124 | + converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop( |
| 125 | + f"blocks.{depth}.cross_attn.proj.weight" |
| 126 | + ) |
| 127 | + converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.bias"] = state_dict.pop( |
| 128 | + f"blocks.{depth}.cross_attn.proj.bias" |
| 129 | + ) |
| 130 | + |
| 131 | + # Final block. |
| 132 | + converted_state_dict["proj_out.weight"] = state_dict.pop("final_layer.linear.weight") |
| 133 | + converted_state_dict["proj_out.bias"] = state_dict.pop("final_layer.linear.bias") |
| 134 | + converted_state_dict["scale_shift_table"] = state_dict.pop("final_layer.scale_shift_table") |
| 135 | + |
| 136 | + # DiT XL/2 |
| 137 | + transformer = Transformer2DModel( |
| 138 | + sample_size=args.image_size // 8, |
| 139 | + num_layers=28, |
| 140 | + attention_head_dim=72, |
| 141 | + in_channels=4, |
| 142 | + out_channels=8, |
| 143 | + patch_size=2, |
| 144 | + attention_bias=True, |
| 145 | + num_attention_heads=16, |
| 146 | + cross_attention_dim=1152, |
| 147 | + activation_fn="gelu-approximate", |
| 148 | + num_embeds_ada_norm=1000, |
| 149 | + norm_type="ada_norm_single", |
| 150 | + norm_elementwise_affine=False, |
| 151 | + norm_eps=1e-6, |
| 152 | + caption_channels=4096, |
| 153 | + ) |
| 154 | + transformer.load_state_dict(converted_state_dict, strict=True) |
| 155 | + |
| 156 | + assert transformer.pos_embed.pos_embed is not None |
| 157 | + state_dict.pop("pos_embed") |
| 158 | + assert len(state_dict) == 0, f"State dict is not empty, {state_dict.keys()}" |
| 159 | + |
| 160 | + num_model_params = sum(p.numel() for p in transformer.parameters()) |
| 161 | + print(f"Total number of transformer parameters: {num_model_params}") |
| 162 | + |
| 163 | + if args.only_transformer: |
| 164 | + transformer.save_pretrained(os.path.join(args.dump_path, "transformer")) |
| 165 | + else: |
| 166 | + scheduler = DPMSolverMultistepScheduler() |
| 167 | + |
| 168 | + vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder="sd-vae-ft-ema") |
| 169 | + |
| 170 | + tokenizer = T5Tokenizer.from_pretrained(ckpt_id, subfolder="t5-v1_1-xxl") |
| 171 | + text_encoder = T5EncoderModel.from_pretrained(ckpt_id, subfolder="t5-v1_1-xxl") |
| 172 | + |
| 173 | + pipeline = PixArtAlphaPipeline( |
| 174 | + tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, vae=vae, scheduler=scheduler |
| 175 | + ) |
| 176 | + |
| 177 | + pipeline.save_pretrained(args.dump_path) |
| 178 | + |
| 179 | + |
| 180 | +if __name__ == "__main__": |
| 181 | + parser = argparse.ArgumentParser() |
| 182 | + |
| 183 | + parser.add_argument( |
| 184 | + "--orig_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert." |
| 185 | + ) |
| 186 | + parser.add_argument( |
| 187 | + "--image_size", |
| 188 | + default=1024, |
| 189 | + type=int, |
| 190 | + choices=[512, 1024], |
| 191 | + required=False, |
| 192 | + help="Image size of pretrained model, either 512 or 1024.", |
| 193 | + ) |
| 194 | + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.") |
| 195 | + parser.add_argument("--only_transformer", default=True, type=bool, required=True) |
| 196 | + |
| 197 | + args = parser.parse_args() |
| 198 | + main(args) |
0 commit comments