# In [None]:
import torch
from diffusers import StableDiffusionXLPipeline, ControlNetModel
import numpy as np
from typing import List
import os
from PIL import Image
import cv2
import matplotlib.pyplot as plt

def generate_canny_edge_map(prompt: str, width: int = 512, height: int = 512) -> Image.Image:
    """Build a synthetic Canny edge map to use as a ControlNet conditioning image.

    A simple white shape is drawn on a black canvas according to the first
    keyword found in *prompt* (checked case-insensitively, in this order):

    * ``"mountain"`` -> two lines forming a peak,
    * ``"flower"``   -> a filled circle at the centre,
    * ``"cave"``     -> a filled square around the centre.

    If no keyword matches, the canvas stays black and so does the edge map.

    The shape coordinates were authored for a 512x512 canvas; they are scaled
    to the requested ``width``/``height`` so the shape keeps its relative
    position and size on other canvases. At the default size the output is
    identical to using the raw coordinates.

    Args:
        prompt: Text prompt that is scanned for the shape keywords.
        width: Width of the edge map in pixels.
        height: Height of the edge map in pixels.

    Returns:
        A single-channel PIL image holding the Canny edges of the drawn shape.
    """
    # Side length of the canvas the literal coordinates below refer to.
    reference_side = 512.0
    scale_x = width / reference_side
    scale_y = height / reference_side

    def scaled(x: int, y: int):
        # OpenCV drawing functions need integer pixel coordinates.
        return (int(round(x * scale_x)), int(round(y * scale_y)))

    canvas = np.zeros((height, width, 3), dtype=np.uint8)
    white = (255, 255, 255)
    text = prompt.lower()

    if "mountain" in text:
        peak = scaled(256, 100)
        cv2.line(canvas, scaled(100, 400), peak, white, 3)
        cv2.line(canvas, peak, scaled(412, 400), white, 3)
    elif "flower" in text:
        # Scale the radius by the smaller axis so the circle stays on canvas.
        radius = max(1, int(round(40 * min(scale_x, scale_y))))
        cv2.circle(canvas, scaled(256, 256), radius, white, -1)
    elif "cave" in text:
        cv2.rectangle(canvas, scaled(180, 180), scaled(330, 330), white, -1)

    gray = cv2.cvtColor(canvas, cv2.COLOR_BGR2GRAY)
    edges = cv2.Canny(gray, 100, 200)
    return Image.fromarray(edges)

def split_story_into_prompts(
    story: str,
    num_images: int = 6,
    base_description_girl: str = "",
    base_description_flower: str = "",
    base_description_mountain: str = "",
    style_description: str = "",
) -> List[str]:
    """Turn a story into one image prompt per frame.

    The story is split into sentences at every ``.`` (and at existing line
    breaks). Frame ``i`` uses sentence ``i``; when the story has fewer
    sentences than ``num_images`` the sentences are reused cyclically.
    Each sentence is followed by a fixed scene description for that frame
    position (frames 0-4 have their own scene, every later frame uses the
    "flower inside the cave" scene).

    Args:
        story: Source text; sentences are expected to end with ``.``.
        num_images: Number of prompts to produce.
        base_description_girl: Appearance text inserted into the girl scenes.
        base_description_flower: Appearance text inserted into the flower scenes.
        base_description_mountain: Scenery text inserted into the mountain scenes.
        style_description: Art-style text appended to every prompt.

    Returns:
        A list of ``num_images`` prompts, or an empty list when the story
        contains no sentences or ``num_images`` is not positive.
    """
    pieces = story.replace(".", ".\n").split("\n")
    sentences = [piece.strip() for piece in pieces if piece.strip()]

    # An empty story previously caused a ZeroDivisionError while trying to
    # repeat the (empty) sentence list; there is simply nothing to illustrate.
    if not sentences or num_images <= 0:
        return []

    girl = base_description_girl
    flower = base_description_flower
    mountain = base_description_mountain
    style = style_description

    # Scene text for frames 0..4, in order.
    scene_by_frame = (
        f"a magical flower {flower} on a snowy mountain peak, {mountain}, {style}",
        f"a girl {girl} hiking with a backpack, mountain trail background, {mountain}, {style}",
        f"a girl {girl} walking through thick clouds on the mountain, {mountain}, {style}",
        f"a frightened girl {girl} walking faster, mysterious mountain forest atmosphere, {mountain}, {style}",
        f"a glowing cave entrance hidden behind rocks, {mountain}, {style}",
    )
    # Scene text for frame 5 and every frame after it.
    closing_scene = f"a magical flower {flower} glowing inside the cave, golden light, mystical, {style}"

    prompts = []
    for frame in range(num_images):
        # Modular indexing reuses sentences when the story is too short.
        sentence = sentences[frame % len(sentences)]
        scene = scene_by_frame[frame] if frame < len(scene_by_frame) else closing_scene
        prompts.append(f"{sentence}, {scene}")
    return prompts

def check_image_validity(img: "Image.Image") -> bool:
    """Report whether an image contains at least one non-black pixel.

    Args:
        img: Image to inspect; it is converted to a NumPy array first.

    Returns:
        ``True`` when the largest pixel value is greater than zero,
        ``False`` for an all-zero image or one with no pixels at all.
    """
    pixels = np.asarray(img)
    # max() raises on a zero-size array; an image without pixels is invalid.
    if pixels.size == 0:
        return False
    # Coerce numpy.bool_ to a plain bool so the declared return type holds.
    return bool(pixels.max() > 0)

