In [23]:
"""
Pipeline:
  1. gpt-4o-mini → breaks story into 8 visual scenes.
  2. gpt-4o-mini → creates detailed image prompts preserving Aldar Kose’s art style.
  3. gpt-image-1 → generates images, using reference images for style locking:
        - Scene 1: uses pre-defined Aldar Kose reference images.
        - Scene 2+: uses the pre-defined reference images plus previously generated scene(s).
"""

from openai import OpenAI
import base64
import os
from dotenv import load_dotenv

# Load variables from a .env file (python-dotenv) so OPENAI_API_KEY can be read below.
load_dotenv()
# None when OPENAI_API_KEY is unset. NOTE(review): the OpenAI client is expected to
# reject a missing key at construction time — confirm against the installed SDK.
api_key = os.environ.get("OPENAI_API_KEY")

# Shared client used by the chat-completion and image-edit calls in the cells below.
client = OpenAI(api_key=api_key)

# Predefined reference images for Aldar Kose’s art style (folder relative to the working directory)
BASE_STYLE_IMAGES_FOLDER = "base-images"

In [24]:
def generate_scene_descriptions(story: str) -> list[str]:
    """Split the short story into 8 visual scenes."""
    prompt = f"""
    Break this short story about Aldar Kose into **exactly** 8 distinct visual scenes.
    Each visual scene should describe the key visual actions, characters in 1-2 sentences.

    Story:
    {story}

    Make sure to create a list of **exactly** 8 scenes.
    Return as a numbered list.
    """
    res = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": prompt}]
    )
    scenes = res.choices[0].message.content.strip().split("\n")
    return [s.strip("0123456789. ").strip() for s in scenes if s.strip()]

In [25]:
def generate_image_prompt(scene: str, scenes: [str]) -> str:
    """Produce a detailed image prompt preserving Aldar Kose’s art style and referencing style images."""
    prompt = f"""
    You are generating an image prompt for the GPT image model.

Follow this exact format for the output:
"Referenced images show Kazakh folk character Aldar Kose and previous scenes and should be used to preserve his look, proportions, and clothing style. 
[Scene Description].
Vibrant comic-book style, bold outlines, warm earthy colors, soft shading, expressive faces, Kazakh steppe, traditional clothing, realistic proportions."

Scene description is generated from: {scene}.

Pay attention to other scenes: {scenes}.

Output only the final prompt.
    """
    res = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": prompt}]
    )
    return res.choices[0].message.content.strip()

In [26]:
def save_base64_image(b64_data: str, filename: str, folder: str) -> str:
    """Write base64-encoded image bytes to ``folder/filename``.

    `folder` is created first if it does not already exist.

    Returns:
      Absolute path of the written file.
    """
    os.makedirs(folder, exist_ok=True)
    target = os.path.join(folder, filename)
    with open(target, mode="wb") as handle:
        handle.write(base64.b64decode(b64_data))
    return os.path.abspath(target)

import glob

def load_base_style_images(
    folder: str,
    extensions: tuple[str, ...] = (".jpg", ".jpeg", ".png", ".webp"),
) -> list[str]:
    """Return the paths of all style reference images directly inside `folder`.

    Extensions are matched case-insensitively, so ``.PNG`` / ``.JPG`` files are
    found on case-sensitive filesystems too. Hidden files (leading ``.``, e.g.
    macOS ``._foo.png``) are skipped, as glob's ``*`` pattern did. The result is
    sorted, so reference images reach the API in a stable order across runs.

    Args:
      folder: directory to scan (not recursive).
      extensions: file suffixes to accept (compared lowercase).

    Returns:
      Sorted list of ``os.path.join(folder, name)`` paths.

    Raises:
      FileNotFoundError: if `folder` does not exist or contains no matching image.
    """
    if not os.path.isdir(folder):
        raise FileNotFoundError(f"No style images found in {folder}")

    wanted = tuple(ext.lower() for ext in extensions)
    paths = sorted(
        os.path.join(folder, name)
        for name in os.listdir(folder)
        if not name.startswith(".")
        and name.lower().endswith(wanted)
        and os.path.isfile(os.path.join(folder, name))
    )
    if not paths:
        raise FileNotFoundError(f"No style images found in {folder}")
    return paths

def generate_images_from_story(story: str, folder: str):
    """Main pipeline for sequential image generation with style continuity."""
    scenes = generate_scene_descriptions(story)
    print(f"Generated {len(scenes)} scenes.\n")
    print(scenes)

    base_image_paths = load_base_style_images(BASE_STYLE_IMAGES_FOLDER)
    prev_image_files = [open(p, "rb") for p in base_image_paths]
    generated_image_paths = []

    for i, scene in enumerate(scenes, 1):
        image_prompt = generate_image_prompt(scene, scenes)
        print(f"Scene {i} prompt:\n{image_prompt}\n")

        # Generate image with current reference images
        image = client.images.edit(
            model="gpt-image-1",
            prompt=image_prompt,
            image=prev_image_files,
            size="1024x1024",
            n=1,
            quality="high",
            input_fidelity="high"
        )

        # Save output
        image_path = save_base64_image(image.data[0].b64_json, f"scene_{i}.png", folder)
        generated_image_paths.append(image_path)
        print(f"Saved {image_path}")

        prev_image_files.append(open(image_path, "rb"))  

    # Clean up base style files
    for f in prev_image_files:
        f.close()
    return generated_image_paths

