In [None]:
import base64
import io
import os
from getpass import getpass
from pathlib import Path
from typing import Dict, List, Optional, Union

import dspy
import instructor
import rich
import torch
import weave
from diffusers import AutoPipelineForText2Image, DiffusionPipeline
from openai import OpenAI
from PIL import Image
from pydantic import BaseModel, Field

In [None]:
# Reuse a key that is already in the environment so "Restart & Run All" can run
# non-interactively; only fall back to an interactive prompt when none is set.
api_key = os.environ.get("OPENAI_API_KEY") or getpass("Enter your OpenAI API key: ")
os.environ["OPENAI_API_KEY"] = api_key
# NOTE(review): presumably limits weave's evaluation concurrency to one worker — confirm.
os.environ["WEAVE_PARALLELISM"] = "1"

weave.init(project_name="geekyrakshit/diffusion-prompt-upsample")

# Start DSPy with an empty global trace list.
dspy.settings.configure(trace=[])

In [None]:
# Maps file extensions (lowercase, including the leading dot, as produced by
# ``Path.suffix``) to the MIME type used in ``data:`` URLs. Both common JPEG
# spellings are listed so ".jpeg" files do not raise KeyError.
EXT_TO_MIMETYPE = {
    ".jpg": "image/jpeg",
    ".jpeg": "image/jpeg",
    ".png": "image/png",
    ".svg": "image/svg+xml",
}


def base64_encode_image(
    image_path: Union[str, Image.Image], mimetype: Optional[str] = None
) -> str:
    """Encode an image as a base64 ``data:`` URL.

    The image is re-encoded in memory, and the MIME type written into the URL
    always matches the format actually used for the payload.

    Args:
        image_path: Path to an image file (``str`` or ``os.PathLike``), or an
            already-loaded ``PIL.Image.Image``.
        mimetype: Output encoding, ``"image/png"`` (default) or ``"image/jpeg"``.

    Returns:
        A string of the form ``"data:<mimetype>;base64,<payload>"``.

    Raises:
        ValueError: If ``mimetype`` is not a supported output encoding.
    """
    pil_formats = {"image/png": "PNG", "image/jpeg": "JPEG"}
    mimetype = (mimetype or "image/png").lower()
    if mimetype not in pil_formats:
        raise ValueError(
            f"Unsupported mimetype {mimetype!r}; expected one of {sorted(pil_formats)}"
        )
    pil_format = pil_formats[mimetype]

    def _encode(image: Image.Image) -> str:
        # Pillow cannot write alpha/palette modes as JPEG, so flatten to RGB first.
        if pil_format == "JPEG" and image.mode not in ("RGB", "L"):
            image = image.convert("RGB")
        buffer = io.BytesIO()
        image.save(buffer, format=pil_format)
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    if isinstance(image_path, (str, os.PathLike)):
        # Close the underlying file handle once the pixels have been re-encoded.
        with Image.open(image_path) as image:
            payload = _encode(image)
    else:
        payload = _encode(image_path)
    return f"data:{mimetype};base64,{payload}"

In [None]:
# System prompt that instructs the LLM to expand a short user prompt into a
# single, detailed image description of 15-80 words.
UPSAMPLER_SYSTEM_PROMPT = """
You are part of a team of bots that creates images. You work with an assistant bot that will draw anything
you say in square brackets. For example, outputting "a beautiful morning in the woods with the sun peaking
through the trees" will trigger your partner bot to output an image of a forest morning, as described.
You will be prompted by people looking to create detailed, amazing images. The way to accomplish this is to
take their short prompts and make them extremely detailed and descriptive.

There are a few rules to follow:
- You will only ever output a single image description per user request.
- Often times, the base prompt might consist of spelling mistakes or grammatical errors. You should correct
    such errors before making them extremely detailed and descriptive.
- Image descriptions must be between 15-80 words. Extra words will be ignored.
"""

