In [None]:
import torch
from transformers import CLIPTextModel, CLIPTokenizer

In [None]:
# Text Process

In [None]:
# Pre-tokenized CLIP ids for ONE caption, split into three chunks of 77 ids.
# Each chunk starts with 49406 (start-of-text) and is filled up with 49407
# (end-of-text, also used as padding). Only the first chunk carries prompt
# tokens; the second and third chunks are start-of-text followed by padding.
text_ids = [[49406, 47124, 15144,   267, 32515,   267,  1033,  7425,   267,  5860,
            267,  9680,   267, 15567, 24190,   267, 21154,   267,  6687,   318,
           3940,   267,   534,  1863,   746,   267,  2660,   268,   705,   267,
           1774,   268,  3940,   267,    67,  1892,   267, 14531,   267,  7681,
            268,  3940,   267, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407],
         [49406, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407],
         [49406, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407]]

# Keep an explicit batch dimension: (batch, n_chunks, chunk_len) == (1, 3, 77).
# Without it the three chunks are read as three independent samples, so
# get_hidden_states() returns a batch of 3 that cannot be paired with the
# single image / latent (batch of 1) fed to the U-Net further down.
input_ids = torch.tensor(text_ids).unsqueeze(0)
b_size = input_ids.size()[0]
# Flattened (batch * n_chunks, 77) layout, i.e. one chunk per row. It gets its
# own name so that `input_ids` keeps its batch dimension and re-running this
# cell always yields the same shapes.
input_ids_flat = input_ids.reshape((-1, 77))

In [None]:
# Identifier of the CLIP checkpoint whose tokenizer is used for the prompts.
tokenizer_name = "openai/clip-vit-large-patch14"
# Load the tokenizer for that checkpoint; it is later queried for
# `model_max_length` (the per-chunk length used when merging hidden states).
tokenizer = CLIPTokenizer.from_pretrained(tokenizer_name)

In [None]:
# Identifier of the CLIP checkpoint that provides the text encoder
# (the same checkpoint the tokenizer is loaded from).
text_encoder_name = "openai/clip-vit-large-patch14"
# Load the pretrained text encoder; get_hidden_states() calls it on token ids.
text_encoder = CLIPTextModel.from_pretrained(text_encoder_name)

In [None]:
# Length of one tokenizer chunk as reported by the tokenizer
# (presumably 77, matching the 77-wide rows of `text_ids` - confirm).
model_max_length = tokenizer.model_max_length
# Number of prompt tokens covered by the merged sequence: 3 chunks x 75 tokens
# (a 77-id chunk minus its first and last position) = 225.
# The earlier value 255 is not a multiple of 75; it only worked because the
# extra loop step in get_hidden_states() produced an empty slice. The merged
# result for the three-chunk input is the same with 225.
max_token_length = 225

In [None]:
def get_hidden_states(max_token_length, input_ids, tokenizer, text_encoder, weight_dtype=None):
    """Encode chunked token ids and merge the chunks into one hidden-state sequence.

    The ids are flattened to rows of ``tokenizer.model_max_length``, every row
    is encoded on its own, and the per-row outputs are laid end to end again for
    each sample. The merge then keeps the very first position, the inner
    ``model_max_length - 2`` positions of every chunk, and the very last
    position.

    Args:
        max_token_length: Upper bound for the chunk start offsets visited while
            merging. Offsets beyond the encoded length are ignored. ``None``
            skips the merge and returns the per-sample states unmerged.
        input_ids: Integer tensor whose first dimension is the batch and whose
            per-sample size is a multiple of ``tokenizer.model_max_length``,
            e.g. ``(batch, n_chunks, model_max_length)``.
        tokenizer: Object exposing ``model_max_length``.
        text_encoder: Callable; ``text_encoder(ids)[0]`` must be a tensor of
            shape ``(rows, model_max_length, hidden)``.
        weight_dtype: Optional dtype the result is cast to. ``None`` (default)
            keeps whatever dtype the encoder produced.

    Returns:
        Tensor of shape ``(batch, merged_len, hidden)``.
    """
    chunk_len = tokenizer.model_max_length
    b_size = input_ids.size()[0]
    # One chunk per row, which is the layout the encoder is called with.
    input_ids = input_ids.reshape((-1, chunk_len))

    encoder_hidden_states = text_encoder(input_ids)[0]
    # Back to one row per sample: (batch, n_chunks * chunk_len, hidden).
    encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))

    if max_token_length is not None:
        total_len = encoder_hidden_states.shape[1]
        # First position of the first chunk.
        states_list = [encoder_hidden_states[:, 0].unsqueeze(1)]
        # Each step jumps ahead by one chunk (chunk_len positions, the
        # tokenizer's maximum length) and keeps the chunk's inner
        # chunk_len - 2 positions. Capping the bound at total_len avoids
        # appending empty slices when max_token_length overshoots.
        for i in range(1, min(max_token_length, total_len), chunk_len):
            states_list.append(encoder_hidden_states[:, i:i + chunk_len - 2])
        # Last position of the last chunk.
        states_list.append(encoder_hidden_states[:, -1].unsqueeze(1))
        encoder_hidden_states = torch.cat(states_list, dim=1)

    if weight_dtype is not None:
        # Previously this parameter was accepted but silently ignored.
        encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
    return encoder_hidden_states

