In [1]:
import torch
from transformers import CLIPTextModel, CLIPTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Text processing: CLIP token ids -> text-encoder hidden states

In [43]:
# Three 77-token CLIP id sequences. Every row starts with 49406 (CLIP <|startoftext|>) and is
# padded out with 49407 (CLIP <|endoftext|>). Row 0 carries a 42-token prompt; rows 1 and 2
# contain no prompt tokens at all.
# NOTE(review): rows 1-2 look like the empty 2nd/3rd chunks of a long-prompt (max_token_length)
# encoding of ONE prompt. If so, input_ids should keep a batch axis, i.e. shape (1, 3, 77), so
# that b_size == 1; as written b_size == 3 — confirm intent.
_BOS, _EOS = 49406, 49407
_SEQ_LEN = 77
_prompt_ids = [
    _BOS, 47124, 15144, 267, 32515, 267, 1033, 7425, 267, 5860,
    267, 9680, 267, 15567, 24190, 267, 21154, 267, 6687, 318,
    3940, 267, 534, 1863, 746, 267, 2660, 268, 705, 267,
    1774, 268, 3940, 267, 67, 1892, 267, 14531, 267, 7681,
    268, 3940, 267,
]
text_ids = [
    _prompt_ids + [_EOS] * (_SEQ_LEN - len(_prompt_ids)),
    [_BOS] + [_EOS] * (_SEQ_LEN - 1),
    [_BOS] + [_EOS] * (_SEQ_LEN - 1),
]
input_ids = torch.tensor(text_ids)
b_size = input_ids.shape[0]
input_ids = input_ids.reshape(-1, _SEQ_LEN)

In [9]:
# CLIP tokenizer for the checkpoint id below (Hub repo id or local path).
tokenizer_name = "openai/clip-vit-large-patch14"
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path=tokenizer_name)

In [None]:
# CLIP text encoder from the same checkpoint id as the tokenizer; the two must match.
text_encoder_name = "openai/clip-vit-large-patch14"
text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path=text_encoder_name)

In [17]:
# Length of one tokenizer chunk (77 positions for this CLIP tokenizer, per the reshapes above).
model_max_length = tokenizer.model_max_length
# Total prompt positions spread across several chunks. get_hidden_states walks the chunks with
# range(1, max_token_length, model_max_length) and keeps model_max_length - 2 (= 75) positions
# from each, so this should be 75 x number_of_chunks (75, 150, 225, ...).
# 225 visits exactly 3 chunks, matching the 3 rows of text_ids.
max_token_length = 225

In [28]:
def get_hidden_states(max_token_length, input_ids, tokenizer, text_encoder, weight_dtype=None):
    """Encode token ids chunk-by-chunk and stitch the chunk outputs into one sequence per sample.

    Every row of ``tokenizer.model_max_length`` ids is encoded independently. The per-sample
    chunk outputs are then concatenated, dropping each chunk's first and last positions
    (assumed to be <BOS>/<EOS>). The exceptions are the first position of the first chunk and
    the last position of the last chunk, which are kept.

    Args:
        max_token_length: Upper bound of the chunk-walking loop. 75 * n visits n chunks
            (75 -> 1, 150 -> 2, 225 -> 3). If None, the encoder states are returned
            without merging.
        input_ids: Integer tensor whose first dimension is the batch. It is flattened to
            (-1, tokenizer.model_max_length), e.g. from (batch, n_chunks, 77) or (batch, 77).
        tokenizer: Object exposing ``model_max_length`` (the chunk length).
        text_encoder: Callable whose output[0] is expected to be per-token hidden states of
            shape (n_rows, model_max_length, hidden).
        weight_dtype: Optional dtype the result is cast to (e.g. torch.float16).

    Returns:
        Tensor of shape (batch, seq, hidden). seq is 2 + (model_max_length - 2) * chunks_visited
        after merging, or n_chunks * model_max_length when max_token_length is None.
    """
    chunk_len = tokenizer.model_max_length
    b_size = input_ids.size()[0]
    # (batch, n_chunks, L) -> (batch * n_chunks, L): each chunk goes through the encoder on its own.
    input_ids = input_ids.reshape((-1, chunk_len))

    encoder_hidden_states = text_encoder(input_ids)[0]
    # Put each sample's chunks back side by side: (batch, n_chunks * L, hidden).
    encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))

    if max_token_length is not None:
        states_list = [encoder_hidden_states[:, 0].unsqueeze(1)]  # first position of the first chunk
        # Jump one chunk (chunk_len positions) per step; from each chunk keep the chunk_len - 2
        # token positions between its first and its last position.
        for i in range(1, max_token_length, chunk_len):
            states_list.append(encoder_hidden_states[:, i:i + chunk_len - 2])
        states_list.append(encoder_hidden_states[:, -1].unsqueeze(1))  # last position of the last chunk
        encoder_hidden_states = torch.cat(states_list, dim=1)

    if weight_dtype is not None:
        encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
    return encoder_hidden_states

