In [None]:
import base64
import io
import os
from getpass import getpass
from pathlib import Path
from typing import List, Optional, Union

import torch
import weave
from diffusers import StableDiffusion3Pipeline
from openai import OpenAI
from PIL import Image
from tqdm.auto import tqdm

In [None]:
# Reuse a key already exported in the environment and only prompt when it is
# missing, so Restart & Run All does not stall on interactive input.
api_key = os.environ.get("OPENAI_API_KEY") or getpass("Enter your OpenAI API key: ")
os.environ["OPENAI_API_KEY"] = api_key

In [None]:
EXT_TO_MIMETYPE = {
    ".jpg": "image/jpeg",
    ".jpeg": "image/jpeg",
    ".png": "image/png",
    ".svg": "image/svg+xml",
}

# PIL encoder name for each MIME type this cell can actually emit. SVG is
# intentionally absent: PIL cannot write vector images, so an SVG data URI
# could never contain matching bytes.
_PIL_FORMAT_FOR_MIMETYPE = {
    "image/jpeg": "JPEG",
    "image/png": "PNG",
}


def _pil_image_to_data_uri(image: Image.Image, mimetype: str) -> str:
    """Encode ``image`` with the PIL encoder matching ``mimetype`` as a base64 data URI."""
    pil_format = _PIL_FORMAT_FOR_MIMETYPE.get(mimetype)
    if pil_format is None:
        raise ValueError(
            f"Cannot encode an image as {mimetype!r}; "
            f"supported types: {sorted(_PIL_FORMAT_FOR_MIMETYPE)}"
        )
    if pil_format == "JPEG" and image.mode not in ("RGB", "L", "CMYK"):
        # JPEG cannot store alpha/palette modes; convert so PIL's JPEG encoder accepts it.
        image = image.convert("RGB")
    buffer = io.BytesIO()
    image.save(buffer, format=pil_format)
    payload = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return f"data:{mimetype};base64,{payload}"


def base64_encode_image(
    image_path: Union[str, os.PathLike, Image.Image], mimetype: Optional[str] = None
) -> str:
    """Return an image as a ``data:<mimetype>;base64,...`` URI string.

    Args:
        image_path: A filesystem path (``str`` or ``os.PathLike``) to an image
            PIL can open, or an already-loaded ``PIL.Image.Image``.
        mimetype: MIME type to encode as. When omitted, it is inferred from the
            file extension for paths (case-insensitively, via ``EXT_TO_MIMETYPE``).
            For in-memory images it defaults to ``"image/png"``. The bytes are
            always encoded in this same format, so the URI's declared type
            matches its content.

    Returns:
        The data URI as a ``str``.

    Raises:
        KeyError: If ``mimetype`` is omitted and the path's extension is not a
            key of ``EXT_TO_MIMETYPE``.
        ValueError: If the resolved MIME type is not one this function can
            encode (currently ``image/png`` and ``image/jpeg``).
    """
    if isinstance(image_path, Image.Image):
        return _pil_image_to_data_uri(image_path, mimetype or "image/png")

    path = Path(image_path)
    if mimetype is None:
        mimetype = EXT_TO_MIMETYPE[path.suffix.lower()]
    # The context manager releases the file handle PIL opened once encoding is done.
    with Image.open(path) as image:
        return _pil_image_to_data_uri(image, mimetype)

