<a href="https://colab.research.google.com/github/soumik12345/diffusion_prompt_upsampling/blob/main/notebooks/prompt_upsampling_explained.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Prompt Upsampling for Diffusion Models

The secret sauce to getting good-quality images from text-to-image diffusion models is to provide more control conditions. These models are not really "intelligent": if we don't tell them precisely what we want, they won't be able to generate images with many details. One way to achieve this is to **manually write detailed prompts** that give the diffusion model much more context.

**Prompt upsampling** aims to automate the writing of a detailed prompt using an LLM. The idea is to start from the most bare-bones version of a prompt (such as "a man holding a sword") and let a powerful large language model like GPT-4 fill in more details, ultimately resulting in a better and more detailed-looking image.

You can read a detailed Weights & Biases report on this technique [here](https://wandb.ai/geekyrakshit/prompt-upsampling-diffusion/reports/Prompt-Upsampling-for-Diffusion-Models--Vmlldzo4OTc3NDc3).


## Installations and Initial Setup

First, you need to install DSPy, Weave and 🧨 Diffusers.

In [None]:
!pip install -qU dspy-ai weave diffusers transformers accelerate

Since we'll be using the [OpenAI API](https://openai.com/index/openai-api/) as our LLM vendor, we will also need an OpenAI API key. You can [sign up](https://platform.openai.com/signup) on the OpenAI platform to get your API key.

In [None]:
import os
from getpass import getpass

api_key = getpass("Enter your OpenAI API key: ")
os.environ["OPENAI_API_KEY"] = api_key

Enter your OpenAI API key: ··········


We write a few utilities to convert images into base64 format, which not only lets us build multi-modal prompts for the OpenAI API but also lets us visualize the images in Weave.

In [None]:
import base64
import io
import re
from pathlib import Path
from typing import Optional, Union

from PIL import Image

EXT_TO_MIMETYPE = {
    ".jpg": "image/jpeg",
    ".png": "image/png",
    ".svg": "image/svg+xml",
}


def base64_encode_image(
    image_path: Union[str, Image.Image], mimetype: Optional[str] = None
) -> str:
    image = Image.open(image_path) if isinstance(image_path, str) else image_path
    # Honor an explicitly passed mimetype; otherwise infer it from the file
    # extension, falling back to PNG for in-memory images and unknown extensions.
    if mimetype is None:
        mimetype = (
            EXT_TO_MIMETYPE.get(Path(image_path).suffix.lower(), "image/png")
            if isinstance(image_path, str)
            else "image/png"
        )
    byte_arr = io.BytesIO()
    image.save(byte_arr, format="PNG")
    encoded_string = base64.b64encode(byte_arr.getvalue()).decode("utf-8")
    return f"data:{mimetype};base64,{encoded_string}"


def find_base64_images(input_text):
    pattern = r"(data:image/(jpeg|png|svg\+xml);base64,[A-Za-z0-9+/=]+)"
    return [match[0] for match in re.findall(pattern, input_text)]
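
To make the data-URL format concrete, here is a small sketch that uses only the standard library. It builds a data URL by hand from placeholder bytes (`fake_png_bytes` is a made-up stand-in, not a real image), embeds it in a prompt string, and recovers it with the same regular expression that `find_base64_images` uses.

```python
import base64
import re

# Same pattern as `find_base64_images` above.
DATA_URL_PATTERN = r"(data:image/(jpeg|png|svg\+xml);base64,[A-Za-z0-9+/=]+)"


def find_base64_images(input_text):
    return [match[0] for match in re.findall(DATA_URL_PATTERN, input_text)]


# Placeholder payload for illustration only; these bytes are not a valid image.
fake_png_bytes = b"\x89PNG placeholder bytes"
data_url = "data:image/png;base64," + base64.b64encode(fake_png_bytes).decode("utf-8")

prompt = f"Base prompt: a frog dressed as a knight\nGenerated image: {data_url}\nScore:"
found = find_base64_images(prompt)

# Split the data URL into its header and payload, then decode the payload.
header, payload = found[0].split(",", 1)
decoded = base64.b64decode(payload)
print(header)
print(decoded == fake_png_bytes)
```

Because the character class in the pattern excludes whitespace, the match stops at the newline that follows the data URL, so the surrounding prompt text is left untouched.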

Next, we enable tracking using [Weave](https://wandb.me/weave). Weave is integrated with DSPy, and including `weave.init` at the start of our code lets us automatically trace our DSPy functions, which can then be explored in the Weave UI. Check out the [Weave integration docs for DSPy](https://wandb.github.io/weave/guides/integrations/dspy) to learn more.

In [None]:
import weave

weave.init(project_name="prompt-upsampling-diffusion")

Logged in as Weights & Biases user: geekyrakshit.
View Weave data at https://wandb.ai/geekyrakshit/prompt-upsampling-diffusion/weave


<weave.weave_client.WeaveClient at 0x7dde204aea40>

## Implementing Prompt Upsampling using DSPy

[DSPy](https://dspy-docs.vercel.app/) is a framework that pushes building new LM pipelines away from manipulating free-form strings and closer to programming (composing modular operators to build text transformation graphs), where a compiler automatically generates optimized LM invocation strategies and prompts from a program.

According to the DSPy programming model, string-based prompting techniques are first translated into declarative modules with natural-language typed signatures. Then, each module is parameterized to learn its desired behavior by iteratively bootstrapping useful demonstrations within the pipeline.

We're going to use the [`dspy.OpenAI`](https://dspy-docs.vercel.app/api/language_model_clients/OpenAI) abstraction to make LLM calls to [GPT-4](https://platform.openai.com/docs/models/gpt-4-turbo-and-gpt-4).


In [None]:
import dspy


upsampler_llm = dspy.OpenAI(
    model="gpt-4",
    system_prompt="""
You are part of a team of bots that creates images. You work with an assistant bot that will draw anything
you say in square brackets. For example, outputting "a beautiful morning in the woods with the sun peaking
through the trees" will trigger your partner bot to output an image of a forest morning, as described.
You will be prompted by people looking to create detailed, amazing images. The way to accomplish this is to
take their short prompts and make them extremely detailed and descriptive.


There are a few rules to follow:
- You will only ever output a single image description per user request.
- Often times, the base prompt might consist of spelling mistakes or grammatical errors. You should correct
    such errors before making them extremely detailed and descriptive.
- Image descriptions must be between 15-80 words. Extra words will be ignored.
""",
)

Next, we create a simple signature specifying the input and output behavior of the prompt-upsampling module.

In [None]:
class PromptUpsamplingSignature(dspy.Signature):
    base_prompt = dspy.InputField()
    answer = dspy.OutputField(
        desc="Create an imaginative image descriptive caption for the given base prompt."
    )

We are going to use the [`dspy.MultiChainComparison`](https://dspy-docs.vercel.app/api/modules/MultiChainComparison) module to execute prompt upsampling. This module aggregates all the student reasoning attempts and calls the `predict` method with an extended signature to get the best reasoning.

In [None]:
reasoning_attemps = [
    dspy.Prediction(
        rationale="a man holding a sword",
        answer="a pale figure with long white hair stands in the center of a dark forest, holding a sword high above his head.",
    ),
    dspy.Prediction(
        rationale="a frog playing dominoes",
        answer="a frog sits on a worn table playing a game of dominoes with an elderly raccoon. the table is covered in a green cloth, and the frog is wearing a jacket and a pair of jeans. The scene is set in a forest, with a large tree in the background.",
    ),
    dspy.Prediction(
        rationale="A bird scaring a scarecrow",
        answer="A large, vibrant bird with an impressive wingspan swoops down from the sky, letting out a piercing call as it approaches a weathered scarecrow in a sunlit field. The scarecrow, dressed in tattered clothing and a straw hat, appears to tremble, almost as if it's coming to life in fear of the approaching bird.",
    ),
    dspy.Prediction(
        rationale="Paying for a quarter-sized pizza with a pizza-sized quarter",
        answer="A person is standing at a pizza counter, holding a gigantic quarter the size of a pizza. The cashier, wide-eyed with astonishment, hands over a tiny, quarter-sized pizza in return. The background features various pizza toppings and other customers, all of them equally amazed by the unusual transaction.",
    ),
    dspy.Prediction(
        rationale="a quilt with an iron on it",
        answer="a quilt is laid out on an ironing board with an iron resting on top. the quilt has a patchwork design with pastel-colored strips of fabric and floral patterns. the iron is turned on and the tip is resting on top of one of the strips. the quilt appears to be in the process of being pressed, as the steam from the iron is visible on the surface. the quilt has a vintage feel and the colors are yellow, blue, and white, giving it an antique look.",
    ),
    dspy.Prediction(
        rationale="a furry humanoid skunk",
        answer="In a fantastical setting, a highly detailed furry humanoid skunk with piercing eyes confidently poses in a medium shot, wearing an animal hide jacket. The artist has masterfully rendered the character in digital art, capturing the intricate details of fur and clothing texture.",
    ),
    dspy.Prediction(
        rationale="An icy landscape under a starlit sky",
        answer="An icy landscape under a starlit sky, where a magnificent frozen waterfall flows over a cliff. In the center of the scene, a fire burns bright, its flames seemingly frozen in place, casting a shimmering glow on the surrounding ice and snow.",
    ),
    dspy.Prediction(
        rationale="A fierce garden gnome warrior",
        answer="A fierce garden gnome warrior, clad in armor crafted from leaves and bark, brandishes a tiny sword and shield. He stands valiantly on a rock amidst a blooming garden, surrounded by colorful flowers and towering plants. A determined expression is painted on his face, ready to defend his garden kingdom.",
    ),
    dspy.Prediction(
        rationale="A ferret in a candy jar",
        answer="A mischievous ferret with a playful grin squeezes itself into a large glass jar, surrounded by colorful candy. The jar sits on a wooden table in a cozy kitchen, and warm sunlight filters through a nearby window.",
    ),
    dspy.Prediction(
        rationale="cartoon drawing of an astronaut riding a horse",
        answer="Cartoon drawing of an outer space scene. Amidst floating planets and twinkling stars, a whimsical horse with exaggerated features rides an astronaut, who swims through space with a jetpack, looking a tad overwhelmed.",
    ),
    dspy.Prediction(
        rationale="A smafml vessef epropoeilled on watvewr by ors, sauls, or han engie.",
        answer="A small vessel, propelled on water by oars, sails, or an engine, floats gracefully on a serene lake. the sun casts a warm glow on the water, reflecting the vibrant colors of the sky as birds fly overhead.",
    ),
]


In [None]:
prompt_upsampling_module = dspy.MultiChainComparison(
    PromptUpsamplingSignature, M=len(reasoning_attemps)
)
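
Conceptually, the module folds each `dspy.Prediction` into one extra input field and then asks the LM for a single final answer. The sketch below mimics that step in plain Python without calling DSPy. The `Attempt` class is a hypothetical stand-in for `dspy.Prediction`, the answers are abbreviated versions of the ones above, and the sentence template and `reasoning_attempt_N` field names are an approximation of DSPy's internals that may differ between versions.

```python
from dataclasses import dataclass


@dataclass
class Attempt:
    """Hypothetical stand-in for `dspy.Prediction` with the two fields used above."""

    rationale: str
    answer: str


def fold_attempts(attempts):
    """Turn each attempt into one named text field, keeping only its first line."""
    fields = {}
    for idx, attempt in enumerate(attempts, start=1):
        rationale = attempt.rationale.strip().split("\n")[0].strip()
        answer = attempt.answer.strip().split("\n")[0].strip()
        # Assumed template; the exact wording inside DSPy may be different.
        fields[f"reasoning_attempt_{idx}"] = (
            f"«I'm trying to {rationale} I'm not sure but my prediction is {answer}»"
        )
    return fields


# Abbreviated examples for illustration.
attempts = [
    Attempt("a man holding a sword", "a pale figure holds a sword high above his head."),
    Attempt("a frog playing dominoes", "a frog sits on a worn table playing dominoes."),
]
fields = fold_attempts(attempts)
fields["base_prompt"] = "a frog dressed as a knight"
for name, value in fields.items():
    print(f"{name}: {value}")
```

Under this reading, `M` is set to `len(reasoning_attemps)` because the extended signature needs one input field per attempt.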

Next, we wrap the prompt upsampling and subsequent image generation calls using [`weave.Model`](https://wandb.github.io/weave/guides/core-types/models/). A Weave Model combines data (including configuration, trained model weights, or other information) and code defining the model's operation. By structuring your code to be compatible with this API, you benefit from a structured way to version your application so you can more systematically keep track of your experiments.

In [None]:
class PromptUpsamplingModel(weave.Model):

    @weave.op()
    def predict(self, base_prompt: str) -> str:
        with dspy.context(lm=upsampler_llm):
            return prompt_upsampling_module(
                reasoning_attemps, base_prompt=base_prompt
            ).answer

In [None]:
import torch
from diffusers import AutoPipelineForText2Image, DiffusionPipeline


class StableDiffusionXLModel(weave.Model):
    diffusion_model: str
    enable_cpu_offload: bool = True
    prompt_upsampler: PromptUpsamplingModel
    _pipeline: DiffusionPipeline

    def __init__(
        self,
        diffusion_model: str,
        enable_cpu_offload: bool,
        prompt_upsampler: PromptUpsamplingModel
    ):
        super().__init__(
            diffusion_model=diffusion_model,
            enable_cpu_offload=enable_cpu_offload,
            prompt_upsampler=prompt_upsampler,
        )
        self._pipeline = AutoPipelineForText2Image.from_pretrained(
            self.diffusion_model,
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
        )
        if self.enable_cpu_offload:
            self._pipeline.enable_model_cpu_offload()
        else:
            self._pipeline = self._pipeline.to("cuda")

    @weave.op()
    def predict(
        self,
        base_prompt: str,
        negative_prompt: Optional[str] = None,
        num_inference_steps: Optional[int] = 50,
        image_size: Optional[int] = 1024,
        guidance_scale: Optional[float] = 7.0,
    ) -> dict:
        upsampled_prompt = self.prompt_upsampler.predict(base_prompt)
        image = self._pipeline(
            prompt=upsampled_prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            height=image_size,
            width=image_size,
            guidance_scale=guidance_scale,
        ).images[0]
        return {
            "upsampled_prompt": upsampled_prompt,
            "image": base64_encode_image(image)
        }

In [None]:
prompt_upsampler = PromptUpsamplingModel()

model = StableDiffusionXLModel(
    diffusion_model="stabilityai/stable-diffusion-xl-base-1.0",
    enable_cpu_offload=True,
    prompt_upsampler=prompt_upsampler,
)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

In [None]:
sdxl_prediction = model.predict(base_prompt="a frog dressed as a knight")

Token indices sequence length is longer than the specified maximum sequence length for this model (98 > 77). Running this sequence through the model will result in indexing errors
The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: ['bian gaze. this knightly frog is set against a background of lily pads floating on a serene pond.]']
Token indices sequence length is longer than the specified maximum sequence length for this model (98 > 77). Running this sequence through the model will result in indexing errors
The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: ['bian gaze. this knightly frog is set against a background of lily pads floating on a serene pond.]']


  0%|          | 0/50 [00:00<?, ?it/s]

🍩 https://wandb.ai/geekyrakshit/prompt-upsampling-diffusion/r/call/caafbfbf-2ea7-410c-86ac-80adb561dfe2


## Building a Multi-Modal Evaluation Judge using DSPy

Let's now implement an LLM-assisted evaluation strategy to automatically evaluate our generated images for **prompt-following**, i.e., how accurately the generated image follows the corresponding base prompt. To implement this metric, we use a multi-modal LLM like GPT-4o to look at the generated image and the base prompt, assign a correctness score between 0 and 1, and justify the score with an explanation.

### Building a Custom Multi-modal OpenAI Interface for DSPy

DSPy doesn't natively support multi-modal prompts. Hence, we first build a custom language model interface called `DSPyOpenAIMultiModalLM` on top of `dsp.GPT3` and implement the logic for interpreting multi-modal prompts. This class can now act as a drop-in replacement for [`dspy.OpenAI`](https://dspy-docs.vercel.app/api/language_model_clients/OpenAI) for multi-modal prompts with [base64 encoded images](https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images).

In [None]:
from dsp import GPT3
from openai import OpenAI


class DSPyOpenAIMultiModalLM(GPT3):

    def __init__(
        self,
        model: str = "gpt-4o",
        api_key: str | None = None,
        system_prompt: str | None = None,
        **kwargs,
    ):
        super().__init__(
            model,
            api_key,
            api_provider="openai",
            api_base=None,
            model_type=None,
            system_prompt=system_prompt,
            **kwargs,
        )
        self.model_type = model
        # Prefer an explicitly passed key and fall back to the environment variable.
        self._openai_client = OpenAI(
            api_key=api_key or os.environ.get("OPENAI_API_KEY")
        )

    @weave.op()
    def create_messages(self, prompt: str):
        images = find_base64_images(prompt)
        for image in images:
            prompt = prompt.replace(image, "")

        user_prompt = [{"type": "text", "text": prompt}]
        for image in images:
            user_prompt.append({"type": "image_url", "image_url": {"url": image}})
        messages = []
        if self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})
        messages.append({"role": "user", "content": user_prompt})
        return messages

    @weave.op()
    def basic_request(self, prompt: str, **kwargs):
        messages = self.create_messages(prompt)
        response = self._openai_client.chat.completions.create(
            model=self.model_type, messages=messages, **kwargs
        )
        self.history.append({"prompt": prompt, "response": response, "kwargs": kwargs})
        return response

    @weave.op()
    def request(self, prompt: str, **kwargs):
        return super().request(prompt, **kwargs)

    @weave.op()
    def __call__(
        self, prompt: str, only_completed: bool = True, **kwargs
    ) -> list:
        response = self.request(prompt, **kwargs)
        choices = (
            [choice for choice in response.choices if choice.finish_reason == "stop"]
            if only_completed and len(response.choices) != 0
            else response.choices
        )
        return [choice.message.content for choice in choices]
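
To see what `create_messages` produces without calling the API, here is a standalone sketch of the same splitting logic that uses only the standard library. `build_messages` is a hypothetical free-function version of the method, and the data URL is a placeholder rather than a real image.

```python
import re

DATA_URL_PATTERN = r"data:image/(?:jpeg|png|svg\+xml);base64,[A-Za-z0-9+/=]+"


def build_messages(prompt, system_prompt=None):
    """Move embedded data URLs out of the text and into `image_url` content parts."""
    images = re.findall(DATA_URL_PATTERN, prompt)
    for image in images:
        prompt = prompt.replace(image, "")
    content = [{"type": "text", "text": prompt}]
    content += [{"type": "image_url", "image_url": {"url": image}} for image in images]
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": content})
    return messages


placeholder_url = "data:image/png;base64,QUJD"  # base64 of b"ABC", not a real image
messages = build_messages(
    f"Base Prompt: a frog dressed as a knight\nGenerated Image: {placeholder_url}",
    system_prompt="You are a judge.",
)
print([part["type"] for part in messages[-1]["content"]])
```

The user message ends up with one text part followed by one `image_url` part per embedded image, which is the content layout the OpenAI chat completions endpoint expects for vision inputs.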

### Using DSPy Typed Predictors to Ensure Structured Outputs

Next, we define the judge module's DSPy signature to structure the inputs and outputs according to a fixed [pydantic](https://docs.pydantic.dev/latest/) schema. When building the predictor for the `JudgeSignature`, we use [`dspy.TypedPredictor`](https://dspy-docs.vercel.app/docs/building-blocks/typed_predictors) that lets us provide the input and parse the output of the module in a structured manner that is consistent with the pydantic schema.

In [None]:
from pydantic import BaseModel, Field


class JudgeInput(BaseModel):
    base_prompt: str = Field(description="The base prompt used to generate the image")
    generated_image: str = Field(description="The generated image")


class JudgeMent(BaseModel):
    think_out_loud: str = Field(
        description="Think out loud about your eventual judgement"
    )
    score: float = Field(description="A score between 0 and 1")
    judgement: str = Field(description="Output either 'correct' or 'incorrect'")


class JudgeSignature(dspy.Signature):
    input: JudgeInput = dspy.InputField()
    output: JudgeMent = dspy.OutputField()


class MultiModalJudgeModule(dspy.Module):

    def __init__(self):
        # Initialize the base `dspy.Module` before registering sub-modules.
        super().__init__()
        self.prog = dspy.TypedPredictor(JudgeSignature)

    @weave.op()
    def forward(self, base_prompt: str, generated_image: str) -> JudgeMent:
        return self.prog(
            input=JudgeInput(
                base_prompt=base_prompt, generated_image=generated_image
            )
        ).output


judgement_module = MultiModalJudgeModule()
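
The typed predictor asks the LM for a reply that matches the `JudgeMent` schema and then validates it. The sketch below illustrates that validation step using only the standard library. `ParsedJudgement`, `parse_judgement`, and the sample reply are hypothetical stand-ins, not DSPy or pydantic APIs, and the range and label checks are stricter than the schema above, which only states those constraints in its field descriptions.

```python
import json
from dataclasses import dataclass


@dataclass
class ParsedJudgement:
    think_out_loud: str
    score: float
    judgement: str


def parse_judgement(raw_reply: str) -> ParsedJudgement:
    """Parse a JSON reply and enforce the constraints described in the schema."""
    data = json.loads(raw_reply)
    parsed = ParsedJudgement(
        think_out_loud=str(data["think_out_loud"]),
        score=float(data["score"]),
        judgement=str(data["judgement"]),
    )
    if not 0.0 <= parsed.score <= 1.0:
        raise ValueError(f"score must be between 0 and 1, got {parsed.score}")
    if parsed.judgement not in ("correct", "incorrect"):
        raise ValueError(f"unexpected judgement: {parsed.judgement!r}")
    return parsed


# Made-up reply for illustration; not actual model output.
sample_reply = json.dumps(
    {
        "think_out_loud": "The frog wears armor and holds a sword.",
        "score": 0.9,
        "judgement": "correct",
    }
)
result = parse_judgement(sample_reply)
print(result.score, result.judgement)
```

Failing loudly on an out-of-range score or an unexpected label keeps malformed judgements from silently skewing an evaluation.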

### Building the Judge as a Weave Model

We will adopt the evaluation prompt from Appendix D of the paper [Improving Image Generation with Better Captions](https://cdn.openai.com/papers/dall-e-3.pdf) as the multi-modal judge's system prompt.

In [None]:
JUDGE_SYSTEM_PROMPT = """
You are responsible for judging the faithfulness of images generated by a computer program to the
base prompt used to generate them. You will be presented with an image and given the base prompt
that was used to produce the image. The base prompts you are judging are designed to stress-test
image generation programs, and may include things such as:
1. Scrambled or mis-spelled words (the image generator should create an image associated with
    the probable meaning).
2. Color assignment (the image generator should apply the correct color to the correct object).
3. Counting (the image generator should draw the correct number of objects).
4. Abnormal associations, for example 'elephant under a sea', where the image should depict
    what is requested.
5. Descriptions of objects, the image generator should draw the most commonly associated object.
6. Rare single words, where the image generator should create an image somewhat associable with
    the specified image.
7. Images with text in them, where the image generator should create an image with the specified
    text in it.

You need to make a decision as to whether or not the image is correct, given the base prompt.

You will first think out loud about your eventual judgement, enumerating reasons why the image
does or does not match the given base prompt. After thinking out loud, you should assign a score
between 0 and 1 depending on how much you think the image is faithful to the base prompt. Next,
you should output either 'correct' or 'incorrect' depending on whether you think the image is
faithful to the base prompt.

A few rules:
1. The score should be used to indicate how close the image is to the base prompt in terms of objects,
    color or count; with 0 being very far and 1 being very close.
2. If other objects are present in the image that are not explicitly mentioned by the base prompt,
    assign a higher score.
3. If the objects being displayed are deformed, assign a lower score. Assign a higher score if the objects
    are displayed in a more detailed manner.
4. 'incorrect' should be reserved for instances where a specific aspect of the base prompt is not followed
    correctly, such as a wrong object, color or count and the score should be less than or equal to 0.5.
"""

Finally, we will write the OpenAI Multi-modal judge as a Weave Model.

In [None]:
class OpenAIJudgeModel(weave.Model):
    openai_model: str
    seed: int
    _judgement_llm: dspy.Module

    def __init__(self, openai_model: str = "gpt-4o", seed: int = 42):
        super().__init__(openai_model=openai_model, seed=seed)
        # Use the configured model name so the tracked attribute matches the LM actually called.
        self._judgement_llm = DSPyOpenAIMultiModalLM(
            model=self.openai_model, system_prompt=JUDGE_SYSTEM_PROMPT, seed=self.seed
        )

    @weave.op()
    def predict(self, base_prompt: str, generated_image: str) -> JudgeMent:
        with dspy.context(lm=self._judgement_llm):
            judgement = judgement_module(base_prompt, generated_image)
        return judgement

    @weave.op()
    def score(self, base_prompt: str, model_output: dict) -> dict:
        judgement: JudgeMent = self.predict(
            base_prompt=base_prompt, generated_image=model_output["image"]
        )
        return {
            "score": judgement.score,
            "is_image_correct": judgement.judgement == "correct",
        }

In [None]:
judge_model = OpenAIJudgeModel()
judgement = judge_model.score("a frog dressed as a knight", sdxl_prediction)

🍩 https://wandb.ai/geekyrakshit/prompt-upsampling-diffusion/r/call/f65464b2-b86d-4e0b-a81b-e0fa7b65ff60
