## Libraries

In [None]:
# %pip install --upgrade diffusers transformers scipy accelerate datasets peft
# %pip install gradio

In [None]:
# General Imports
import os, gc
from tqdm.auto import tqdm
import numpy as np
import torch
from PIL import Image

# Clustering Imports
from transformers import CLIPProcessor, CLIPModel
from transformers import CLIPTextModel, CLIPTokenizer
# from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture

# Generation Imports
from torchvision import transforms
from torch.utils.data import Dataset
import torch.nn.functional as F
import torch.utils.checkpoint
from transformers import CLIPTextModel
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration
from datasets import Dataset
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict, set_peft_model_state_dict
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig
import math
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
    UNet2DConditionModel,
)
from diffusers.loaders import LoraLoaderMixin
from diffusers.optimization import get_scheduler
from diffusers.utils import (convert_state_dict_to_diffusers, convert_unet_state_dict_to_peft)
from diffusers.utils.torch_utils import is_compiled_module
from diffusers.utils.import_utils import is_xformers_available

# Gradio Imports
import gradio as gr

## Pipeline

In [None]:
'''
Structure of the pipeline: generate images, find clusters, and update the model
'''

class PrefrenceDiffusion():
    '''
    Generate images for a prompt, cluster them by CLIP image embedding, and
    fine-tune the UNet toward a preferred cluster with LoRA + Diffusion-DPO.

    Typical flow: generate_images -> find_prefrences -> update_model ->
    generate_images(with_dpo=True).
    '''

    def __init__(self, model_path="CompVis/stable-diffusion-v1-4", output_dir="./"):
        '''
        Inputs:
        - model_path: Diffusers model id or local path used for generation and training.
        - output_dir: Directory for LoRA weights, training checkpoints and saved images.
        '''
        self.model_name = model_path
        self.output_dir = output_dir
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Can add more inputs based on the requirements...
        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
        self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

        # makedirs also creates output_dir itself when it does not exist yet.
        os.makedirs(os.path.join(self.output_dir, "checkpoints"), exist_ok=True)

        # Scale applied to the DPO logits inside the log-sigmoid loss.
        self.beta_dpo = 2500
        self.accelerator = Accelerator(
            gradient_accumulation_steps=2,
            mixed_precision="fp16",
            project_config=ProjectConfiguration(project_dir=self.output_dir),
        )
        # Register the checkpoint hooks exactly once. Registering them inside
        # update_model() stacked one more copy per call, and the duplicate save
        # hook then called weights.pop() on an already-emptied list.
        self.accelerator.register_save_state_pre_hook(self.save_model_hook)
        self.accelerator.register_load_state_pre_hook(self.load_model_hook)
        # Seed for shuffling the DPO dataset; None lets `datasets` pick fresh entropy.
        self.seed = None

    def _get_weight_dtype(self):
        '''Return the torch dtype matching the accelerator's mixed-precision mode.'''
        if self.accelerator.mixed_precision == "fp16":
            return torch.float16
        if self.accelerator.mixed_precision == "bf16":
            return torch.bfloat16
        return torch.float32

    @staticmethod
    def _as_prompt_list(prompt):
        '''Wrap a bare string prompt in a list; other iterables are converted to a list.'''
        return [prompt] if isinstance(prompt, str) else list(prompt)

    def generate_images(self, prompt, num_imgs=100, save_imgs=False, with_dpo=False, dpo_file_name="lora_dpo_weights.safetensors"):
        '''
        Inputs:
        - prompt: Prompt(s) to generate images from, ex: ["a photo of a dog"]. A plain string is treated as one prompt.
        - num_imgs: Number of images to generate per prompt.
        - save_imgs: Saves the images to output_dir/images/<first prompt, spaces replaced by underscores>.
        - with_dpo: Whether to load the DPO-trained LoRA weights from output_dir.
        - dpo_file_name: Name of the LoRA weights file loaded when with_dpo is True. Default name is lora_dpo_weights.safetensors.

        Output:
        - images: List of PIL images, len(images) == len(prompt) * num_imgs. The prompt list is repeated
          num_imgs times as a whole, so the prompts are interleaved.
        '''
        weight_dtype = self._get_weight_dtype()
        prompts = self._as_prompt_list(prompt) * num_imgs

        pipeline = DiffusionPipeline.from_pretrained(self.model_name, torch_dtype=weight_dtype)

        if with_dpo:
            pipeline.load_lora_weights(self.output_dir, weight_name=dpo_file_name)

        save_path = None
        if save_imgs and prompts:
            save_path = os.path.join(self.output_dir, "images", prompts[0].replace(" ", "_"))
            os.makedirs(save_path, exist_ok=True)

        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
        pipeline = pipeline.to(self.accelerator.device, dtype=weight_dtype)
        if is_xformers_available():
            pipeline.enable_xformers_memory_efficient_attention()

        pipeline.safety_checker = None

        images = []
        for i, p in enumerate(prompts, start=1):
            img = pipeline(p, num_inference_steps=25).images[0]
            images.append(img)
            if save_path is not None:
                img.save(os.path.join(save_path, f"image_{i}.png"))

        # Free space before the caller loads other models.
        self.accelerator.free_memory()
        del pipeline
        torch.cuda.empty_cache()
        gc.collect()

        return images

    def get_clip_embeddings(self, images):
        """
        Compute CLIP image features.

        :param images: A PIL image or a list of PIL images.
        :return: CPU tensor of image features with one row per input image.
        """
        inputs = self.clip_processor(images=images, return_tensors="pt", padding=True).to(self.device)

        with torch.no_grad():
            image_features = self.clip_model.get_image_features(**inputs)

        return image_features.cpu()

    def get_embedding_from_images(self, images):
        '''
        Inputs:
        - images: List of PIL images to embed (one CLIP forward pass per image).

        Output:
        - img_embeddings: List of CLIP image-feature tensors, one per input image.
        '''
        return [self.get_clip_embeddings(img) for img in images]

    @staticmethod
    def _cluster_std_norms(embeddings, labels, n_components):
        '''
        For each cluster id in range(n_components), return the L2 norm of the
        per-dimension standard deviation of that cluster's embeddings.

        Inputs:
        - embeddings: (N, D) array of embeddings.
        - labels: length-N array of cluster ids.
        - n_components: Number of clusters.

        Output:
        - List of n_components norms. Empty clusters get 0.0 (np.std of an empty set is NaN).
        '''
        embeddings = np.asarray(embeddings)
        labels = np.asarray(labels)
        norms = []
        for i in range(n_components):
            members = embeddings[labels == i]
            norms.append(np.linalg.norm(np.std(members, axis=0)) if len(members) else 0.0)
        return norms

    def find_prefrences(self, images, n_clusters=10):
        '''
        Inputs:
        - images: List of PIL images on which we need to perform GMM clustering.
        - n_clusters: Number of clusters to split the input images into. Default value is 10.

        Output:
        - clustered_images: Map of cluster id -> list of PIL images in that cluster (every id in range(n_clusters) is present).
        - standardard_deviation_info: {"original": norm of the per-dimension std over all embeddings,
                                       "clusters": the same quantity per cluster}.
        '''
        n_components = n_clusters
        gmm = GaussianMixture(n_components=n_components, random_state=0)

        # Stack the per-image (1, D) features into an (N, D) matrix. concatenate keeps
        # the matrix 2-D, whereas squeeze() collapsed it for single images.
        img_embeddings = np.concatenate(
            [emb.cpu().numpy() for emb in self.get_embedding_from_images(images)], axis=0
        )

        standardard_deviation_info = {
            # Norm of the per-dimension standard deviation over all images.
            "original": np.linalg.norm(np.std(img_embeddings, axis=0)),
            "clusters": [0] * n_components,
        }

        gmm.fit(img_embeddings)
        labels = gmm.predict(img_embeddings)

        # Group images by their cluster label.
        clustered_images = {i: [] for i in range(n_components)}
        for label, image in zip(labels, images):
            clustered_images[int(label)].append(image)

        # Reuse the embeddings computed above instead of running CLIP again per cluster.
        standardard_deviation_info["clusters"] = self._cluster_std_norms(img_embeddings, labels, n_components)

        return clustered_images, standardard_deviation_info

    def tokenize_text(self, tokenizer, prompts):
        '''Tokenize prompts, padded/truncated to the tokenizer's max length; returns input_ids.'''
        max_length = tokenizer.model_max_length
        text_inputs = tokenizer(prompts, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt")
        return text_inputs.input_ids

    @torch.no_grad()
    def encode_prompt(self, text_encoder, input_ids):
        '''Run the frozen text encoder on input_ids and return its first output (hidden states).'''
        text_encoder_ids = input_ids.to(text_encoder.device)
        attention_mask = None
        prompt_embeds = text_encoder(text_encoder_ids, attention_mask=attention_mask)[0]
        return prompt_embeds

    def unwrap_model(self, model):
        '''Strip accelerate and torch.compile wrappers from a model.'''
        model = self.accelerator.unwrap_model(model)
        model = model._orig_mod if is_compiled_module(model) else model
        return model

    def save_model_hook(self, models, weights, output_dir):
        '''
        Accelerate save-state pre-hook: write only the UNet LoRA layers.

        Each handled model's entry is popped from `weights` so accelerate does not
        also save that model's full weights.
        '''
        if not self.accelerator.is_main_process:
            return

        unet_lora_layers_to_save = None
        unet_type = type(self.unwrap_model(self.unet))
        for model in models:
            if not isinstance(model, unet_type):
                raise ValueError(f"Unexpected save model: {model.__class__}")
            unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
            weights.pop()

        LoraLoaderMixin.save_lora_weights(
            output_dir,
            unet_lora_layers=unet_lora_layers_to_save,
            text_encoder_lora_layers=None,
        )

    def inject_lora_into_unet(self, state_dict, unet, network_alphas=None, adapter_name="default"):
        '''Load the "unet."-prefixed LoRA entries of state_dict into `unet`'s PEFT adapter.'''
        unet_keys = [k for k in state_dict.keys() if k.startswith("unet.")]
        state_dict = {k.replace("unet.", ""): v for k, v in state_dict.items() if k in unet_keys}
        state_dict = convert_unet_state_dict_to_peft(state_dict)

        if network_alphas is not None:
            alpha_keys = [k for k in network_alphas.keys() if k.startswith("unet")]
            network_alphas = {k.replace("unet.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
            network_alphas = convert_unet_state_dict_to_peft(network_alphas)

        set_peft_model_state_dict(unet, state_dict, adapter_name)
        # NOTE(review): set_peft_model_state_dict above already loads the adapter
        # weights; confirm this second call is needed for the installed diffusers version.
        unet.load_attn_procs(state_dict, network_alphas=network_alphas)

    def load_model_hook(self, models, input_dir):
        '''Accelerate load-state pre-hook: restore the UNet LoRA weights from input_dir.'''
        unet_ = None
        unet_type = type(self.unwrap_model(self.unet))

        # Popping tells accelerate that this hook has handled the model.
        while len(models) > 0:
            model = models.pop()
            if not isinstance(model, unet_type):
                raise ValueError(f"Unexpected load model: {model.__class__}")
            unet_ = model

        lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
        self.inject_lora_into_unet(lora_state_dict, unet_, network_alphas=network_alphas)

        # Upcast trainable parameters (LoRA) into fp32
        for p in unet_.parameters():
            if p.requires_grad:
                p.data = p.data.to(torch.float32)

    def update_model(self, dpo_preferred_images, dpo_unpreferred_images, prompt, dpo_epochs=5, continue_training=False, output_file_name="lora_dpo_weights.safetensors"):
        '''
        Fine-tune LoRA adapters on the UNet with Diffusion-DPO.

        Inputs:
        - dpo_preferred_images: List of PIL images (or image file paths) from the preferred cluster.
        - dpo_unpreferred_images: List of PIL images (or image file paths) from the unpreferred clusters.
          Must satisfy len(dpo_unpreferred_images) == len(dpo_preferred_images); pairs are formed by position.
        - prompt: Prompt the images were generated from, ex: ["a photo of a dog"]. A plain string is also accepted.
        - dpo_epochs: Number of epochs the DPO training should run. Default value is 5.
        - continue_training: Resume from the latest checkpoint in output_dir/checkpoints. Default value is False.
          When used iteratively, dpo_epochs is the cumulative total: e.g. 5 in the first iteration and 10 in the
          second to train 5 more epochs.
        - output_file_name: File name of the DPO LoRA weights. Default name is lora_dpo_weights.safetensors.

        Output:
        - No return value. LoRA weights are saved to output_dir/output_file_name, and a training checkpoint is
          written to output_dir/checkpoints/checkpoint-<step> after every optimizer step.

        Raises:
        - ValueError: if the image lists differ in length or are empty, or the text encoder is not a CLIPTextModel.
        '''
        prompts = self._as_prompt_list(prompt)
        num_pairs = len(dpo_preferred_images)
        if num_pairs != len(dpo_unpreferred_images):
            raise ValueError("dpo_preferred_images and dpo_unpreferred_images must have the same length")
        if num_pairs == 0:
            # An empty dataset would otherwise fail with ZeroDivisionError when computing epochs.
            raise ValueError("At least one preferred/unpreferred image pair is required")

        checkpoint_dir = os.path.join(self.output_dir, "checkpoints")

        tokenizer = AutoTokenizer.from_pretrained(self.model_name, subfolder="tokenizer", use_fast=False)
        text_encoder_config = PretrainedConfig.from_pretrained(self.model_name, subfolder="text_encoder")
        architectures = text_encoder_config.architectures or []
        if not architectures or architectures[0] != "CLIPTextModel":
            raise ValueError("Model not supported")
        text_encoder_class = CLIPTextModel

        # Load scheduler and models
        noise_scheduler = DDPMScheduler.from_pretrained(self.model_name, subfolder="scheduler")
        text_encoder = text_encoder_class.from_pretrained(self.model_name, subfolder="text_encoder")
        vae = AutoencoderKL.from_pretrained(self.model_name, subfolder="vae")
        self.unet = UNet2DConditionModel.from_pretrained(self.model_name, subfolder="unet")

        vae.requires_grad_(False)
        text_encoder.requires_grad_(False)
        self.unet.requires_grad_(False)

        # For mixed precision training, cast all non-trainable weights (vae, text_encoder and the non-LoRA
        # unet) to half precision. These weights are only used for inference, so full precision is not required.
        weight_dtype = self._get_weight_dtype()
        self.unet.to(self.accelerator.device, dtype=weight_dtype)
        vae.to(self.accelerator.device, dtype=weight_dtype)
        text_encoder.to(self.accelerator.device, dtype=weight_dtype)

        # Setup LoRA
        unet_lora_config = LoraConfig(
            r=8,
            lora_alpha=8,
            init_lora_weights="gaussian",
            target_modules=["to_k", "to_q", "to_v", "to_out.0"],
        )
        self.unet.add_adapter(unet_lora_config)
        # Upcast trainable parameters (LoRA) into fp32
        for p in self.unet.parameters():
            if p.requires_grad:
                p.data = p.data.to(torch.float32)

        # Trade compute for activation memory.
        self.unet.enable_gradient_checkpointing()

        # Efficient training
        if is_xformers_available():
            self.unet.enable_xformers_memory_efficient_attention()

        # Optimizer creation (only the LoRA parameters require grad at this point).
        params_to_optimize = [p for p in self.unet.parameters() if p.requires_grad]
        # Values from the huggingface github page
        optimizer = torch.optim.AdamW(
            params_to_optimize,
            lr=1e-5,
            betas=(0.9, 0.999),
            weight_decay=1e-2,
            eps=1e-08,
        )

        # Dataset and dataloader creation
        data_dict = {
            "prompt": prompts * num_pairs,
            "preferred": dpo_preferred_images,  # jpg_0
            "unpreferred": dpo_unpreferred_images,  # jpg_1
            # label_0 == 1 marks "preferred" as the winner of every pair.
            "label_0": [1] * num_pairs,
            "label_1": [0] * num_pairs,
        }

        train_dataset = Dataset.from_dict(data_dict)
        train_transform = transforms.Compose([
            transforms.Resize(int(512), interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.RandomCrop(512),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])

        def preprocess_train_data(data):
            '''Build per-pair pixel tensors (winner channels first) and prompt token ids.'''
            all_pixel_values = []

            for key in ["preferred", "unpreferred"]:
                images = [img if isinstance(img, Image.Image) else Image.open(img).convert("RGB") for img in data[key]]
                pixel_values = [train_transform(image) for image in images]
                all_pixel_values.append(pixel_values)

            # Concatenate each pair along the channel dimension: winner first, loser second.
            im_tup_iterator = zip(*all_pixel_values)
            combined_pixel_values = []

            for im_tup, label_0 in zip(im_tup_iterator, data["label_0"]):
                # Put the winner first when "preferred" is not labelled as the winner
                # (never the case with the labels built above).
                if label_0 == 0:
                    im_tup = im_tup[::-1]
                combined_im = torch.cat(im_tup, dim=0)
                combined_pixel_values.append(combined_im)
            data["pixel_values"] = combined_pixel_values
            data["input_ids"] = self.tokenize_text(tokenizer, data["prompt"])
            return data

        with self.accelerator.main_process_first():
            train_dataset = train_dataset.shuffle(seed=self.seed)
            train_dataset = train_dataset.with_transform(preprocess_train_data)

        def collate_fn(data):
            '''Stack the pixel tensors and token ids of a batch.'''
            pixel_values = torch.stack([x["pixel_values"] for x in data])
            pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
            return {
                "pixel_values": pixel_values,
                "input_ids": torch.stack([x["input_ids"] for x in data]),
            }

        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=1,  # Might need to change this
            shuffle=True,
            collate_fn=collate_fn,
        )

        # Keep the step arithmetic in sync with the accelerator's accumulation setting.
        num_gradient_update = self.accelerator.gradient_accumulation_steps

        # Training steps
        num_update_steps_per_epoch = math.ceil(len(train_dataloader) / num_gradient_update)
        max_train_steps = dpo_epochs * num_update_steps_per_epoch

        # Scheduler
        lr_scheduler = get_scheduler(
            "constant",  # Type of scheduler
            optimizer=optimizer,
            num_warmup_steps=0 * self.accelerator.num_processes,
            num_training_steps=max_train_steps * self.accelerator.num_processes,
            num_cycles=1,
            power=1.0,
        )
        self.unet, optimizer, train_dataloader, lr_scheduler = self.accelerator.prepare(
            self.unet, optimizer, train_dataloader, lr_scheduler
        )

        # Recalculate the training steps (prepare() may shard the dataloader across processes).
        num_update_steps_per_epoch = math.ceil(len(train_dataloader) / num_gradient_update)
        dpo_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)

        # Train!
        global_step = 0
        first_epoch = 0
        initial_global_step = 0

        # Loading weights for continued training
        if continue_training:
            dirs = sorted(
                (d for d in os.listdir(checkpoint_dir) if d.startswith("checkpoint")),
                key=lambda x: int(x.split("-")[-1]),
            )
            if not dirs:
                self.accelerator.print("No checkpoint found, training from scratch")
            else:
                path = dirs[-1]
                self.accelerator.print(f"Loading checkpoint {path}")
                self.accelerator.load_state(os.path.join(checkpoint_dir, path))
                global_step = int(path.split("-")[-1])
                initial_global_step = global_step
                first_epoch = global_step // num_update_steps_per_epoch

        progress_bar = tqdm(
            range(0, max_train_steps),
            initial=initial_global_step,
            desc="Steps",
            disable=not self.accelerator.is_local_main_process,
        )

        self.unet.train()
        for _ in range(first_epoch, dpo_epochs):
            for batch in train_dataloader:
                with self.accelerator.accumulate(self.unet):
                    pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
                    # Split the winner/loser channel halves and stack them on the batch axis:
                    # first half of the batch = winners, second half = losers.
                    feed_pixel_values = torch.cat(pixel_values.chunk(2, dim=1))

                    latents = []
                    for i in range(0, feed_pixel_values.shape[0], 2):  # 2 is VAE batch size
                        latents.append(
                            vae.encode(feed_pixel_values[i:i + 2]).latent_dist.sample()
                        )
                    latents = torch.cat(latents, dim=0)
                    latents = latents * vae.config.scaling_factor

                    # Same noise for the winner and loser of each pair.
                    noise = torch.randn_like(latents).chunk(2)[0].repeat(2, 1, 1, 1)

                    # Sample one random timestep per pair, shared by winner and loser.
                    bsz = latents.shape[0] // 2
                    timesteps = torch.randint(
                        0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device, dtype=torch.long
                    ).repeat(2)

                    # Add noise to the model input according to the scheduler
                    # (Forward diffusion process)
                    noisy_model_input = noise_scheduler.add_noise(latents, noise, timesteps)

                    # Get text embeddings
                    encoder_hidden_states = self.encode_prompt(text_encoder, batch["input_ids"]).repeat(2, 1, 1)

                    # Predict the noise residuals
                    model_pred = self.unet(
                        noisy_model_input,
                        timesteps,
                        encoder_hidden_states,
                    ).sample

                    # Get the target for loss depending on the prediction type
                    if noise_scheduler.config.prediction_type == "epsilon":
                        target = noise
                    elif noise_scheduler.config.prediction_type == "v_prediction":
                        target = noise_scheduler.get_velocity(latents, noise, timesteps)
                    else:
                        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                    # Per-sample denoising MSE of the LoRA model.
                    model_losses = F.mse_loss(model_pred.float(), target.float(), reduction="none")
                    model_losses = model_losses.mean(dim=list(range(1, len(model_losses.shape))))
                    model_losses_w, model_losses_l = model_losses.chunk(2)
                    raw_model_loss = 0.5 * (model_losses_w.mean() + model_losses_l.mean())
                    model_diff = model_losses_w - model_losses_l

                    # Reference model prediction: the same UNet with the LoRA adapters switched off.
                    self.accelerator.unwrap_model(self.unet).disable_adapters()
                    with torch.no_grad():
                        ref_preds = self.unet(
                            noisy_model_input,
                            timesteps,
                            encoder_hidden_states,
                        ).sample.detach()
                        ref_loss = F.mse_loss(ref_preds.float(), target.float(), reduction="none")
                        ref_loss = ref_loss.mean(dim=list(range(1, len(ref_loss.shape))))

                        ref_loss_w, ref_loss_l = ref_loss.chunk(2)
                        ref_diff = ref_loss_w - ref_loss_l
                        raw_ref_loss = ref_loss.mean()

                    # Re-enable adapters
                    self.accelerator.unwrap_model(self.unet).enable_adapters()

                    # logits > 0 when the LoRA model favours the winner (relative to the loser)
                    # more than the frozen reference does.
                    logits = ref_diff - model_diff
                    loss = -1 * F.logsigmoid(self.beta_dpo * logits).mean()  # Sigmoid loss

                    implicit_acc = (logits > 0).sum().float() / logits.size(0)
                    implicit_acc += 0.5 * (logits == 0).sum().float() / logits.size(0)

                    # Backward pass
                    self.accelerator.backward(loss)
                    if self.accelerator.sync_gradients:
                        self.accelerator.clip_grad_norm_(params_to_optimize, 1.0)
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()

                # sync_gradients is True only on steps where an optimizer update was applied.
                if self.accelerator.sync_gradients:
                    progress_bar.update(1)
                    global_step += 1

                    if self.accelerator.is_main_process:
                        save_path = os.path.join(checkpoint_dir, f"checkpoint-{global_step}")
                        self.accelerator.save_state(output_dir=save_path)

                logs = {
                    "loss": loss.detach().item(),
                    "raw_model_loss": raw_model_loss.detach().item(),
                    "ref_loss": raw_ref_loss.detach().item(),
                    "implicit_acc": implicit_acc.detach().item(),
                    "lr": lr_scheduler.get_last_lr()[0],
                }
                progress_bar.set_postfix(logs)
                self.accelerator.log(logs, step=global_step)

                if global_step >= max_train_steps:
                    break
            # Stop the outer loop too; otherwise every remaining epoch ran one extra batch.
            if global_step >= max_train_steps:
                break

        # Save the LoRA layers
        self.accelerator.wait_for_everyone()
        if self.accelerator.is_main_process:
            self.unet = self.accelerator.unwrap_model(self.unet).to(torch.float32)
            unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(self.unet))

            LoraLoaderMixin.save_lora_weights(
                save_directory=self.output_dir,
                unet_lora_layers=unet_lora_state_dict,
                text_encoder_lora_layers=None,
                weight_name=output_file_name,
            )

        self.accelerator.end_training()

        # Free up CUDA memory
        self.accelerator.free_memory()
        del train_dataloader, optimizer, lr_scheduler, self.unet, vae, text_encoder, noise_scheduler
        torch.cuda.empty_cache()
        gc.collect()

## Testing

In [None]:
# pd = PrefrenceDiffusion()

In [None]:
# gen_images = pd.generate_images(prompt = ["a photo of a dog"], num_imgs=20)

In [None]:
# gen_images

In [None]:
# pref_imgs = pd.find_prefrences(gen_images)
# pref_imgs

In [None]:
# pref_imgs[0][0]

In [None]:
# dpo_prefered, dpo_unprefered = get_images(pref_imgs, 4)

In [None]:
# pd.update_model(dpo_prefered, dpo_unprefered, prompt, dpo_epochs=5)

# GRADIO UI


In [None]:
import random
from itertools import cycle

def get_images(preference_imgs, selected_index):
    """
    Build DPO training pairs from clustered images.

    Inputs:
    - preference_imgs: Map of cluster id -> list of images (as returned by find_prefrences).
    - selected_index: Cluster id the user prefers.

    Output:
    - (selected_images, other_images_sample): the preferred cluster's images and an
      equally long random sample drawn from all other clusters. When the other
      clusters hold fewer images than needed, images are reused.

    Raises:
    - ValueError: if unpreferred images are needed but no other cluster has any.
    """
    # Images from the selected index
    selected_images = preference_imgs[selected_index]
    num_images = len(selected_images)

    # Collect all images from other indexes
    all_other_images = []
    for index, images in preference_imgs.items():
        if index != selected_index:
            all_other_images.extend(images)

    if len(all_other_images) >= num_images:
        # Enough distinct images: sample without replacement.
        other_images_sample = random.sample(all_other_images, num_images)
    elif not all_other_images:
        raise ValueError("No images in the other clusters to use as unpreferred samples")
    else:
        # Too few distinct images: repeat the pool just enough times, then sample.
        # (list(cycle(...)) never terminated, because cycle() is infinite.)
        repeats = -(-num_images // len(all_other_images))  # ceiling division
        other_images_sample = random.sample(all_other_images * repeats, num_images)

    return selected_images, other_images_sample

In [None]:
# Module-level state shared by the Gradio callbacks below.
global_user_input = None  # last prompt submitted through the UI
clustered_imgs = None  # cluster id -> list of PIL images, filled in after generation
standard_deviation_information = dict(
    original=0,
    clusters=[0] * 10,
    avg_per_clusters=0,
)

# Single pipeline instance reused by every callback.
pd = PrefrenceDiffusion()

def display_images_by_cluster(clustered_images, cluster_number, rows=2, cols=5):
    """
    Return at most rows * cols images from one cluster, for display in a gallery.

    Returns an empty list when nothing has been clustered yet (clustered_images is
    None or empty) or when cluster_number is not a known cluster id.
    """
    if not clustered_images:
        return []
    return clustered_images.get(cluster_number, [])[: rows * cols]

def update_global_variable_and_generate_images(prompt):
    """
    Gradio callback for "Generate Images".

    Generates 100 images for `prompt`, clusters them, and stores the prompt and the
    clusters in module-level globals for the other callbacks.

    Returns:
    - The updated standard-deviation summary (shown in the result textbox).
    - A dropdown update whose choices are the ids of the non-empty clusters.
    """
    global global_user_input, clustered_imgs, standard_deviation_information

    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt for image generation.")

    global_user_input = prompt

    # CALL GEN IMAGE & GMM HERE
    gen_images = pd.generate_images(prompt=[global_user_input], num_imgs=100)
    clustered_imgs, std_info = pd.find_prefrences(gen_images)

    # Update standard deviation information. Plain floats display more cleanly than numpy scalars.
    cluster_stds = [float(s) for s in std_info["clusters"]]
    standard_deviation_information["original"] = float(std_info["original"])
    standard_deviation_information["clusters"] = cluster_stds
    standard_deviation_information["avg_per_clusters"] = sum(cluster_stds) / len(cluster_stds)

    print(standard_deviation_information)

    # A plain list returned to a Dropdown output sets its *value*, not its choices.
    # gr.update(choices=...) repopulates the options. Empty clusters are left out
    # because DPO cannot train on an empty preferred set.
    non_empty_clusters = [k for k, imgs in clustered_imgs.items() if imgs]
    return standard_deviation_information, gr.update(choices=non_empty_clusters, value=None)

def show_images(cluster_number):
    """Gradio callback for "Show Images": return a preview of the selected cluster."""
    return display_images_by_cluster(clustered_imgs, cluster_number)

def process_images(cluster_number):
    """
    Gradio callback for "Call DPO".

    The selected cluster's images form the preferred set. An equal number of images
    sampled from the other clusters form the unpreferred set. After 20 DPO epochs,
    five new images are generated with the saved LoRA weights and returned.
    """
    if clustered_imgs is None or global_user_input is None:
        raise gr.Error("Generate images before calling DPO.")
    if cluster_number is None or not clustered_imgs.get(cluster_number):
        # An empty preferred set would make the training step arithmetic divide by zero.
        raise gr.Error("Select a non-empty cluster before calling DPO.")

    dpo_prefered, dpo_unprefered = get_images(clustered_imgs, cluster_number)

    # CALL DPO HERE
    pd.update_model(dpo_prefered, dpo_unprefered, [global_user_input], dpo_epochs=20)

    # Generate images with the freshly trained LoRA weights.
    return pd.generate_images(prompt=[global_user_input], num_imgs=5, with_dpo=True)


# Build the Gradio interface: prompt -> generate & cluster -> inspect a cluster -> run DPO.
with gr.Blocks() as blocks:
    gr.Markdown("## Prefrence Diffusion")

    with gr.Row():
        input_text = gr.Textbox(label="Enter a prompt for image generation")
        submit_button = gr.Button("Generate Images")

    dropdown = gr.Dropdown(choices=list(range(10)), label="Select Cluster")

    with gr.Row():
        with gr.Column(scale=1):
            show_button = gr.Button("Show Images")
            gallery = gr.Gallery(label="Images")
        with gr.Column(scale=1):
            process_button = gr.Button("Call DPO")
            result_gallery = gr.Gallery(label="DPO Results")

    result_textbox = gr.Textbox(label="Processing Result")

    # Wire each button to its callback.
    submit_button.click(
        fn=update_global_variable_and_generate_images,
        inputs=input_text,
        outputs=[result_textbox, dropdown],
    )
    show_button.click(fn=show_images, inputs=dropdown, outputs=gallery)
    process_button.click(fn=process_images, inputs=dropdown, outputs=result_gallery)

# Launch the interface (debug=True keeps the cell attached and prints errors).
blocks.launch(debug=True)