# GPT-4o primed with the prompt above; later cells activate it through
# `dspy.context(lm=upsampler_llm)`.
upsampler_llm = dspy.OpenAI(model="gpt-4o", system_prompt=UPSAMPLER_SYSTEM_PROMPT)

In [None]:
# DSPy signature for prompt upsampling: takes a short `base_prompt` and produces
# a detailed image caption in `answer`.
# NOTE(review): DSPy presumably uses a signature's class docstring as its task
# instructions, so no docstring is added here — adding one would change the
# prompt sent to the LM. Confirm before documenting it with a docstring.
class PromptUpsamplingSignature(dspy.Signature):
    # The user's short, possibly misspelled prompt.
    base_prompt = dspy.InputField()
    # The upsampled caption; `desc` is shown to the LM as the field description.
    answer = dspy.OutputField(
        desc="Create an imaginative image descriptive caption for the given base prompt."
    )

In [None]:
class StableDiffusionXLModel(weave.Model):
    """Weave model that optionally upsamples a prompt with DSPy, then renders it with SDXL.

    Prompt upsampling runs ``dspy.MultiChainComparison`` over ``PromptUpsamplingSignature``,
    using ``completions`` as the comparison set. The class does not configure a
    language model itself; callers must provide one (e.g. ``dspy.context(lm=...)``).
    """

    model_name_or_path: str
    # The triple "f" is part of the public field name; kept as-is for compatibility.
    enable_cpu_offfload: bool
    completions: Optional[List[dspy.Prediction]] = None
    diffusion_prompt_upsampler: Optional[dspy.Module] = None
    stock_negative_prompt: Optional[str] = None
    _pipeline: DiffusionPipeline

    def __init__(
        self,
        model_name_or_path: str,
        enable_cpu_offfload: bool,
        completions: Optional[List[dspy.Prediction]] = None,
        stock_negative_prompt: Optional[str] = None,
    ):
        """Build the prompt upsampler and load the diffusion pipeline.

        Args:
            model_name_or_path: Hub id or local path passed to
                ``AutoPipelineForText2Image.from_pretrained``.
            enable_cpu_offfload: If True, call ``enable_model_cpu_offload()`` on the
                pipeline; otherwise move the whole pipeline to ``"cuda"``.
            completions: Examples fed to ``MultiChainComparison``. Defaults to
                ``get_completion_rationales()``.
            stock_negative_prompt: Negative prompt used when ``predict`` is called
                without one. Defaults to a built-in list of artifacts to avoid.
        """
        super().__init__(
            model_name_or_path=model_name_or_path,
            enable_cpu_offfload=enable_cpu_offfload,
            completions=completions,
            stock_negative_prompt=stock_negative_prompt,
        )
        if self.completions is None:
            self.completions = self.get_completion_rationales()
        # MultiChainComparison compares exactly M completions, so M must track the list length.
        self.diffusion_prompt_upsampler = dspy.MultiChainComparison(
            PromptUpsamplingSignature, M=len(self.completions)
        )
        if self.stock_negative_prompt is None:
            self.stock_negative_prompt = "frame, border, 2d, ugly, static, dull, monochrome, distorted face, deformed fingers, scary, horror, nightmare, deformed lips, deformed eyes, deformed hands, deformed legs, impossible physics, absurdly placed objects"
        self._pipeline = AutoPipelineForText2Image.from_pretrained(
            self.model_name_or_path,
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
        )
        if self.enable_cpu_offfload:
            self._pipeline.enable_model_cpu_offload()
        else:
            self._pipeline = self._pipeline.to("cuda")

    def get_completion_rationales(self) -> List[dspy.Prediction]:
        """Return the default example set for the prompt upsampler.

        Each ``Prediction`` pairs a short prompt (``rationale``) with a detailed,
        upsampled caption (``answer``). One short prompt is deliberately misspelled.
        """
        return [
            dspy.Prediction(
                rationale="a man holding a sword",
                answer="a pale figure with long white hair stands in the center of a dark forest, holding a sword high above his head.",
            ),
            dspy.Prediction(
                rationale="a frog playing dominoes",
                answer="a frog sits on a worn table playing a game of dominoes with an elderly raccoon. the table is covered in a green cloth, and the frog is wearing a jacket and a pair of jeans. The scene is set in a forest, with a large tree in the background.",
            ),
            dspy.Prediction(
                rationale="A bird scaring a scarecrow",
                answer="A large, vibrant bird with an impressive wingspan swoops down from the sky, letting out a piercing call as it approaches a weathered scarecrow in a sunlit field. The scarecrow, dressed in tattered clothing and a straw hat, appears to tremble, almost as if it's coming to life in fear of the approaching bird.",
            ),
            dspy.Prediction(
                rationale="Paying for a quarter-sized pizza with a pizza-sized quarter",
                answer="A person is standing at a pizza counter, holding a gigantic quarter the size of a pizza. The cashier, wide-eyed with astonishment, hands over a tiny, quartersized pizza in return. The background features various pizza toppings and other customers, all of them equally amazed by the unusual transaction.",
            ),
            dspy.Prediction(
                rationale="a quilt with an iron on it",
                answer="a quilt is laid out on a ironing board with an iron resting on top. the quilt has a patchwork design with pastel-colored strips of fabric and floral patterns. the iron is turned on and the tip is resting on top of one of the strips. the quilt appears to be in the process of being pressed, as the steam from the iron is visible on the surface. the quilt has a vintage feel and the colors are yellow, blue, and white, giving it an antique look.",
            ),
            dspy.Prediction(
                rationale="a furry humanoid skunk",
                answer="In a fantastical setting, a highly detailed furry humanoid skunk with piercing eyes confidently poses in a medium shot, wearing an animal hide jacket. The artist has masterfully rendered the character in digital art, capturing the intricate details of fur and clothing texture.",
            ),
            dspy.Prediction(
                rationale="An icy landscape under a starlit sky",
                answer="An icy landscape under a starlit sky, where a magnificent frozen waterfall flows over a cliff. In the center of the scene, a fire burns bright, its flames seemingly frozen in place, casting a shimmering glow on the surrounding ice and snow.",
            ),
            dspy.Prediction(
                rationale="A fierce garden gnome warrior",
                answer="A fierce garden gnome warrior, clad in armor crafted from leaves and bark, brandishes a tiny sword and shield. He stands valiantly on a rock amidst a blooming garden, surrounded by colorful flowers and towering plants. A determined expression is painted on his face, ready to defend his garden kingdom.",
            ),
            dspy.Prediction(
                rationale="A ferret in a candy jar",
                answer="A mischievous ferret with a playful grin squeezes itself into a large glass jar, surrounded by colorful candy. The jar sits on a wooden table in a cozy kitchen, and warm sunlight filters through a nearby window.",
            ),
            dspy.Prediction(
                rationale="cartoon drawing of an astronaut riding a horse",
                answer="Cartoon drawing of an outer space scene. Amidst floating planets and twinkling stars, a whimsical horse with exaggerated features rides an astronaut, who swims through space with a jetpack, looking a tad overwhelmed.",
            ),
            dspy.Prediction(
                rationale="A smafml vessef epropoeilled on watvewr by ors, sauls, or han engie.",
                answer="A small vessel, propelled on water by oars, sails, or an engine, floats gracefully on a serene lake. the sun casts a warm glow on the water, reflecting the vibrant colors of the sky as birds fly overhead.",
            ),
        ]

    @weave.op()
    def predict(
        self,
        base_prompt: str,
        negative_prompt: Optional[str] = None,
        num_inference_steps: int = 50,
        image_size: int = 1024,
        guidance_scale: float = 7.0,
        upsample_prompt: bool = True,
    ) -> str:
        """Generate one square image for ``base_prompt`` and return it as a PNG data URL.

        Args:
            base_prompt: Short user prompt.
            negative_prompt: Negative prompt for the pipeline; falls back to
                ``self.stock_negative_prompt`` when ``None``.
            num_inference_steps: Number of inference steps passed to the pipeline.
            image_size: Height and width of the generated image in pixels.
            guidance_scale: Guidance scale passed to the pipeline.
            upsample_prompt: If True, rewrite ``base_prompt`` with the DSPy upsampler
                before generation (needs an LM configured in DSPy).
        """
        if upsample_prompt:
            prompt = self.diffusion_prompt_upsampler(
                self.completions, base_prompt=base_prompt
            ).answer
        else:
            prompt = base_prompt
        if negative_prompt is None:
            negative_prompt = self.stock_negative_prompt
        image = self._pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            height=image_size,
            width=image_size,
            guidance_scale=guidance_scale,
        ).images[0]
        return base64_encode_image(image)