In [None]:
class SD3Model(weave.Model):
    """Weave-tracked wrapper around a Stable Diffusion 3 text-to-image pipeline.

    Attributes:
        model_name_or_path: Identifier or path passed to
            ``StableDiffusion3Pipeline.from_pretrained``.
        enable_cpu_offfload: If true, call ``enable_model_cpu_offload()`` on the
            pipeline; otherwise move the whole pipeline to ``"cuda"``. (The
            triple "f" is part of the public field name and is kept so existing
            callers keep working.)
        seed: Optional RNG seed. When set, every ``predict`` call samples from a
            freshly seeded generator, so identical arguments give identical
            images. ``None`` keeps the pipeline's default unseeded sampling.
    """

    model_name_or_path: str
    enable_cpu_offfload: bool
    seed: Optional[int] = None
    _pipeline: StableDiffusion3Pipeline

    def __init__(
        self,
        model_name_or_path: str,
        enable_cpu_offfload: bool,
        seed: Optional[int] = None,
    ):
        super().__init__(
            model_name_or_path=model_name_or_path,
            enable_cpu_offfload=enable_cpu_offfload,
            seed=seed,
        )
        # Weights are loaded in float16.
        self._pipeline = StableDiffusion3Pipeline.from_pretrained(
            self.model_name_or_path, torch_dtype=torch.float16
        )
        if self.enable_cpu_offfload:
            self._pipeline.enable_model_cpu_offload()
        else:
            self._pipeline = self._pipeline.to("cuda")

    @weave.op()
    def predict(
        self,
        prompt: str,
        negative_prompt: str,
        num_inference_steps: int,
        image_size: int,
        guidance_scale: float,
    ) -> str:
        """Generate one square image and return it via ``base64_encode_image``.

        Args:
            prompt: Text describing the desired image.
            negative_prompt: Text describing what the image should avoid.
            num_inference_steps: Number of denoising steps.
            image_size: Used for both ``height`` and ``width``.
            guidance_scale: Classifier-free guidance strength.

        Returns:
            The first generated image as a base64 data-URI string.
        """
        # A CPU generator is used so the seeded noise does not depend on which
        # device the offloaded/moved pipeline runs on. None = unseeded (default).
        generator = (
            torch.Generator(device="cpu").manual_seed(self.seed)
            if self.seed is not None
            else None
        )
        image = self._pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            height=image_size,
            width=image_size,
            guidance_scale=guidance_scale,
            generator=generator,
        ).images[0]
        return base64_encode_image(image)

