# In [None]:
# kitchen_generation.py

import os, cv2, numpy as np, gc, torch
import tempfile
from PIL import Image
from datetime import datetime
from diffusers import (
    ControlNetModel, AutoencoderKL, StableDiffusionXLControlNetPipeline,
    StableDiffusionXLImg2ImgPipeline, UNet2DConditionModel, EulerDiscreteScheduler
)
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
from diffusers.utils import load_image
import cohere

# === Initial setup ===
# CUDA caching-allocator options. PyTorch reads this when CUDA is first
# initialised, so it only takes effect if nothing has touched the GPU yet.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Collect unreachable Python objects first so any CUDA tensors they hold are
# released, THEN hand the now-unused cached blocks back to the driver. The
# reverse order leaves gc-freed tensors sitting in PyTorch's cache.
gc.collect()
torch.cuda.empty_cache()

# Scratch space for temporary files. tempfile silently skips a TMPDIR that
# does not exist (falling back to the system default) and caches whichever
# directory it picks. Create the directory first, then clear the cached
# choice so the next tempfile call re-reads TMPDIR.
TMP_DIR = "/workspace/tmp"
os.makedirs(TMP_DIR, exist_ok=True)
os.environ["TMPDIR"] = TMP_DIR
tempfile.tempdir = None

# NOTE(review): diffusers only uses xFormers when it is explicitly enabled on
# a pipeline; confirm that something actually reads XFORMERS_DISABLED.
os.environ["XFORMERS_DISABLED"] = "1"
# NOTE(review): huggingface_hub reads HF_HUB_OFFLINE when it is imported, and
# diffusers/transformers are imported earlier in this file, so this likely has
# no effect here. Export it in the shell (or set it before the imports) if
# offline mode must be guaranteed.
os.environ["HF_HUB_OFFLINE"] = "1"

# === Paths ===
# Every model and output lives inside the stable-diffusion-webui install on
# the /workspace volume; derive each location from that shared root.
_WEBUI_ROOT = "/workspace/stable-diffusion-webui"
_WEBUI_MODELS = f"{_WEBUI_ROOT}/models"
_WEBUI_CONTROLNETS = f"{_WEBUI_MODELS}/ControlNet"

# SDXL ControlNet weights (edge- and depth-conditioned respectively).
CONTROLNET_CANNY_DIR = f"{_WEBUI_CONTROLNETS}/controlnet-canny-sdxl"
CONTROLNET_DEPTH_DIR = f"{_WEBUI_CONTROLNETS}/controlnet-depth-sdxl"

# SDXL base and refiner checkpoints; the base is expected to use the diffusers
# sub-folder layout (vae/, unet/, tokenizer/, ...).
MODEL_BASE = f"{_WEBUI_MODELS}/sd_xl_base_1.0"
REFINER_DIR = f"{_WEBUI_MODELS}/sd_xl_refiner_1.0"
VAE_PATH = MODEL_BASE + "/vae"

# Root folder for per-run output directories.
OUTPUT_DIR = _WEBUI_ROOT + "/outputs"

# === Load once ===
print("🔹 Loading models...")

# SECURITY: the Cohere API key used to be hard-coded here. Secrets in source
# leak through version control and shared notebooks, so revoke that key and
# supply a fresh one via the environment. Check before the slow model loads
# so a missing key fails fast.
_COHERE_API_KEY = os.environ.get("COHERE_API_KEY") or os.environ.get("CO_API_KEY")
if not _COHERE_API_KEY:
    raise RuntimeError(
        "Cohere API key not found: set COHERE_API_KEY (or CO_API_KEY) in the environment."
    )
co = cohere.Client(_COHERE_API_KEY)

# SDXL ControlNets in half precision. The order of the list handed to the
# pipeline below is the order that conditioning images and scales must follow.
controlnet_canny = ControlNetModel.from_pretrained(CONTROLNET_CANNY_DIR, torch_dtype=torch.float16)
controlnet_depth = ControlNetModel.from_pretrained(CONTROLNET_DEPTH_DIR, torch_dtype=torch.float16)