In [None]:
SDXL_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"

# SDXL base pipeline with model CPU offload enabled instead of moving everything to CUDA.
diffusion_model = StableDiffusionXLModel(
    model_name_or_path=SDXL_MODEL_ID, enable_cpu_offfload=True
)

In [None]:
DEMO_BASE_PROMPT = "girl with wild hairs"

# Render the same prompt verbatim and upsampled, in that order, with the upsampler LM active.
with dspy.context(lm=upsampler_llm):
    base_prompt_image, upsampled_prompt_image = [
        diffusion_model.predict(base_prompt=DEMO_BASE_PROMPT, upsample_prompt=upsample)
        for upsample in (False, True)
    ]

In [None]:
# Structured verdict the judge LLM must return (parsed via instructor's `response_model`).
# Descriptions and bounds are part of the JSON schema sent to the model; out-of-range
# scores fail validation, which instructor's `max_retries` presumably re-asks on.
class Judgement(BaseModel):
    think_out_loud: str = Field(
        description="Reasoning about faithfulness to the base prompt and about aesthetics."
    )
    score: float = Field(
        ge=0.0, le=1.0, description="Faithfulness to the base prompt, from 0 to 1."
    )
    judgement: str = Field(description="Either 'correct' or 'incorrect'.")
    aesthetic_score: float = Field(
        ge=0.0, le=1.0, description="How aesthetically pleasing the image is, from 0 to 1."
    )
    aesthetic_score: float