In [None]:
class OpenAIModel(weave.Model):
    """Weave-tracked wrapper that turns a story paragraph into a short visual prompt.

    Attributes:
        model: OpenAI chat model name passed to ``chat.completions.create``.
        max_retries: Retry budget forwarded to the ``OpenAI`` client.
        seed: Sent as ``seed`` with every chat-completion request.
    """

    model: str
    max_retries: int = 5
    seed: int = 42
    _openai_client: OpenAI

    def __init__(self, model: str, max_retries: int = 5, seed: int = 42):
        super().__init__(model=model, max_retries=max_retries, seed=seed)
        # Forward max_retries so the declared field actually governs the
        # client's retry behaviour.
        self._openai_client = OpenAI(
            api_key=os.environ["OPENAI_API_KEY"], max_retries=self.max_retries
        )

    @weave.op()
    def predict(self, story: str, context: str, paragraph: str) -> str:
        """Summarise ``paragraph`` into a visual scene description for image generation.

        Args:
            story: Full story text, given to the model for visual cues.
            context: Background about the story (time, setting, publication).
            paragraph: The paragraph to summarise.

        Returns:
            The assistant message text.

        Raises:
            RuntimeError: If the response message has no text content
                (``content is None``), instead of returning ``None`` despite
                the ``str`` annotation.
        """
        response = self._openai_client.chat.completions.create(
            model=self.model,
            seed=self.seed,
            messages=[
                {
                    "role": "system",
                    "content": """
You are a helpful assistant to a visionary film director. You will be provided with a story, some context around
the story, and a paragraph from the story. Given this information, you are supposed to summarize the paragraph in
less than 40 words such that the summary provides a detailed and accurate visual description of the paragraph
which could be used by the director and his crew to set up a scene and do a photoshoot. The summary should capture
visual cues from the time and setting of the story from the context as well as visual cues from the entire story.
If there are any characters or objects in the paragraph, they should be visually described in the summary,
by looking for clues from the story.
                    """,
                },
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": f"""
Story:
---
{story}
---

Context around the story:
---
{context}
---

Paragraph:
---
{paragraph}
---
""",
                        },
                    ],
                },
            ],
        )
        content = response.choices[0].message.content
        if content is None:
            raise RuntimeError(
                f"Model {self.model!r} returned no text content for the paragraph."
            )
        return content

In [None]:
# Every @weave.op call below is traced into this Weave project. Set the
# WEAVE_PROJECT environment variable to log under your own entity/project;
# the default is the original author's project.
weave.init(
    project_name=os.environ.get("WEAVE_PROJECT", "geekyrakshit/story-illustration")
)

In [None]:
context_around_story = """
"The Gift of the Magi" is a short story by O. Henry first published in 1905.
The story tells of a young husband and wife and how they deal with the challenge
of buying secret Christmas gifts for each other with very little money.
As a sentimental story with a moral lesson about gift-giving, it has been popular
for adaptation, especially for presentation at Christmas time. The plot and its twist
ending are well known; the ending is generally considered an example of cosmic irony.
The story was allegedly written at Pete's Tavern on Irving Place in New York City.

The story was initially published in The New York Sunday World under the title
"Gifts of the Magi" on December 10, 1905. It was first published in book form in the
O. Henry collection The Four Million in April 1906.
"""

# Story source text, resolved relative to the kernel's working directory.
STORY_PATH = Path("../../gift_of_the_magi.txt")
# Explicit UTF-8 so decoding does not depend on the platform default encoding
# if the file contains non-ASCII characters (e.g. typographic quotes).
story = STORY_PATH.read_text(encoding="utf-8")

# First line of the story. NOTE(review): not used by generate_images, which does
# its own "\n\n" paragraph split — looks like a leftover sanity check; confirm.
paragraph = story.split("\n")[0]

In [None]:
# Text model: condenses each story paragraph into a visual scene description.
openai_model = OpenAIModel(model="gpt-4o")

# Image model: renders those descriptions. With enable_cpu_offfload=True,
# SD3Model calls enable_model_cpu_offload() instead of moving the pipeline to CUDA.
diffusion_model = SD3Model(
    enable_cpu_offfload=True,
    model_name_or_path="stabilityai/stable-diffusion-3-medium-diffusers",
)

In [None]:
# Style tags appended to every LLM-generated scene description.
IMAGE_STYLE_SUFFIX = "surreal style, artstation, digital art, illustration"
# Content the diffusion model is asked to steer away from.
NEGATIVE_PROMPT = "2d, ugly, distorted face, deformed fingers, scary, horror, nightmare, deformed lips, deformed eyes, deformed hands, deformed legs, impossible physics, absurdly placed objects"
# Sampling hyperparameters forwarded unchanged to SD3Model.predict.
NUM_INFERENCE_STEPS = 28
IMAGE_SIZE = 1024
GUIDANCE_SCALE = 7.0


@weave.op()
def generate_images(story: str, context: str, max_paragraphs: int) -> List[str]:
    """Illustrate the first ``max_paragraphs`` non-blank paragraphs of ``story``.

    Paragraphs are chunks separated by blank lines (``"\\n\\n"``). Each one is
    summarised by the global ``openai_model``, and the summary plus style tags
    is rendered by the global ``diffusion_model``.

    Args:
        story: Full story text.
        context: Background about the story, forwarded to the text model.
        max_paragraphs: Maximum number of non-blank paragraphs to illustrate.

    Returns:
        One image per illustrated paragraph, in story order, as returned by
        ``diffusion_model.predict`` (base64 data-URI strings).
    """
    # Drop blank chunks (e.g. from runs of 3+ newlines) *before* applying the
    # limit, so max_paragraphs counts real paragraphs rather than empty ones.
    paragraphs = [chunk for chunk in story.split("\n\n") if chunk.strip()][:max_paragraphs]
    images = []
    for paragraph in tqdm(paragraphs):
        prompt = openai_model.predict(story=story, context=context, paragraph=paragraph)
        image = diffusion_model.predict(
            # Join with ", " so the description's last word and the first style
            # tag stay separate instead of being fused together.
            prompt=f"{prompt.strip()}, {IMAGE_STYLE_SUFFIX}",
            negative_prompt=NEGATIVE_PROMPT,
            num_inference_steps=NUM_INFERENCE_STEPS,
            image_size=IMAGE_SIZE,
            guidance_scale=GUIDANCE_SCALE,
        )
        images.append(image)
    return images

In [None]:
# Illustrate the first five paragraphs; each entry is a base64 data-URI string.
images = generate_images(
    story=story,
    context=context_around_story,
    max_paragraphs=5,
)