In [None]:
!pip install open_clip_torch diffusers transformers accelerate scikit-image gradio --quiet

import torch
from diffusers import StableDiffusionXLPipeline, ControlNetModel
import numpy as np
from typing import List
import os
from PIL import Image
import cv2
import gradio as gr
import open_clip
from torchvision import transforms
from skimage.metrics import structural_similarity as ssim

# Load CLIP model
# Done once at module level so compute_clip_score can reuse the same objects
# on every call. create_model_and_transforms returns a 3-tuple; the middle
# element is discarded here and the last one is kept as the image transform.
clip_model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
# Tokenizer is requested for the same architecture name as the model above.
tokenizer = open_clip.get_tokenizer('ViT-B-32')
# The model is only used for scoring, so put it in evaluation mode.
# No device is specified in this block.
clip_model.eval()

# CLIP score calculation
def compute_clip_score(prompt: str, image: Image.Image) -> float:
    """Return the cosine similarity between *prompt* and *image* under CLIP.

    The image goes through the module-level ``preprocess`` transform and the
    prompt through the module-level ``tokenizer``; both are encoded with
    ``clip_model``, L2-normalised, and compared with a dot product.
    """
    pixel_batch = preprocess(image).unsqueeze(0)  # add a batch dimension
    token_batch = tokenizer([prompt])

    # Scoring only: no gradients are needed.
    with torch.no_grad():
        img_emb = clip_model.encode_image(pixel_batch)
        txt_emb = clip_model.encode_text(token_batch)

        # Unit-length embeddings make the dot product a cosine similarity.
        img_emb = img_emb / img_emb.norm(dim=-1, keepdim=True)
        txt_emb = txt_emb / txt_emb.norm(dim=-1, keepdim=True)

        return (img_emb @ txt_emb.T).item()

# SSIM calculation
def compute_ssim_score(images: List[Image.Image]) -> float:
    """Return the mean SSIM over each pair of consecutive images.

    Every image is converted to 8-bit grayscale before comparison. If the two
    images of a pair differ in size, the second is resized to the first's
    size for that comparison only, because ``ssim`` requires equally shaped
    inputs.

    Args:
        images: Frames in story order.

    Returns:
        The average pairwise SSIM, or ``0.0`` when fewer than two images are
        given (there is nothing to compare).
    """
    grays = [img.convert("L") for img in images]

    scores = []
    for prev, curr in zip(grays, grays[1:]):
        if curr.size != prev.size:
            # Local rebinding only: the next pair still sees the original frame.
            curr = curr.resize(prev.size)
        scores.append(ssim(np.array(prev), np.array(curr)))

    return float(np.mean(scores)) if scores else 0.0

# Diversity score
def calculate_diversity_score(images: List[Image.Image]) -> float:
    """Return the mean per-pixel standard deviation across *images*.

    Each image is scaled by 1/255 into float32, the images are stacked along
    a new leading axis, the standard deviation is taken across that axis and
    then averaged over all remaining positions.

    Args:
        images: Images to compare. They must all convert to arrays of the
            same shape, otherwise ``np.stack`` raises ``ValueError``.

    Returns:
        The diversity score, or ``0.0`` for an empty list (consistent with
        ``compute_ssim_score``, which also reports 0.0 when it has no data).
    """
    if not images:
        # np.stack cannot stack an empty sequence.
        return 0.0

    stacked = np.stack(
        [np.array(img).astype(np.float32) / 255.0 for img in images],
        axis=0,
    )
    return float(np.std(stacked, axis=0).mean())

