<a href="https://colab.research.google.com/github/soumik12345/diffusion_prompt_upsampling/blob/main/generate_and_validate.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Prompt Upsampling for Diffusion Models

[![GitHub](https://img.shields.io/badge/github-%23121011.svg?style=for-the-badge&logo=github&logoColor=white)](https://github.com/soumik12345/diffusion_prompt_upsampling)

Prompting the current generation of text-to-image diffusion models, such as Stable Diffusion XL, is extremely brittle: it is difficult to devise a prompting strategy that reliably produces images of a given quality, and sometimes even one that reliably makes the model follow the prompt.

This project implements the prompt upsampling strategy described in the OpenAI technical report accompanying DALL-E 3, [Improving Image Generation with Better Captions](https://cdn.openai.com/papers/dall-e-3.pdf). Prompt upsampling uses an LLM to expand a short prompt into a more detailed one, which helps improve the quality of generated images even when starting from baseline prompts.

The repository uses [DSPy](https://dspy-docs.vercel.app) with GPT-4 for prompt upsampling and [Weave](https://wandb.me/weave) to trace and evaluate the workflow.
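To make the idea concrete, here is a minimal sketch of prompt upsampling that runs offline. It is not the repository's DSPy implementation: the instruction text, the `upsample_prompt` helper, and the `fake_llm` stand-in are all assumptions made for illustration, and any callable that maps a request string to a completion could replace `fake_llm`.

```python
from typing import Callable

# Hypothetical instruction text; the repository's actual DSPy prompt may differ.
UPSAMPLE_INSTRUCTION = (
    "Rewrite the following image caption into a longer, more detailed caption. "
    "Describe the subject, setting, lighting, and style, and keep every detail "
    "of the original caption.\n\nCaption: {caption}\nDetailed caption:"
)


def upsample_prompt(base_prompt: str, llm: Callable[[str], str]) -> str:
    """Expand a short prompt with an LLM, falling back to the original on empty output."""
    request = UPSAMPLE_INSTRUCTION.format(caption=base_prompt.strip())
    detailed = llm(request).strip()
    return detailed or base_prompt


def fake_llm(request: str) -> str:
    # Stand-in for a real LLM call so the sketch runs without network access.
    caption = request.split("Caption: ")[1].split("\n")[0]
    return (
        f"{caption}, standing in a misty swamp at dawn, "
        "polished steel armor, detailed digital painting"
    )


print(upsample_prompt("A frog dressed as a knight", fake_llm))
```

Passing the LLM in as a callable keeps the sketch independent of any particular SDK; the fallback to `base_prompt` guards against an empty completion.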

## Installation

In [1]:
!git clone https://github.com/soumik12345/diffusion_prompt_upsampling

Cloning into 'diffusion_prompt_upsampling'...
remote: Enumerating objects: 82, done.
remote: Counting objects: 100% (82/82), done.
remote: Compressing objects: 100% (55/55), done.
remote: Total 82 (delta 44), reused 59 (delta 25), pack-reused 0
Receiving objects: 100% (82/82), 142.05 KiB | 15.78 MiB/s, done.
Resolving deltas: 100% (44/44), done.


In [2]:
%cd diffusion_prompt_upsampling
!pip install -qe .

/content/diffusion_prompt_upsampling

In [1]:
import os
from typing import Optional

import rich
import weave
from diffusion_prompt_upsampling.diffusion_model import StableDiffusionXLModel
from diffusion_prompt_upsampling.judge_model import OpenAIJudgeModel
from getpass import getpass

  deprecate("VQEncoderOutput", "0.31", deprecation_message)
  deprecate("VQModel", "0.31", deprecation_message)


## Initializing Weave

DSPy and the OpenAI SDK are already integrated with Weave. All you need to do is call `weave.init()` at the start of your code, and everything is traced automatically.
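As a rough mental model of what call tracing means, the sketch below records the inputs, output, and duration of each call to a decorated function. It is a conceptual illustration only, not how Weave is implemented: the `traced` decorator and the in-memory `TRACE_LOG` are hypothetical names, and Weave sends call records to its own backend rather than to a Python list.

```python
import functools
import time
from typing import Any, Callable

# Hypothetical in-memory trace log used only for this illustration.
TRACE_LOG: list[dict[str, Any]] = []


def traced(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Record the name, inputs, output, and duration of every call to `fn`."""

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        start = time.perf_counter()
        output = fn(*args, **kwargs)
        TRACE_LOG.append(
            {
                "op": fn.__name__,
                "inputs": {"args": args, "kwargs": kwargs},
                "output": output,
                "seconds": time.perf_counter() - start,
            }
        )
        return output

    return wrapper


@traced
def shout(text: str) -> str:
    return text.upper()


shout("frog")
print(TRACE_LOG[-1]["op"], TRACE_LOG[-1]["output"])
```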

In [2]:
weave_project_name = "diffusion-prompt-upsample" # @param {type:"string"}

weave.init(project_name=weave_project_name)

Logged in as Weights & Biases user: geekyrakshit.
View Weave data at https://wandb.ai/geekyrakshit/diffusion-prompt-upsample/weave


<weave.weave_client.WeaveClient at 0x7a63d402be80>

## Initialize Stable Diffusion XL

The [`diffusion_prompt_upsampling.diffusion_model.StableDiffusionXLModel`](https://github.com/soumik12345/diffusion_prompt_upsampling/blob/main/diffusion_prompt_upsampling/diffusion_model.py#L22) is implemented as a [`weave.Model`](https://wandb.github.io/weave/guides/core-types/models), so Weave automatically tracks calls to this class and versions its code as an object.

Note that the prompt upsampling feature depends on GPT-4o and the evaluation model uses an OpenAI multi-modal LLM. Hence, you will need to provide an OpenAI API key.

In [3]:
openai_api_key = getpass("Enter OpenAI API Key:")
os.environ["OPENAI_API_KEY"] = openai_api_key

Enter OpenAI API Key:··········
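If you rerun the notebook often, you may prefer to prompt for the key only when it is not already set. The helper below is a small sketch of that pattern using only the standard library; `ensure_api_key` and its `ask` parameter are hypothetical names introduced here, not part of the repository.

```python
import os
from getpass import getpass
from typing import Callable


def ensure_api_key(
    var_name: str = "OPENAI_API_KEY",
    ask: Callable[[str], str] = getpass,
) -> str:
    """Return the key from the environment, prompting only when it is missing."""
    key = os.environ.get(var_name, "").strip()
    if not key:
        # Nothing usable in the environment, so ask the user without echoing input.
        key = ask(f"Enter {var_name}: ").strip()
        if not key:
            raise ValueError(f"{var_name} must not be empty")
        os.environ[var_name] = key
    return key
```

Calling `ensure_api_key()` in place of the two lines above would leave an existing `OPENAI_API_KEY` untouched and fail early on an empty entry.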


In [4]:
model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0" # @param {type:"string"}
enable_cpu_offfload = True # @param {type:"boolean"}
upsample_prompt = True # @param {type:"boolean"}
use_stock_negative_prompt = True # @param {type:"boolean"}

diffusion_model = StableDiffusionXLModel(
    model_name_or_path=model_name_or_path,
    enable_cpu_offfload=enable_cpu_offfload,
    upsample_prompt=upsample_prompt,
    use_stock_negative_prompt=use_stock_negative_prompt,
)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

## Initialize the Judge Model

The `OpenAIJudgeModel` in [`diffusion_prompt_upsampling.judge_model`](https://github.com/soumik12345/diffusion_prompt_upsampling/blob/main/diffusion_prompt_upsampling/judge_model.py) is also implemented as a [`weave.Model`](https://wandb.github.io/weave/guides/core-types/models). Its purpose is to evaluate how closely the generated image follows the base prompt.

In [5]:
openai_model = "gpt-4-turbo" # @param ["gpt-4-turbo", "gpt-4o", "gpt-4o-mini"]
judge_model_seed = 42 # @param {type:"integer"}

judge_model = OpenAIJudgeModel(openai_model=openai_model, seed=judge_model_seed)
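For intuition about how a multi-modal judge can see both the prompt and the image, the sketch below builds a chat-style message list with the image embedded as a base64 data URL, which is the shape OpenAI's chat API accepts for image inputs. The `build_judge_messages` helper and the system instruction are assumptions for illustration; the repository's `OpenAIJudgeModel` may construct its request differently.

```python
import base64
from typing import Any


def build_judge_messages(base_prompt: str, png_bytes: bytes) -> list[dict[str, Any]]:
    """Build chat messages asking a multi-modal model whether an image follows a prompt."""
    encoded = base64.b64encode(png_bytes).decode("ascii")
    return [
        {
            "role": "system",
            # Hypothetical instruction; the repository's judge prompt may differ.
            "content": "You judge whether an image follows a text prompt. "
            "Answer with a score from 0 to 1 and a one-sentence reason.",
        },
        {
            "role": "user",
            "content": [
                {"type": "text", "text": f"Prompt: {base_prompt}"},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{encoded}"},
                },
            ],
        },
    ]
```

A list like this could be passed as `messages` to a chat completion request, together with a fixed `seed` such as the `judge_model_seed` above to make the judgement more repeatable.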

## Generate and Validate the Image

We now run the diffusion model on a base prompt and evaluate the resulting image with the judge model.

Running the following cell prints links to the corresponding Weave traces, which let you explore the results and dig into the underlying behavior of the diffusion model and the LLM judge in a rich, interactive UI.

In [6]:
base_prompt = "A frog dressed as a knight" # @param {type:"string"}

image = diffusion_model.predict(base_prompt=base_prompt)["image"]
judgement = judge_model.predict(base_prompt=base_prompt, generated_image=image)

rich.print(f"{judgement=}")

  0%|          | 0/50 [00:00<?, ?it/s]

🍩 https://wandb.ai/geekyrakshit/diffusion-prompt-upsample/r/call/25c6e0ec-5eaf-4796-a4e8-055fc9163170


ValueError: Class 'JudgeMent' not found in the notebook