# In [None]:
# Standard-library imports come first so the process environment can be
# configured before any ML library is loaded.
import gc
import os
from datetime import datetime

# Set these BEFORE importing torch / diffusers / transformers: values such as
# HF_HUB_OFFLINE are read when the Hugging Face libraries are first imported,
# so assigning them after the import can be silently ignored.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["TMPDIR"] = "/workspace/tmp"
os.environ["XFORMERS_DISABLED"] = "1"
os.environ["HF_HUB_OFFLINE"] = "1"

# Third-party imports (deliberately placed after the environment setup above).
import cohere  # noqa: E402
import cv2  # noqa: E402
import numpy as np  # noqa: E402
import torch  # noqa: E402
from diffusers import (  # noqa: E402
    AutoencoderKL,
    ControlNetModel,
    EulerDiscreteScheduler,
    StableDiffusionXLControlNetPipeline,
)
from diffusers.utils import load_image  # noqa: E402
from PIL import Image  # noqa: E402
from transformers import (  # noqa: E402
    CLIPConfig,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
)

# Drop unreachable Python objects first, then hand cached CUDA blocks back to
# the driver; this only matters when the script is re-run in a live
# interpreter (e.g. a notebook kernel).
gc.collect()
torch.cuda.empty_cache()

# === Paths ===
# Every location lives under the stable-diffusion-webui checkout.
_WEBUI_ROOT = "/workspace/stable-diffusion-webui"
_MODELS_ROOT = f"{_WEBUI_ROOT}/models"

# ControlNet checkpoints (edge-conditioned and depth-conditioned).
CONTROLNET_CANNY_DIR = f"{_MODELS_ROOT}/ControlNet/controlnet-canny-sdxl"
CONTROLNET_DEPTH_DIR = f"{_MODELS_ROOT}/ControlNet/controlnet-depth-sdxl"

# SDXL base model directory and the VAE sub-folder inside it.
MODEL_BASE = f"{_MODELS_ROOT}/sd_xl_base_1.0"
VAE_PATH = MODEL_BASE + "/vae"

# Source photo to restyle and the root folder that receives each run's output.
INPUT_IMAGE = (
    f"{_WEBUI_ROOT}/images/"
    "depositphotos_5894774-stock-photo-empty-kitchen.jpg"
)
OUTPUT_DIR = f"{_WEBUI_ROOT}/outputs"

# === Prompt ===
# The Cohere API key is read from the CO_API_KEY environment variable so the
# secret is never stored in source control.
# NOTE(review): the key that used to be hard-coded on this line should be
# treated as leaked and rotated.
_cohere_api_key = os.environ.get("CO_API_KEY")
if not _cohere_api_key:
    raise RuntimeError(
        "Cohere API key missing: set the CO_API_KEY environment variable "
        "before running this script."
    )
co = cohere.Client(_cohere_api_key)

# Client conversation summary that the image prompt is derived from. The
# leading and trailing newlines are part of the text sent to the LLM.
transcript = (
    "\n"
    "Leon wants a cozy, beachy kitchen with an L-shaped layout. "
    "He loves blue and beige colors, natural wood elements, and is "
    "considering a beachy backsplash with mosaic tiles or sea glass. "
    "He's also interested in light fixtures made of natural materials "
    "like rattan or driftwood."
    "\n"
)

# STEP 3: Instructions telling the LLM how to turn the transcript into an
# SDXL prompt. Joined with newlines; the empty first/last entries give the
# text its leading and trailing newline.
_SYSTEM_PROMPT_LINES = (
    "",
    "You are a kitchen prompt expert. Your job is to read a conversation "
    "transcript and generate a short, visually rich prompt that can be used "
    "with Stable Diffusion XL (SDXL).",
    "",
    "Rules:",
    '- Format: "[Style] kitchen, [materials], [layout/furniture], [lighting], photorealistic"',
    '- Do NOT use vague words like "beautiful" or "cozy"',
    "- Focus on visual elements only (e.g. colors, materials, objects, layout)",
    "- Max 77 tokens",
    "- No user names, no quotes, no assistant text",
    "",
)
system_prompt = "\n".join(_SYSTEM_PROMPT_LINES)

