# In [None]:
import os
import gc
import torch
import torch.nn.functional as F
from PIL import Image
from tqdm.auto import tqdm
from torchvision import transforms
from torch.optim import AdamW
from transformers import CLIPTextModel, CLIPTokenizer, set_seed
from accelerate import Accelerator
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
from accelerate.utils import ProjectConfiguration
import numpy as np 
from diffusers.optimization import get_scheduler

# Run configuration
output_dir = "text-inversion-model"
image_path = "1.png"  # the single training image
pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5"
num_vectors = 1
placeholder_token = ["*"]
resolution = 512
train_batch_size = 1
max_train_steps = 5000
gradient_accumulation_steps = 1
learning_rate = 1e-4
seed = 42

# Fix the random seed before anything stochastic runs.
set_seed(seed)

# The accelerator provides the target device used throughout the script.
accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)


def _load_component(component_cls, subfolder):
    """Load one sub-model of the pretrained pipeline, caching files under ``model``."""
    return component_cls.from_pretrained(
        pretrained_model_name_or_path,
        subfolder=subfolder,
        cache_dir="model",
    )


# All pipeline components come from the same pretrained checkpoint.
vae = _load_component(AutoencoderKL, "vae")
tokenizer = _load_component(CLIPTokenizer, "tokenizer")
text_encoder = _load_component(CLIPTextModel, "text_encoder")
unet = _load_component(UNet2DConditionModel, "unet")
noise_scheduler = _load_component(DDPMScheduler, "scheduler")

# Put the scheduler's lookup tensors on the accelerator device.
# noise_scheduler.alphas_cumprod_prev = noise_scheduler.alphas_cumprod_prev.to(accelerator.device)
for _scheduler_attr in ("alphas_cumprod", "betas"):
    setattr(
        noise_scheduler,
        _scheduler_attr,
        getattr(noise_scheduler, _scheduler_attr).to(accelerator.device),
    )

# Resolve the placeholder token(s) to vocabulary ids and seed their embedding
# rows from the initializer token.
# NOTE(review): the placeholder is never registered with tokenizer.add_tokens(),
# so it presumably resolves to an id that already exists in the vocabulary and
# the copy below is then a self-copy — confirm whether a new token was intended.
token_ids = tokenizer.encode(placeholder_token, add_special_tokens=False)
initializer_token_id = token_ids[0]
# convert_tokens_to_ids expects token *strings*. The previous code passed the
# integer initializer id, which is not a vocabulary key, so the lookup fell
# back to an unrelated id and that row was overwritten in the loop below.
placeholder_token_ids = tokenizer.convert_tokens_to_ids(placeholder_token)
print(placeholder_token_ids)
# Keep the embedding matrix the same size as the tokenizer vocabulary.
text_encoder.resize_token_embeddings(len(tokenizer))
token_embeds = text_encoder.get_input_embeddings().weight.data

with torch.no_grad():
    for token_id in placeholder_token_ids:
        token_embeds[token_id] = token_embeds[initializer_token_id].clone()

# Turn off gradients for the VAE, the UNet and for every part of the text
# encoder listed below (its token-embedding table is not touched here).
for _frozen_module in (
    vae,
    unet,
    text_encoder.text_model.encoder,
    text_encoder.text_model.final_layer_norm,
    text_encoder.text_model.embeddings.position_embedding,
):
    _frozen_module.requires_grad_(False)

# Train mode on the UNet plus gradient checkpointing on both networks.
unet.train()
text_encoder.gradient_checkpointing_enable()
unet.enable_gradient_checkpointing()

# Load the single training image and turn it into a (1, 3, H, W) tensor.
# The context manager closes the underlying file as soon as the RGB copy exists.
with Image.open(image_path) as _source_image:
    image = _source_image.convert("RGB")

