# Prompt Upsampling for Diffusion Models

<a href="https://colab.research.google.com/github/soumik12345/diffusion_prompt_upsampling/blob/main/notebooks/generate_and_validate.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

[![GitHub](https://img.shields.io/badge/github-%23121011.svg?style=for-the-badge&logo=github&logoColor=white)](https://github.com/soumik12345/diffusion_prompt_upsampling)

Prompting the current generation of text-to-image diffusion models (such as Stable Diffusion XL) is extremely brittle: it is difficult to devise a prompting strategy that reliably generates images of a certain quality, and sometimes even to get images that reliably follow the prompt.

This project aims to implement the prompt upsampling strategy described in OpenAI's technical report accompanying DALL-E 3, [Improving Image Generation with Better Captions](https://cdn.openai.com/papers/dall-e-3.pdf). This strategy helps improve the quality of generated images even when starting from baseline prompts.

This repository uses [DSPy](https://dspy-docs.vercel.app) with GPT-4 for prompt upsampling and [Weave](https://wandb.me/weave) to trace and evaluate the workflow.
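To make the idea concrete, the sketch below shows the general shape of an upsampling request: a short base prompt is wrapped in an instruction asking an LLM to rewrite it as a detailed, descriptive prompt, formatted as OpenAI-style chat messages. The instruction wording and the `build_upsampling_messages` helper are hypothetical. The repository builds its request through DSPy, whose actual prompt is not shown here, and this sketch only constructs the messages without calling any API.

```python
# Hypothetical sketch: assembling an OpenAI-style chat request for prompt upsampling.
# The instruction wording below is an assumption, not the repository's actual prompt.
from typing import Dict, List

UPSAMPLING_INSTRUCTION = (
    "You rewrite short prompts for a text-to-image model. "
    "Expand the user's prompt into a single, detailed description of the image, "
    "covering subject, setting, composition, lighting and style, "
    "without changing what the user asked for."
)


def build_upsampling_messages(base_prompt: str) -> List[Dict[str, str]]:
    """Return chat messages asking an LLM to upsample `base_prompt`."""
    base_prompt = base_prompt.strip()
    if not base_prompt:
        raise ValueError("base_prompt must be a non-empty string")
    return [
        {"role": "system", "content": UPSAMPLING_INSTRUCTION},
        {"role": "user", "content": base_prompt},
    ]


messages = build_upsampling_messages("A frog dressed as a knight")
for message in messages:
    print(f"{message['role']}: {message['content']}")
```

The resulting list has the format accepted by the `messages` argument of OpenAI's Chat Completions API; the LLM's reply would then replace the base prompt as the input to the diffusion model.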

![](../assets/sample_generation_trace.gif)

## Installation

In [None]:
!git clone https://github.com/soumik12345/diffusion_prompt_upsampling

In [None]:
%cd diffusion_prompt_upsampling
!pip install -qe .

In [None]:
import os
from getpass import getpass
from typing import Optional

import rich
import weave

from diffusion_prompt_upsampling.diffusion_model import StableDiffusionXLModel
from diffusion_prompt_upsampling.judge_model import OpenAIJudgeModel

## Initializing Weave

DSPy and the OpenAI SDK are already integrated with Weave: all you need to do is call `weave.init()` at the start of your code, and their calls are traced automatically!

In [None]:
weave_project_name = "diffusion-prompt-upsample" # @param {type:"string"}

weave.init(project_name=weave_project_name)
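Conceptually, "tracing" means recording each call's inputs and output so they can be inspected later. The toy decorator below illustrates only that idea with an in-memory list; `toy_trace`, `TRACE_LOG` and `add_style` are hypothetical names, and this is not how Weave is implemented. In Weave itself, your own functions can be traced by decorating them with `@weave.op`, and the recorded calls are sent to the Weave UI instead of a local list.

```python
# Toy illustration of call tracing; this is NOT how Weave is implemented.
import functools
from typing import Any, Callable, Dict, List

TRACE_LOG: List[Dict[str, Any]] = []  # in-memory stand-in for a tracing backend


def toy_trace(func: Callable) -> Callable:
    """Record the name, inputs and output of every call to `func`."""

    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        output = func(*args, **kwargs)
        TRACE_LOG.append(
            {"op": func.__name__, "args": args, "kwargs": kwargs, "output": output}
        )
        return output

    return wrapper


@toy_trace
def add_style(prompt: str, style: str = "oil painting") -> str:
    # Hypothetical example function whose calls get recorded.
    return f"{prompt}, {style}"


add_style("A frog dressed as a knight", style="watercolor")
print(TRACE_LOG)
```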

## Initialize Stable Diffusion XL

The [`diffusion_prompt_upsampling.diffusion_model.StableDiffusionXLModel`](https://github.com/soumik12345/diffusion_prompt_upsampling/blob/main/diffusion_prompt_upsampling/diffusion_model.py#L22) is implemented as a [`weave.Model`](https://wandb.github.io/weave/guides/core-types/models), which lets Weave automatically track calls to this class and version its code as an object.

Note that prompt upsampling depends on GPT-4o and the judge model uses a multimodal OpenAI LLM, so you will need to provide an OpenAI API key.

In [None]:
openai_api_key = getpass("Enter OpenAI API Key:")
os.environ["OPENAI_API_KEY"] = openai_api_key
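The cell above prompts for the key on every run. A small variation, sketched below as a hypothetical helper named `ensure_openai_api_key`, reuses `OPENAI_API_KEY` when it is already set in the environment and only falls back to `getpass` when it is missing.

```python
# Hypothetical helper: reuse an existing OPENAI_API_KEY instead of prompting every run.
import os
from getpass import getpass


def ensure_openai_api_key(env_var: str = "OPENAI_API_KEY") -> str:
    """Return the API key from `env_var`, prompting for it only if it is unset."""
    api_key = os.environ.get(env_var, "").strip()
    if not api_key:
        # Only ask interactively when nothing usable is in the environment.
        api_key = getpass(f"Enter {env_var}:").strip()
        if not api_key:
            raise RuntimeError(f"{env_var} is required but was not provided")
        os.environ[env_var] = api_key
    return api_key
```

Calling `ensure_openai_api_key()` in place of the two lines above avoids re-entering the key when the cell is re-run.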

In [None]:
model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0" # @param {type:"string"}
enable_cpu_offfload = True # @param {type:"boolean"}
upsample_prompt = True # @param {type:"boolean"}
use_stock_negative_prompt = True # @param {type:"boolean"}

diffusion_model = StableDiffusionXLModel(
    model_name_or_path=model_name_or_path,
    enable_cpu_offfload=enable_cpu_offfload,
    upsample_prompt=upsample_prompt,
    use_stock_negative_prompt=use_stock_negative_prompt,
)

## Initialize the Judge Model

The judge model in [`diffusion_prompt_upsampling.judge_model`](https://github.com/soumik12345/diffusion_prompt_upsampling/blob/main/diffusion_prompt_upsampling/judge_model.py) is also implemented as a [`weave.Model`](https://wandb.github.io/weave/guides/core-types/models). Its purpose is to evaluate how closely the generated image follows the base prompt.

In [None]:
openai_model = "gpt-4-turbo" # @param ["gpt-4-turbo", "gpt-4o", "gpt-4o-mini"]
judge_model_seed = 42 # @param {type:"integer"}

judge_model = OpenAIJudgeModel(openai_model=openai_model, seed=judge_model_seed)
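For a sense of what a multimodal judge request can look like, the sketch below encodes image bytes as a base64 data URL and pairs them with an evaluation question, using the text and `image_url` content parts that OpenAI's Chat Completions API accepts for vision-capable models. The question wording, the 1 to 10 scale and the `build_judge_messages` helper are hypothetical; how `OpenAIJudgeModel` actually phrases and encodes its request is not shown in this notebook.

```python
# Hypothetical sketch of a vision-judge request; not OpenAIJudgeModel's actual prompt.
import base64
from typing import Any, Dict, List


def build_judge_messages(
    base_prompt: str, image_bytes: bytes, mime_type: str = "image/png"
) -> List[Dict[str, Any]]:
    """Pair a generated image with a question about how well it follows the prompt."""
    encoded = base64.b64encode(image_bytes).decode("ascii")
    question = (
        f'The image was generated from the prompt: "{base_prompt}". '
        "Does the image follow the prompt? Answer with a score from 1 to 10 "
        "and a short justification."
    )
    return [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": question},
                {
                    "type": "image_url",
                    # Inline the image as a data URL instead of hosting it somewhere.
                    "image_url": {"url": f"data:{mime_type};base64,{encoded}"},
                },
            ],
        }
    ]
```

If the generated `image` is a PIL image (an assumption; diffusers pipelines return PIL images by default), its bytes could be obtained by saving it to an `io.BytesIO` buffer in PNG format before passing them in.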

## Generate and Validate the Image

We now run the diffusion model on a base prompt and evaluate the generated image with the judge model.

Running the following cell will show you links to the corresponding Weave traces, where you can explore the results and dive deeper into the underlying mechanisms of the diffusion model and the LLM judge in a rich, interactive UI.

In [None]:
base_prompt = "A frog dressed as a knight" # @param {type:"string"}

image = diffusion_model.predict(base_prompt=base_prompt)["image"]
judgement = judge_model.predict(base_prompt=base_prompt, generated_image=image)

rich.print(f"{judgement=}")
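The single-prompt cell above generalizes naturally to a small batch. The sketch below runs any pair of objects exposing the same `predict` interface used in this notebook over several prompts and collects the results. `StubDiffusionModel`, `StubJudgeModel` and `generate_and_validate` are hypothetical stand-ins so the snippet runs without a GPU or API key; with the real models, you would pass `diffusion_model` and `judge_model` instead.

```python
# Hypothetical batch sketch; the stubs only mimic the predict() interface used above.
from typing import Any, Dict, List


class StubDiffusionModel:
    """Stand-in for StableDiffusionXLModel: returns a placeholder instead of an image."""

    def predict(self, base_prompt: str) -> Dict[str, Any]:
        return {"image": f"<image for {base_prompt!r}>"}


class StubJudgeModel:
    """Stand-in for OpenAIJudgeModel: returns a fixed placeholder judgement."""

    def predict(self, base_prompt: str, generated_image: Any) -> Dict[str, Any]:
        return {"prompt": base_prompt, "verdict": "placeholder"}


def generate_and_validate(
    diffusion_model: Any, judge_model: Any, prompts: List[str]
) -> List[Dict[str, Any]]:
    """Generate one image per prompt and collect the judge's output for each."""
    results = []
    for prompt in prompts:
        image = diffusion_model.predict(base_prompt=prompt)["image"]
        judgement = judge_model.predict(base_prompt=prompt, generated_image=image)
        results.append({"prompt": prompt, "image": image, "judgement": judgement})
    return results


results = generate_and_validate(
    StubDiffusionModel(),
    StubJudgeModel(),
    ["A frog dressed as a knight", "A lighthouse at dawn"],
)
for result in results:
    print(result["prompt"], "->", result["judgement"])
```

Because both real classes are `weave.Model`s, running this loop with them should produce a separate Weave trace for each `predict` call, just as the single-prompt cell does.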

Here's what an evaluation trace looks like on Weave 👇

![](../assets/sample_evaluation_trace.gif)