In [None]:
# VAE encoding

In [None]:
from diffusers import AutoencoderKL

In [None]:
def create_vae_diffusers_config():
    """Return the keyword arguments this notebook passes to ``AutoencoderKL``.

    The encoder and decoder each get one block per entry of the channel list,
    so the two block-type tuples always have the same length as that list.

    Returns:
        dict: constructor settings; the block types are tuples, the channel
        widths are a list.
    """
    channels_per_level = [128, 256, 512, 512]
    n_levels = len(channels_per_level)

    return {
        "sample_size": 256,
        "in_channels": 3,
        "out_channels": 3,
        "down_block_types": ("DownEncoderBlock2D",) * n_levels,
        "up_block_types": ("UpDecoderBlock2D",) * n_levels,
        "block_out_channels": channels_per_level,
        "latent_channels": 4,
        "layers_per_block": 2,
    }

In [None]:
# Constructor settings for the VAE.
vae_config = create_vae_diffusers_config()
# Checkpoint conversion is disabled, so no pretrained VAE weights are loaded here.
#converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
# Build the VAE straight from the config dict.
vae = AutoencoderKL(**vae_config)
# Disabled together with the conversion above.
#vae.load_state_dict_stat(converted_vae_checkpoint)

In [None]:
# Noise Process

In [None]:
# U-Net

In [None]:
from diffusers import UNet2DConditionModel

In [None]:
# Downsampling factors at which the U-Net uses cross-attention blocks.
UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
def create_unet_diffusers_config():
    """Return the keyword arguments this notebook passes to ``UNet2DConditionModel``.

    Level ``k`` of the down path works at factor ``2 ** k``; the up path visits
    the same factors in reverse order. A level gets a cross-attention block when
    its factor is listed in ``UNET_PARAMS_ATTENTION_RESOLUTIONS`` and a plain
    block otherwise.

    Returns:
        dict: constructor settings; block types and channel widths are lists.
    """
    block_out_channels = [320, 640, 1280, 1280]
    n_levels = len(block_out_channels)

    # Factors on the way down: 1, 2, 4, 8.
    down_block_types = [
        "CrossAttnDownBlock2D" if 2 ** level in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
        for level in range(n_levels)
    ]
    # Factors on the way up: 8, 4, 2, 1.
    up_block_types = [
        "CrossAttnUpBlock2D" if 2 ** level in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
        for level in reversed(range(n_levels))
    ]

    return {
        "sample_size": 64,
        "in_channels": 4,
        "out_channels": 4,
        "down_block_types": down_block_types,
        "up_block_types": up_block_types,
        "block_out_channels": block_out_channels,
        "layers_per_block": 2,
        "cross_attention_dim": 768,
        "attention_head_dim": 8,
    }

In [None]:
# Constructor settings for the U-Net.
unet_config = create_unet_diffusers_config()
# Build the U-Net straight from the config dict; no checkpoint is loaded in this cell.
unet = UNet2DConditionModel(**unet_config)

In [None]:
# Random stand-in for a batch of one 3-channel 576x576 image; it only serves
# to push a tensor of the right shape through the VAE below.
image = torch.randn((1, 3, 576, 576)).to(dtype=torch.float32)
# Show the shape rather than dumping about a million random values as output.
image.shape

In [None]:
# Encode the image, draw one sample from the returned latent distribution and
# rescale it by 0.18215.
# NOTE(review): the original author was unsure what the factor is for ("seems
# to be about normalising the variance"); it looks like the Stable Diffusion
# latent scaling factor - confirm against the VAE config.
latents = vae.encode(image).latent_dist.sample() * 0.18215
# Batch size of the latents, used below to draw one timestep per sample.
b_size = latents.size(0)

In [None]:
# Text conditioning for the U-Net: encode `input_ids` chunk by chunk and merge
# the chunk outputs into one sequence per sample (see get_hidden_states).
encoder_hidden_states = get_hidden_states(max_token_length, input_ids, tokenizer, text_encoder)

In [None]:
# Gaussian noise with the same shape and dtype as the latents, created on the
# latents' device. The attribute was misspelled `deivce`, which raised
# AttributeError before any noise was drawn.
noise = torch.randn_like(latents, device=latents.device)

# Optional "diffusion with offset noise" variant, left disabled: it would add
# one extra random value per (sample, channel), broadcast over height and width.
# noise_offset = 0.
# noise += noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)