# Final text sent to the LLM: instructions, then the transcript, then an
# "Output:" cue for the model to complete.
full_prompt = f"{system_prompt}\n\nTranscript:\n{transcript}\n\nOutput:"

# STEP 4: Ask Cohere to condense the transcript into a single SDXL prompt.
response = co.generate(
    prompt=full_prompt,
    model="command-r-plus",
    temperature=0.6,
    max_tokens=100,
)

# STEP 5: Keep the first generation (whitespace-trimmed) as the positive
# prompt and echo it for inspection.
first_generation = response.generations[0]
PROMPT = first_generation.text.strip()
print("\n🎯 Final SDXL Prompt:", PROMPT, sep="\n")

# Traits the diffusion model is steered away from.
NEGATIVE_PROMPT = ", ".join(
    (
        "blurry",
        "low quality",
        "distorted",
        "granite countertops",
        "generic kitchen",
        "plain design",
    )
)

# === Prepare output ===
# Each run writes into its own directory named after the current local time,
# so earlier results are never overwritten.
timestamp = f"{datetime.now():%Y%m%d_%H%M%S}"
output_run_dir = os.path.join(OUTPUT_DIR, "combined_run_" + timestamp)
os.makedirs(output_run_dir, exist_ok=True)

# === Load ControlNet Models ===
print("🔹 Loading ControlNet models...")
# Both checkpoints are loaded the same way (half precision), Canny first and
# depth second.
controlnet_canny, controlnet_depth = (
    ControlNetModel.from_pretrained(checkpoint_dir, torch_dtype=torch.float16)
    for checkpoint_dir in (CONTROLNET_CANNY_DIR, CONTROLNET_DEPTH_DIR)
)

# === Load VAE ===
print("🔹 Loading custom VAE...")

# Load the fp16 VAE weights from the shared VAE_PATH constant instead of a
# second hard-coded copy of the same path, so the VAE always follows
# MODEL_BASE when the base model location changes.
vae = AutoencoderKL.from_pretrained(
    VAE_PATH,
    torch_dtype=torch.float16,
    variant="fp16",
)



# === Load Tokenizers & Text Encoders ===
print("🔹 Loading tokenizers and encoders...")

# The two tokenizers live in sibling sub-folders of the base model.
tokenizer, tokenizer_2 = (
    CLIPTokenizer.from_pretrained(f"{MODEL_BASE}/{subfolder}")
    for subfolder in ("tokenizer", "tokenizer_2")
)

# Both text encoders are loaded in half precision; the second one is the
# projection variant.
_encoder_kwargs = {"torch_dtype": torch.float16}
text_encoder = CLIPTextModel.from_pretrained(
    f"{MODEL_BASE}/text_encoder", **_encoder_kwargs
)
text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
    f"{MODEL_BASE}/text_encoder_2", **_encoder_kwargs
)

# === Load SDXL Pipeline ===
print("🔹 Loading SDXL pipeline...")
# UNet2DConditionModel is not part of the file-level diffusers import, so it
# is imported here where it is first needed.
from diffusers import UNet2DConditionModel

# fp16 UNet weights from the base model directory.
unet = UNet2DConditionModel.from_pretrained(
    f"{MODEL_BASE}/unet",
    torch_dtype=torch.float16,
    variant="fp16",
)

# Scheduler config from the same base model. The path is derived from
# MODEL_BASE (rather than a separate hard-coded copy) so every component
# follows the base model location.
scheduler = EulerDiscreteScheduler.from_pretrained(f"{MODEL_BASE}/scheduler")

# Assemble the ControlNet pipeline from the components loaded above. The two
# ControlNets are passed as a list, Canny first and depth second; the
# conditioning images and scales given at call time must use the same order.
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    pretrained_model_name_or_path=MODEL_BASE,
    controlnet=[controlnet_canny, controlnet_depth],
    unet=unet,
    vae=vae,
    text_encoder=text_encoder,
    text_encoder_2=text_encoder_2,
    tokenizer=tokenizer,
    tokenizer_2=tokenizer_2,
    scheduler=scheduler,
    torch_dtype=torch.float16,
)

# Memory-saving options for the base pipeline.
pipe.enable_attention_slicing()
pipe.enable_vae_tiling()
pipe.enable_model_cpu_offload()
# The scheduler passed in above is already an EulerDiscreteScheduler, so it
# is not rebuilt from its own config here.
# StableDiffusionXLImg2ImgPipeline is not part of the file-level diffusers
# import, so it is imported here where it is first needed.
from diffusers import StableDiffusionXLImg2ImgPipeline