In [27]:
import math
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

def show_image_grid(images, cols=3, titles=None, cell_size=(3, 3), pad=0.1):
    """
    Display images in a grid (Jupyter-friendly).

    Args:
      images: list of file paths, PIL.Image.Image, or numpy arrays (H,W[,C]).
      cols: number of columns.
      titles: optional list of strings for per-image captions.
      cell_size: (width, height) in inches per cell.
      pad: padding for tight_layout.

    Returns:
      None. Renders the grid inline.
    """
    n = len(images)
    if n == 0:
        return
    cols = max(1, int(cols))
    rows = math.ceil(n / cols)

    fig_w = cols * cell_size[0]
    fig_h = rows * cell_size[1]
    fig, axes = plt.subplots(rows, cols, figsize=(fig_w, fig_h))
    axes = np.atleast_2d(axes)  # ensure 2D array

    for idx in range(rows * cols):
        r, c = divmod(idx, cols)
        ax = axes[r, c]
        ax.axis("off")
        if idx >= n:
            continue

        img = images[idx]
        if isinstance(img, str):
            img = Image.open(img)
        elif isinstance(img, np.ndarray):
            if img.dtype != np.uint8:
                # normalize to 0..255 for display
                arr = img.astype(np.float32)
                arr = (255 * (arr - arr.min()) / (arr.ptp() + 1e-8)).astype(np.uint8)
                img = Image.fromarray(arr)
            else:
                img = Image.fromarray(img)
        # if already PIL.Image.Image, pass through

        ax.imshow(img)
        if titles and idx < len(titles) and titles[idx]:
            ax.set_title(titles[idx], fontsize=9)

    plt.tight_layout(pad=pad)
    plt.show()


In [28]:
def create_storyboard(story: str, folder: str):
    """Run the full story → images pipeline into `folder`, then show the scenes as a grid."""
    show_image_grid(images=generate_images_from_story(story, folder))

In [29]:
import os
from pathlib import Path
import ipywidgets as widgets
from IPython.display import display, HTML

def run_task(prompt: str, out_dir: Path) -> Path:
    """
    Run the storyboard pipeline for the UI's Run button.

    Calls ``create_storyboard(prompt, out_dir)``, which generates the scene
    images into `out_dir` and displays them as an image grid.

    Args:
      prompt: story text entered in the prompt box.
      out_dir: directory that receives the generated scene images.

    Returns:
      `out_dir`, unchanged.
    """
    create_storyboard(prompt, out_dir)
    return out_dir

# ---- widgets ----
# Story input: on Run, its stripped value is passed to run_task as `prompt`.
prompt_box = widgets.Textarea(
    description="Prompt:",
    placeholder="Type your prompt...",
    value="",
    layout=widgets.Layout(height="120px", width="100%"),
)

# Output directory; on Run it is expanded (~) and resolved against the working directory.
folder_box = widgets.Text(
    description="Folder:",
    placeholder="folder to store results",
    value="outputs",
    layout=widgets.Layout(width="50%"),
)

# Run button; clicking it calls on_run_click once registered via run_btn.on_click.
run_btn = widgets.Button(
    icon="play",
    tooltip="Execute with current inputs",
    button_style="primary",
    description="Run",
)

# Shared area for validation messages, pipeline prints and the image grid.
status = widgets.Output()

def _sanitize_folder(name: str) -> str:
    # basic safety: strip and forbid empty; remove trailing slashes
    name = name.strip().rstrip("/").rstrip("\\")
    return name or "outputs"

def on_run_click(_):
    """Run-button callback: validate inputs, run the pipeline and report the outcome.

    Everything (validation errors, pipeline prints, the image grid and the final
    status line) is rendered into the `status` Output widget, which is cleared
    at the start of each click.
    """
    from html import escape

    def _show_error(label: str, detail: object) -> None:
        # Exception text may contain '<', '&' or quotes; escape it so it renders literally.
        display(HTML(f"<b style='color:#b00020'>{label}</b> {escape(str(detail))}"))

    with status:
        status.clear_output()
        prompt = prompt_box.value.strip()
        folder = _sanitize_folder(folder_box.value)

        if not prompt:
            display(HTML("<b style='color:#b00020'>Enter a prompt.</b>"))
            return

        out_dir = Path(folder).expanduser().resolve()
        try:
            out_dir.mkdir(parents=True, exist_ok=True)
        except Exception as e:
            _show_error("Failed to create folder:", e)
            return

        try:
            out_path = run_task(prompt, out_dir)
        except Exception as e:
            _show_error("Task error:", e)
            return

        # relpath raises ValueError when the paths are on different drives (Windows).
        try:
            shown = os.path.relpath(out_path, Path.cwd())
        except ValueError:
            shown = str(out_path)
        display(HTML(f"<b>Done.</b> Images saved to <code>{escape(shown)}</code>"))

# Route Run clicks to the handler.
run_btn.on_click(on_run_click)

# ---- layout ----
# Prompt on top, folder box + Run button on one row, output area underneath.
ui = widgets.VBox(children=[
    prompt_box,
    widgets.HBox(children=[folder_box, run_btn]),
    status,
])

display(ui)


VBox(children=(Textarea(value='', description='Prompt:', layout=Layout(height='120px', width='100%'), placehol…