# SDXL base components, loaded individually from MODEL_BASE's sub-folders.
# variant="fp16" selects the *.fp16.* weight files where it is passed.
vae = AutoencoderKL.from_pretrained(VAE_PATH, torch_dtype=torch.float16, variant="fp16")
tokenizer = CLIPTokenizer.from_pretrained(f"{MODEL_BASE}/tokenizer")
tokenizer_2 = CLIPTokenizer.from_pretrained(f"{MODEL_BASE}/tokenizer_2")
text_encoder = CLIPTextModel.from_pretrained(f"{MODEL_BASE}/text_encoder", torch_dtype=torch.float16)
text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(f"{MODEL_BASE}/text_encoder_2", torch_dtype=torch.float16)
unet = UNet2DConditionModel.from_pretrained(f"{MODEL_BASE}/unet", torch_dtype=torch.float16, variant="fp16")
scheduler = EulerDiscreteScheduler.from_pretrained(f"{MODEL_BASE}/scheduler")

# Base SDXL + two ControlNets (canny first, depth second).
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    pretrained_model_name_or_path=MODEL_BASE,
    controlnet=[controlnet_canny, controlnet_depth],
    unet=unet, vae=vae,
    text_encoder=text_encoder, text_encoder_2=text_encoder_2,
    tokenizer=tokenizer, tokenizer_2=tokenizer_2,
    scheduler=scheduler,
    torch_dtype=torch.float16
)
# Memory savers: compute attention in slices, decode the VAE in tiles, and
# keep each sub-model on the CPU until the pipeline needs it on the GPU.
pipe.enable_attention_slicing()
pipe.enable_vae_tiling()
pipe.enable_model_cpu_offload()

# SDXL refiner for a light img2img polish pass. It loads its own components
# from REFINER_DIR and gets the same memory savers.
refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    REFINER_DIR,
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16"
)
refiner.enable_model_cpu_offload()
refiner.enable_attention_slicing()
refiner.enable_vae_tiling()

