<a href="https://colab.research.google.com/github/wayneotemah/AI-and-ML/blob/main/stable_diffutson.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install tensorflow_addons
!pip install ftfy

Collecting tensorflow_addons
  Downloading tensorflow_addons-0.23.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (611 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m611.8/611.8 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
Collecting typeguard<3.0.0,>=2.7 (from tensorflow_addons)
  Downloading typeguard-2.13.3-py3-none-any.whl (17 kB)
Installing collected packages: typeguard, tensorflow_addons
Successfully installed tensorflow_addons-0.23.0 typeguard-2.13.3
Collecting ftfy
  Downloading ftfy-6.1.3-py3-none-any.whl (53 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.4/53.4 kB[0m [31m1.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: ftfy
Successfully installed ftfy-6.1.3


In [32]:
import tensorflow as tf
from tensorflow import keras
import tensorflow_addons as tfa
import numpy as np
import math
from tqdm import tqdm
from PIL import Image
from PIL.PngImagePlugin import PngInfo

# CONSTANTS

In [3]:
PYTORCH_CKPT_MAPPING = {'text_encoder': [('cond_stage_model.transformer.text_model.embeddings.token_embedding.weight',
   None),
  ('cond_stage_model.transformer.text_model.embeddings.position_embedding.weight',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.weight',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.weight',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.weight',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.weight',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.weight',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.weight',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.weight',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.weight',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.weight',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.weight',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.weight',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.weight',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.weight',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.weight',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.weight',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.weight',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.weight',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.weight',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.weight',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.weight',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.weight',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.weight',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.weight',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.weight',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.bias',
   None),
  ('cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.weight',
   (1, 0)),
  ('cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.bias',
   None),
  ('cond_stage_model.transformer.text_model.final_layer_norm.weight', None),
  ('cond_stage_model.transformer.text_model.final_layer_norm.bias', None)],
 'diffusion_model': [('model.diffusion_model.time_embed.0.weight', (1, 0)),
  ('model.diffusion_model.time_embed.0.bias', None),
  ('model.diffusion_model.time_embed.2.weight', (1, 0)),
  ('model.diffusion_model.time_embed.2.bias', None),
  ('model.diffusion_model.input_blocks.0.0.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.input_blocks.0.0.bias', None),
  ('model.diffusion_model.input_blocks.1.0.in_layers.0.weight', None),
  ('model.diffusion_model.input_blocks.1.0.in_layers.0.bias', None),
  ('model.diffusion_model.input_blocks.1.0.in_layers.2.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.input_blocks.1.0.in_layers.2.bias', None),
  ('model.diffusion_model.input_blocks.1.0.emb_layers.1.weight', (1, 0)),
  ('model.diffusion_model.input_blocks.1.0.emb_layers.1.bias', None),
  ('model.diffusion_model.input_blocks.1.0.out_layers.0.weight', None),
  ('model.diffusion_model.input_blocks.1.0.out_layers.0.bias', None),
  ('model.diffusion_model.input_blocks.1.0.out_layers.3.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.input_blocks.1.0.out_layers.3.bias', None),
  ('model.diffusion_model.input_blocks.1.1.norm.weight', None),
  ('model.diffusion_model.input_blocks.1.1.norm.bias', None),
  ('model.diffusion_model.input_blocks.1.1.proj_in.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.input_blocks.1.1.proj_in.bias', None),
  ('model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm1.weight',
   None),
  ('model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm1.bias',
   None),
  ('model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_q.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_k.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_v.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.bias',
   None),
  ('model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm2.weight',
   None),
  ('model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm2.bias',
   None),
  ('model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_q.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_v.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.bias',
   None),
  ('model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm3.weight',
   None),
  ('model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm3.bias',
   None),
  ('model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.bias',
   None),
  ('model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.bias',
   None),
  ('model.diffusion_model.input_blocks.1.1.proj_out.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.input_blocks.1.1.proj_out.bias', None),
  ('model.diffusion_model.input_blocks.2.0.in_layers.0.weight', None),
  ('model.diffusion_model.input_blocks.2.0.in_layers.0.bias', None),
  ('model.diffusion_model.input_blocks.2.0.in_layers.2.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.input_blocks.2.0.in_layers.2.bias', None),
  ('model.diffusion_model.input_blocks.2.0.emb_layers.1.weight', (1, 0)),
  ('model.diffusion_model.input_blocks.2.0.emb_layers.1.bias', None),
  ('model.diffusion_model.input_blocks.2.0.out_layers.0.weight', None),
  ('model.diffusion_model.input_blocks.2.0.out_layers.0.bias', None),
  ('model.diffusion_model.input_blocks.2.0.out_layers.3.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.input_blocks.2.0.out_layers.3.bias', None),
  ('model.diffusion_model.input_blocks.2.1.norm.weight', None),
  ('model.diffusion_model.input_blocks.2.1.norm.bias', None),
  ('model.diffusion_model.input_blocks.2.1.proj_in.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.input_blocks.2.1.proj_in.bias', None),
  ('model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm1.weight',
   None),
  ('model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm1.bias',
   None),
  ('model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_q.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_k.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_v.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.bias',
   None),
  ('model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm2.weight',
   None),
  ('model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm2.bias',
   None),
  ('model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_q.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_v.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.bias',
   None),
  ('model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm3.weight',
   None),
  ('model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm3.bias',
   None),
  ('model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.bias',
   None),
  ('model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.bias',
   None),
  ('model.diffusion_model.input_blocks.2.1.proj_out.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.input_blocks.2.1.proj_out.bias', None),
  ('model.diffusion_model.input_blocks.3.0.op.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.input_blocks.3.0.op.bias', None),
  ('model.diffusion_model.input_blocks.4.0.in_layers.0.weight', None),
  ('model.diffusion_model.input_blocks.4.0.in_layers.0.bias', None),
  ('model.diffusion_model.input_blocks.4.0.in_layers.2.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.input_blocks.4.0.in_layers.2.bias', None),
  ('model.diffusion_model.input_blocks.4.0.emb_layers.1.weight', (1, 0)),
  ('model.diffusion_model.input_blocks.4.0.emb_layers.1.bias', None),
  ('model.diffusion_model.input_blocks.4.0.out_layers.0.weight', None),
  ('model.diffusion_model.input_blocks.4.0.out_layers.0.bias', None),
  ('model.diffusion_model.input_blocks.4.0.out_layers.3.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.input_blocks.4.0.out_layers.3.bias', None),
  ('model.diffusion_model.input_blocks.4.0.skip_connection.weight',
   (2, 3, 1, 0)),
  ('model.diffusion_model.input_blocks.4.0.skip_connection.bias', None),
  ('model.diffusion_model.input_blocks.4.1.norm.weight', None),
  ('model.diffusion_model.input_blocks.4.1.norm.bias', None),
  ('model.diffusion_model.input_blocks.4.1.proj_in.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.input_blocks.4.1.proj_in.bias', None),
  ('model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.weight',
   None),
  ('model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.bias',
   None),
  ('model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias',
   None),
  ('model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.weight',
   None),
  ('model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.bias',
   None),
  ('model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias',
   None),
  ('model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.weight',
   None),
  ('model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.bias',
   None),
  ('model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias',
   None),
  ('model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias',
   None),
  ('model.diffusion_model.input_blocks.4.1.proj_out.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.input_blocks.4.1.proj_out.bias', None),
  ('model.diffusion_model.input_blocks.5.0.in_layers.0.weight', None),
  ('model.diffusion_model.input_blocks.5.0.in_layers.0.bias', None),
  ('model.diffusion_model.input_blocks.5.0.in_layers.2.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.input_blocks.5.0.in_layers.2.bias', None),
  ('model.diffusion_model.input_blocks.5.0.emb_layers.1.weight', (1, 0)),
  ('model.diffusion_model.input_blocks.5.0.emb_layers.1.bias', None),
  ('model.diffusion_model.input_blocks.5.0.out_layers.0.weight', None),
  ('model.diffusion_model.input_blocks.5.0.out_layers.0.bias', None),
  ('model.diffusion_model.input_blocks.5.0.out_layers.3.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.input_blocks.5.0.out_layers.3.bias', None),
  ('model.diffusion_model.input_blocks.5.1.norm.weight', None),
  ('model.diffusion_model.input_blocks.5.1.norm.bias', None),
  ('model.diffusion_model.input_blocks.5.1.proj_in.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.input_blocks.5.1.proj_in.bias', None),
  ('model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.weight',
   None),
  ('model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.bias',
   None),
  ('model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias',
   None),
  ('model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.weight',
   None),
  ('model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.bias',
   None),
  ('model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias',
   None),
  ('model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.weight',
   None),
  ('model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.bias',
   None),
  ('model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias',
   None),
  ('model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias',
   None),
  ('model.diffusion_model.input_blocks.5.1.proj_out.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.input_blocks.5.1.proj_out.bias', None),
  ('model.diffusion_model.input_blocks.6.0.op.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.input_blocks.6.0.op.bias', None),
  ('model.diffusion_model.input_blocks.7.0.in_layers.0.weight', None),
  ('model.diffusion_model.input_blocks.7.0.in_layers.0.bias', None),
  ('model.diffusion_model.input_blocks.7.0.in_layers.2.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.input_blocks.7.0.in_layers.2.bias', None),
  ('model.diffusion_model.input_blocks.7.0.emb_layers.1.weight', (1, 0)),
  ('model.diffusion_model.input_blocks.7.0.emb_layers.1.bias', None),
  ('model.diffusion_model.input_blocks.7.0.out_layers.0.weight', None),
  ('model.diffusion_model.input_blocks.7.0.out_layers.0.bias', None),
  ('model.diffusion_model.input_blocks.7.0.out_layers.3.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.input_blocks.7.0.out_layers.3.bias', None),
  ('model.diffusion_model.input_blocks.7.0.skip_connection.weight',
   (2, 3, 1, 0)),
  ('model.diffusion_model.input_blocks.7.0.skip_connection.bias', None),
  ('model.diffusion_model.input_blocks.7.1.norm.weight', None),
  ('model.diffusion_model.input_blocks.7.1.norm.bias', None),
  ('model.diffusion_model.input_blocks.7.1.proj_in.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.input_blocks.7.1.proj_in.bias', None),
  ('model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.weight',
   None),
  ('model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.bias',
   None),
  ('model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias',
   None),
  ('model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.weight',
   None),
  ('model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.bias',
   None),
  ('model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias',
   None),
  ('model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.weight',
   None),
  ('model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.bias',
   None),
  ('model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias',
   None),
  ('model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias',
   None),
  ('model.diffusion_model.input_blocks.7.1.proj_out.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.input_blocks.7.1.proj_out.bias', None),
  ('model.diffusion_model.input_blocks.8.0.in_layers.0.weight', None),
  ('model.diffusion_model.input_blocks.8.0.in_layers.0.bias', None),
  ('model.diffusion_model.input_blocks.8.0.in_layers.2.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.input_blocks.8.0.in_layers.2.bias', None),
  ('model.diffusion_model.input_blocks.8.0.emb_layers.1.weight', (1, 0)),
  ('model.diffusion_model.input_blocks.8.0.emb_layers.1.bias', None),
  ('model.diffusion_model.input_blocks.8.0.out_layers.0.weight', None),
  ('model.diffusion_model.input_blocks.8.0.out_layers.0.bias', None),
  ('model.diffusion_model.input_blocks.8.0.out_layers.3.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.input_blocks.8.0.out_layers.3.bias', None),
  ('model.diffusion_model.input_blocks.8.1.norm.weight', None),
  ('model.diffusion_model.input_blocks.8.1.norm.bias', None),
  ('model.diffusion_model.input_blocks.8.1.proj_in.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.input_blocks.8.1.proj_in.bias', None),
  ('model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.weight',
   None),
  ('model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.bias',
   None),
  ('model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias',
   None),
  ('model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.weight',
   None),
  ('model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.bias',
   None),
  ('model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias',
   None),
  ('model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.weight',
   None),
  ('model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.bias',
   None),
  ('model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias',
   None),
  ('model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight',
   (1, 0)),
  ('model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias',
   None),
  ('model.diffusion_model.input_blocks.8.1.proj_out.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.input_blocks.8.1.proj_out.bias', None),
  ('model.diffusion_model.input_blocks.9.0.op.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.input_blocks.9.0.op.bias', None),
  ('model.diffusion_model.input_blocks.10.0.in_layers.0.weight', None),
  ('model.diffusion_model.input_blocks.10.0.in_layers.0.bias', None),
  ('model.diffusion_model.input_blocks.10.0.in_layers.2.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.input_blocks.10.0.in_layers.2.bias', None),
  ('model.diffusion_model.input_blocks.10.0.emb_layers.1.weight', (1, 0)),
  ('model.diffusion_model.input_blocks.10.0.emb_layers.1.bias', None),
  ('model.diffusion_model.input_blocks.10.0.out_layers.0.weight', None),
  ('model.diffusion_model.input_blocks.10.0.out_layers.0.bias', None),
  ('model.diffusion_model.input_blocks.10.0.out_layers.3.weight',
   (2, 3, 1, 0)),
  ('model.diffusion_model.input_blocks.10.0.out_layers.3.bias', None),
  ('model.diffusion_model.input_blocks.11.0.in_layers.0.weight', None),
  ('model.diffusion_model.input_blocks.11.0.in_layers.0.bias', None),
  ('model.diffusion_model.input_blocks.11.0.in_layers.2.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.input_blocks.11.0.in_layers.2.bias', None),
  ('model.diffusion_model.input_blocks.11.0.emb_layers.1.weight', (1, 0)),
  ('model.diffusion_model.input_blocks.11.0.emb_layers.1.bias', None),
  ('model.diffusion_model.input_blocks.11.0.out_layers.0.weight', None),
  ('model.diffusion_model.input_blocks.11.0.out_layers.0.bias', None),
  ('model.diffusion_model.input_blocks.11.0.out_layers.3.weight',
   (2, 3, 1, 0)),
  ('model.diffusion_model.input_blocks.11.0.out_layers.3.bias', None),
  ('model.diffusion_model.middle_block.0.in_layers.0.weight', None),
  ('model.diffusion_model.middle_block.0.in_layers.0.bias', None),
  ('model.diffusion_model.middle_block.0.in_layers.2.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.middle_block.0.in_layers.2.bias', None),
  ('model.diffusion_model.middle_block.0.emb_layers.1.weight', (1, 0)),
  ('model.diffusion_model.middle_block.0.emb_layers.1.bias', None),
  ('model.diffusion_model.middle_block.0.out_layers.0.weight', None),
  ('model.diffusion_model.middle_block.0.out_layers.0.bias', None),
  ('model.diffusion_model.middle_block.0.out_layers.3.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.middle_block.0.out_layers.3.bias', None),
  ('model.diffusion_model.middle_block.1.norm.weight', None),
  ('model.diffusion_model.middle_block.1.norm.bias', None),
  ('model.diffusion_model.middle_block.1.proj_in.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.middle_block.1.proj_in.bias', None),
  ('model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.weight',
   None),
  ('model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.bias',
   None),
  ('model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight',
   (1, 0)),
  ('model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_k.weight',
   (1, 0)),
  ('model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_v.weight',
   (1, 0)),
  ('model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.weight',
   (1, 0)),
  ('model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.bias',
   None),
  ('model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.weight',
   None),
  ('model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.bias',
   None),
  ('model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_q.weight',
   (1, 0)),
  ('model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight',
   (1, 0)),
  ('model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight',
   (1, 0)),
  ('model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.weight',
   (1, 0)),
  ('model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.bias',
   None),
  ('model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.weight',
   None),
  ('model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.bias',
   None),
  ('model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.weight',
   (1, 0)),
  ('model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.bias',
   None),
  ('model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2.weight',
   (1, 0)),
  ('model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2.bias',
   None),
  ('model.diffusion_model.middle_block.1.proj_out.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.middle_block.1.proj_out.bias', None),
  ('model.diffusion_model.middle_block.2.in_layers.0.weight', None),
  ('model.diffusion_model.middle_block.2.in_layers.0.bias', None),
  ('model.diffusion_model.middle_block.2.in_layers.2.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.middle_block.2.in_layers.2.bias', None),
  ('model.diffusion_model.middle_block.2.emb_layers.1.weight', (1, 0)),
  ('model.diffusion_model.middle_block.2.emb_layers.1.bias', None),
  ('model.diffusion_model.middle_block.2.out_layers.0.weight', None),
  ('model.diffusion_model.middle_block.2.out_layers.0.bias', None),
  ('model.diffusion_model.middle_block.2.out_layers.3.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.middle_block.2.out_layers.3.bias', None),
  ('model.diffusion_model.output_blocks.0.0.in_layers.0.weight', None),
  ('model.diffusion_model.output_blocks.0.0.in_layers.0.bias', None),
  ('model.diffusion_model.output_blocks.0.0.in_layers.2.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.0.0.in_layers.2.bias', None),
  ('model.diffusion_model.output_blocks.0.0.emb_layers.1.weight', (1, 0)),
  ('model.diffusion_model.output_blocks.0.0.emb_layers.1.bias', None),
  ('model.diffusion_model.output_blocks.0.0.out_layers.0.weight', None),
  ('model.diffusion_model.output_blocks.0.0.out_layers.0.bias', None),
  ('model.diffusion_model.output_blocks.0.0.out_layers.3.weight',
   (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.0.0.out_layers.3.bias', None),
  ('model.diffusion_model.output_blocks.0.0.skip_connection.weight',
   (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.0.0.skip_connection.bias', None),
  ('model.diffusion_model.output_blocks.1.0.in_layers.0.weight', None),
  ('model.diffusion_model.output_blocks.1.0.in_layers.0.bias', None),
  ('model.diffusion_model.output_blocks.1.0.in_layers.2.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.1.0.in_layers.2.bias', None),
  ('model.diffusion_model.output_blocks.1.0.emb_layers.1.weight', (1, 0)),
  ('model.diffusion_model.output_blocks.1.0.emb_layers.1.bias', None),
  ('model.diffusion_model.output_blocks.1.0.out_layers.0.weight', None),
  ('model.diffusion_model.output_blocks.1.0.out_layers.0.bias', None),
  ('model.diffusion_model.output_blocks.1.0.out_layers.3.weight',
   (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.1.0.out_layers.3.bias', None),
  ('model.diffusion_model.output_blocks.1.0.skip_connection.weight',
   (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.1.0.skip_connection.bias', None),
  ('model.diffusion_model.output_blocks.2.0.in_layers.0.weight', None),
  ('model.diffusion_model.output_blocks.2.0.in_layers.0.bias', None),
  ('model.diffusion_model.output_blocks.2.0.in_layers.2.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.2.0.in_layers.2.bias', None),
  ('model.diffusion_model.output_blocks.2.0.emb_layers.1.weight', (1, 0)),
  ('model.diffusion_model.output_blocks.2.0.emb_layers.1.bias', None),
  ('model.diffusion_model.output_blocks.2.0.out_layers.0.weight', None),
  ('model.diffusion_model.output_blocks.2.0.out_layers.0.bias', None),
  ('model.diffusion_model.output_blocks.2.0.out_layers.3.weight',
   (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.2.0.out_layers.3.bias', None),
  ('model.diffusion_model.output_blocks.2.0.skip_connection.weight',
   (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.2.0.skip_connection.bias', None),
  ('model.diffusion_model.output_blocks.2.1.conv.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.2.1.conv.bias', None),
  ('model.diffusion_model.output_blocks.3.0.in_layers.0.weight', None),
  ('model.diffusion_model.output_blocks.3.0.in_layers.0.bias', None),
  ('model.diffusion_model.output_blocks.3.0.in_layers.2.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.3.0.in_layers.2.bias', None),
  ('model.diffusion_model.output_blocks.3.0.emb_layers.1.weight', (1, 0)),
  ('model.diffusion_model.output_blocks.3.0.emb_layers.1.bias', None),
  ('model.diffusion_model.output_blocks.3.0.out_layers.0.weight', None),
  ('model.diffusion_model.output_blocks.3.0.out_layers.0.bias', None),
  ('model.diffusion_model.output_blocks.3.0.out_layers.3.weight',
   (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.3.0.out_layers.3.bias', None),
  ('model.diffusion_model.output_blocks.3.0.skip_connection.weight',
   (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.3.0.skip_connection.bias', None),
  ('model.diffusion_model.output_blocks.3.1.norm.weight', None),
  ('model.diffusion_model.output_blocks.3.1.norm.bias', None),
  ('model.diffusion_model.output_blocks.3.1.proj_in.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.3.1.proj_in.bias', None),
  ('model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.weight',
   None),
  ('model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.bias',
   None),
  ('model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_q.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_k.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_v.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.bias',
   None),
  ('model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.weight',
   None),
  ('model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.bias',
   None),
  ('model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_q.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_k.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_v.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.bias',
   None),
  ('model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.weight',
   None),
  ('model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.bias',
   None),
  ('model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.bias',
   None),
  ('model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.bias',
   None),
  ('model.diffusion_model.output_blocks.3.1.proj_out.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.3.1.proj_out.bias', None),
  ('model.diffusion_model.output_blocks.4.0.in_layers.0.weight', None),
  ('model.diffusion_model.output_blocks.4.0.in_layers.0.bias', None),
  ('model.diffusion_model.output_blocks.4.0.in_layers.2.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.4.0.in_layers.2.bias', None),
  ('model.diffusion_model.output_blocks.4.0.emb_layers.1.weight', (1, 0)),
  ('model.diffusion_model.output_blocks.4.0.emb_layers.1.bias', None),
  ('model.diffusion_model.output_blocks.4.0.out_layers.0.weight', None),
  ('model.diffusion_model.output_blocks.4.0.out_layers.0.bias', None),
  ('model.diffusion_model.output_blocks.4.0.out_layers.3.weight',
   (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.4.0.out_layers.3.bias', None),
  ('model.diffusion_model.output_blocks.4.0.skip_connection.weight',
   (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.4.0.skip_connection.bias', None),
  ('model.diffusion_model.output_blocks.4.1.norm.weight', None),
  ('model.diffusion_model.output_blocks.4.1.norm.bias', None),
  ('model.diffusion_model.output_blocks.4.1.proj_in.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.4.1.proj_in.bias', None),
  ('model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.weight',
   None),
  ('model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.bias',
   None),
  ('model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_q.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_k.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_v.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias',
   None),
  ('model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.weight',
   None),
  ('model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.bias',
   None),
  ('model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_q.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_k.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_v.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias',
   None),
  ('model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.weight',
   None),
  ('model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.bias',
   None),
  ('model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias',
   None),
  ('model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.bias',
   None),
  ('model.diffusion_model.output_blocks.4.1.proj_out.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.4.1.proj_out.bias', None),
  ('model.diffusion_model.output_blocks.5.0.in_layers.0.weight', None),
  ('model.diffusion_model.output_blocks.5.0.in_layers.0.bias', None),
  ('model.diffusion_model.output_blocks.5.0.in_layers.2.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.5.0.in_layers.2.bias', None),
  ('model.diffusion_model.output_blocks.5.0.emb_layers.1.weight', (1, 0)),
  ('model.diffusion_model.output_blocks.5.0.emb_layers.1.bias', None),
  ('model.diffusion_model.output_blocks.5.0.out_layers.0.weight', None),
  ('model.diffusion_model.output_blocks.5.0.out_layers.0.bias', None),
  ('model.diffusion_model.output_blocks.5.0.out_layers.3.weight',
   (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.5.0.out_layers.3.bias', None),
  ('model.diffusion_model.output_blocks.5.0.skip_connection.weight',
   (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.5.0.skip_connection.bias', None),
  ('model.diffusion_model.output_blocks.5.1.norm.weight', None),
  ('model.diffusion_model.output_blocks.5.1.norm.bias', None),
  ('model.diffusion_model.output_blocks.5.1.proj_in.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.5.1.proj_in.bias', None),
  ('model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.weight',
   None),
  ('model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.bias',
   None),
  ('model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_q.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_k.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_v.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias',
   None),
  ('model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.weight',
   None),
  ('model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.bias',
   None),
  ('model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_q.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_k.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_v.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias',
   None),
  ('model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.weight',
   None),
  ('model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.bias',
   None),
  ('model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias',
   None),
  ('model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.bias',
   None),
  ('model.diffusion_model.output_blocks.5.1.proj_out.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.5.1.proj_out.bias', None),
  ('model.diffusion_model.output_blocks.5.2.conv.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.5.2.conv.bias', None),
  ('model.diffusion_model.output_blocks.6.0.in_layers.0.weight', None),
  ('model.diffusion_model.output_blocks.6.0.in_layers.0.bias', None),
  ('model.diffusion_model.output_blocks.6.0.in_layers.2.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.6.0.in_layers.2.bias', None),
  ('model.diffusion_model.output_blocks.6.0.emb_layers.1.weight', (1, 0)),
  ('model.diffusion_model.output_blocks.6.0.emb_layers.1.bias', None),
  ('model.diffusion_model.output_blocks.6.0.out_layers.0.weight', None),
  ('model.diffusion_model.output_blocks.6.0.out_layers.0.bias', None),
  ('model.diffusion_model.output_blocks.6.0.out_layers.3.weight',
   (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.6.0.out_layers.3.bias', None),
  ('model.diffusion_model.output_blocks.6.0.skip_connection.weight',
   (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.6.0.skip_connection.bias', None),
  ('model.diffusion_model.output_blocks.6.1.norm.weight', None),
  ('model.diffusion_model.output_blocks.6.1.norm.bias', None),
  ('model.diffusion_model.output_blocks.6.1.proj_in.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.6.1.proj_in.bias', None),
  ('model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm1.weight',
   None),
  ('model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm1.bias',
   None),
  ('model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_q.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_k.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_v.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_out.0.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_out.0.bias',
   None),
  ('model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm2.weight',
   None),
  ('model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm2.bias',
   None),
  ('model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_q.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_k.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_v.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_out.0.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_out.0.bias',
   None),
  ('model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm3.weight',
   None),
  ('model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm3.bias',
   None),
  ('model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.0.proj.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.0.proj.bias',
   None),
  ('model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.2.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.2.bias',
   None),
  ('model.diffusion_model.output_blocks.6.1.proj_out.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.6.1.proj_out.bias', None),
  ('model.diffusion_model.output_blocks.7.0.in_layers.0.weight', None),
  ('model.diffusion_model.output_blocks.7.0.in_layers.0.bias', None),
  ('model.diffusion_model.output_blocks.7.0.in_layers.2.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.7.0.in_layers.2.bias', None),
  ('model.diffusion_model.output_blocks.7.0.emb_layers.1.weight', (1, 0)),
  ('model.diffusion_model.output_blocks.7.0.emb_layers.1.bias', None),
  ('model.diffusion_model.output_blocks.7.0.out_layers.0.weight', None),
  ('model.diffusion_model.output_blocks.7.0.out_layers.0.bias', None),
  ('model.diffusion_model.output_blocks.7.0.out_layers.3.weight',
   (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.7.0.out_layers.3.bias', None),
  ('model.diffusion_model.output_blocks.7.0.skip_connection.weight',
   (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.7.0.skip_connection.bias', None),
  ('model.diffusion_model.output_blocks.7.1.norm.weight', None),
  ('model.diffusion_model.output_blocks.7.1.norm.bias', None),
  ('model.diffusion_model.output_blocks.7.1.proj_in.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.7.1.proj_in.bias', None),
  ('model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm1.weight',
   None),
  ('model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm1.bias',
   None),
  ('model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_q.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_k.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_v.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias',
   None),
  ('model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm2.weight',
   None),
  ('model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm2.bias',
   None),
  ('model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_q.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_k.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_v.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias',
   None),
  ('model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm3.weight',
   None),
  ('model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm3.bias',
   None),
  ('model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias',
   None),
  ('model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.2.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.2.bias',
   None),
  ('model.diffusion_model.output_blocks.7.1.proj_out.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.7.1.proj_out.bias', None),
  ('model.diffusion_model.output_blocks.8.0.in_layers.0.weight', None),
  ('model.diffusion_model.output_blocks.8.0.in_layers.0.bias', None),
  ('model.diffusion_model.output_blocks.8.0.in_layers.2.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.8.0.in_layers.2.bias', None),
  ('model.diffusion_model.output_blocks.8.0.emb_layers.1.weight', (1, 0)),
  ('model.diffusion_model.output_blocks.8.0.emb_layers.1.bias', None),
  ('model.diffusion_model.output_blocks.8.0.out_layers.0.weight', None),
  ('model.diffusion_model.output_blocks.8.0.out_layers.0.bias', None),
  ('model.diffusion_model.output_blocks.8.0.out_layers.3.weight',
   (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.8.0.out_layers.3.bias', None),
  ('model.diffusion_model.output_blocks.8.0.skip_connection.weight',
   (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.8.0.skip_connection.bias', None),
  ('model.diffusion_model.output_blocks.8.1.norm.weight', None),
  ('model.diffusion_model.output_blocks.8.1.norm.bias', None),
  ('model.diffusion_model.output_blocks.8.1.proj_in.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.8.1.proj_in.bias', None),
  ('model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm1.weight',
   None),
  ('model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm1.bias',
   None),
  ('model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_q.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_k.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_v.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias',
   None),
  ('model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm2.weight',
   None),
  ('model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm2.bias',
   None),
  ('model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_q.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_k.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_v.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias',
   None),
  ('model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm3.weight',
   None),
  ('model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm3.bias',
   None),
  ('model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias',
   None),
  ('model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.2.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.2.bias',
   None),
  ('model.diffusion_model.output_blocks.8.1.proj_out.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.8.1.proj_out.bias', None),
  ('model.diffusion_model.output_blocks.8.2.conv.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.8.2.conv.bias', None),
  ('model.diffusion_model.output_blocks.9.0.in_layers.0.weight', None),
  ('model.diffusion_model.output_blocks.9.0.in_layers.0.bias', None),
  ('model.diffusion_model.output_blocks.9.0.in_layers.2.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.9.0.in_layers.2.bias', None),
  ('model.diffusion_model.output_blocks.9.0.emb_layers.1.weight', (1, 0)),
  ('model.diffusion_model.output_blocks.9.0.emb_layers.1.bias', None),
  ('model.diffusion_model.output_blocks.9.0.out_layers.0.weight', None),
  ('model.diffusion_model.output_blocks.9.0.out_layers.0.bias', None),
  ('model.diffusion_model.output_blocks.9.0.out_layers.3.weight',
   (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.9.0.out_layers.3.bias', None),
  ('model.diffusion_model.output_blocks.9.0.skip_connection.weight',
   (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.9.0.skip_connection.bias', None),
  ('model.diffusion_model.output_blocks.9.1.norm.weight', None),
  ('model.diffusion_model.output_blocks.9.1.norm.bias', None),
  ('model.diffusion_model.output_blocks.9.1.proj_in.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.9.1.proj_in.bias', None),
  ('model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm1.weight',
   None),
  ('model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm1.bias',
   None),
  ('model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_q.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_k.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_v.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_out.0.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_out.0.bias',
   None),
  ('model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm2.weight',
   None),
  ('model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm2.bias',
   None),
  ('model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_q.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_k.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_v.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_out.0.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_out.0.bias',
   None),
  ('model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.weight',
   None),
  ('model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.bias',
   None),
  ('model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.0.proj.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.0.proj.bias',
   None),
  ('model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.2.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.2.bias',
   None),
  ('model.diffusion_model.output_blocks.9.1.proj_out.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.9.1.proj_out.bias', None),
  ('model.diffusion_model.output_blocks.10.0.in_layers.0.weight', None),
  ('model.diffusion_model.output_blocks.10.0.in_layers.0.bias', None),
  ('model.diffusion_model.output_blocks.10.0.in_layers.2.weight',
   (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.10.0.in_layers.2.bias', None),
  ('model.diffusion_model.output_blocks.10.0.emb_layers.1.weight', (1, 0)),
  ('model.diffusion_model.output_blocks.10.0.emb_layers.1.bias', None),
  ('model.diffusion_model.output_blocks.10.0.out_layers.0.weight', None),
  ('model.diffusion_model.output_blocks.10.0.out_layers.0.bias', None),
  ('model.diffusion_model.output_blocks.10.0.out_layers.3.weight',
   (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.10.0.out_layers.3.bias', None),
  ('model.diffusion_model.output_blocks.10.0.skip_connection.weight',
   (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.10.0.skip_connection.bias', None),
  ('model.diffusion_model.output_blocks.10.1.norm.weight', None),
  ('model.diffusion_model.output_blocks.10.1.norm.bias', None),
  ('model.diffusion_model.output_blocks.10.1.proj_in.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.10.1.proj_in.bias', None),
  ('model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm1.weight',
   None),
  ('model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm1.bias',
   None),
  ('model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_q.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_k.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_v.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_out.0.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_out.0.bias',
   None),
  ('model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm2.weight',
   None),
  ('model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm2.bias',
   None),
  ('model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_q.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_k.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_v.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_out.0.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_out.0.bias',
   None),
  ('model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm3.weight',
   None),
  ('model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm3.bias',
   None),
  ('model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.0.proj.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.0.proj.bias',
   None),
  ('model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.2.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.2.bias',
   None),
  ('model.diffusion_model.output_blocks.10.1.proj_out.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.10.1.proj_out.bias', None),
  ('model.diffusion_model.output_blocks.11.0.in_layers.0.weight', None),
  ('model.diffusion_model.output_blocks.11.0.in_layers.0.bias', None),
  ('model.diffusion_model.output_blocks.11.0.in_layers.2.weight',
   (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.11.0.in_layers.2.bias', None),
  ('model.diffusion_model.output_blocks.11.0.emb_layers.1.weight', (1, 0)),
  ('model.diffusion_model.output_blocks.11.0.emb_layers.1.bias', None),
  ('model.diffusion_model.output_blocks.11.0.out_layers.0.weight', None),
  ('model.diffusion_model.output_blocks.11.0.out_layers.0.bias', None),
  ('model.diffusion_model.output_blocks.11.0.out_layers.3.weight',
   (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.11.0.out_layers.3.bias', None),
  ('model.diffusion_model.output_blocks.11.0.skip_connection.weight',
   (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.11.0.skip_connection.bias', None),
  ('model.diffusion_model.output_blocks.11.1.norm.weight', None),
  ('model.diffusion_model.output_blocks.11.1.norm.bias', None),
  ('model.diffusion_model.output_blocks.11.1.proj_in.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.11.1.proj_in.bias', None),
  ('model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.weight',
   None),
  ('model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias',
   None),
  ('model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_q.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_k.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_v.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_out.0.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_out.0.bias',
   None),
  ('model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm2.weight',
   None),
  ('model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm2.bias',
   None),
  ('model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_q.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_k.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_v.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_out.0.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_out.0.bias',
   None),
  ('model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm3.weight',
   None),
  ('model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm3.bias',
   None),
  ('model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.0.proj.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.0.proj.bias',
   None),
  ('model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.2.weight',
   (1, 0)),
  ('model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.2.bias',
   None),
  ('model.diffusion_model.output_blocks.11.1.proj_out.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.output_blocks.11.1.proj_out.bias', None),
  ('model.diffusion_model.out.0.weight', None),
  ('model.diffusion_model.out.0.bias', None),
  ('model.diffusion_model.out.2.weight', (2, 3, 1, 0)),
  ('model.diffusion_model.out.2.bias', None)],
 'decoder': [('first_stage_model.post_quant_conv.weight', (2, 3, 1, 0)),
  ('first_stage_model.post_quant_conv.bias', None),
  ('first_stage_model.decoder.conv_in.weight', (2, 3, 1, 0)),
  ('first_stage_model.decoder.conv_in.bias', None),
  ('first_stage_model.decoder.mid.block_1.norm1.weight', None),
  ('first_stage_model.decoder.mid.block_1.norm1.bias', None),
  ('first_stage_model.decoder.mid.block_1.conv1.weight', (2, 3, 1, 0)),
  ('first_stage_model.decoder.mid.block_1.conv1.bias', None),
  ('first_stage_model.decoder.mid.block_1.norm2.weight', None),
  ('first_stage_model.decoder.mid.block_1.norm2.bias', None),
  ('first_stage_model.decoder.mid.block_1.conv2.weight', (2, 3, 1, 0)),
  ('first_stage_model.decoder.mid.block_1.conv2.bias', None),
  ('first_stage_model.decoder.mid.attn_1.norm.weight', None),
  ('first_stage_model.decoder.mid.attn_1.norm.bias', None),
  ('first_stage_model.decoder.mid.attn_1.q.weight', (2, 3, 1, 0)),
  ('first_stage_model.decoder.mid.attn_1.q.bias', None),
  ('first_stage_model.decoder.mid.attn_1.k.weight', (2, 3, 1, 0)),
  ('first_stage_model.decoder.mid.attn_1.k.bias', None),
  ('first_stage_model.decoder.mid.attn_1.v.weight', (2, 3, 1, 0)),
  ('first_stage_model.decoder.mid.attn_1.v.bias', None),
  ('first_stage_model.decoder.mid.attn_1.proj_out.weight', (2, 3, 1, 0)),
  ('first_stage_model.decoder.mid.attn_1.proj_out.bias', None),
  ('first_stage_model.decoder.mid.block_2.norm1.weight', None),
  ('first_stage_model.decoder.mid.block_2.norm1.bias', None),
  ('first_stage_model.decoder.mid.block_2.conv1.weight', (2, 3, 1, 0)),
  ('first_stage_model.decoder.mid.block_2.conv1.bias', None),
  ('first_stage_model.decoder.mid.block_2.norm2.weight', None),
  ('first_stage_model.decoder.mid.block_2.norm2.bias', None),
  ('first_stage_model.decoder.mid.block_2.conv2.weight', (2, 3, 1, 0)),
  ('first_stage_model.decoder.mid.block_2.conv2.bias', None),
  ('first_stage_model.decoder.up.3.block.0.norm1.weight', None),
  ('first_stage_model.decoder.up.3.block.0.norm1.bias', None),
  ('first_stage_model.decoder.up.3.block.0.conv1.weight', (2, 3, 1, 0)),
  ('first_stage_model.decoder.up.3.block.0.conv1.bias', None),
  ('first_stage_model.decoder.up.3.block.0.norm2.weight', None),
  ('first_stage_model.decoder.up.3.block.0.norm2.bias', None),
  ('first_stage_model.decoder.up.3.block.0.conv2.weight', (2, 3, 1, 0)),
  ('first_stage_model.decoder.up.3.block.0.conv2.bias', None),
  ('first_stage_model.decoder.up.3.block.1.norm1.weight', None),
  ('first_stage_model.decoder.up.3.block.1.norm1.bias', None),
  ('first_stage_model.decoder.up.3.block.1.conv1.weight', (2, 3, 1, 0)),
  ('first_stage_model.decoder.up.3.block.1.conv1.bias', None),
  ('first_stage_model.decoder.up.3.block.1.norm2.weight', None),
  ('first_stage_model.decoder.up.3.block.1.norm2.bias', None),
  ('first_stage_model.decoder.up.3.block.1.conv2.weight', (2, 3, 1, 0)),
  ('first_stage_model.decoder.up.3.block.1.conv2.bias', None),
  ('first_stage_model.decoder.up.3.block.2.norm1.weight', None),
  ('first_stage_model.decoder.up.3.block.2.norm1.bias', None),
  ('first_stage_model.decoder.up.3.block.2.conv1.weight', (2, 3, 1, 0)),
  ('first_stage_model.decoder.up.3.block.2.conv1.bias', None),
  ('first_stage_model.decoder.up.3.block.2.norm2.weight', None),
  ('first_stage_model.decoder.up.3.block.2.norm2.bias', None),
  ('first_stage_model.decoder.up.3.block.2.conv2.weight', (2, 3, 1, 0)),
  ('first_stage_model.decoder.up.3.block.2.conv2.bias', None),
  ('first_stage_model.decoder.up.3.upsample.conv.weight', (2, 3, 1, 0)),
  ('first_stage_model.decoder.up.3.upsample.conv.bias', None),
  ('first_stage_model.decoder.up.2.block.0.norm1.weight', None),
  ('first_stage_model.decoder.up.2.block.0.norm1.bias', None),
  ('first_stage_model.decoder.up.2.block.0.conv1.weight', (2, 3, 1, 0)),
  ('first_stage_model.decoder.up.2.block.0.conv1.bias', None),
  ('first_stage_model.decoder.up.2.block.0.norm2.weight', None),
  ('first_stage_model.decoder.up.2.block.0.norm2.bias', None),
  ('first_stage_model.decoder.up.2.block.0.conv2.weight', (2, 3, 1, 0)),
  ('first_stage_model.decoder.up.2.block.0.conv2.bias', None),
  ('first_stage_model.decoder.up.2.block.1.norm1.weight', None),
  ('first_stage_model.decoder.up.2.block.1.norm1.bias', None),
  ('first_stage_model.decoder.up.2.block.1.conv1.weight', (2, 3, 1, 0)),
  ('first_stage_model.decoder.up.2.block.1.conv1.bias', None),
  ('first_stage_model.decoder.up.2.block.1.norm2.weight', None),
  ('first_stage_model.decoder.up.2.block.1.norm2.bias', None),
  ('first_stage_model.decoder.up.2.block.1.conv2.weight', (2, 3, 1, 0)),
  ('first_stage_model.decoder.up.2.block.1.conv2.bias', None),
  ('first_stage_model.decoder.up.2.block.2.norm1.weight', None),
  ('first_stage_model.decoder.up.2.block.2.norm1.bias', None),
  ('first_stage_model.decoder.up.2.block.2.conv1.weight', (2, 3, 1, 0)),
  ('first_stage_model.decoder.up.2.block.2.conv1.bias', None),
  ('first_stage_model.decoder.up.2.block.2.norm2.weight', None),
  ('first_stage_model.decoder.up.2.block.2.norm2.bias', None),
  ('first_stage_model.decoder.up.2.block.2.conv2.weight', (2, 3, 1, 0)),
  ('first_stage_model.decoder.up.2.block.2.conv2.bias', None),
  ('first_stage_model.decoder.up.2.upsample.conv.weight', (2, 3, 1, 0)),
  ('first_stage_model.decoder.up.2.upsample.conv.bias', None),
  ('first_stage_model.decoder.up.1.block.0.norm1.weight', None),
  ('first_stage_model.decoder.up.1.block.0.norm1.bias', None),
  ('first_stage_model.decoder.up.1.block.0.conv1.weight', (2, 3, 1, 0)),
  ('first_stage_model.decoder.up.1.block.0.conv1.bias', None),
  ('first_stage_model.decoder.up.1.block.0.norm2.weight', None),
  ('first_stage_model.decoder.up.1.block.0.norm2.bias', None),
  ('first_stage_model.decoder.up.1.block.0.conv2.weight', (2, 3, 1, 0)),
  ('first_stage_model.decoder.up.1.block.0.conv2.bias', None),
  ('first_stage_model.decoder.up.1.block.0.nin_shortcut.weight', (2, 3, 1, 0)),
  ('first_stage_model.decoder.up.1.block.0.nin_shortcut.bias', None),
  ('first_stage_model.decoder.up.1.block.1.norm1.weight', None),
  ('first_stage_model.decoder.up.1.block.1.norm1.bias', None),
  ('first_stage_model.decoder.up.1.block.1.conv1.weight', (2, 3, 1, 0)),
  ('first_stage_model.decoder.up.1.block.1.conv1.bias', None),
  ('first_stage_model.decoder.up.1.block.1.norm2.weight', None),
  ('first_stage_model.decoder.up.1.block.1.norm2.bias', None),
  ('first_stage_model.decoder.up.1.block.1.conv2.weight', (2, 3, 1, 0)),
  ('first_stage_model.decoder.up.1.block.1.conv2.bias', None),
  ('first_stage_model.decoder.up.1.block.2.norm1.weight', None),
  ('first_stage_model.decoder.up.1.block.2.norm1.bias', None),
  ('first_stage_model.decoder.up.1.block.2.conv1.weight', (2, 3, 1, 0)),
  ('first_stage_model.decoder.up.1.block.2.conv1.bias', None),
  ('first_stage_model.decoder.up.1.block.2.norm2.weight', None),
  ('first_stage_model.decoder.up.1.block.2.norm2.bias', None),
  ('first_stage_model.decoder.up.1.block.2.conv2.weight', (2, 3, 1, 0)),
  ('first_stage_model.decoder.up.1.block.2.conv2.bias', None),
  ('first_stage_model.decoder.up.1.upsample.conv.weight', (2, 3, 1, 0)),
  ('first_stage_model.decoder.up.1.upsample.conv.bias', None),
  ('first_stage_model.decoder.up.0.block.0.norm1.weight', None),
  ('first_stage_model.decoder.up.0.block.0.norm1.bias', None),
  ('first_stage_model.decoder.up.0.block.0.conv1.weight', (2, 3, 1, 0)),
  ('first_stage_model.decoder.up.0.block.0.conv1.bias', None),
  ('first_stage_model.decoder.up.0.block.0.norm2.weight', None),
  ('first_stage_model.decoder.up.0.block.0.norm2.bias', None),
  ('first_stage_model.decoder.up.0.block.0.conv2.weight', (2, 3, 1, 0)),
  ('first_stage_model.decoder.up.0.block.0.conv2.bias', None),
  ('first_stage_model.decoder.up.0.block.0.nin_shortcut.weight', (2, 3, 1, 0)),
  ('first_stage_model.decoder.up.0.block.0.nin_shortcut.bias', None),
  ('first_stage_model.decoder.up.0.block.1.norm1.weight', None),
  ('first_stage_model.decoder.up.0.block.1.norm1.bias', None),
  ('first_stage_model.decoder.up.0.block.1.conv1.weight', (2, 3, 1, 0)),
  ('first_stage_model.decoder.up.0.block.1.conv1.bias', None),
  ('first_stage_model.decoder.up.0.block.1.norm2.weight', None),
  ('first_stage_model.decoder.up.0.block.1.norm2.bias', None),
  ('first_stage_model.decoder.up.0.block.1.conv2.weight', (2, 3, 1, 0)),
  ('first_stage_model.decoder.up.0.block.1.conv2.bias', None),
  ('first_stage_model.decoder.up.0.block.2.norm1.weight', None),
  ('first_stage_model.decoder.up.0.block.2.norm1.bias', None),
  ('first_stage_model.decoder.up.0.block.2.conv1.weight', (2, 3, 1, 0)),
  ('first_stage_model.decoder.up.0.block.2.conv1.bias', None),
  ('first_stage_model.decoder.up.0.block.2.norm2.weight', None),
  ('first_stage_model.decoder.up.0.block.2.norm2.bias', None),
  ('first_stage_model.decoder.up.0.block.2.conv2.weight', (2, 3, 1, 0)),
  ('first_stage_model.decoder.up.0.block.2.conv2.bias', None),
  ('first_stage_model.decoder.norm_out.weight', None),
  ('first_stage_model.decoder.norm_out.bias', None),
  ('first_stage_model.decoder.conv_out.weight', (2, 3, 1, 0)),
  ('first_stage_model.decoder.conv_out.bias', None)],
 'encoder': [('first_stage_model.encoder.conv_in.weight', (2, 3, 1, 0)),
  ('first_stage_model.encoder.conv_in.bias', None),
  ('first_stage_model.encoder.down.0.block.0.norm1.weight', None),
  ('first_stage_model.encoder.down.0.block.0.norm1.bias', None),
  ('first_stage_model.encoder.down.0.block.0.conv1.weight', (2, 3, 1, 0)),
  ('first_stage_model.encoder.down.0.block.0.conv1.bias', None),
  ('first_stage_model.encoder.down.0.block.0.norm2.weight', None),
  ('first_stage_model.encoder.down.0.block.0.norm2.bias', None),
  ('first_stage_model.encoder.down.0.block.0.conv2.weight', (2, 3, 1, 0)),
  ('first_stage_model.encoder.down.0.block.0.conv2.bias', None),
  ('first_stage_model.encoder.down.0.block.1.norm1.weight', None),
  ('first_stage_model.encoder.down.0.block.1.norm1.bias', None),
  ('first_stage_model.encoder.down.0.block.1.conv1.weight', (2, 3, 1, 0)),
  ('first_stage_model.encoder.down.0.block.1.conv1.bias', None),
  ('first_stage_model.encoder.down.0.block.1.norm2.weight', None),
  ('first_stage_model.encoder.down.0.block.1.norm2.bias', None),
  ('first_stage_model.encoder.down.0.block.1.conv2.weight', (2, 3, 1, 0)),
  ('first_stage_model.encoder.down.0.block.1.conv2.bias', None),
  ('first_stage_model.encoder.down.0.downsample.conv.weight', (2, 3, 1, 0)),
  ('first_stage_model.encoder.down.0.downsample.conv.bias', None),
  ('first_stage_model.encoder.down.1.block.0.norm1.weight', None),
  ('first_stage_model.encoder.down.1.block.0.norm1.bias', None),
  ('first_stage_model.encoder.down.1.block.0.conv1.weight', (2, 3, 1, 0)),
  ('first_stage_model.encoder.down.1.block.0.conv1.bias', None),
  ('first_stage_model.encoder.down.1.block.0.norm2.weight', None),
  ('first_stage_model.encoder.down.1.block.0.norm2.bias', None),
  ('first_stage_model.encoder.down.1.block.0.conv2.weight', (2, 3, 1, 0)),
  ('first_stage_model.encoder.down.1.block.0.conv2.bias', None),
  ('first_stage_model.encoder.down.1.block.0.nin_shortcut.weight',
   (2, 3, 1, 0)),
  ('first_stage_model.encoder.down.1.block.0.nin_shortcut.bias', None),
  ('first_stage_model.encoder.down.1.block.1.norm1.weight', None),
  ('first_stage_model.encoder.down.1.block.1.norm1.bias', None),
  ('first_stage_model.encoder.down.1.block.1.conv1.weight', (2, 3, 1, 0)),
  ('first_stage_model.encoder.down.1.block.1.conv1.bias', None),
  ('first_stage_model.encoder.down.1.block.1.norm2.weight', None),
  ('first_stage_model.encoder.down.1.block.1.norm2.bias', None),
  ('first_stage_model.encoder.down.1.block.1.conv2.weight', (2, 3, 1, 0)),
  ('first_stage_model.encoder.down.1.block.1.conv2.bias', None),
  ('first_stage_model.encoder.down.1.downsample.conv.weight', (2, 3, 1, 0)),
  ('first_stage_model.encoder.down.1.downsample.conv.bias', None),
  ('first_stage_model.encoder.down.2.block.0.norm1.weight', None),
  ('first_stage_model.encoder.down.2.block.0.norm1.bias', None),
  ('first_stage_model.encoder.down.2.block.0.conv1.weight', (2, 3, 1, 0)),
  ('first_stage_model.encoder.down.2.block.0.conv1.bias', None),
  ('first_stage_model.encoder.down.2.block.0.norm2.weight', None),
  ('first_stage_model.encoder.down.2.block.0.norm2.bias', None),
  ('first_stage_model.encoder.down.2.block.0.conv2.weight', (2, 3, 1, 0)),
  ('first_stage_model.encoder.down.2.block.0.conv2.bias', None),
  ('first_stage_model.encoder.down.2.block.0.nin_shortcut.weight',
   (2, 3, 1, 0)),
  ('first_stage_model.encoder.down.2.block.0.nin_shortcut.bias', None),
  ('first_stage_model.encoder.down.2.block.1.norm1.weight', None),
  ('first_stage_model.encoder.down.2.block.1.norm1.bias', None),
  ('first_stage_model.encoder.down.2.block.1.conv1.weight', (2, 3, 1, 0)),
  ('first_stage_model.encoder.down.2.block.1.conv1.bias', None),
  ('first_stage_model.encoder.down.2.block.1.norm2.weight', None),
  ('first_stage_model.encoder.down.2.block.1.norm2.bias', None),
  ('first_stage_model.encoder.down.2.block.1.conv2.weight', (2, 3, 1, 0)),
  ('first_stage_model.encoder.down.2.block.1.conv2.bias', None),
  ('first_stage_model.encoder.down.2.downsample.conv.weight', (2, 3, 1, 0)),
  ('first_stage_model.encoder.down.2.downsample.conv.bias', None),
  ('first_stage_model.encoder.down.3.block.0.norm1.weight', None),
  ('first_stage_model.encoder.down.3.block.0.norm1.bias', None),
  ('first_stage_model.encoder.down.3.block.0.conv1.weight', (2, 3, 1, 0)),
  ('first_stage_model.encoder.down.3.block.0.conv1.bias', None),
  ('first_stage_model.encoder.down.3.block.0.norm2.weight', None),
  ('first_stage_model.encoder.down.3.block.0.norm2.bias', None),
  ('first_stage_model.encoder.down.3.block.0.conv2.weight', (2, 3, 1, 0)),
  ('first_stage_model.encoder.down.3.block.0.conv2.bias', None),
  ('first_stage_model.encoder.down.3.block.1.norm1.weight', None),
  ('first_stage_model.encoder.down.3.block.1.norm1.bias', None),
  ('first_stage_model.encoder.down.3.block.1.conv1.weight', (2, 3, 1, 0)),
  ('first_stage_model.encoder.down.3.block.1.conv1.bias', None),
  ('first_stage_model.encoder.down.3.block.1.norm2.weight', None),
  ('first_stage_model.encoder.down.3.block.1.norm2.bias', None),
  ('first_stage_model.encoder.down.3.block.1.conv2.weight', (2, 3, 1, 0)),
  ('first_stage_model.encoder.down.3.block.1.conv2.bias', None),
  ('first_stage_model.encoder.mid.block_1.norm1.weight', None),
  ('first_stage_model.encoder.mid.block_1.norm1.bias', None),
  ('first_stage_model.encoder.mid.block_1.conv1.weight', (2, 3, 1, 0)),
  ('first_stage_model.encoder.mid.block_1.conv1.bias', None),
  ('first_stage_model.encoder.mid.block_1.norm2.weight', None),
  ('first_stage_model.encoder.mid.block_1.norm2.bias', None),
  ('first_stage_model.encoder.mid.block_1.conv2.weight', (2, 3, 1, 0)),
  ('first_stage_model.encoder.mid.block_1.conv2.bias', None),
  ('first_stage_model.encoder.mid.attn_1.norm.weight', None),
  ('first_stage_model.encoder.mid.attn_1.norm.bias', None),
  ('first_stage_model.encoder.mid.attn_1.q.weight', (2, 3, 1, 0)),
  ('first_stage_model.encoder.mid.attn_1.q.bias', None),
  ('first_stage_model.encoder.mid.attn_1.k.weight', (2, 3, 1, 0)),
  ('first_stage_model.encoder.mid.attn_1.k.bias', None),
  ('first_stage_model.encoder.mid.attn_1.v.weight', (2, 3, 1, 0)),
  ('first_stage_model.encoder.mid.attn_1.v.bias', None),
  ('first_stage_model.encoder.mid.attn_1.proj_out.weight', (2, 3, 1, 0)),
  ('first_stage_model.encoder.mid.attn_1.proj_out.bias', None),
  ('first_stage_model.encoder.mid.block_2.norm1.weight', None),
  ('first_stage_model.encoder.mid.block_2.norm1.bias', None),
  ('first_stage_model.encoder.mid.block_2.conv1.weight', (2, 3, 1, 0)),
  ('first_stage_model.encoder.mid.block_2.conv1.bias', None),
  ('first_stage_model.encoder.mid.block_2.norm2.weight', None),
  ('first_stage_model.encoder.mid.block_2.norm2.bias', None),
  ('first_stage_model.encoder.mid.block_2.conv2.weight', (2, 3, 1, 0)),
  ('first_stage_model.encoder.mid.block_2.conv2.bias', None),
  ('first_stage_model.encoder.norm_out.weight', None),
  ('first_stage_model.encoder.norm_out.bias', None),
  ('first_stage_model.encoder.conv_out.weight', (2, 3, 1, 0)),
  ('first_stage_model.encoder.conv_out.bias', None),
  ('first_stage_model.quant_conv.weight', (2, 3, 1, 0)),
  ('first_stage_model.quant_conv.bias', None)]}


# Fixed-length token id sequence: a single 49406 followed by 76 copies of
# 49407, i.e. 77 ids in total.
# NOTE(review): presumably the tokenizer's start-of-text id padded out with its
# end-of-text id (the encoding of an empty prompt) — confirm against the
# tokenizer this notebook uses.
_UNCONDITIONAL_TOKENS = [49406] + [49407] * 76
_ALPHAS_CUMPROD = [
    0.99915,
    0.998296,
    0.9974381,
    0.9965762,
    0.99571025,
    0.9948404,
    0.9939665,
    0.9930887,
    0.9922069,
    0.9913211,
    0.9904313,
    0.98953754,
    0.9886398,
    0.9877381,
    0.9868324,
    0.98592263,
    0.98500896,
    0.9840913,
    0.9831696,
    0.982244,
    0.98131436,
    0.9803808,
    0.97944313,
    0.97850156,
    0.977556,
    0.9766064,
    0.97565293,
    0.9746954,
    0.9737339,
    0.9727684,
    0.97179896,
    0.97082555,
    0.96984816,
    0.96886677,
    0.9678814,
    0.96689206,
    0.96589875,
    0.9649015,
    0.96390027,
    0.9628951,
    0.9618859,
    0.96087277,
    0.95985574,
    0.95883465,
    0.9578097,
    0.95678073,
    0.95574784,
    0.954711,
    0.95367026,
    0.9526256,
    0.9515769,
    0.95052433,
    0.94946784,
    0.94840735,
    0.947343,
    0.94627476,
    0.9452025,
    0.9441264,
    0.9430464,
    0.9419625,
    0.9408747,
    0.939783,
    0.9386874,
    0.93758786,
    0.9364845,
    0.93537724,
    0.9342661,
    0.9331511,
    0.9320323,
    0.9309096,
    0.929783,
    0.9286526,
    0.9275183,
    0.9263802,
    0.92523825,
    0.92409253,
    0.92294294,
    0.9217895,
    0.92063236,
    0.9194713,
    0.9183065,
    0.9171379,
    0.91596556,
    0.9147894,
    0.9136095,
    0.91242576,
    0.9112383,
    0.9100471,
    0.9088522,
    0.9076535,
    0.9064511,
    0.90524495,
    0.9040351,
    0.90282154,
    0.9016043,
    0.90038335,
    0.8991587,
    0.8979304,
    0.8966984,
    0.89546275,
    0.89422345,
    0.8929805,
    0.89173394,
    0.89048374,
    0.88922995,
    0.8879725,
    0.8867115,
    0.88544685,
    0.88417864,
    0.88290685,
    0.8816315,
    0.88035256,
    0.8790701,
    0.87778413,
    0.8764946,
    0.8752016,
    0.873905,
    0.87260497,
    0.8713014,
    0.8699944,
    0.86868393,
    0.86737,
    0.8660526,
    0.8647318,
    0.86340755,
    0.8620799,
    0.8607488,
    0.85941434,
    0.8580765,
    0.8567353,
    0.8553907,
    0.8540428,
    0.85269153,
    0.85133696,
    0.84997904,
    0.84861785,
    0.8472533,
    0.8458856,
    0.8445145,
    0.84314024,
    0.84176266,
    0.8403819,
    0.8389979,
    0.8376107,
    0.8362203,
    0.83482677,
    0.83343,
    0.8320301,
    0.8306271,
    0.8292209,
    0.82781166,
    0.82639927,
    0.8249838,
    0.82356524,
    0.8221436,
    0.82071894,
    0.81929123,
    0.81786054,
    0.8164268,
    0.8149901,
    0.8135504,
    0.81210774,
    0.81066215,
    0.8092136,
    0.8077621,
    0.80630773,
    0.80485046,
    0.8033903,
    0.80192727,
    0.8004614,
    0.79899275,
    0.79752123,
    0.7960469,
    0.7945698,
    0.7930899,
    0.79160726,
    0.7901219,
    0.7886338,
    0.787143,
    0.7856495,
    0.7841533,
    0.78265446,
    0.78115296,
    0.7796488,
    0.77814204,
    0.7766327,
    0.7751208,
    0.7736063,
    0.77208924,
    0.7705697,
    0.7690476,
    0.767523,
    0.7659959,
    0.7644664,
    0.76293445,
    0.7614,
    0.7598632,
    0.75832397,
    0.75678235,
    0.75523835,
    0.75369203,
    0.7521434,
    0.75059247,
    0.7490392,
    0.7474837,
    0.7459259,
    0.7443659,
    0.74280363,
    0.7412392,
    0.7396726,
    0.7381038,
    0.73653287,
    0.7349598,
    0.7333846,
    0.73180735,
    0.730228,
    0.7286466,
    0.7270631,
    0.7254777,
    0.72389024,
    0.72230077,
    0.7207094,
    0.71911603,
    0.7175208,
    0.7159236,
    0.71432453,
    0.7127236,
    0.71112084,
    0.7095162,
    0.7079098,
    0.7063016,
    0.70469165,
    0.70307994,
    0.7014665,
    0.69985133,
    0.6982345,
    0.696616,
    0.6949958,
    0.69337404,
    0.69175065,
    0.69012564,
    0.6884991,
    0.68687093,
    0.6852413,
    0.68361014,
    0.6819775,
    0.6803434,
    0.67870784,
    0.6770708,
    0.6754324,
    0.6737926,
    0.67215145,
    0.670509,
    0.66886514,
    0.66722,
    0.6655736,
    0.66392595,
    0.662277,
    0.6606269,
    0.65897554,
    0.657323,
    0.65566933,
    0.6540145,
    0.6523586,
    0.6507016,
    0.6490435,
    0.64738435,
    0.6457241,
    0.64406294,
    0.6424008,
    0.64073765,
    0.63907355,
    0.63740855,
    0.6357426,
    0.6340758,
    0.6324082,
    0.6307397,
    0.6290704,
    0.6274003,
    0.6257294,
    0.62405777,
    0.6223854,
    0.62071234,
    0.6190386,
    0.61736417,
    0.6156891,
    0.61401343,
    0.6123372,
    0.6106603,
    0.6089829,
    0.607305,
    0.6056265,
    0.6039476,
    0.60226816,
    0.6005883,
    0.598908,
    0.59722733,
    0.5955463,
    0.59386486,
    0.5921831,
    0.59050107,
    0.5888187,
    0.5871361,
    0.5854532,
    0.5837701,
    0.5820868,
    0.5804033,
    0.5787197,
    0.5770359,
    0.575352,
    0.57366806,
    0.571984,
    0.5702999,
    0.5686158,
    0.56693166,
    0.56524754,
    0.5635635,
    0.5618795,
    0.56019557,
    0.5585118,
    0.5568281,
    0.55514455,
    0.5534612,
    0.551778,
    0.5500951,
    0.5484124,
    0.54673,
    0.5450478,
    0.54336596,
    0.54168445,
    0.54000324,
    0.53832245,
    0.5366421,
    0.53496206,
    0.5332825,
    0.53160346,
    0.5299248,
    0.52824676,
    0.5265692,
    0.52489215,
    0.5232157,
    0.5215398,
    0.51986456,
    0.51818997,
    0.51651603,
    0.51484275,
    0.5131702,
    0.5114983,
    0.5098272,
    0.50815684,
    0.5064873,
    0.50481856,
    0.50315064,
    0.50148356,
    0.4998174,
    0.4981521,
    0.49648774,
    0.49482432,
    0.49316183,
    0.49150035,
    0.48983985,
    0.4881804,
    0.486522,
    0.48486462,
    0.4832084,
    0.48155323,
    0.4798992,
    0.47824633,
    0.47659463,
    0.4749441,
    0.47329482,
    0.4716468,
    0.47,
    0.46835446,
    0.46671024,
    0.46506736,
    0.4634258,
    0.46178558,
    0.46014675,
    0.45850933,
    0.45687333,
    0.45523876,
    0.45360568,
    0.45197406,
    0.45034397,
    0.44871536,
    0.44708833,
    0.44546285,
    0.44383895,
    0.44221666,
    0.440596,
    0.43897697,
    0.43735963,
    0.43574396,
    0.43412998,
    0.43251774,
    0.43090722,
    0.4292985,
    0.42769152,
    0.42608637,
    0.42448303,
    0.4228815,
    0.42128187,
    0.4196841,
    0.41808826,
    0.4164943,
    0.4149023,
    0.41331223,
    0.41172415,
    0.41013804,
    0.40855396,
    0.4069719,
    0.4053919,
    0.40381396,
    0.4022381,
    0.40066436,
    0.39909273,
    0.39752322,
    0.3959559,
    0.39439073,
    0.39282778,
    0.39126703,
    0.3897085,
    0.3881522,
    0.3865982,
    0.38504648,
    0.38349706,
    0.38194993,
    0.38040516,
    0.37886274,
    0.37732267,
    0.375785,
    0.37424973,
    0.37271687,
    0.37118647,
    0.36965853,
    0.36813304,
    0.36661002,
    0.36508954,
    0.36357155,
    0.3620561,
    0.36054322,
    0.3590329,
    0.35752517,
    0.35602003,
    0.35451752,
    0.35301763,
    0.3515204,
    0.3500258,
    0.3485339,
    0.3470447,
    0.34555823,
    0.34407446,
    0.34259343,
    0.34111515,
    0.33963963,
    0.33816692,
    0.336697,
    0.3352299,
    0.33376563,
    0.3323042,
    0.33084565,
    0.32938993,
    0.32793713,
    0.3264872,
    0.32504022,
    0.32359615,
    0.32215503,
    0.32071686,
    0.31928164,
    0.31784943,
    0.3164202,
    0.314994,
    0.3135708,
    0.31215066,
    0.31073356,
    0.3093195,
    0.30790854,
    0.30650064,
    0.30509588,
    0.30369422,
    0.30229566,
    0.30090025,
    0.299508,
    0.2981189,
    0.29673296,
    0.29535022,
    0.2939707,
    0.29259437,
    0.29122123,
    0.28985137,
    0.28848472,
    0.28712133,
    0.2857612,
    0.28440437,
    0.2830508,
    0.28170055,
    0.2803536,
    0.27900997,
    0.27766964,
    0.27633268,
    0.27499905,
    0.2736688,
    0.27234194,
    0.27101842,
    0.2696983,
    0.26838157,
    0.26706827,
    0.26575837,
    0.26445192,
    0.26314887,
    0.2618493,
    0.26055318,
    0.2592605,
    0.25797132,
    0.2566856,
    0.2554034,
    0.25412467,
    0.25284946,
    0.25157773,
    0.2503096,
    0.24904492,
    0.24778382,
    0.24652626,
    0.24527225,
    0.2440218,
    0.24277493,
    0.24153163,
    0.24029191,
    0.23905578,
    0.23782326,
    0.23659433,
    0.23536903,
    0.23414734,
    0.23292927,
    0.23171483,
    0.23050404,
    0.22929688,
    0.22809339,
    0.22689353,
    0.22569734,
    0.22450483,
    0.22331597,
    0.2221308,
    0.22094932,
    0.21977153,
    0.21859743,
    0.21742703,
    0.21626033,
    0.21509734,
    0.21393807,
    0.21278252,
    0.21163069,
    0.21048258,
    0.20933822,
    0.20819758,
    0.2070607,
    0.20592754,
    0.20479813,
    0.20367248,
    0.20255059,
    0.20143245,
    0.20031808,
    0.19920748,
    0.19810064,
    0.19699757,
    0.19589828,
    0.19480278,
    0.19371104,
    0.1926231,
    0.19153893,
    0.19045855,
    0.18938197,
    0.18830918,
    0.18724018,
    0.18617497,
    0.18511358,
    0.18405597,
    0.18300217,
    0.18195218,
    0.18090598,
    0.1798636,
    0.17882504,
    0.17779027,
    0.1767593,
    0.17573217,
    0.17470883,
    0.1736893,
    0.1726736,
    0.1716617,
    0.17065361,
    0.16964935,
    0.1686489,
    0.16765225,
    0.16665943,
    0.16567042,
    0.16468522,
    0.16370384,
    0.16272627,
    0.16175252,
    0.16078258,
    0.15981644,
    0.15885411,
    0.1578956,
    0.15694089,
    0.15599,
    0.15504292,
    0.15409963,
    0.15316014,
    0.15222447,
    0.15129258,
    0.1503645,
    0.14944021,
    0.14851972,
    0.14760303,
    0.14669013,
    0.14578101,
    0.14487568,
    0.14397413,
    0.14307636,
    0.14218238,
    0.14129217,
    0.14040573,
    0.13952307,
    0.13864417,
    0.13776903,
    0.13689767,
    0.13603005,
    0.13516618,
    0.13430607,
    0.13344972,
    0.1325971,
    0.13174823,
    0.1309031,
    0.13006169,
    0.12922402,
    0.12839006,
    0.12755983,
    0.12673332,
    0.12591052,
    0.12509143,
    0.12427604,
    0.12346435,
    0.12265636,
    0.121852055,
    0.12105144,
    0.1202545,
    0.11946124,
    0.11867165,
    0.11788572,
    0.11710346,
    0.11632485,
    0.115549885,
    0.11477857,
    0.11401089,
    0.11324684,
    0.11248643,
    0.11172963,
    0.11097645,
    0.110226884,
    0.10948092,
    0.10873855,
    0.10799977,
    0.107264586,
    0.106532976,
    0.105804935,
    0.10508047,
    0.10435956,
    0.1036422,
    0.10292839,
    0.10221813,
    0.1015114,
    0.10080819,
    0.100108504,
    0.09941233,
    0.098719664,
    0.0980305,
    0.09734483,
    0.09666264,
    0.09598393,
    0.095308684,
    0.09463691,
    0.093968585,
    0.09330372,
    0.092642285,
    0.09198428,
    0.09132971,
    0.09067855,
    0.090030804,
    0.089386456,
    0.088745505,
    0.088107936,
    0.08747375,
    0.08684293,
    0.08621547,
    0.085591376,
    0.084970616,
    0.08435319,
    0.0837391,
    0.08312833,
    0.08252087,
    0.08191671,
    0.08131585,
    0.08071827,
    0.080123976,
    0.07953294,
    0.078945175,
    0.078360654,
    0.077779375,
    0.07720133,
    0.07662651,
    0.07605491,
    0.07548651,
    0.07492131,
    0.0743593,
    0.07380046,
    0.073244795,
    0.07269229,
    0.07214294,
    0.07159673,
    0.07105365,
    0.070513695,
    0.06997685,
    0.069443114,
    0.06891247,
    0.06838491,
    0.067860425,
    0.06733901,
    0.066820644,
    0.06630533,
    0.06579305,
    0.0652838,
    0.06477757,
    0.06427433,
    0.0637741,
    0.063276865,
    0.06278259,
    0.062291294,
    0.061802953,
    0.06131756,
    0.0608351,
    0.060355574,
    0.05987896,
    0.059405252,
    0.058934443,
    0.05846652,
    0.058001474,
    0.057539295,
    0.05707997,
    0.056623492,
    0.05616985,
    0.05571903,
    0.055271026,
    0.054825824,
    0.05438342,
    0.053943794,
    0.053506944,
    0.05307286,
    0.052641522,
    0.052212927,
    0.051787063,
    0.051363923,
    0.05094349,
    0.050525755,
    0.05011071,
    0.04969834,
    0.049288645,
    0.0488816,
    0.048477206,
    0.048075445,
    0.04767631,
    0.047279786,
    0.04688587,
    0.046494544,
    0.046105802,
    0.04571963,
    0.04533602,
    0.04495496,
    0.04457644,
    0.044200446,
    0.04382697,
    0.043456003,
    0.043087535,
    0.042721547,
    0.042358037,
    0.04199699,
    0.041638397,
    0.041282244,
    0.040928524,
    0.040577225,
    0.040228333,
    0.039881844,
    0.039537743,
    0.039196018,
    0.038856663,
    0.038519662,
    0.038185004,
    0.037852682,
    0.037522685,
    0.037195,
    0.036869615,
    0.036546525,
    0.036225714,
    0.03590717,
    0.035590887,
    0.035276853,
    0.034965057,
    0.034655485,
    0.03434813,
    0.03404298,
    0.033740025,
    0.033439253,
    0.033140652,
    0.032844216,
    0.03254993,
    0.032257784,
    0.03196777,
    0.031679876,
    0.031394087,
    0.031110398,
    0.030828796,
    0.030549273,
    0.030271813,
    0.02999641,
    0.029723052,
    0.029451728,
    0.029182427,
    0.02891514,
    0.028649855,
    0.028386563,
    0.028125253,
    0.02786591,
    0.027608532,
    0.027353102,
    0.027099613,
    0.026848052,
    0.026598409,
    0.026350675,
    0.02610484,
    0.02586089,
    0.02561882,
    0.025378617,
    0.025140269,
    0.024903767,
    0.0246691,
    0.02443626,
    0.024205236,
    0.023976017,
    0.023748592,
    0.023522953,
    0.023299087,
    0.023076987,
    0.022856642,
    0.02263804,
    0.022421172,
    0.022206029,
    0.0219926,
    0.021780876,
    0.021570845,
    0.021362498,
    0.021155827,
    0.020950818,
    0.020747466,
    0.020545758,
    0.020345684,
    0.020147236,
    0.019950403,
    0.019755175,
    0.019561544,
    0.019369498,
    0.019179028,
    0.018990126,
    0.01880278,
    0.018616982,
    0.018432721,
    0.01824999,
    0.018068777,
    0.017889075,
    0.017710872,
    0.01753416,
    0.017358929,
    0.017185168,
    0.017012872,
    0.016842028,
    0.016672628,
    0.016504662,
    0.016338123,
    0.016173,
    0.016009282,
    0.015846964,
    0.015686033,
    0.015526483,
    0.015368304,
    0.015211486,
    0.0150560215,
    0.014901901,
    0.014749114,
    0.014597654,
    0.014447511,
    0.0142986765,
    0.014151142,
    0.014004898,
    0.013859936,
    0.013716248,
    0.0135738235,
    0.013432656,
    0.013292736,
    0.013154055,
    0.013016605,
    0.012880377,
    0.012745362,
    0.012611552,
    0.012478939,
    0.012347515,
    0.01221727,
    0.012088198,
    0.0119602885,
    0.0118335355,
    0.011707929,
    0.011583461,
    0.011460125,
    0.011337912,
    0.011216813,
    0.011096821,
    0.010977928,
    0.0108601255,
    0.010743406,
    0.010627762,
    0.0105131855,
    0.010399668,
    0.010287202,
    0.01017578,
    0.010065395,
    0.009956039,
    0.009847702,
    0.009740381,
    0.0096340645,
    0.009528747,
    0.009424419,
    0.009321076,
    0.009218709,
    0.00911731,
    0.009016872,
    0.008917389,
    0.008818853,
    0.008721256,
    0.008624591,
    0.008528852,
    0.00843403,
    0.00834012,
    0.008247114,
    0.008155004,
    0.008063785,
    0.007973449,
    0.007883989,
    0.007795398,
    0.0077076694,
    0.0076207966,
    0.0075347726,
    0.007449591,
    0.0073652444,
    0.007281727,
    0.0071990318,
    0.007117152,
    0.0070360815,
    0.0069558136,
    0.0068763415,
    0.006797659,
    0.00671976,
    0.0066426382,
    0.0065662866,
    0.006490699,
    0.0064158696,
    0.006341792,
    0.00626846,
    0.0061958674,
    0.0061240084,
    0.0060528764,
    0.0059824656,
    0.0059127696,
    0.0058437833,
    0.0057755,
    0.0057079145,
    0.00564102,
    0.0055748112,
    0.0055092825,
    0.005444428,
    0.005380241,
    0.0053167176,
    0.005253851,
    0.005191636,
    0.005130066,
    0.0050691366,
    0.0050088423,
    0.0049491767,
    0.004890135,
    0.0048317118,
    0.004773902,
    0.004716699,
    0.0046600983,
]


In [4]:
class PaddedConv2D(keras.layers.Layer):
    """Explicit zero padding followed by a 2D convolution.

    Args:
        channels: Number of output filters of the convolution.
        kernel_size: Kernel size forwarded to ``keras.layers.Conv2D``.
        padding: Value used for both spatial axes; it is forwarded to
            ``keras.layers.ZeroPadding2D`` as ``(padding, padding)``.
        stride: Stride applied along both spatial axes.
    """

    def __init__(self, channels, kernel_size, padding=0, stride=1):
        super().__init__()
        pad_spec = (padding, padding)
        stride_spec = (stride, stride)
        self.padding2d = keras.layers.ZeroPadding2D(pad_spec)
        self.conv2d = keras.layers.Conv2D(channels, kernel_size, strides=stride_spec)

    def call(self, x):
        """Pad ``x`` with zeros, then convolve the padded tensor."""
        padded = self.padding2d(x)
        return self.conv2d(padded)


class GEGLU(keras.layers.Layer):
    """Gated linear unit whose gate goes through ``gelu``.

    A single dense projection produces ``2 * dim_out`` features; the first
    half is the value and the second half is the gate.

    Args:
        dim_out: Size of the last axis of the output.
    """

    def __init__(self, dim_out):
        super().__init__()
        self.dim_out = dim_out
        # One projection yields both the value half and the gate half.
        self.proj = keras.layers.Dense(2 * dim_out)

    def call(self, x):
        """Return ``value * gelu(gate)`` computed from the projection of ``x``."""
        projected = self.proj(x)
        value = projected[..., : self.dim_out]
        gate = projected[..., self.dim_out :]
        return value * gelu(gate)


def gelu(x):
    """Tanh-based approximation of the GELU activation.

    Computes ``0.5 * x * (1 + tanh(x * 0.7978845608 * (1 + 0.044715 * x**2)))``;
    0.7978845608 is sqrt(2 / pi) rounded to ten digits.
    """
    cubic_correction = 1 + 0.044715 * (x**2)
    tanh_arg = x * 0.7978845608 * cubic_correction
    return 0.5 * x * (1 + keras.activations.tanh(tanh_arg))


def quick_gelu(x):
    """Sigmoid-based GELU approximation: ``x * sigmoid(1.702 * x)``."""
    gate = tf.sigmoid(x * 1.702)
    return x * gate


def apply_seq(x, layers):
    """Feed ``x`` through every callable in ``layers``, in order.

    Args:
        x: Initial value.
        layers: Iterable of one-argument callables (layers or plain functions).

    Returns:
        The output of the last callable, or ``x`` unchanged when ``layers``
        is empty.
    """
    result = x
    for layer in layers:
        result = layer(result)
    return result


def td_dot(a, b):
    """Batched dot product of two rank-4 tensors over their last two axes.

    The two leading axes of each input are folded into one batch axis, the
    resulting rank-3 tensors are multiplied with ``keras.backend.batch_dot``,
    and the product is reshaped back to rank 4 using ``a.shape[1]`` as the
    second axis.
    """
    lhs = tf.reshape(a, (-1, a.shape[2], a.shape[3]))
    rhs = tf.reshape(b, (-1, b.shape[2], b.shape[3]))
    product = keras.backend.batch_dot(lhs, rhs)
    restored_shape = (-1, a.shape[1], product.shape[1], product.shape[2])
    return tf.reshape(product, restored_shape)


# AUTO ENCODER


## Attention Block

In [5]:
class AttentionBlock(keras.layers.Layer):
    """Residual self-attention over the spatial positions of a feature map.

    The input is group-normalised, projected to queries, keys and values by
    1x1 convolutions, attended across all ``h * w`` positions, projected
    again and added back to the input.

    Args:
        channels: Number of filters of the q/k/v and output projections.
    """

    def __init__(self, channels):
        super().__init__()
        self.norm = tfa.layers.GroupNormalization(epsilon=1e-5)
        self.q = PaddedConv2D(channels, 1)
        self.k = PaddedConv2D(channels, 1)
        self.v = PaddedConv2D(channels, 1)
        self.proj_out = PaddedConv2D(channels, 1)

    def call(self, x):
        """Return ``x`` plus the projected attention output."""
        normed = self.norm(x)
        q, k, v = self.q(normed), self.k(normed), self.v(normed)

        _, height, width, channels = q.shape
        n_positions = height * width

        # Attention weights between every pair of spatial positions.
        q = tf.reshape(q, (-1, n_positions, channels))  # b,hw,c
        k = tf.transpose(k, (0, 3, 1, 2))
        k = tf.reshape(k, (-1, channels, n_positions))  # b,c,hw
        scores = (q @ k) * (channels ** (-0.5))
        probs = keras.activations.softmax(scores)

        # Weighted sum of the values, computed in channels-first layout.
        v = tf.transpose(v, (0, 3, 1, 2))
        v = tf.reshape(v, (-1, channels, n_positions))
        attended = v @ tf.transpose(probs, (0, 2, 1))
        attended = tf.transpose(attended, (0, 2, 1))
        attended = tf.reshape(attended, (-1, height, width, channels))
        return x + self.proj_out(attended)



## ResnetBlock

In [6]:
class ResnetBlock(keras.layers.Layer):
    """Residual block: two (GroupNorm -> swish -> 3x3 conv) stages plus a skip.

    Args:
        in_channels: Channel count of the input; only used to decide whether
            the skip path needs a projection.
        out_channels: Number of filters of both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.norm1 = tfa.layers.GroupNormalization(epsilon=1e-5)
        self.conv1 = PaddedConv2D(out_channels, 3, padding=1)
        # Use the same tfa GroupNormalization as norm1 and the other blocks in
        # this notebook instead of mixing in keras.layers.GroupNormalization.
        self.norm2 = tfa.layers.GroupNormalization(epsilon=1e-5)
        self.conv2 = PaddedConv2D(out_channels, 3, padding=1)
        # A 1x1 convolution matches channel counts on the skip path when they
        # differ; otherwise the input is passed through unchanged.
        self.nin_shortcut = (
            PaddedConv2D(out_channels, 1)
            if in_channels != out_channels
            else lambda x: x
        )

    def call(self, x):
        """Return the (possibly projected) input plus the residual branch."""
        h = self.conv1(keras.activations.swish(self.norm1(x)))
        h = self.conv2(keras.activations.swish(self.norm2(h)))
        return self.nin_shortcut(x) + h


## Encoder

In [7]:
class Encoder(keras.Sequential):
    """Sequential image encoder.

    Three stride-2 convolutions reduce the spatial resolution, the channel
    width grows from 128 to 512, and the final ``Lambda`` keeps the first 4
    of the 8 output channels scaled by 0.18215.
    """

    def __init__(self):
        # Layers are instantiated in exactly the order in which they run.
        stem = [
            PaddedConv2D(128, 3, padding=1),
            ResnetBlock(128, 128),
            ResnetBlock(128, 128),
        ]
        stage_256 = [
            PaddedConv2D(128, 3, padding=(0, 1), stride=2),
            ResnetBlock(128, 256),
            ResnetBlock(256, 256),
        ]
        stage_512 = [
            PaddedConv2D(256, 3, padding=(0, 1), stride=2),
            ResnetBlock(256, 512),
            ResnetBlock(512, 512),
        ]
        bottleneck = [
            PaddedConv2D(512, 3, padding=(0, 1), stride=2),
            ResnetBlock(512, 512),
            ResnetBlock(512, 512),
            ResnetBlock(512, 512),
            AttentionBlock(512),
            ResnetBlock(512, 512),
        ]
        head = [
            tfa.layers.GroupNormalization(epsilon=1e-5),
            keras.layers.Activation("swish"),
            PaddedConv2D(8, 3, padding=1),
            PaddedConv2D(8, 1),
            keras.layers.Lambda(lambda x: x[..., :4] * 0.18215),
        ]
        super().__init__(stem + stage_256 + stage_512 + bottleneck + head)


## Decoder

In [8]:
class Decoder(keras.Sequential):
    """Sequential decoder producing a 3-channel output.

    The input is first multiplied by ``1 / 0.18215``, then passed through
    residual/attention blocks and three 2x upsampling stages while the
    channel width shrinks from 512 to 128.
    """

    def __init__(self):
        # Layers are instantiated in exactly the order in which they run.
        entry = [
            keras.layers.Lambda(lambda x: 1 / 0.18215 * x),
            PaddedConv2D(4, 1),
            PaddedConv2D(512, 3, padding=1),
        ]
        middle = [
            ResnetBlock(512, 512),
            AttentionBlock(512),
            ResnetBlock(512, 512),
            ResnetBlock(512, 512),
            ResnetBlock(512, 512),
            ResnetBlock(512, 512),
        ]
        up_512 = [
            keras.layers.UpSampling2D(size=(2, 2)),
            PaddedConv2D(512, 3, padding=1),
            ResnetBlock(512, 512),
            ResnetBlock(512, 512),
            ResnetBlock(512, 512),
        ]
        up_256 = [
            keras.layers.UpSampling2D(size=(2, 2)),
            PaddedConv2D(512, 3, padding=1),
            ResnetBlock(512, 256),
            ResnetBlock(256, 256),
            ResnetBlock(256, 256),
        ]
        up_128 = [
            keras.layers.UpSampling2D(size=(2, 2)),
            PaddedConv2D(256, 3, padding=1),
            ResnetBlock(256, 128),
            ResnetBlock(128, 128),
            ResnetBlock(128, 128),
        ]
        head = [
            tfa.layers.GroupNormalization(epsilon=1e-5),
            keras.layers.Activation("swish"),
            PaddedConv2D(3, 3, padding=1),
        ]
        super().__init__(entry + middle + up_512 + up_256 + up_128 + head)



# CLIP

## Clip attention

In [9]:
class CLIPAttention(keras.layers.Layer):
    """Multi-head self-attention with an additive attention mask.

    Uses a fixed embedding size of 768 split across 12 heads.
    """

    def __init__(self):
        super().__init__()
        self.embed_dim = 768
        self.num_heads = 12
        self.head_dim = self.embed_dim // self.num_heads
        self.scale = self.head_dim**-0.5
        self.q_proj = keras.layers.Dense(self.embed_dim)
        self.k_proj = keras.layers.Dense(self.embed_dim)
        self.v_proj = keras.layers.Dense(self.embed_dim)
        self.out_proj = keras.layers.Dense(self.embed_dim)

    def _shape(self, tensor, seq_len: int, bsz: int):
        """Split the last axis into heads: result is (bs, n_head, seq_len, head_dim)."""
        per_head = tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim))
        return tf.transpose(per_head, (0, 2, 1, 3))

    def call(self, inputs):
        """Attend over ``hidden_states``; ``inputs`` is ``[hidden_states, mask]``."""
        hidden_states, causal_attention_mask = inputs
        _, tgt_len, embed_dim = hidden_states.shape
        src_len = tgt_len
        flat_shape = (-1, tgt_len, self.head_dim)

        # Queries are scaled before being split into heads.
        queries = self.q_proj(hidden_states) * self.scale
        keys = self._shape(self.k_proj(hidden_states), tgt_len, -1)
        values = self._shape(self.v_proj(hidden_states), tgt_len, -1)
        queries = self._shape(queries, tgt_len, -1)

        # Merge batch and head axes so plain batched matmuls can be used.
        queries = tf.reshape(queries, flat_shape)
        keys = tf.reshape(keys, flat_shape)
        values = tf.reshape(values, flat_shape)

        scores = queries @ tf.transpose(keys, (0, 2, 1))

        # The mask is added with the head axis temporarily restored.
        scores = tf.reshape(scores, (-1, self.num_heads, tgt_len, src_len))
        scores = scores + causal_attention_mask
        scores = tf.reshape(scores, (-1, tgt_len, src_len))

        probs = tf.nn.softmax(scores)
        attended = probs @ values

        # Bring heads back next to head_dim and merge them into embed_dim.
        attended = tf.reshape(
            attended, (-1, self.num_heads, tgt_len, self.head_dim)
        )
        attended = tf.transpose(attended, (0, 2, 1, 3))
        attended = tf.reshape(attended, (-1, tgt_len, embed_dim))

        return self.out_proj(attended)


## CLIP Encoder Layer

In [10]:
class CLIPEncoderLayer(keras.layers.Layer):
    """Transformer layer: pre-norm self-attention then a pre-norm MLP.

    Both sub-blocks are wrapped in residual connections; the MLP expands to
    3072 features, applies ``quick_gelu`` and projects back to 768.
    """

    def __init__(self):
        super().__init__()
        self.layer_norm1 = keras.layers.LayerNormalization(epsilon=1e-5)
        self.self_attn = CLIPAttention()
        self.layer_norm2 = keras.layers.LayerNormalization(epsilon=1e-5)
        self.fc1 = keras.layers.Dense(3072)
        self.fc2 = keras.layers.Dense(768)

    def call(self, inputs):
        """Run one layer; ``inputs`` is ``[hidden_states, causal_attention_mask]``."""
        hidden_states, causal_attention_mask = inputs

        # Self-attention sub-block with its residual connection.
        attn_out = self.self_attn(
            [self.layer_norm1(hidden_states), causal_attention_mask]
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block with its residual connection.
        mlp_out = self.fc1(self.layer_norm2(hidden_states))
        mlp_out = self.fc2(quick_gelu(mlp_out))
        return hidden_states + mlp_out


## CLIPEncoder

In [11]:
class CLIPEncoder(keras.layers.Layer):
    """Stack of 12 ``CLIPEncoderLayer`` instances applied one after another."""

    def __init__(self):
        super().__init__()
        self.layers = [CLIPEncoderLayer() for _ in range(12)]

    def call(self, inputs):
        """Thread the hidden states through every layer with the same mask."""
        states, mask = inputs
        for encoder_layer in self.layers:
            states = encoder_layer([states, mask])
        return states



## CLIP Text Embeddings

In [12]:
class CLIPTextEmbeddings(keras.layers.Layer):
    """Sum of a token embedding and a position embedding, both 768-dimensional.

    Args:
        n_words: Number of rows in the position embedding table.
    """

    def __init__(self, n_words=77):
        super().__init__()
        # The token table has 49408 entries.
        self.token_embedding_layer = keras.layers.Embedding(
            input_dim=49408, output_dim=768, name="token_embedding"
        )
        self.position_embedding_layer = keras.layers.Embedding(
            input_dim=n_words, output_dim=768, name="position_embedding"
        )

    def call(self, inputs):
        """Embed ``[input_ids, position_ids]`` and add the two results."""
        input_ids, position_ids = inputs
        token_part = self.token_embedding_layer(input_ids)
        position_part = self.position_embedding_layer(position_ids)
        return token_part + position_part



## CLIP Text Transformer

In [13]:
class CLIPTextTransformer(keras.models.Model):
    """Text encoder: embeddings, a causally masked encoder stack, final LayerNorm.

    Args:
        n_words: Number of positions in the position embedding table; it also
            sets the side length of the causal attention mask.
    """

    def __init__(self, n_words=77):
        super().__init__()
        self.embeddings = CLIPTextEmbeddings(n_words=n_words)
        self.encoder = CLIPEncoder()
        self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5)
        # Additive mask: -inf strictly above the diagonal, 0 elsewhere, so a
        # position gets zero attention weight on later positions. It is sized
        # from n_words rather than a hard-coded 77, so a non-default n_words
        # yields a mask of matching size.
        self.causal_attention_mask = tf.constant(
            np.triu(
                np.ones((1, 1, n_words, n_words), dtype="float32") * -np.inf, k=1
            )
        )

    def call(self, inputs):
        """Encode ``[input_ids, position_ids]`` into per-token features."""
        input_ids, position_ids = inputs
        x = self.embeddings([input_ids, position_ids])
        # Crop the mask to the actual sequence length; when the static length
        # is unknown (None) the slice keeps the full mask.
        seq_len = x.shape[1]
        mask = self.causal_attention_mask[:, :, :seq_len, :seq_len]
        x = self.encoder([x, mask])
        return self.final_layer_norm(x)


# Diffusion Model

In [14]:
class ResBlock(keras.layers.Layer):
    """Residual block whose hidden features are shifted by an embedding.

    Args:
        channels: Channel count of the input; only used to decide whether the
            skip path needs a 1x1 projection.
        out_channels: Number of filters of the convolutions and of the dense
            layer applied to the embedding.
    """

    def __init__(self, channels, out_channels):
        super().__init__()
        self.in_layers = [
            tfa.layers.GroupNormalization(epsilon=1e-5),
            keras.activations.swish,
            PaddedConv2D(out_channels, 3, padding=1),
        ]
        self.emb_layers = [
            keras.activations.swish,
            keras.layers.Dense(out_channels),
        ]
        self.out_layers = [
            tfa.layers.GroupNormalization(epsilon=1e-5),
            keras.activations.swish,
            PaddedConv2D(out_channels, 3, padding=1),
        ]
        if channels != out_channels:
            # Project the input so it can be added to the block output.
            self.skip_connection = PaddedConv2D(out_channels, 1)
        else:
            self.skip_connection = lambda x: x

    def call(self, inputs):
        """Apply the block; ``inputs`` is ``[x, emb]``."""
        x, emb = inputs
        hidden = apply_seq(x, self.in_layers)
        emb_proj = apply_seq(emb, self.emb_layers)
        # Two new axes are inserted so the embedding broadcasts over them.
        hidden = hidden + emb_proj[:, None, None]
        hidden = apply_seq(hidden, self.out_layers)
        return self.skip_connection(x) + hidden



In [15]:
class CrossAttention(keras.layers.Layer):
    """Multi-head attention from ``x`` to an optional ``context``.

    When no context is supplied the layer attends from ``x`` to itself.

    Args:
        n_heads: Number of attention heads.
        d_head: Feature size of each head.
    """

    def __init__(self, n_heads, d_head):
        super().__init__()
        inner_dim = n_heads * d_head
        self.to_q = keras.layers.Dense(inner_dim, use_bias=False)
        self.to_k = keras.layers.Dense(inner_dim, use_bias=False)
        self.to_v = keras.layers.Dense(inner_dim, use_bias=False)
        self.scale = d_head**-0.5
        self.num_heads = n_heads
        self.head_size = d_head
        self.to_out = [keras.layers.Dense(inner_dim)]

    def call(self, inputs):
        """Attend; ``inputs`` is the list ``[x]`` or ``[x, context]``."""
        assert type(inputs) is list
        if len(inputs) == 1:
            inputs = inputs + [None]
        x, context = inputs
        if context is None:
            context = x
        q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)
        assert len(x.shape) == 3

        query_len = x.shape[1]
        context_len = context.shape[1]
        q = tf.reshape(q, (-1, query_len, self.num_heads, self.head_size))
        k = tf.reshape(k, (-1, context_len, self.num_heads, self.head_size))
        v = tf.reshape(v, (-1, context_len, self.num_heads, self.head_size))

        q = tf.transpose(q, (0, 2, 1, 3))  # (bs, num_heads, time, head_size)
        k = tf.transpose(k, (0, 2, 3, 1))  # (bs, num_heads, head_size, time)
        v = tf.transpose(v, (0, 2, 1, 3))  # (bs, num_heads, time, head_size)

        scores = td_dot(q, k) * self.scale
        probs = keras.activations.softmax(scores)  # (bs, num_heads, time, time)
        mixed = td_dot(probs, v)
        mixed = tf.transpose(mixed, (0, 2, 1, 3))  # (bs, time, num_heads, head_size)
        merged = tf.reshape(mixed, (-1, query_len, self.num_heads * self.head_size))
        return apply_seq(merged, self.to_out)



In [16]:
class BasicTransformerBlock(keras.layers.Layer):
    """Self-attention, cross-attention and a GEGLU feed-forward, each residual.

    Args:
        dim: Output size of the final dense layer; the GEGLU uses ``dim * 4``.
        n_heads: Number of heads in both attention layers.
        d_head: Per-head feature size in both attention layers.
    """

    def __init__(self, dim, n_heads, d_head):
        super().__init__()
        self.norm1 = keras.layers.LayerNormalization(epsilon=1e-5)
        self.attn1 = CrossAttention(n_heads, d_head)

        self.norm2 = keras.layers.LayerNormalization(epsilon=1e-5)
        self.attn2 = CrossAttention(n_heads, d_head)

        self.norm3 = keras.layers.LayerNormalization(epsilon=1e-5)
        self.geglu = GEGLU(dim * 4)
        self.dense = keras.layers.Dense(dim)

    def call(self, inputs):
        """Apply the block; ``inputs`` is ``[x, context]``."""
        hidden, context = inputs

        # attn1 receives no context, so it attends within ``hidden`` itself.
        self_attended = self.attn1([self.norm1(hidden)])
        hidden = self_attended + hidden

        cross_attended = self.attn2([self.norm2(hidden), context])
        hidden = cross_attended + hidden

        feed_forward = self.dense(self.geglu(self.norm3(hidden)))
        return feed_forward + hidden


In [17]:
class SpatialTransformer(keras.layers.Layer):
    """Runs transformer blocks over the flattened positions of a feature map.

    Args:
        channels: Channel count; must equal ``n_heads * d_head``.
        n_heads: Number of attention heads in the transformer block.
        d_head: Per-head feature size in the transformer block.
    """

    def __init__(self, channels, n_heads, d_head):
        super().__init__()
        self.norm = tfa.layers.GroupNormalization(epsilon=1e-5)
        assert channels == n_heads * d_head
        self.proj_in = PaddedConv2D(n_heads * d_head, 1)
        self.transformer_blocks = [BasicTransformerBlock(channels, n_heads, d_head)]
        self.proj_out = PaddedConv2D(channels, 1)

    def call(self, inputs):
        """Apply the layer; ``inputs`` is ``[x, context]``."""
        x, context = inputs
        _, height, width, channels = x.shape
        residual = x

        tokens = self.proj_in(self.norm(x))
        # Flatten the two spatial axes into one sequence axis.
        tokens = tf.reshape(tokens, (-1, height * width, channels))
        for block in self.transformer_blocks:
            tokens = block([tokens, context])

        feature_map = tf.reshape(tokens, (-1, height, width, channels))
        return self.proj_out(feature_map) + residual



In [18]:
class Downsample(keras.layers.Layer):
    """Stride-2, 3x3 padded convolution.

    Args:
        channels: Number of output filters.
    """

    def __init__(self, channels):
        super().__init__()
        self.op = PaddedConv2D(channels, 3, padding=1, stride=2)

    def call(self, x):
        """Return the strided convolution of ``x``."""
        downsampled = self.op(x)
        return downsampled



In [19]:
class Upsample(keras.layers.Layer):
    """2x ``UpSampling2D`` followed by a 3x3 padded convolution.

    Args:
        channels: Number of output filters of the convolution.
    """

    def __init__(self, channels):
        super().__init__()
        self.ups = keras.layers.UpSampling2D(size=(2, 2))
        self.conv = PaddedConv2D(channels, 3, padding=1)

    def call(self, x):
        """Upsample ``x`` and convolve the enlarged map."""
        return self.conv(self.ups(x))



In [20]:

class UNetModel(keras.models.Model):
    """U-Net built from ``ResBlock``/``SpatialTransformer`` stages with skips.

    ``call`` takes ``[x, t_emb, context]``: ``t_emb`` is passed through
    ``time_embed`` and fed to every ``ResBlock``, while ``context`` is fed to
    every ``SpatialTransformer``. The final convolution has 4 filters.
    """

    def __init__(self):
        super().__init__()
        self.time_embed = [
            keras.layers.Dense(1280),
            keras.activations.swish,
            keras.layers.Dense(1280),
        ]
        # Encoder path: the output of every entry is stored as a skip tensor.
        self.input_blocks = [
            [PaddedConv2D(320, kernel_size=3, padding=1)],
            [ResBlock(320, 320), SpatialTransformer(320, 8, 40)],
            [ResBlock(320, 320), SpatialTransformer(320, 8, 40)],
            [Downsample(320)],
            [ResBlock(320, 640), SpatialTransformer(640, 8, 80)],
            [ResBlock(640, 640), SpatialTransformer(640, 8, 80)],
            [Downsample(640)],
            [ResBlock(640, 1280), SpatialTransformer(1280, 8, 160)],
            [ResBlock(1280, 1280), SpatialTransformer(1280, 8, 160)],
            [Downsample(1280)],
            [ResBlock(1280, 1280)],
            [ResBlock(1280, 1280)],
        ]
        self.middle_block = [
            ResBlock(1280, 1280),
            SpatialTransformer(1280, 8, 160),
            ResBlock(1280, 1280),
        ]
        # Decoder path: one entry per stored skip tensor.
        self.output_blocks = [
            [ResBlock(2560, 1280)],
            [ResBlock(2560, 1280)],
            [ResBlock(2560, 1280), Upsample(1280)],
            [ResBlock(2560, 1280), SpatialTransformer(1280, 8, 160)],
            [ResBlock(2560, 1280), SpatialTransformer(1280, 8, 160)],
            [
                ResBlock(1920, 1280),
                SpatialTransformer(1280, 8, 160),
                Upsample(1280),
            ],
            [ResBlock(1920, 640), SpatialTransformer(640, 8, 80)],  # 6
            [ResBlock(1280, 640), SpatialTransformer(640, 8, 80)],
            [
                ResBlock(960, 640),
                SpatialTransformer(640, 8, 80),
                Upsample(640),
            ],
            [ResBlock(960, 320), SpatialTransformer(320, 8, 40)],
            [ResBlock(640, 320), SpatialTransformer(320, 8, 40)],
            [ResBlock(640, 320), SpatialTransformer(320, 8, 40)],
        ]
        self.out = [
            tfa.layers.GroupNormalization(epsilon=1e-5),
            keras.activations.swish,
            PaddedConv2D(4, kernel_size=3, padding=1),
        ]

    def call(self, inputs):
        """Run the U-Net on ``[x, t_emb, context]`` and return the output map."""
        x, t_emb, context = inputs
        emb = apply_seq(t_emb, self.time_embed)

        def run_layer(tensor, layer):
            # Each layer type takes a different second input.
            if isinstance(layer, ResBlock):
                return layer([tensor, emb])
            if isinstance(layer, SpatialTransformer):
                return layer([tensor, context])
            return layer(tensor)

        skips = []
        for block in self.input_blocks:
            for layer in block:
                x = run_layer(x, layer)
            skips.append(x)

        for layer in self.middle_block:
            x = run_layer(x, layer)

        for block in self.output_blocks:
            # Skips are consumed in reverse order (last stored, first used).
            x = tf.concat([x, skips.pop()], axis=-1)
            for layer in block:
                x = run_layer(x, layer)
        return apply_seq(x, self.out)


# CLIP Tokenizer

In [21]:
import zipfile


# Unzip the CLIP tokenizer archive into the current working directory.
# The context manager closes the archive even if extraction raises.
with zipfile.ZipFile("/content/drive/MyDrive/clip_tokenizer.zip", "r") as zip_ref:
    zip_ref.extractall()


# stable diffusion

In [22]:
from clip_tokenizer import SimpleTokenizer

In [23]:
# Presumably the prompt length in tokens.
# NOTE(review): the visible part of StableDiffusion.generate() hard-codes 77
# instead of referencing this constant; confirm where it is consumed.
MAX_TEXT_LEN = 77

class StableDiffusion:
    """Prompt-driven image generation pipeline built from four Keras models.

    The text encoder, diffusion model, decoder and encoder come from
    ``get_models``; prompts are tokenized with ``SimpleTokenizer``.
    ``generate`` optionally starts from an input image and can restrict
    changes to a masked region.
    """

    def __init__(
        self, img_height=1000, img_width=1000, jit_compile=False, download_weights=True
    ):
        """Create the tokenizer and the four models for a fixed image size.

        Args:
            img_height: Image height passed to ``get_models``; input images
                and masks are resized to it in ``generate``.
            img_width: Image width, used the same way as ``img_height``.
            jit_compile: When true, every model is compiled with
                ``jit_compile=True``.
            download_weights: Forwarded to ``get_models``.
        """
        self.img_height = img_height
        self.img_width = img_width
        self.tokenizer = SimpleTokenizer()

        text_encoder, diffusion_model, decoder, encoder = get_models(
            img_height, img_width, download_weights=download_weights
        )
        self.text_encoder = text_encoder
        self.diffusion_model = diffusion_model
        self.decoder = decoder
        self.encoder = encoder

        if jit_compile:
            self.text_encoder.compile(jit_compile=True)
            self.diffusion_model.compile(jit_compile=True)
            self.decoder.compile(jit_compile=True)
            self.encoder.compile(jit_compile=True)

        # Tensors created by this class follow the global mixed-precision policy.
        self.dtype = tf.float32
        if tf.keras.mixed_precision.global_policy().name == "mixed_float16":
            self.dtype = tf.float16

    @staticmethod
    def _noise_timestep_index(num_timesteps, strength):
        """Map ``input_image_strength`` to an index into the timestep schedule.

        The index is clamped to the last valid position so that a strength of
        1.0 selects the final timestep instead of indexing past the end.
        """
        return min(int(num_timesteps * strength), num_timesteps - 1)

    def generate(
        self,
        prompt,
        negative_prompt=None,
        batch_size=1,
        num_steps=25,
        unconditional_guidance_scale=7.5,
        temperature=1,
        seed=None,
        input_image=None,
        input_mask=None,
        input_image_strength=0.5,
    ):
        """Generate a batch of images for ``prompt``.

        Args:
            prompt: Text to condition on; must tokenize to fewer than 77 tokens.
            negative_prompt: Optional text used for the unconditional branch;
                when ``None`` the module-level ``_UNCONDITIONAL_TOKENS`` is used.
            batch_size: Number of images to generate.
            num_steps: Controls the timestep stride (``1000 // num_steps``).
            unconditional_guidance_scale: Weight of the conditional prediction
                relative to the unconditional one (see ``get_model_output``).
            temperature: Forwarded to ``get_x_prev_and_pred_x0``.
            seed: Seed passed to the random ops that accept one.
            input_image: Optional starting image: a file path, a NumPy array
                (pixel values are treated as 0-255), or any object accepted by
                ``np.array``. Paths and arrays are resized to the model size.
            input_mask: Optional path to a mask image; requires ``input_image``.
                Pixel values are scaled by 1/255.
                # NOTE(review): the mask is assumed to be single-channel — confirm.
            input_image_strength: Fraction of the timestep schedule used when
                starting from ``input_image``.

        Returns:
            ``uint8`` NumPy array clipped to [0, 255]; the first axis indexes
            the batch (remaining shape is whatever the decoder produces).

        Raises:
            ValueError: ``input_mask`` is given without ``input_image``.
            TypeError: ``input_mask`` is not a string path.
            AssertionError: a prompt tokenizes to 77 or more tokens.
        """
        # Fail fast on mask misuse; previously these cases surfaced later as a
        # NameError on variables that were never assigned.
        if input_mask is not None:
            if input_image is None:
                raise ValueError("input_mask requires input_image to be provided")
            if not isinstance(input_mask, str):
                raise TypeError("input_mask must be a path to an image file (str)")

        # Tokenize prompt (i.e. starting context)
        inputs = self.tokenizer.encode(prompt)
        assert len(inputs) < 77, "Prompt is too long (should be < 77 tokens)"
        # Pad to 77 ids with token 49407 (presumably the tokenizer's
        # end-of-text id — confirm against the tokenizer).
        phrase = inputs + [49407] * (77 - len(inputs))
        phrase = np.array(phrase)[None].astype("int32")
        phrase = np.repeat(phrase, batch_size, axis=0)

        # Encode prompt tokens (and their positions) into a "context vector"
        pos_ids = np.array(list(range(77)))[None].astype("int32")
        pos_ids = np.repeat(pos_ids, batch_size, axis=0)
        context = self.text_encoder.predict_on_batch([phrase, pos_ids])

        input_image_tensor = None
        if input_image is not None:
            if type(input_image) is str:
                input_image = Image.open(input_image)
                input_image = input_image.resize((self.img_width, self.img_height))

            elif type(input_image) is np.ndarray:
                # np.resize repeats/truncates the flat data rather than
                # interpolating, so arrays of a different size are resized
                # through PIL exactly like the path branch above.
                if input_image.shape[:2] != (self.img_height, self.img_width):
                    input_image = Image.fromarray(input_image.astype("uint8")).resize(
                        (self.img_width, self.img_height)
                    )

            # Keep at most the first three channels and add a batch axis.
            input_image_array = np.array(input_image, dtype=np.float32)[None, ..., :3]
            # Map 0..255 pixel values to -1..1.
            input_image_tensor = tf.cast(
                (input_image_array / 255.0) * 2 - 1, self.dtype
            )

        if input_mask is not None:
            input_mask = Image.open(input_mask)
            input_mask = input_mask.resize((self.img_width, self.img_height))
            input_mask_array = np.array(input_mask, dtype=np.float32)[None, ..., None]
            input_mask_array = input_mask_array / 255.0

            # Latent-resolution mask, inverted: 1 where the original image's
            # latent is kept, 0 where the generated latent is kept.
            latent_mask = input_mask.resize((self.img_width // 8, self.img_height // 8))
            latent_mask = np.array(latent_mask, dtype=np.float32)[None, ..., None]
            latent_mask = 1 - (latent_mask.astype("float") / 255.0)
            latent_mask_tensor = tf.cast(
                tf.repeat(latent_mask, batch_size, axis=0), self.dtype
            )

        # Tokenize negative prompt or use default padding tokens
        unconditional_tokens = _UNCONDITIONAL_TOKENS
        if negative_prompt is not None:
            inputs = self.tokenizer.encode(negative_prompt)
            assert (
                len(inputs) < 77
            ), "Negative prompt is too long (should be < 77 tokens)"
            unconditional_tokens = inputs + [49407] * (77 - len(inputs))

        # Encode unconditional tokens (and their positions) into an
        # "unconditional context vector"
        unconditional_tokens = np.array(unconditional_tokens)[None].astype("int32")
        unconditional_tokens = np.repeat(unconditional_tokens, batch_size, axis=0)
        unconditional_context = self.text_encoder.predict_on_batch(
            [unconditional_tokens, pos_ids]
        )
        timesteps = np.arange(1, 1000, 1000 // num_steps)
        # Clamped so that input_image_strength == 1.0 does not index past the
        # end of the schedule.
        noise_index = self._noise_timestep_index(len(timesteps), input_image_strength)
        input_img_noise_t = timesteps[noise_index]
        latent, alphas, alphas_prev = self.get_starting_parameters(
            timesteps,
            batch_size,
            seed,
            input_image=input_image_tensor,
            input_img_noise_t=input_img_noise_t,
        )

        if input_image is not None:
            # Only the first part of the schedule is run when starting from an image.
            timesteps = timesteps[: int(len(timesteps) * input_image_strength)]

        # Diffusion stage (iterates from the largest timestep down to the smallest)
        progbar = tqdm(list(enumerate(timesteps))[::-1])
        for index, timestep in progbar:
            progbar.set_description(f"{index:3d} {timestep:3d}")
            e_t = self.get_model_output(
                latent,
                timestep,
                context,
                unconditional_context,
                unconditional_guidance_scale,
                batch_size,
            )
            a_t, a_prev = alphas[index], alphas_prev[index]
            latent, pred_x0 = self.get_x_prev_and_pred_x0(
                latent, e_t, index, a_t, a_prev, temperature, seed
            )

            if input_mask is not None and input_image is not None:
                # If mask is provided, noise at current timestep will be added to input image.
                # The intermediate latent will be merged with input latent.
                latent_origin, alphas, alphas_prev = self.get_starting_parameters(
                    timesteps,
                    batch_size,
                    seed,
                    input_image=input_image_tensor,
                    input_img_noise_t=timestep,
                )
                latent = latent_origin * latent_mask_tensor + latent * (
                    1 - latent_mask_tensor
                )

        # Decoding stage: map decoder output from -1..1 to 0..255
        decoded = self.decoder.predict_on_batch(latent)
        decoded = ((decoded + 1) / 2) * 255

        if input_mask is not None:
            # Merge inpainting output with original image
            decoded = (
                input_image_array * (1 - input_mask_array)
                + np.array(decoded) * input_mask_array
            )

        return np.clip(decoded, 0, 255).astype("uint8")

    def timestep_embedding(self, timesteps, dim=320, max_period=10000):
        """Return a ``(1, dim)`` cos/sin embedding tensor for ``timesteps``.

        ``dim // 2`` frequencies are generated; the cosine half is placed
        before the sine half. The result has dtype ``self.dtype``.
        """
        half = dim // 2
        freqs = np.exp(
            -math.log(max_period) * np.arange(0, half, dtype="float32") / half
        )
        args = np.array(timesteps) * freqs
        embedding = np.concatenate([np.cos(args), np.sin(args)])
        return tf.convert_to_tensor(embedding.reshape(1, -1), dtype=self.dtype)

    def add_noise(self, x, t, noise=None):
        """Blend ``x`` with noise according to ``_ALPHAS_CUMPROD[t]``.

        Args:
            x: Tensor whose first three dimensions size the generated noise.
            t: Index into ``_ALPHAS_CUMPROD``.
            noise: Optional noise tensor; when ``None`` standard-normal noise
                with 4 channels is drawn (this draw does not take a seed).

        Returns:
            ``sqrt(alpha) * x + sqrt(1 - alpha) * noise``.
        """
        batch_size, w, h = x.shape[0], x.shape[1], x.shape[2]
        if noise is None:
            noise = tf.random.normal((batch_size, w, h, 4), dtype=self.dtype)
        sqrt_alpha_prod = _ALPHAS_CUMPROD[t] ** 0.5
        sqrt_one_minus_alpha_prod = (1 - _ALPHAS_CUMPROD[t]) ** 0.5

        return sqrt_alpha_prod * x + sqrt_one_minus_alpha_prod * noise

    def get_starting_parameters(
        self, timesteps, batch_size, seed, input_image=None, input_img_noise_t=None
    ):
        """Build the initial latent and the per-step alpha lists.

        Without ``input_image`` the latent is seeded random normal noise of
        shape ``(batch_size, img_height // 8, img_width // 8, 4)``. With an
        image, the encoder output is repeated ``batch_size`` times and noised
        via ``add_noise`` at ``input_img_noise_t``.

        Returns:
            ``(latent, alphas, alphas_prev)`` where ``alphas[i]`` is
            ``_ALPHAS_CUMPROD[timesteps[i]]`` and ``alphas_prev`` is ``alphas``
            shifted right by one with a leading ``1.0``.
        """
        n_h = self.img_height // 8
        n_w = self.img_width // 8
        alphas = [_ALPHAS_CUMPROD[t] for t in timesteps]
        alphas_prev = [1.0] + alphas[:-1]
        if input_image is None:
            latent = tf.random.normal((batch_size, n_h, n_w, 4), seed=seed)
        else:
            latent = self.encoder(input_image)
            latent = tf.repeat(latent, batch_size, axis=0)
            latent = self.add_noise(latent, input_img_noise_t)
        return latent, alphas, alphas_prev

    def get_model_output(
        self,
        latent,
        t,
        context,
        unconditional_context,
        unconditional_guidance_scale,
        batch_size,
    ):
        """Run the diffusion model twice and combine the two predictions.

        The model is evaluated once with ``unconditional_context`` and once
        with ``context``; the result is
        ``uncond + unconditional_guidance_scale * (cond - uncond)``.
        """
        timesteps = np.array([t])
        t_emb = self.timestep_embedding(timesteps)
        t_emb = np.repeat(t_emb, batch_size, axis=0)
        unconditional_latent = self.diffusion_model.predict_on_batch(
            [latent, t_emb, unconditional_context]
        )
        latent = self.diffusion_model.predict_on_batch([latent, t_emb, context])
        return unconditional_latent + unconditional_guidance_scale * (
            latent - unconditional_latent
        )

    def get_x_prev_and_pred_x0(self, x, e_t, index, a_t, a_prev, temperature, seed):
        """Compute the previous-step latent and the predicted clean latent.

        Args:
            x: Current latent.
            e_t: Model output for the current step.
            index: Step index (not used in the computation).
            a_t: Alpha value for the current step.
            a_prev: Alpha value for the previous step.
            temperature: Scales the (currently unused) noise term.
            seed: Seed for the noise draw.

        Returns:
            ``(x_prev, pred_x0)``.
        """
        sigma_t = 0
        sqrt_one_minus_at = math.sqrt(1 - a_t)
        pred_x0 = (x - sqrt_one_minus_at * e_t) / math.sqrt(a_t)

        # Direction pointing to x_t
        dir_xt = math.sqrt(1.0 - a_prev - sigma_t**2) * e_t
        # NOTE(review): `noise` is multiplied by sigma_t (0) and never added to
        # x_prev, so the update is deterministic. The draw is kept so the
        # sequence of random ops executed per step is unchanged.
        noise = sigma_t * tf.random.normal(x.shape, seed=seed) * temperature
        x_prev = math.sqrt(a_prev) * pred_x0 + dir_xt
        return x_prev, pred_x0

    def load_weights_from_pytorch_ckpt(self, pytorch_ckpt_path):
        """Load weights for all four models from a PyTorch checkpoint.

        For each module, ``PYTORCH_CKPT_MAPPING`` lists ``(state_dict key,
        transpose permutation or None)`` pairs in the order expected by the
        Keras model's ``set_weights``.
        """
        import torch

        # SECURITY: torch.load unpickles the file; only load checkpoints from
        # a trusted source.
        pt_weights = torch.load(pytorch_ckpt_path, map_location="cpu")
        for module_name in ["text_encoder", "diffusion_model", "decoder", "encoder"]:
            module_weights = []
            for key, perm in PYTORCH_CKPT_MAPPING[module_name]:
                w = pt_weights["state_dict"][key].numpy()
                if perm is not None:
                    w = np.transpose(w, perm)
                module_weights.append(w)
            getattr(self, module_name).set_weights(module_weights)
            print("Loaded %d weights for %s" % (len(module_weights), module_name))



In [24]:

def get_models(img_height, img_width, download_weights=True):
    """Build the text encoder, diffusion UNet, decoder and encoder models.

    Args:
        img_height: Image height; the latent height is ``img_height // 8``.
        img_width: Image width; the latent width is ``img_width // 8``.
        download_weights: When true, fetch the pretrained ``.h5`` weight files
            and load them into the four models.

    Returns:
        ``(text_encoder, diffusion_model, decoder, encoder)`` Keras models.
    """
    latent_h = img_height // 8
    latent_w = img_width // 8

    # Text encoder: token ids + position ids -> embeddings.
    token_ids_in = keras.layers.Input(shape=(MAX_TEXT_LEN,), dtype="int32")
    position_ids_in = keras.layers.Input(shape=(MAX_TEXT_LEN,), dtype="int32")
    text_embeddings = CLIPTextTransformer()([token_ids_in, position_ids_in])
    text_encoder = keras.models.Model(
        [token_ids_in, position_ids_in], text_embeddings
    )

    # Diffusion UNet: (latent, time embedding, text context) -> UNet output.
    context_in = keras.layers.Input((MAX_TEXT_LEN, 768))
    t_emb_in = keras.layers.Input((320,))
    unet_latent_in = keras.layers.Input((latent_h, latent_w, 4))
    unet_inputs = [unet_latent_in, t_emb_in, context_in]
    unet_layer = UNetModel()
    diffusion_model = keras.models.Model(unet_inputs, unet_layer(unet_inputs))

    # Decoder: latent -> decoder output.
    decoder_latent_in = keras.layers.Input((latent_h, latent_w, 4))
    decoder_layer = Decoder()
    decoder = keras.models.Model(decoder_latent_in, decoder_layer(decoder_latent_in))

    # Encoder: image -> encoder output.
    image_in = keras.layers.Input((img_height, img_width, 3))
    encoder_layer = Encoder()
    encoder = keras.models.Model(image_in, encoder_layer(image_in))

    if not download_weights:
        return text_encoder, diffusion_model, decoder, encoder

    # (model, weights URL, expected file hash), in download/load order.
    weight_sources = (
        (
            text_encoder,
            "https://huggingface.co/fchollet/stable-diffusion/resolve/main/text_encoder.h5",
            "d7805118aeb156fc1d39e38a9a082b05501e2af8c8fbdc1753c9cb85212d6619",
        ),
        (
            diffusion_model,
            "https://huggingface.co/fchollet/stable-diffusion/resolve/main/diffusion_model.h5",
            "a5b2eea58365b18b40caee689a2e5d00f4c31dbcb4e1d58a9cf1071f55bbbd3a",
        ),
        (
            decoder,
            "https://huggingface.co/fchollet/stable-diffusion/resolve/main/decoder.h5",
            "6d3c5ba91d5cc2b134da881aaa157b2d2adc648e5625560e3ed199561d0e39d5",
        ),
        (
            encoder,
            "https://huggingface.co/divamgupta/stable-diffusion-tensorflow/resolve/main/encoder_newW.h5",
            "56a2578423c640746c5e90c0a789b9b11481f47497f817e65b44a1a5538af754",
        ),
    )

    # Fetch every file first, then load, so no model is touched until all
    # downloads have succeeded.
    weight_paths = [
        keras.utils.get_file(origin=url, file_hash=sha256)
        for _, url, sha256 in weight_sources
    ]
    for (model, _, _), weights_fpath in zip(weight_sources, weight_paths):
        model.load_weights(weights_fpath)

    return text_encoder, diffusion_model, decoder, encoder


# Prompting

In [37]:
# Build the pipeline for 512x512 images, then generate one image for the
# prompt. No seed is passed, so each run produces a different result.
generator = StableDiffusion(jit_compile=False, img_width=512, img_height=512)
img = generator.generate(
    prompt="a painting of a virus monster playing guitar",
    negative_prompt="",
    batch_size=1,
    num_steps=50,
    temperature=1,
    unconditional_guidance_scale=7.5,
)

 49 981:   0%|          | 0/50 [00:21<?, ?it/s]


ResourceExhaustedError: ignored

In [None]:
# Write the first image of the generated batch to disk as a PNG, attaching an
# (empty) PNG metadata container.
output_image_name = "x.png"
pnginfo = PngInfo()
first_image = Image.fromarray(img[0])
first_image.save(output_image_name, pnginfo=pnginfo)
print(f"saved at {output_image_name}")