# Canny edge map generator
def generate_canny_edge_map(prompt: str, width: int = 512, height: int = 512) -> Image.Image:
    """Draw a simple keyword-driven shape and return its Canny edge map.

    The first matching keyword in *prompt* (case-insensitive) selects the
    shape: "mountain" draws a two-line peak, "flower" a filled circle, and
    "cave" a filled square. With no match the canvas stays black, so the
    edge map is empty.

    The shape coordinates are defined on a 512x512 reference canvas and are
    scaled to the requested size, so the shape keeps its relative position
    for any *width* / *height*. At the default 512x512 the result is the same
    as with the unscaled coordinates.

    Args:
        prompt: Text searched for the shape keywords.
        width: Canvas width in pixels.
        height: Canvas height in pixels.

    Returns:
        A single-channel image holding the output of ``cv2.Canny``.
    """
    reference = 512
    white = (255, 255, 255)
    canvas = np.zeros((height, width, 3), dtype=np.uint8)

    def scaled(x: int, y: int) -> tuple:
        # Map a point on the reference canvas to the requested canvas.
        return (
            int(round(x * width / reference)),
            int(round(y * height / reference)),
        )

    text = prompt.lower()
    if "mountain" in text:
        cv2.line(canvas, scaled(100, 400), scaled(256, 100), white, 3)
        cv2.line(canvas, scaled(256, 100), scaled(412, 400), white, 3)
    elif "flower" in text:
        # Scale the radius by the shorter side so the circle always fits.
        radius = max(1, int(round(40 * min(width, height) / reference)))
        cv2.circle(canvas, scaled(256, 256), radius, white, -1)
    elif "cave" in text:
        cv2.rectangle(canvas, scaled(180, 180), scaled(330, 330), white, -1)

    gray = cv2.cvtColor(canvas, cv2.COLOR_BGR2GRAY)
    edges = cv2.Canny(gray, 100, 200)
    return Image.fromarray(edges)

# Simple flat prompt builder
def flat_prompt_generation(
    story: str, num_images: int = 6,
    char_desc: str = "", obj_desc: str = "",
    env_desc: str = "", style_desc: str = ""
) -> List[str]:
    """Build one prompt per image by appending the descriptors to a sentence.

    The story is split after every "." and blank pieces are dropped. Image
    ``i`` uses sentence ``i``, cycling back to the first sentence when the
    story has fewer sentences than *num_images*. Non-empty descriptors are
    appended in the order character, object, environment, style, separated
    by ", ".

    Args:
        story: Source text; may be empty.
        num_images: Number of prompts to produce. Coerced with ``int`` so a
            float coming from a UI slider is accepted.
        char_desc: Optional character description.
        obj_desc: Optional object description.
        env_desc: Optional environment description.
        style_desc: Optional art-style description.

    Returns:
        A list of *num_images* prompts. When the story contains no sentences
        each prompt consists of the descriptors alone (possibly "").
    """
    sentences = [
        piece.strip()
        for piece in story.replace(".", ".\n").split("\n")
        if piece.strip()
    ]
    if not sentences:
        # Avoid dividing/indexing by zero sentences: fall back to a prompt
        # made only of the descriptors.
        sentences = [""]

    extras = [d for d in (char_desc, obj_desc, env_desc, style_desc) if d]

    prompts = []
    for i in range(int(num_images)):
        sentence = sentences[i % len(sentences)]
        parts = ([sentence] if sentence else []) + extras
        prompts.append(", ".join(parts))
    return prompts

# Smart handcrafted prompt builder
def story_specific_prompt_generation(
    story: str, num_images: int = 6,
    char_desc: str = "", obj_desc: str = "",
    env_desc: str = "", style_desc: str = ""
) -> List[str]:
    """Build one prompt per image from fixed, hand-written scene templates.

    The story is split after every "." and blank pieces are dropped. Image
    ``i`` uses sentence ``i`` (cycling when the story is shorter than
    *num_images*) followed by scene template ``i``; images past the fifth
    index all reuse the last template. Descriptors are substituted verbatim,
    so an empty descriptor leaves its surrounding punctuation in place.

    Args:
        story: Source text; may be empty.
        num_images: Number of prompts to produce. Coerced with ``int`` so a
            float coming from a UI slider is accepted.
        char_desc: Character description inserted into the character scenes.
        obj_desc: Object description inserted into the flower scenes.
        env_desc: Environment description (not used by the last template).
        style_desc: Art-style description appended to every template.

    Returns:
        A list of *num_images* prompts. When the story contains no sentences
        the prompts consist of the scene templates alone.
    """
    scene_templates = (
        "a radiant magical flower {obj_desc}, high on a mystical peak, {env_desc}, {style_desc}",
        "a young explorer {char_desc} hiking at sunset, misty mountain backdrop, {env_desc}, {style_desc}",
        "girl {char_desc} walking through glowing clouds, magical mountain vibe, {env_desc}, {style_desc}",
        "scared girl {char_desc} rushing along an eerie forest trail, {env_desc}, {style_desc}",
        "ancient cave opening behind glowing rocks, mystical aura, {env_desc}, {style_desc}",
        "magical flower {obj_desc} inside the glowing cave, shining golden light, {style_desc}",
    )
    last_template = len(scene_templates) - 1

    sentences = [
        piece.strip()
        for piece in story.replace(".", ".\n").split("\n")
        if piece.strip()
    ]
    if not sentences:
        # Avoid dividing/indexing by zero sentences: emit the bare templates.
        sentences = [""]

    prompts = []
    for i in range(int(num_images)):
        sentence = sentences[i % len(sentences)]
        scene = scene_templates[min(i, last_template)].format(
            char_desc=char_desc, obj_desc=obj_desc,
            env_desc=env_desc, style_desc=style_desc,
        )
        # Concatenate rather than format the sentence in, so braces in user
        # text are never interpreted as format fields.
        prompts.append(f"{sentence}, {scene}" if sentence else scene)
    return prompts