# System prompt for the judge: rates prompt faithfulness and aesthetics of one image.
_JUDGE_SYSTEM_PROMPT = """
You are responsible for judging the faithfulness of images generated by a computer program to the
base prompt used to generate them and the overall aesthetics of the image. You will be presented
with an image and given the base prompt that was used to produce the image. The base prompts you
are judging are designed to stress-test image generation programs, and may include things such as:
1. Scrambled or misspelled words (the image generator should produce an image associated with
    the probable meaning).
2. Color assignment (the image generator should apply the correct color to the correct object).
3. Counts (the image generator should depict the correct number of each object).
4. Abnormal associations, for example 'elephant under a sea', where the image should depict
    what is requested.
5. Descriptions of objects, the image generator should draw the most commonly associated object.
6. Rare single words, where the image generator should create an image somewhat associable with
    the specified word.
7. Images with text in them, where the image generator should create an image with the specified
    text in it. You need to make a decision as to whether or not the image is correct, given the
    base prompt.

You will first think out loud about your eventual judgement, enumerating reasons why the image
does or does not match the given base prompt and the overall aesthetics of the image (such as
how vibrant the color palette is, how visually pleasing the image looks, and how much depth and
details the image shows). After thinking out loud, you should assign a score between 0 and 1
depending on how much you think the image is faithful to the base prompt. Next, you should output
either 'correct' or 'incorrect' depending on whether you think the image is faithful to the base
prompt. Finally, you will also assign an aesthetic score between 0 and 1 depending on how
aesthetically pleasing you find the image.

A few rules:
1. The score should be used to indicate how close the image is to the base prompt in terms of objects,
    color or count; with 0 being very far and 1 being very close.
2. Ignore other objects in the image that are not explicitly mentioned by the base prompt; it is fine for
    these to be shown.
3. 'incorrect' should be reserved for instances where a specific aspect of the base prompt is not followed
    correctly, such as a wrong object, color or count; in that case the score should be less than or
    equal to 0.5.
4. The aesthetic score should be used to indicate how visually appealing you find the image, with 0 being
    close to dull, monochrome, and gray with the image lacking any depth, and 1 being very visually
    appealing with a vibrant color palette and the image showing good depth and detail.
"""