In [44]:
# Run the text encoder over input_ids and merge its chunk outputs (no dtype cast requested).
encoder_hidden_states = get_hidden_states(
    max_token_length=max_token_length,
    input_ids=input_ids,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
)

In [53]:
# VAE: build an AutoencoderKL from a hand-written diffusers config

In [54]:
from diffusers import AutoencoderKL

In [55]:
def create_vae_diffusers_config(sample_size=256, block_out_channels=(128, 256, 512, 512)):
    """Build the keyword arguments for ``diffusers.AutoencoderKL``.

    Args:
        sample_size: Value stored as the config's ``sample_size``.
        block_out_channels: Output channel count of each resolution level. One
            DownEncoderBlock2D and one UpDecoderBlock2D are created per entry.

    Returns:
        dict suitable for ``AutoencoderKL(**config)``.
    """
    # AutoencoderKL expects one channel count per level, as a tuple (not a single int).
    block_out_channels = tuple(block_out_channels)
    n_levels = len(block_out_channels)

    return dict(
        sample_size=sample_size,
        in_channels=3,
        out_channels=3,
        down_block_types=("DownEncoderBlock2D",) * n_levels,
        up_block_types=("UpDecoderBlock2D",) * n_levels,
        block_out_channels=block_out_channels,
        latent_channels=4,
        layers_per_block=2,
    )

In [None]:
# Instantiate the VAE from the config above. No checkpoint is loaded in this cell.
vae_config = create_vae_diffusers_config()
vae = AutoencoderKL(**vae_config)
# TODO: load pretrained weights. This needs an LDM `state_dict` and a converter such as
# `convert_ldm_vae_checkpoint` to be provided first, then e.g.:
#     converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
#     vae.load_state_dict(converted_vae_checkpoint)

In [31]:
# Noise scheduler and timestep sampling

In [48]:
from diffusers import DDPMScheduler

In [49]:
# DDPM scheduler: 1000 training timesteps, "scaled_linear" betas from 0.00085 to 0.012,
# and clip_sample disabled.
noise_scheduler = DDPMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
)

In [51]:
# b_size random integer timesteps drawn uniformly from [0, num_train_timesteps), as int64.
timesteps = torch.randint(
    low=0,
    high=noise_scheduler.config.num_train_timesteps,
    size=(b_size,),
).long()

In [None]:
# UNet: diffusers config

In [None]:
# Downsampling factors ("resolution" in create_unet_diffusers_config, which starts at 1 and doubles
# per level) whose UNet blocks use cross-attention. A level whose factor is listed here becomes
# CrossAttnDownBlock2D / CrossAttnUpBlock2D; otherwise it becomes a plain DownBlock2D / UpBlock2D.
UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
def create_unet_diffusers_config():
    
    block_out_channels = [320, 640, 1280, 1280]
    
    down_block_types = []
    resolution = 1
    for i in range(len(block_out_channels)):
        block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
        down_block_types.append(block_type)
        if i != len(block_out_channels) - 1:
            resolution *= 2
    
    up_block_types = []
    for i in range(len(block_out_channels)):
        block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
        up_block_types.append(block_type)
        resolution //= 2