def check_image_validity(img: Image.Image) -> bool:
    """Return True when *img* has at least one pixel value above zero.

    An image whose pixel values are all zero (completely black) is treated as
    invalid. The comparison result is converted to a plain ``bool`` so the
    return value matches the annotation instead of being a ``numpy.bool_``.
    """
    return bool(np.array(img).max() > 0)

# Main pipeline
# Loaded pipelines keyed by device name, so repeated button clicks reuse the
# already-initialised models instead of reloading them every time.
_PIPELINE_CACHE = {}


def _load_sdxl_controlnet_pipeline(device):
    """Return the SDXL + Canny ControlNet pipeline for *device*, loading it once."""
    # Imported here because the file's top-level import only brings in the
    # plain SDXL pipeline, which has no ControlNet component.
    from diffusers import StableDiffusionXLControlNetPipeline

    if device not in _PIPELINE_CACHE:
        # Half precision is only used on the GPU; fall back to float32 on CPU.
        dtype = torch.float16 if device == "cuda" else torch.float32
        controlnet = ControlNetModel.from_pretrained(
            "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=dtype
        )
        _PIPELINE_CACHE[device] = StableDiffusionXLControlNetPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet,
            torch_dtype=dtype
        ).to(device)
    return _PIPELINE_CACHE[device]