preprocess = transforms.Compose([
    transforms.Resize((resolution, resolution)),
    transforms.ToTensor(),
    # The tensor is fed to the Stable Diffusion VAE, which expects pixel values
    # in [-1, 1]. The CLIP image-encoder mean/std used previously belong to a
    # different model and shift/scale the input away from that range.
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
image_tensor = preprocess(image).unsqueeze(0).to(accelerator.device)

# Tokenize the fixed prompt once. `input_index` is the number of token ids the
# tokenizer returns for it with default settings, and marks where the extra
# learned positions begin.
prompt = "a photo of a "
text_input = tokenizer(prompt, return_tensors="pt")
input_index = len(text_input.input_ids[0])
print(input_index)
text_input = text_input.to(accelerator.device)

text_encoder.to(accelerator.device)

# The text encoder is only used for inference here.
with torch.no_grad():
    text_embeddings_prompt = text_encoder(text_input.input_ids)[0]

# Trainable conditioning tensor: the prompt positions followed by 3 randomly
# initialised positions.
text_embeddings = torch.randn((1, input_index + 3, 768), dtype=torch.float32, device=accelerator.device, requires_grad=True)
text_embeddings.data[:, :input_index, :] = text_embeddings_prompt
print(text_embeddings.shape)

# Warm start: overwrite the whole tensor with a previously saved one.
# map_location lets a checkpoint saved on another device (e.g. a GPU that is
# not present on this machine) be loaded instead of raising.
# text_embeddings_ssf = torch.load(f'tensor_without_grad_final-lr0.001.pt').to(accelerator.device)
text_embeddings_ssf = torch.load(
    '1720587980.2415967_average_tensor.pt',
    map_location=accelerator.device,
).to(accelerator.device)

text_embeddings.data[:] = text_embeddings_ssf

# AdamW over the single trainable tensor.
_adamw_settings = {
    "lr": learning_rate,
    "betas": (0.9, 0.999),
    "weight_decay": 0.01,
    "eps": 1e-8,
}
optimizer = AdamW([text_embeddings], **_adamw_settings)

# Alternative schedule kept for reference:
# lr_scheduler = get_scheduler(
#     "cosine",
#     optimizer=optimizer,
#     num_warmup_steps=50,
#     num_training_steps=max_train_steps,
#     num_cycles=1,
# )

# NOTE(review): the plain "constant" schedule presumably ignores
# num_warmup_steps; "constant_with_warmup" would be needed for an actual
# warmup — confirm the intent.
lr_scheduler = get_scheduler("constant", optimizer=optimizer, num_warmup_steps=80)

# Hand the trainable tensor, optimizer and scheduler to the accelerator.
text_embeddings, optimizer, lr_scheduler = accelerator.prepare(
    text_embeddings,
    optimizer,
    lr_scheduler,
)

# The frozen networks run in full precision on the accelerator device.
weight_dtype = torch.float32
for _frozen_net in (unet, vae):
    _frozen_net.to(accelerator.device, dtype=weight_dtype)

# Progress bar (hidden on non-main processes) and training bookkeeping.
progress_bar = tqdm(
    range(max_train_steps),
    desc="Steps",
    disable=not accelerator.is_local_main_process,
)
# Copy of the text encoder's input-embedding weights taken before training.
_input_embedding = accelerator.unwrap_model(text_encoder).get_input_embeddings()
orig_embeds_params = _input_embedding.weight.data.clone()
losses = []

import time

# Timestamped intermediate checkpoints are written under "2024"; create the
# directory up front so the first save (step 50) cannot fail on a missing path.
os.makedirs("2024", exist_ok=True)

# Reference copy of the prompt positions. Zeroing their gradient is not enough
# to keep them fixed: AdamW's decoupled weight decay still scales every
# parameter each step, so they are restored from this copy after each update.
frozen_prompt_embeds = text_embeddings.detach()[:, :input_index, :].clone()

# The VAE is frozen and the image never changes, so run the encoder once.
# Only the sampling from the resulting distribution is repeated per step.
with torch.no_grad():
    latent_dist = vae.encode(image_tensor.to(dtype=weight_dtype)).latent_dist

for t in progress_bar:
    # Draw a fresh latent for the training image and apply the VAE scaling.
    target_latents = latent_dist.sample().detach()
    target_latents = target_latents * vae.config.scaling_factor

    with accelerator.accumulate(text_embeddings):
        # Sample the noise that will be added to the latents.
        noise = torch.randn_like(target_latents)
        bsz = target_latents.shape[0]
        # Sample a random timestep for each image.
        timestep = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=target_latents.device)
        timestep = timestep.long()

        # Forward diffusion: noise the latents according to the timestep.
        noisy_latents = noise_scheduler.add_noise(target_latents, noise, timestep)

        # The UNet predicts the noise, conditioned on the trainable embeddings.
        noise_pred = unet(noisy_latents, timestep, encoder_hidden_states=text_embeddings).sample

        loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

        accelerator.backward(loss)

        with torch.no_grad():
            # No gradient for the prompt positions; this also keeps the
            # optimizer's moment estimates for them at zero.
            text_embeddings.grad[:, :input_index, :] = 0

        optimizer.step()
        lr_scheduler.step()

        with torch.no_grad():
            # Undo the weight-decay shrinkage applied to the prompt positions.
            text_embeddings[:, :input_index, :] = frozen_prompt_embeds

        optimizer.zero_grad()

    losses.append(loss.detach().item())
    logs = {"lr": lr_scheduler.get_last_lr()[0], "avg": np.mean(losses[-100:])}

    progress_bar.set_postfix(**logs)

    if len(losses) % 50 == 0:
        text_embeddings_without_grad = text_embeddings.detach()

        now = time.time()  # Current time, used to make the file name unique

        torch.save(text_embeddings_without_grad, os.path.join("2024", str(now) + f'_tensor_without_grad-{len(losses)}-{np.mean(losses[-100:])}.pt'))

        # Every 500th step is also a 50th step, so the snapshot above is reused.
        if len(losses) % 500 == 0:
            torch.save(text_embeddings_without_grad, f'tensor_without_grad-{len(losses)}.pt')

# Final result.
text_embeddings_without_grad = text_embeddings.detach()
torch.save(text_embeddings_without_grad, f'tensor_without_grad_final-lr{learning_rate}.pt')

# Clean up: collect Python garbage first so the CUDA cache can release the
# blocks those objects were holding.
gc.collect()
torch.cuda.empty_cache()

accelerator.end_training()