REFINER_DIR = "/workspace/stable-diffusion-webui/models/sd_xl_refiner_1.0"

# Announce once (the message used to be printed twice) and load the fp16
# safetensors weights of the refiner from REFINER_DIR.
print("🔹 Loading SDXL Refiner pipeline...")
refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    REFINER_DIR,
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
)

# Same memory-saving options, in the same order, as the base pipeline; CPU
# offload is enabled a single time to keep memory usage low.
refiner.enable_attention_slicing()
refiner.enable_vae_tiling()
refiner.enable_model_cpu_offload()

# === Prepare input image ===
print("🔹 Processing input image...")
# Load the source photo as RGB; its size is reused later as the requested
# output resolution.
init_image = load_image(INPUT_IMAGE)
init_image = init_image.convert("RGB")
width, height = init_image.width, init_image.height

# === Generate Canny map ===
print("🔹 Creating Canny edges...")
# Hysteresis thresholds passed to the Canny detector.
_CANNY_LOW, _CANNY_HIGH = 100, 200

rgb_pixels = np.array(init_image)
gray = cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2GRAY)
canny = cv2.Canny(gray, _CANNY_LOW, _CANNY_HIGH)
# Replicate the single-channel edge map across three channels so it can be
# stored as an RGB image.
canny_image = Image.fromarray(np.stack((canny, canny, canny), axis=-1))

# === Simulated depth map ===
print("🔹 Creating Depth map...")
# No depth estimator is run here: the "depth" conditioning image is the
# histogram-equalized grayscale version of the input photo.
depth = init_image.convert("L")
luminance = np.array(depth)
depth_array = cv2.equalizeHist(luminance)
# Expand the single channel to three identical RGB channels.
depth_image = Image.fromarray(cv2.cvtColor(depth_array, cv2.COLOR_GRAY2RGB))

# === Save debug images ===
# Persist the source photo and both conditioning maps next to the results so
# a run can be inspected afterwards.
for debug_name, debug_image in (
    ("01_original.png", init_image),
    ("02_canny.png", canny_image),
    ("03_depth.png", depth_image),
):
    debug_image.save(os.path.join(output_run_dir, debug_name))

# === Generate result ===
print("🎨 Generating design with Canny + Depth ControlNet...")
# Cap the prompt at 77 token ids by tokenizing with truncation and decoding
# the ids back to text.
tokenized = tokenizer(
    PROMPT,
    max_length=77,
    truncation=True,
    return_tensors="pt",
)
prompt_token_ids = tokenized["input_ids"][0]
PROMPT = tokenizer.decode(prompt_token_ids, skip_special_tokens=True)
print(f"🧠 Truncated prompt token count: {len(prompt_token_ids)}")

# Conditioning images and their weights, in the order Canny then depth.
control_images = [canny_image, depth_image]
control_scales = [0.5, 0.2]

# Run the base ControlNet pipeline at the input photo's resolution and keep
# the first generated image.
base_output = pipe(
    prompt=PROMPT,
    negative_prompt=NEGATIVE_PROMPT,
    image=control_images,
    controlnet_conditioning_scale=control_scales,
    width=width,
    height=height,
    guidance_scale=9.0,
    num_inference_steps=40,
)
result = base_output.images[0]

# === Save result ===
# Write the un-refined image into this run's directory and report its path.
result_path = os.path.join(output_run_dir, "04_result.png")
result.save(result_path)
print("✅ Base result saved to: " + str(result_path))

# === Refine result with SDXL Refiner ===
print("✨ Refining the image with SDXL Refiner...")

# Image-to-image pass over the base result using the same prompts.
refiner_output = refiner(
    image=result,
    prompt=PROMPT,
    negative_prompt=NEGATIVE_PROMPT,
    num_inference_steps=25,
    guidance_scale=7.5,
    strength=0.3,
)
refined = refiner_output.images[0]

# === Save refined result ===
refined_path = os.path.join(output_run_dir, "05_refined.png")
refined.save(refined_path)
print("✅ Refined result saved to: " + str(refined_path))