class OpenAIJudgeModel(weave.Model):
    """Weave model that uses an OpenAI vision model, via instructor, to grade generated images."""

    openai_model: str
    max_retries: int
    seed: int
    _openai_client: Optional[instructor.Instructor] = None

    def __init__(self, openai_model: str = "gpt-4-turbo", max_retries: int = 5, seed: int = 42):
        """Store the judge settings and build an instructor-patched OpenAI client.

        Reads the API key from the ``OPENAI_API_KEY`` environment variable.
        """
        super().__init__(
            openai_model=openai_model,
            max_retries=max_retries,
            seed=seed,
        )
        self._openai_client = instructor.from_openai(
            OpenAI(api_key=os.environ["OPENAI_API_KEY"])
        )

    @weave.op()
    def predict(self, base_prompt: str, generated_image: str) -> Judgement:
        """Ask the judge model to grade one image against its base prompt.

        Args:
            base_prompt: The prompt the image was generated from.
            generated_image: The image as a URL usable in an ``image_url`` content
                part (in this notebook, a base64 ``data:`` URL).

        Returns:
            The parsed ``Judgement``.
        """
        return self._openai_client.chat.completions.create(
            model=self.openai_model,
            response_model=Judgement,
            max_retries=self.max_retries,
            seed=self.seed,
            messages=[
                {"role": "system", "content": _JUDGE_SYSTEM_PROMPT},
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": f"The base prompt is '{base_prompt}'"},
                        {"type": "image_url", "image_url": {"url": generated_image}},
                    ],
                },
            ],
        )

    @weave.op()
    def score(self, base_prompt: str, model_output: str) -> Dict:
        """Grade ``model_output`` (an image URL) and flatten the verdict into a dict."""
        judgement: Judgement = self.predict(base_prompt=base_prompt, generated_image=model_output)
        return {
            "correctness": judgement.judgement,
            "score": judgement.score,
            "aesthetic_score": judgement.aesthetic_score,
        }

In [None]:
# Judge with its default settings spelled out: GPT-4 Turbo, 5 retries, fixed seed 42.
judge_model = OpenAIJudgeModel(openai_model="gpt-4-turbo", max_retries=5, seed=42)

In [None]:
base_prompt = "One cat and one dog sitting on the grass."

# Render the prompt verbatim and upsampled under identical diffusion settings.
with dspy.context(lm=upsampler_llm):
    base_prompt_image = diffusion_model.predict(
        base_prompt=base_prompt, upsample_prompt=False
    )
    upsampled_prompt_image = diffusion_model.predict(
        base_prompt=base_prompt, upsample_prompt=True
    )

# Both images are judged against the ORIGINAL base prompt, so the scores are comparable.
judgement_base_prompt_image = judge_model.predict(
    base_prompt=base_prompt, generated_image=base_prompt_image
)
judgement_upsampled_prompt_image = judge_model.predict(
    base_prompt=base_prompt, generated_image=upsampled_prompt_image
)

# Pass the objects themselves rather than pre-formatted f-strings: rich then
# pretty-prints them and does not parse "[...]" in the judge's free-text
# reasoning as console markup.
rich.print("judgement_base_prompt_image =", judgement_base_prompt_image)
rich.print("judgement_upsampled_prompt_image =", judgement_upsampled_prompt_image)