# Main pipeline
def generate_images_for_ui(
    story, char_desc, obj_desc, env_desc, style_desc,
    num_images, seed, prompt_style
):
    """Generate story images and their evaluation metrics for the Gradio UI.

    Builds one prompt per image with the selected prompt strategy, renders
    each prompt with SDXL conditioned on a keyword-derived Canny edge map,
    saves every frame to ``story_images/image_<n>.png`` and scores the set.

    Args:
        story: Story text that is split into per-image sentences.
        char_desc: Character description forwarded to the prompt builder.
        obj_desc: Object description forwarded to the prompt builder.
        env_desc: Environment description forwarded to the prompt builder.
        style_desc: Art-style description forwarded to the prompt builder.
        num_images: Number of images to generate (coerced to ``int``).
        seed: Base random seed (coerced to ``int``); image ``i`` uses
            ``seed + i``.
        prompt_style: ``"Simple"`` selects ``flat_prompt_generation``; any
            other value selects ``story_specific_prompt_generation``.

    Returns:
        ``(images, avg_clip, diversity, ssim)`` with the three metrics rounded
        to four decimals. On any failure the error is printed with its
        traceback and ``([], 0.0, 0.0, 0.0)`` is returned.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    output_dir = "story_images"
    os.makedirs(output_dir, exist_ok=True)

    try:
        # UI sliders may hand over floats; range() and manual_seed need ints.
        num_images = int(num_images)
        seed = int(seed)

        pipe = _load_sdxl_controlnet_pipeline(device)

        # Choose prompt strategy
        if prompt_style == "Simple":
            prompts = flat_prompt_generation(story, num_images, char_desc, obj_desc, env_desc, style_desc)
        else:
            prompts = story_specific_prompt_generation(story, num_images, char_desc, obj_desc, env_desc, style_desc)

        negative_prompt = "blurry, low quality, distorted, extra limbs, unnatural colors, black image, blank image"

        images, clip_scores = [], []
        for i, prompt in enumerate(prompts):
            # The edge map is single-channel; hand the pipeline an RGB image.
            # NOTE(review): no width/height is passed, so the output
            # resolution is expected to follow this 512x512 control image -
            # confirm that is the intended size for SDXL.
            control_image = generate_canny_edge_map(prompt).convert("RGB")
            generator = torch.Generator(device=device).manual_seed(seed + i)
            result = pipe(
                prompt,
                num_inference_steps=60,
                guidance_scale=8.5,
                negative_prompt=negative_prompt,
                num_images_per_prompt=1,
                generator=generator,
                image=control_image,
                controlnet_conditioning_scale=0.5
            ).images
            img = result[0]
            img_path = os.path.join(output_dir, f"image_{i+1}.png")
            img.save(img_path)
            if check_image_validity(img):
                images.append(img)
                clip_scores.append(compute_clip_score(prompt, img))
            else:
                # Keep the placeholder the same size as the real frames so
                # the stacked/pairwise metrics below still line up.
                fallback = Image.new("RGB", img.size, (0, 0, 0))
                images.append(fallback)
                clip_scores.append(0.0)

        if not images:
            # Nothing was requested; avoid dividing by zero below.
            return [], 0.0, 0.0, 0.0

        avg_clip = sum(clip_scores) / len(clip_scores)
        diversity = calculate_diversity_score(images)
        ssim_score = compute_ssim_score(images)
        return images, round(avg_clip, 4), round(diversity, 4), round(ssim_score, 4)

    except Exception as e:
        # Top-level UI boundary: report the failure but keep the app alive.
        import traceback

        print(f"Error: {str(e)}")
        traceback.print_exc()
        return [], 0.0, 0.0, 0.0

# Gradio UI
# Declarative layout: each `with gr.Row()` groups the widgets created inside
# it, and the button at the bottom wires them to generate_images_for_ui.
with gr.Blocks() as demo:
    gr.Markdown("## 📖 Story-to-Image Generator using SDXL + ControlNet + Prompt Styles + Evaluation Metrics")

    # Free-text story input.
    with gr.Row():
        story = gr.Textbox(label="Story (multi-sentence)", lines=5, placeholder="Enter your story here...")

    # Optional descriptors passed to the prompt builders.
    with gr.Row():
        char_desc = gr.Textbox(label="Main Character Description", placeholder="e.g., a brave astronaut")
        obj_desc = gr.Textbox(label="Object/Element Description", placeholder="e.g., a glowing crystal flower")
        env_desc = gr.Textbox(label="Environment Description", placeholder="e.g., snowy mountain, glowing cave")
        style_desc = gr.Textbox(label="Art Style", placeholder="e.g., cinematic, Ghibli, hyperreal")

    # Generation controls. generate_images_for_ui compares prompt_style
    # against the literal "Simple"; the other choice takes the else branch.
    with gr.Row():
        prompt_style = gr.Dropdown(["Simple", "Story-Specific Detailed"], value="Story-Specific Detailed", label="🧠 Prompt Template Style")
        num_images = gr.Slider(label=" Number of Images", minimum=1, maximum=6, step=1, value=6)
        seed = gr.Slider(label="Random Seed", minimum=0, maximum=9999, step=1, value=42)

    generate_btn = gr.Button(" Generate")

    # Output area: generated frames.
    with gr.Row():
        gallery = gr.Gallery(label=" Generated Images", columns=3, height="auto")

    # Output area: read-only metric fields.
    with gr.Row():
        clip_output = gr.Textbox(label=" Average CLIP Score", interactive=False)
        diversity_output = gr.Textbox(label=" Diversity Score", interactive=False)
        ssim_output = gr.Textbox(label=" SSIM Score", interactive=False)

    # `inputs` is listed in the same order as generate_images_for_ui's
    # positional parameters, and `outputs` in the order of the 4-tuple it
    # returns (images, CLIP, diversity, SSIM).
    generate_btn.click(
        fn=generate_images_for_ui,
        inputs=[story, char_desc, obj_desc, env_desc, style_desc, num_images, seed, prompt_style],
        outputs=[gallery, clip_output, diversity_output, ssim_output]
    )

# Start the app with default launch settings.
demo.launch()


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m39.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m17.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m32.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m37.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m7.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m13.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


open_clip_model.safetensors:   0%|          | 0.00/605M [00:00<?, ?B/s]

It looks like you are running Gradio on a hosted a Jupyter notebook. For the Gradio app to work, sharing must be enabled. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://1a75fcec5f946e5dc4.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