# === Function ===
def _snap_dim(value, multiple=8):
    """Round ``value`` down to a multiple of ``multiple`` (never below ``multiple``).

    SDXL denoises latents at 1/8 of the pixel resolution, so the generation
    width/height must be divisible by 8. Arbitrary input photos usually are not.
    """
    return max(multiple, (int(value) // multiple) * multiple)


def _make_generator(seed):
    """Return a seeded CPU ``torch.Generator``, or None when ``seed`` is None."""
    # diffusers recommends a CPU generator for reproducible noise across devices.
    return None if seed is None else torch.Generator("cpu").manual_seed(seed)


def _extract_prompt(transcript):
    """Ask Cohere to condense ``transcript`` into one SDXL prompt and return it."""
    print("🎯 Extracting SDXL Prompt from transcript...")
    system_prompt = """
You are a kitchen prompt expert. Your job is to read a conversation transcript and generate a short, visually rich prompt that can be used with Stable Diffusion XL (SDXL).

Rules:
- Format: "[Style] kitchen, [materials], [layout/furniture], [lighting], photorealistic"
- Do NOT use vague words like "beautiful" or "cozy"
- Focus on visual elements only (e.g. colors, materials, objects, layout)
- Max 77 tokens
- No user names, no quotes, no assistant text
"""
    full_prompt = system_prompt + "\n\nTranscript:\n" + transcript + "\n\nOutput:"
    # NOTE(review): Cohere documents the Generate endpoint as legacy; confirm
    # command-r-plus is still served there, or migrate to co.chat(...).
    response = co.generate(model="command-r-plus", prompt=full_prompt, max_tokens=100, temperature=0.6)
    # Defensive: the rules above say "no quotes", but strip surrounding double
    # quotes anyway so they are not fed to the text encoders as literal tokens.
    prompt = response.generations[0].text.strip().strip('"').strip()
    print(f"✅ Prompt: {prompt}")
    return prompt


def _prepare_control_images(input_image_path):
    """Load the input photo and derive the canny and "depth" conditioning images.

    Returns ``(init_image, canny_image, depth_image)``: three RGB PIL images of
    identical size, with width and height snapped down to multiples of 8.
    """
    init_image = load_image(input_image_path).convert("RGB")
    target_size = (_snap_dim(init_image.width), _snap_dim(init_image.height))
    if target_size != init_image.size:
        # Resize rather than crop so the whole room stays in frame. Each side
        # shrinks by at most 7 px; sides under 8 px are enlarged to 8.
        init_image = init_image.resize(target_size, Image.LANCZOS)

    gray = cv2.cvtColor(np.array(init_image), cv2.COLOR_RGB2GRAY)
    canny = cv2.Canny(gray, 100, 200)  # low/high hysteresis thresholds
    canny_image = Image.fromarray(cv2.cvtColor(canny, cv2.COLOR_GRAY2RGB))

    # NOTE(review): this is histogram-equalised luminance, not an estimated
    # depth map; a depth ControlNet presumably expects real depth. Its low
    # conditioning scale (0.2) limits the effect, but consider a depth estimator.
    depth_array = cv2.equalizeHist(np.array(init_image.convert("L")))
    depth_image = Image.fromarray(np.stack([depth_array] * 3, axis=-1))
    return init_image, canny_image, depth_image


def _make_run_dir():
    """Create and return a new, empty per-run directory under OUTPUT_DIR.

    The name is ``run_<YYYYmmdd_HHMMSS>``. If a run in the same second already
    created it, a numeric suffix is appended instead of sharing (and
    overwriting) that directory.
    """
    base = os.path.join(OUTPUT_DIR, f"run_{datetime.now().strftime('%Y%m%d_%H%M%S')}")
    candidate, suffix = base, 1
    while True:
        try:
            os.makedirs(candidate)  # raises if the leaf already exists
            return candidate
        except FileExistsError:
            candidate = f"{base}_{suffix}"
            suffix += 1


def run_kitchen_generation(transcript, input_image_path, *, seed=None):
    """Generate a refined kitchen image from a conversation transcript and a photo.

    Args:
        transcript: conversation text the SDXL prompt is extracted from.
        input_image_path: image location passed to diffusers' ``load_image``.
        seed: optional int. When given, both diffusion passes use a seeded
            CPU generator so the output is reproducible.

    Returns:
        Path (str) of the refined PNG. The original, canny, "depth" and
        unrefined images are saved in the same run directory.
    """
    # 🧠 Prompt Engineering
    prompt = _extract_prompt(transcript)
    neg_prompt = "blurry, low quality, distorted, granite countertops, generic kitchen, plain design"

    # 🖼 Image Preprocessing (size snapped to multiples of 8)
    init_image, canny_image, depth_image = _prepare_control_images(input_image_path)
    width, height = init_image.size

    # 🗂 Save debug images (01_original is the size-snapped input)
    output_run_dir = _make_run_dir()
    init_image.save(os.path.join(output_run_dir, "01_original.png"))
    canny_image.save(os.path.join(output_run_dir, "02_canny.png"))
    depth_image.save(os.path.join(output_run_dir, "03_depth.png"))

    # 🎨 Generation
    print("🎨 Generating with ControlNet...")
    result = pipe(
        prompt=prompt,
        negative_prompt=neg_prompt,
        # Must match the order of the controlnet list the pipeline was built with.
        image=[canny_image, depth_image],
        controlnet_conditioning_scale=[0.5, 0.2],
        num_inference_steps=40,
        guidance_scale=9.0,
        width=width,
        height=height,
        generator=_make_generator(seed),
    ).images[0]
    result_path = os.path.join(output_run_dir, "04_result.png")
    result.save(result_path)

    # ✨ Refinement: strength=0.3 only lightly re-noises the image, so roughly
    # 30% of the 25 scheduled steps actually run.
    print("✨ Refining with SDXL Refiner...")
    refined = refiner(
        prompt=prompt,
        negative_prompt=neg_prompt,
        image=result,
        strength=0.3,
        guidance_scale=7.5,
        num_inference_steps=25,
        generator=_make_generator(seed),
    ).images[0]
    refined_path = os.path.join(output_run_dir, "05_refined.png")
    refined.save(refined_path)

    print(f"✅ Refined image saved to: {refined_path}")
    return refined_path