def generate_image_sequence(story: str, output_dir: str, model_id: str = "stabilityai/stable-diffusion-xl-base-1.0",
                           controlnet_model_id: str = "diffusers/controlnet-canny-sdxl-1.0", num_images: int = 6,
                           device: str = "cuda", seed: int = 42):
    """Generate, save and display one ControlNet-guided SDXL image per story frame.

    Steps: load the ControlNet and SDXL weights, build one prompt per frame
    with ``split_story_into_prompts``, render each prompt with a synthetic
    Canny edge map from ``generate_canny_edge_map`` as conditioning image,
    save every result as ``image_<n>.png`` in ``output_dir`` and show all
    frames in a matplotlib grid.

    Args:
        story: Story text that is split into per-frame prompts.
        output_dir: Directory for the PNG files; created if missing.
        model_id: Hugging Face id of the SDXL base model.
        controlnet_model_id: Hugging Face id of the Canny ControlNet model.
        num_images: Number of frames to generate.
        device: Torch device string the pipeline runs on.
        seed: Base seed; frame ``i`` uses ``seed + i``.

    Returns:
        The list of generated images that are not entirely black, or
        ``None`` if any step raised an exception (the error is printed).
    """
    try:
        # The ControlNet-aware SDXL pipeline is required for `controlnet`,
        # `image` and `controlnet_conditioning_scale` to have any effect; the
        # plain StableDiffusionXLPipeline has no ControlNet component.
        # Imported here because the file's top-level imports do not include it.
        from diffusers import StableDiffusionXLControlNetPipeline

        os.makedirs(output_dir, exist_ok=True)
        print("Loading Stable Diffusion XL and ControlNet models...")
        # Half precision is only appropriate on CUDA; fall back to float32
        # elsewhere (e.g. when the caller passes device="cpu").
        dtype = torch.float16 if str(device).startswith("cuda") else torch.float32
        controlnet = ControlNetModel.from_pretrained(controlnet_model_id, torch_dtype=dtype)
        pipe = StableDiffusionXLControlNetPipeline.from_pretrained(model_id, controlnet=controlnet, torch_dtype=dtype)
        pipe = pipe.to(device)

        base_description_girl = "young woman, wearing a red jacket and hiking boots, determined expression, detailed, high quality"
        base_description_flower = "large petals, glowing softly, magical aura, bioluminescent, detailed"
        base_description_mountain = "snow-capped mountain, foggy cliffs, high altitude, detailed, high quality"
        style_description = "storybook watercolor style, soft tones, mystical atmosphere, no shadows"
        prompts = split_story_into_prompts(story, num_images, base_description_girl, base_description_flower, base_description_mountain, style_description)
        print(f"Generated prompts: {prompts}")

        negative_prompt = "blurry, low quality, distorted, extra limbs, unnatural colors, black image, blank image, inconsistent clothes, realistic style, shadows, cartoonish, low detail"

        # Keep the prompt index with each image so file names and plot titles
        # stay aligned with their prompt even if an earlier frame failed.
        frames = []
        for i, prompt in enumerate(prompts):
            print(f"Generating image {i+1}/{len(prompts)} for prompt: {prompt[:50]}...")
            # The edge map is single-channel; hand the pipeline an RGB image.
            control_image = generate_canny_edge_map(prompt).convert("RGB")
            generator = torch.Generator(device=device).manual_seed(seed + i)
            result = pipe(
                prompt,
                num_inference_steps=50,
                guidance_scale=7.5,
                negative_prompt=negative_prompt,
                num_images_per_prompt=1,
                generator=generator,
                image=control_image,
                controlnet_conditioning_scale=0.5
            ).images
            if result:
                frames.append((i, result[0]))
            else:
                print(f"Warning: Failed to generate image for prompt {i+1}!")

        # Size the grid from the frame count instead of assuming 2x3, so more
        # than six frames no longer overflow the subplot layout.
        columns = 3
        rows = max(1, (len(frames) + columns - 1) // columns)
        plt.figure(figsize=(15, 5 * rows))
        valid_images = []
        for slot, (i, img) in enumerate(frames, start=1):
            img_path = os.path.join(output_dir, f"image_{i+1}.png")
            img.save(img_path)
            print(f"Saved {img_path}")
            if check_image_validity(img):
                valid_images.append(img)
            else:
                print(f"Warning: Image {i+1} is black or invalid!")
            plt.subplot(rows, columns, slot)
            plt.imshow(img)
            plt.title(f"Frame {i+1}: {prompts[i][:30]}...", fontsize=8)
            plt.axis('off')

        plt.tight_layout()
        plt.show()

        if not valid_images:
            print("Error: All generated images are black or invalid!")
        return valid_images

    except Exception as e:
        # Top-level boundary for the notebook: report the failure and signal
        # it to the caller with None instead of propagating.
        print(f"Error during image generation: {str(e)}")
        return None

# ==== Run this block ====
if __name__ == "__main__":
    # Use the GPU when torch can see one, otherwise run on the CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Destination folder for the generated frames.
    output_dir = "/content/story_flower_images"

    # Six sentences, one per generated frame.
    story = (
        "A rumor spread that a magical flower bloomed once a year on a mountain peak. "
        "She packed her backpack and started hiking toward the mountain. "
        "As she climbed higher, thick clouds surrounded her. "
        "She heard strange whispers and quickened her pace, heart pounding. "
        "Suddenly, she stumbled upon a glowing cave hidden behind some rocks. "
        "Inside the cave, the magical flower shone brightly, surrounded by golden light."
    )

    images = generate_image_sequence(
        story=story,
        output_dir=output_dir,
        num_images=6,
        device=device,
        seed=42,
    )