In [None]:
from diffusers import DDPMScheduler

In [None]:
# Noise schedule used to corrupt the latents: 1000 training timesteps with a
# "scaled_linear" beta schedule between the two beta values, sample clipping off.
noise_scheduler = DDPMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    num_train_timesteps=1000,
    clip_sample=False,
)

In [None]:
# One random integer timestep per sample, drawn from
# [0, num_train_timesteps), stored as a 64-bit integer tensor.
timesteps = torch.randint(
    low=0,
    high=noise_scheduler.config.num_train_timesteps,
    size=(b_size,),
).long()

In [None]:
# Forward diffusion step: mix `noise` into the clean latents according to each
# sample's timestep. The scheduler method is `add_noise`; the previous
# `add_noisead` does not exist and raised AttributeError.
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

In [None]:
# Run the U-Net on the noised latents, conditioned on the timesteps and the
# text hidden states; the `.sample` attribute of the result is kept as the
# prediction.
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

In [None]:
# Scratch check of Tensor.expand: a one-element tensor is expanded to four
# entries. The result is only displayed, `a` itself is left unchanged.
a = torch.tensor([3])
a.expand(4)

In [None]:
# vae_config
def create_vae_diffusers_config():
    """Return the keyword arguments this notebook passes to ``AutoencoderKL``.

    The encoder and decoder each get one block per entry of the channel list,
    so the two block-type tuples always have the same length as that list.

    Returns:
        dict: constructor settings; the block types are tuples, the channel
        widths are a list.
    """
    channels_per_level = [128, 256, 512, 512]
    n_levels = len(channels_per_level)

    return {
        "sample_size": 256,
        "in_channels": 3,
        "out_channels": 3,
        "down_block_types": ("DownEncoderBlock2D",) * n_levels,
        "up_block_types": ("UpDecoderBlock2D",) * n_levels,
        "block_out_channels": channels_per_level,
        "latent_channels": 4,
        "layers_per_block": 2,
    }


# unet_config
# Downsampling factors at which the U-Net uses cross-attention blocks.
UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
def create_unet_diffusers_config():
    """Return the keyword arguments this notebook passes to ``UNet2DConditionModel``.

    Level ``k`` of the down path works at factor ``2 ** k``; the up path visits
    the same factors in reverse order. A level gets a cross-attention block when
    its factor is listed in ``UNET_PARAMS_ATTENTION_RESOLUTIONS`` and a plain
    block otherwise.

    Returns:
        dict: constructor settings; block types and channel widths are lists.
    """
    block_out_channels = [320, 640, 1280, 1280]
    n_levels = len(block_out_channels)

    # Factors on the way down: 1, 2, 4, 8.
    down_block_types = [
        "CrossAttnDownBlock2D" if 2 ** level in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
        for level in range(n_levels)
    ]
    # Factors on the way up: 8, 4, 2, 1.
    up_block_types = [
        "CrossAttnUpBlock2D" if 2 ** level in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
        for level in reversed(range(n_levels))
    ]

    return {
        "sample_size": 64,
        "in_channels": 4,
        "out_channels": 4,
        "down_block_types": down_block_types,
        "up_block_types": up_block_types,
        "block_out_channels": block_out_channels,
        "layers_per_block": 2,
        "cross_attention_dim": 768,
        "attention_head_dim": 8,
    }



In [None]:
def get_model(vae_config, unet_config, text_encoder_name):
    """Create the three networks used by the training step.

    The VAE and the U-Net are built from their config dicts only; no checkpoint
    is loaded into them here. The text encoder is loaded with
    ``CLIPTextModel.from_pretrained``.

    Args:
        vae_config: Keyword arguments for ``AutoencoderKL``.
        unet_config: Keyword arguments for ``UNet2DConditionModel``.
        text_encoder_name: Identifier passed to ``CLIPTextModel.from_pretrained``.

    Returns:
        tuple: ``(vae, unet, text_encoder)``.
    """
    # Tuple elements are evaluated left to right, so the models are still
    # created in the order VAE, U-Net, text encoder.
    return (
        AutoencoderKL(**vae_config),
        UNet2DConditionModel(**unet_config),
        CLIPTextModel.from_pretrained(text_encoder_name),
    )

In [None]:
# The tokenizer and the text encoder come from the same CLIP checkpoint.
tokenizer_name = "openai/clip-vit-large-patch14"
text_encoder_name = "openai/clip-vit-large-patch14"

# Constructor settings for the two networks that are built from config only.
vae_config = create_vae_diffusers_config()
unet_config = create_unet_diffusers_config()

# Load the tokenizer first, then create the three networks.
tokenizer = CLIPTokenizer.from_pretrained(tokenizer_name)
vae, unet, text_encoder = get_model(vae_config, unet_config, text_encoder_name)

In [None]:
# 1. Get the latent code from the VAE.
# 2. Add noise to the latent code.
