# In [None]:
import gradio as gr
import json
import requests
import time
import base64
from pathlib import Path
from PIL import Image
import io
import numpy as np
import umap
import torch
from transformers import CLIPProcessor, CLIPModel

# ####################
# ComfyUI & Prompt Logic
# ####################

# Path to the ComfyUI workflow exported in API (JSON) format; read by
# load_workflow_template.
# NOTE(review): hard-coded absolute Windows path — adjust per machine.
COMFYUI_API_PATH = r"E:\JKU_LINZ\Winter_25\MS_Thesis\Text2Image\text2imagegen\First_trials\comfyuiapi\Image_gen.json"
# Base URL of the ComfyUI HTTP server.
COMFYUI_SERVER = "http://127.0.0.1:8188"

# Module-level UI state. Its use is outside this chunk; presumably
# BATCH_HISTORY holds generated batches capped at MAX_BATCHES and
# SELECTED_IMAGE_INDEX is the gallery selection (None = none) — verify against callers.
BATCH_HISTORY = []
MAX_BATCHES = 12
SELECTED_IMAGE_INDEX = None

CLIP_HISTORY = []  # list of dicts with "embeddings": [np.ndarray or None, ...]

# Hex colour per attribute label, looked up by get_attr_color (e.g. for the
# radar-chart axis labels). Keys are lower-case; get_attr_color falls back to
# "#AAAAAA" for labels not listed here.
ATTRIBUTE_COLOR_MAP = {
    "cute":       "#FBBF24",
    "happy":      "#22C55E",
    "ugly":       "#EF4444",
    "sad":        "#3B82F6",
    "joyful":     "#06B6D4",
    "miserable":  "#7C3AED",
    "old":        "#A16207",
    "young":      "#86EFAC",
    "rich":       "#F59E0B",
    "poor":       "#6B7280",
    "friendly":   "#14B8A6",
    "aggressive": "#DC2626",
    "sci-fi":     "#818CF8",
    "mechanical": "#0EA5E9",
}

def get_attr_color(attr: str) -> str:
    """Return the display colour for an attribute label.

    The lookup is case-insensitive. ``None``, an empty string, or an
    unknown label all map to the neutral grey ``"#AAAAAA"``.
    """
    key = attr.lower() if attr else ""
    return ATTRIBUTE_COLOR_MAP.get(key, "#AAAAAA")

# ── Bipolar scoring scale: pushes pair probabilities towards 0 or 1.
# Math: p = sigmoid(SCALE * delta), where delta = cos(img, t_pos) - cos(img, t_neg),
# the cosine gap between the positive and negative pole descriptions.
# Typical CLIP cosine gap is 0.03-0.08 for clear cases.
# SCALE=60 → delta=0.05 gives p≈0.95; delta=0.08 gives p≈0.99
# When BIPOLAR_USE_CLIP_LOGIT_SCALE is True, compute_attribute_similarity also
# multiplies SCALE by the CLIP model's exp(logit_scale).
BIPOLAR_USE_CLIP_LOGIT_SCALE = False
BIPOLAR_EXTRA_SCALE = 60.0  # raised from 20 → pushes clear cases to 0.90-0.99

# ── Prompt ensembling templates
# Each attribute description is substituted for "{desc}" in every template.
# Each variant is encoded with CLIP, and the unit vectors are averaged into one
# ensemble vector (see _encode_ensemble / _encode_ensemble_for_subject).
TEXT_TEMPLATES = [
    "{desc}",
    "a photo of {desc}",
    "an illustration of {desc}",
    "a close-up photo of {desc}",
    "a high quality realistic photo of {desc}",
    "a detailed portrait of {desc}",
]

# Species-specific templates. _encode_ensemble_for_subject puts these before
# TEXT_TEMPLATES when the lower-cased subject is a key here; any other subject
# uses DEFAULT_ANIMAL_TEMPLATES instead.
# NOTE(review): the descriptions are long and text is truncated to 77 CLIP
# tokens, so trailing words (e.g. the species noun in "a {desc} cat") may be
# cut off for some variants — worth checking.
ANIMAL_CONTEXT_TEMPLATES: dict = {
    "cat":    ["a {desc} cat",    "a photo of a {desc} cat",    "a {desc} kitten",    "portrait of a {desc} cat"],
    "dog":    ["a {desc} dog",    "a photo of a {desc} dog",    "a {desc} puppy",     "portrait of a {desc} dog"],
    "lion":   ["a {desc} lion",   "a photo of a {desc} lion",   "a {desc} wild lion", "portrait of a {desc} lion"],
    "tiger":  ["a {desc} tiger",  "a photo of a {desc} tiger",  "a {desc} bengal tiger"],
    "bird":   ["a {desc} bird",   "a photo of a {desc} bird",   "a {desc} parrot or eagle"],
    "dragon": ["a {desc} dragon", "a fantasy {desc} dragon",    "an illustration of a {desc} dragon"],
    "robot":  ["a {desc} robot",  "a mechanical {desc} robot",  "a photo of a {desc} robot"],
}
# Fallback templates for subjects without an entry above.
DEFAULT_ANIMAL_TEMPLATES: list = [
    "a {desc} animal",
    "a photo of a {desc} animal",
    "a {desc} creature",
    "portrait of a {desc} creature",
]

# ── Bipolar attribute pairs
# Each entry is one attribute axis with two opposite poles ("positive" and
# "negative"). Every pole has a short "label" (the attribute name used on the
# radar chart) and a long "description" that is fed to the CLIP text encoder.
# The two poles of a pair are scored against each other in
# compute_attribute_similarity, so their scores always sum to 1.
# CLIP detects VISUAL features in the image, so descriptions must target things
# visible in a 512x512 portrait: fur texture, eye clarity, facial expression,
# muzzle colour, body posture. Abstract concepts like "feelings" don't score well.
BIPOLAR_ATTRIBUTES = {

    # pair 0: CUTENESS 
    "appearance_cuteness": {
        "positive": {
            "label": "cute",
            "description": (
                "extremely adorable baby animal, disproportionately huge round glassy eyes, "
                "tiny soft pink button nose, ultra-fluffy plush downy coat, "
                "chubby rounded puffy cheeks, small compact round body, "
                "irresistibly innocent sweet gentle expression"
            ),
        },
        "negative": {
            "label": "ugly",
            "description": (
                "grotesque horribly deformed animal, severely asymmetric scarred face, "
                "completely bald patchy infected scabby skin, deep sunken hollow dull eyes, "
                "crooked broken yellow teeth, festering sores, diseased emaciated body"
            ),
        },
    },

    # pair 1: HAPPINESS — facial expression CLIP can read from portraits
    "emotion_valence": {
        "positive": {
            "label": "happy",
            "description": (
                "visibly happy animal with wide open bright eyes, "
                "mouth open showing teeth in a grin, tongue lolling out joyfully, "
                "ears fully pricked forward, soft relaxed face muscles, "
                "warm inviting alert expression"
            ),
        },
        "negative": {
            "label": "sad",
            "description": (
                "visibly sad animal, downcast half-closed glistening watery eyes, "
                "mouth tightly shut with corners pulled down, ears pinned flat, "
                "furrowed brow, heavy drooping head, deeply unhappy sorrowful face"
            ),
        },
    },

    # pair 2: ALERTNESS/ENERGY — focus on FACIAL alertness, not body posture.
    # "Leaping mid-air" is invisible in a sitting portrait; bright vs dull EYES work.
    "emotion_intensity": {
        "positive": {
            "label": "joyful",
            "description": (
                "extremely alert playful energetic animal, very wide sparkling bright eyes, "
                "ears fully erect and pointing forward, open mouth, "
                "animated lively engaged expression, whiskers spread wide, "
                "intense curious attentive gaze, visibly excited"
            ),
        },
        "negative": {
            "label": "miserable",
            "description": (
                "completely lethargic depressed suffering animal, "
                "half-closed glazed dull lifeless eyes, ears flat against skull, "
                "tightly closed tense mouth, hunched withdrawn body, "
                "vacant hollow stare, visibly exhausted and defeated"
            ),
        },
    },

    # pair 3: AGE — must use FACE-VISIBLE cues that appear even in fantasy portraits.
    # Key CLIP-visible signals: muzzle colour (grey/white = old), eye clarity,
    # coat texture (smooth = young; coarse thin = old), facial wrinkles.
    "age": {
        "positive": {
            "label": "young",
            "description": (
                "very young juvenile animal, completely smooth clean glossy coat, "
                "proportionally enormous bright clear sparkling eyes, "
                "tiny soft pink nose and small round muzzle, "
                "perfectly clean pure-coloured fur with no grey or white patches, "
                "soft baby-like rounded facial features, fresh vibrant appearance"
            ),
        },
        "negative": {
            "label": "old",
            "description": (
                "very elderly senior animal, conspicuously white or grey fur "
                "covering entire muzzle chin and forehead, "
                "deeply cloudy milky opaque eyes with sunken sockets, "
                "coarse sparse thinning dull coat, pronounced deep facial wrinkles, "
                "thick prominent whisker pads, heavy-lidded tired aged expression"
            ),
        },
    },

    # pair 4: STATUS — jewel collars / grooming vs emaciated stray
    "status": {
        "positive": {
            "label": "rich",
            "description": (
                "pampered prize show animal, perfectly groomed immaculate glossy coat, "
                "ornate diamond-studded jewelled collar, bright clean white teeth, "
                "posed on velvet cushion in opulent luxury setting, "
                "well-fed plump healthy shiny appearance"
            ),
        },
        "negative": {
            "label": "poor",
            "description": (
                "emaciated neglected stray street animal, heavily matted tangled filthy fur, "
                "no collar, prominent visible ribs and hip bones, "
                "mud-caked cracked paws, open sores, scruffy unkempt appearance"
            ),
        },
    },

    # pair 5: SOCIAL BEHAVIOUR
    "social_behavior": {
        "positive": {
            "label": "friendly",
            "description": (
                "gentle friendly approachable animal, soft warm inviting eyes, "
                "relaxed open mouth with gentle smile, tail fully raised and wagging, "
                "head tilted sideways, calm non-threatening completely relaxed posture"
            ),
        },
        "negative": {
            "label": "aggressive",
            "description": (
                "ferociously aggressive attack-ready animal, fully bared razor-sharp fangs, "
                "deeply wrinkled snarling nose, wide open threatening jaw, "
                "tensed crouched attack stance, ears pinned flat, raised hackles, "
                "hostile intense deadly glare"
            ),
        },
    },
}

# ── Flat label → pair info lookup
# Maps every pole label (e.g. "cute", "ugly") to its own description, the key
# of the pair it belongs to, its polarity, and the opposite pole's label and
# description. This lets a single label be scored against its opposite.
# NOTE: these loops run at module level, so _pk/_pair/_pol/_opp/_lbl (and
# _pole below) remain defined as module globals afterwards.
_LABEL_TO_BIPOLAR: dict = {}
for _pk, _pair in BIPOLAR_ATTRIBUTES.items():
    for _pol in ("positive", "negative"):
        _opp = "negative" if _pol == "positive" else "positive"
        _lbl = _pair[_pol]["label"]
        _LABEL_TO_BIPOLAR[_lbl] = {
            "description":    _pair[_pol]["description"],
            "pair_key":       _pk,
            "polarity":       _pol,
            "is_positive":    (_pol == "positive"),
            "opposite_label": _pair[_opp]["label"],
            "opposite_desc":  _pair[_opp]["description"],
        }

# Long-form CLIP text description per attribute label: two unpaired
# attributes plus every pole label from BIPOLAR_ATTRIBUTES. Labels missing
# here are encoded verbatim by the callers (dict.get fallback).
CLIP_ATTRIBUTE_DESCRIPTIONS: dict = {
    "sci-fi":     "a sci-fi futuristic cybernetic creature, glowing implants, neon lighting",
    "mechanical": "a fully mechanical robotic creature, exposed metal gears, rigid joints",
}
for _pair in BIPOLAR_ATTRIBUTES.values():
    for _pole in ("positive", "negative"):
        CLIP_ATTRIBUTE_DESCRIPTIONS[_pair[_pole]["label"]] = _pair[_pole]["description"]

# Twelve pole labels: the six positive poles first, then the six negative poles
# in the same pair order. When these are used as radar-chart axes (12 evenly
# spaced angles), each label sits 180° from its opposite (cute/ugly,
# happy/sad, joyful/miserable, young/old, rich/poor, friendly/aggressive).
# Keep that ordering if the list is edited.
COMPARISON_MODIFIERS = [
    "cute", "happy", "joyful", "young", "rich", "friendly",
    "ugly", "sad",  "miserable", "old",  "poor", "aggressive",
]

# ── Art styles optimised for 512×512 highly-detailed outputs
# Each style adds quality boosters that help SD squeeze maximum detail at 512×512.
# combine_prompt_with_style appends the chosen value to the prompt as a weighted
# group "(<suffix>:<weight>)"; the "None" entry is empty and adds nothing.
ART_STYLES = {
    "None": "",
    "Fantasy": (
        "fantasy art, magical atmosphere, ethereal lighting, "
        "masterpiece, best quality, highly detailed, sharp focus, intricate details"
    ),
    "Ultra Realistic": (
        "ultra realistic, photorealistic, hyperrealistic, "
        "masterpiece, best quality, 8k resolution, RAW photo, "
        "highly detailed, sharp focus, HDR, subsurface scattering"
    ),
    "Watercolor": (
        "watercolor painting, soft delicate colors, artistic, painterly, "
        "masterpiece, best quality, highly detailed brushwork, wet-on-wet technique"
    ),
    "Sketch": (
        "detailed pencil sketch, hand drawn, fine linework, crosshatching, "
        "masterpiece, best quality, highly detailed, sharp lines, professional illustration"
    ),
    "Anime": (
        "anime style, manga illustration, vibrant saturated colors, cel shaded, "
        "masterpiece, best quality, highly detailed, sharp, studio quality animation"
    ),
    "Oil Painting": (
        "oil painting, classical fine art, rich impasto brush strokes, "
        "masterpiece, best quality, highly detailed, Renaissance style, dramatic lighting"
    ),
    "Digital Art": (
        "digital art, concept art, trending on artstation, "
        "masterpiece, best quality, highly detailed, sharp, volumetric lighting, 4k"
    ),
    "Cartoon": (
        "cartoon style, stylized illustration, vibrant clean colors, "
        "masterpiece, best quality, highly detailed, bold clean lines, Disney quality"
    ),
    "Cinematic": (
        "cinematic photography, movie still, anamorphic lens, bokeh, "
        "masterpiece, best quality, highly detailed, dramatic studio lighting, 4k"
    ),
    "Portrait": (
        "professional portrait photography, studio lighting, shallow depth of field, "
        "masterpiece, best quality, highly detailed, sharp focus on face, "
        "DSLR 85mm lens, natural skin texture"
    ),
}

# Attribute choices presumably offered in a dropdown. The "Custom..." entry is
# resolved to the user's free-text value by get_effective_value.
BASE_ATTRIBUTES = [
    "cute", "happy", "ugly", "sad", "joyful", "sci-fi", "mechanical",
    "miserable", "old", "young", "rich", "poor",
    "friendly", "aggressive", "Custom..."
]

# Subject choices; "Custom..." works the same way as above. Subjects that are
# keys of ANIMAL_CONTEXT_TEMPLATES get species-specific CLIP prompt templates.
BASE_SUBJECTS = ["cat", "dog", "bird", "robot", "dragon", "tiger", "lion", "Custom..."]

# ── Default negative prompt optimised for 512×512 detail
# Generic artefact / quality suppressors.
# NOTE(review): the list contains "ugly", which is also a selectable attribute
# (BASE_ATTRIBUTES). If this default is sent unchanged when the user asks for
# an "ugly" subject, the negative prompt works against the requested attribute.
# Confirm how callers use this value.
DEFAULT_NEGATIVE_PROMPT = (
    "text, watermark, blurry, out of focus, low quality, low resolution, "
    "jpeg artifacts, noise, grain, overexposed, underexposed, "
    "deformed, disfigured, bad anatomy, extra limbs, missing limbs, "
    "ugly, duplicate, morbid, mutilated, poorly drawn face"
)

# ####################
# CLIP model (image + text)
# ####################

# CLIP checkpoint shared by the image encoder (encode_image_with_clip) and the
# text encoder helpers below. Both model and processor are loaded once, at
# import time, and the model is switched to eval mode. The model is never
# moved to another device in this chunk; inference uses the processor's tensors as-is.
CLIP_MODEL_NAME = "openai/clip-vit-base-patch32"
_clip_model = CLIPModel.from_pretrained(CLIP_MODEL_NAME)
_clip_processor = CLIPProcessor.from_pretrained(CLIP_MODEL_NAME)
_clip_model.eval()


####################
# image_vec  = _clip_model.get_image_features(PIL)  → shared CLIP space
# text_vec   = _clip_model.get_text_features(text)  → shared CLIP space
# Both vectors are L2-normalised, so their dot product is the cosine similarity.
####################

def encode_image_with_clip(pil_image: Image.Image) -> np.ndarray:
    """Embed a PIL image with the module's CLIP image encoder.

    The image is converted to RGB first. The result is a flat float32 vector
    scaled to unit L2 norm (with a 1e-8 guard against division by zero), so
    its dot product with a normalised CLIP text vector is a cosine similarity.
    """
    pixel_batch = _clip_processor(images=pil_image.convert("RGB"), return_tensors="pt")
    with torch.no_grad():
        features = _clip_model.get_image_features(**pixel_batch)
    embedding = features.cpu().numpy().astype(np.float32).reshape(-1)
    # In-place division keeps the float32 dtype.
    embedding /= (np.linalg.norm(embedding) + 1e-8)
    return embedding


def text_modifiers_to_clip_vectors(modifiers):
    """Encode attribute modifiers as unit-length CLIP text embeddings.

    Each modifier is replaced by its long-form description from
    ``CLIP_ATTRIBUTE_DESCRIPTIONS`` (looked up lower-cased). Modifiers without
    an entry are encoded verbatim.

    Args:
        modifiers: Sequence of modifier strings; may be empty or None.

    Returns:
        float32 array of shape (len(modifiers), D) with L2-normalised rows,
        where D is the CLIP projection dimension. With no modifiers, a (0, D)
        array is returned so it still stacks/concatenates with non-empty results.
    """
    if not modifiers:
        # Take the width from the model config rather than hard-coding 512, so
        # the empty result matches non-empty ones for any CLIP checkpoint.
        dim = int(_clip_model.config.projection_dim)
        return np.zeros((0, dim), dtype=np.float32)
    descriptions = [CLIP_ATTRIBUTE_DESCRIPTIONS.get(m.lower(), m) for m in modifiers]
    inputs = _clip_processor(
        text=descriptions, return_tensors="pt",
        padding=True, truncation=True, max_length=77)
    with torch.no_grad():
        text_features = _clip_model.get_text_features(**inputs)
    text_vecs = text_features.cpu().numpy().astype(np.float32)
    norms = np.linalg.norm(text_vecs, axis=1, keepdims=True) + 1e-8
    return text_vecs / norms


def _encode_text_batch(texts: list, cache: dict) -> None:
    missing = [t for t in texts if t not in cache]
    if not missing:
        return
    inp = _clip_processor(text=missing, return_tensors="pt",
                          padding=True, truncation=True, max_length=77)
    with torch.no_grad():
        feats = _clip_model.get_text_features(**inp)
    vecs = feats.cpu().numpy().astype(np.float32)
    vecs /= (np.linalg.norm(vecs, axis=1, keepdims=True) + 1e-8)
    for t, v in zip(missing, vecs):
        cache[t] = v


def _encode_ensemble(base_desc: str, cache: dict) -> np.ndarray:
    """Return the normalised mean CLIP embedding of *base_desc* over TEXT_TEMPLATES.

    Each template is formatted with ``desc=base_desc``. Missing variants are
    encoded into *cache*, then all variant vectors are averaged and rescaled
    to unit length.
    """
    prompts = [template.format(desc=base_desc) for template in TEXT_TEMPLATES]
    _encode_text_batch(prompts, cache)
    ensemble = np.stack([cache[prompt] for prompt in prompts]).mean(axis=0)
    ensemble /= (np.linalg.norm(ensemble) + 1e-8)
    return ensemble


def _encode_ensemble_for_subject(base_desc: str, subject: str, cache: dict) -> np.ndarray:
    """Return a subject-aware, normalised ensemble CLIP embedding of *base_desc*.

    Templates come from ANIMAL_CONTEXT_TEMPLATES[subject.lower()] (or
    DEFAULT_ANIMAL_TEMPLATES for unknown/empty subjects), followed by
    TEXT_TEMPLATES, with duplicates removed in first-seen order. Missing
    variants are encoded into *cache*. The mean vector is rescaled to unit length.
    """
    subject_key = (subject or "").lower()
    species_templates = ANIMAL_CONTEXT_TEMPLATES.get(subject_key, DEFAULT_ANIMAL_TEMPLATES)
    # dict.fromkeys de-duplicates while preserving order.
    templates = list(dict.fromkeys([*species_templates, *TEXT_TEMPLATES]))
    prompts = [template.format(desc=base_desc) for template in templates]
    _encode_text_batch(prompts, cache)
    ensemble = np.stack([cache[prompt] for prompt in prompts]).mean(axis=0)
    ensemble /= (np.linalg.norm(ensemble) + 1e-8)
    return ensemble


def get_effective_value(dropdown_value: str, custom_value: str):
    """Resolve a dropdown selection that may be the "Custom..." sentinel.

    Any non-"Custom..." choice is returned unchanged. For "Custom..." the
    free-text *custom_value* is returned stripped of surrounding whitespace,
    or "" when it is None or blank.
    """
    if dropdown_value != "Custom...":
        return dropdown_value
    return (custom_value or "").strip()


def load_workflow_template(path=None):
    """Load the ComfyUI workflow (API JSON format) from disk.

    Args:
        path: File to read. Defaults to COMFYUI_API_PATH.

    Returns:
        The parsed workflow dict, or {} when the file is missing or is not
        valid JSON (a warning is printed in both cases). Callers such as
        update_workflow_parameters treat {} as "no workflow".
    """
    if path is None:
        path = COMFYUI_API_PATH
    try:
        # Decode explicitly as UTF-8 (tolerating a BOM) instead of the locale
        # codec, which on Windows is often cp1252 and mangles or rejects
        # non-ASCII prompt text in the exported workflow.
        with open(path, "r", encoding="utf-8-sig") as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"Warning: Workflow file not found at {path}")
        return {}
    except json.JSONDecodeError as e:
        print(f"Warning: Workflow file at {path} is not valid JSON: {e}")
        return {}


def combine_prompt_with_style(base_prompt, art_style, art_style_weight):
    """Append the chosen art style to *base_prompt* as a weighted group.

    Produces ``"<base_prompt>, (<style text>:<weight>)"``, with the weight
    rounded to one decimal. The prompt is returned unchanged when no style is
    chosen, when the style is "None", or when the style has no text in ART_STYLES.
    """
    if not art_style or art_style == "None":
        return base_prompt
    suffix = ART_STYLES.get(art_style, "")
    if not suffix:
        return base_prompt
    weight = round(float(art_style_weight), 1)
    return f"{base_prompt}, ({suffix}:{weight})"


def build_base_prompt(attribute, attribute_weight, subject):
    """Build the core prompt ``"a (<attribute>:<weight>) <subject>"``.

    An empty subject becomes the placeholder "subject". Without an attribute
    the prompt is just ``"a <subject>"``, and the weight is not read. The
    weight is rounded to one decimal.
    """
    noun = subject or "subject"
    if attribute:
        weight = round(float(attribute_weight), 1)
        return f"a ({attribute}:{weight}) {noun}"
    return f"a {noun}"


def _linked_node_id(link):
    """Return the source node id (as str) of a ComfyUI API link ``[node_id, slot]``, else None."""
    if isinstance(link, (list, tuple)) and link:
        return str(link[0])
    return None


def _resolve_prompt_node_ids(workflow, text_node_ids):
    """Pick the (positive, negative) CLIPTextEncode node ids of *workflow*.

    Prefers the two distinct CLIPTextEncode nodes wired directly into the
    first KSampler's ``positive`` / ``negative`` inputs. Otherwise it falls
    back to dict order: the first CLIPTextEncode is positive, the second is
    negative. Either id may be None when there are not enough text nodes.
    """
    pos_id = text_node_ids[0] if text_node_ids else None
    neg_id = text_node_ids[1] if len(text_node_ids) > 1 else None
    for node in workflow.values():
        if node.get("class_type") != "KSampler":
            continue
        inputs = node.get("inputs", {})
        linked_pos = _linked_node_id(inputs.get("positive"))
        linked_neg = _linked_node_id(inputs.get("negative"))
        if (linked_pos in text_node_ids and linked_neg in text_node_ids
                and linked_pos != linked_neg):
            return linked_pos, linked_neg
        break
    return pos_id, neg_id


def update_workflow_parameters(
    workflow, positive_prompt, negative_prompt, seed, control_mode,
    steps, cfg, sampler, scheduler, denoise, width, height, batch_size
):
    """Write prompts plus sampler/latent settings into a ComfyUI API workflow.

    The workflow dict is modified in place and also returned.

    Prompt placement: the positive and negative CLIPTextEncode nodes are
    found through the KSampler's ``positive``/``negative`` links. This is more
    reliable than JSON key order, which need not match the wiring. If the links
    do not point at two distinct CLIPTextEncode nodes, key order is used instead.
    An empty prompt leaves that node's text unchanged.

    Every KSampler gets seed/steps/cfg/sampler_name/scheduler/denoise, and
    every EmptyLatentImage gets width/height/batch_size.

    Args:
        workflow: Workflow dict as returned by load_workflow_template.
        control_mode: Accepted for interface compatibility; not used here.

    Returns:
        The updated workflow, or None when *workflow* is empty or an error
        occurs (the error is printed).
    """
    if not workflow:
        return None
    try:
        text_node_ids = [node_id for node_id, node in workflow.items()
                         if node.get("class_type") == "CLIPTextEncode"]
        pos_id, neg_id = _resolve_prompt_node_ids(workflow, text_node_ids)
        if positive_prompt and pos_id is not None:
            workflow[pos_id].setdefault("inputs", {})["text"] = positive_prompt
        if negative_prompt and neg_id is not None:
            workflow[neg_id].setdefault("inputs", {})["text"] = negative_prompt

        sampler_settings = {
            "seed": seed,
            "steps": steps,
            "cfg": cfg,
            "sampler_name": sampler,
            "scheduler": scheduler,
            "denoise": denoise,
        }
        latent_settings = {"width": width, "height": height, "batch_size": batch_size}
        for node in workflow.values():
            class_type = node.get("class_type")
            if class_type == "KSampler":
                node.setdefault("inputs", {}).update(sampler_settings)
            elif class_type == "EmptyLatentImage":
                node.setdefault("inputs", {}).update(latent_settings)
        return workflow
    except Exception as e:
        print(f"Error updating workflow: {e}")
        return None


def get_images_from_comfy(prompt_id, server_address, timeout=30):
    """Download every output image ComfyUI recorded for *prompt_id*.

    Reads ``<server>/history/<prompt_id>``. For each node output that lists
    ``images``, it fetches the file through ``<server>/view``.

    Args:
        prompt_id: Id under which the prompt appears in ComfyUI's history.
        server_address: Base URL of the server, e.g. COMFYUI_SERVER.
        timeout: Seconds to wait per HTTP request. Without it ``requests``
            can block indefinitely if the server stops responding.

    Returns:
        List of fully decoded PIL images (possibly empty). Images whose
        download returns a non-200 status or that fail to decode are skipped.
        Any other error is printed and yields [].
    """
    try:
        response = requests.get(f"{server_address}/history/{prompt_id}", timeout=timeout)
        if response.status_code != 200:
            return []
        history = response.json()
        if prompt_id not in history:
            return []
        outputs = history[prompt_id].get("outputs", {})
        view_url = f"{server_address}/view"
        images = []
        for node_output in outputs.values():
            if "images" not in node_output:
                continue
            for image_info in node_output["images"]:
                params = {
                    "filename": image_info["filename"],
                    "subfolder": image_info.get("subfolder", ""),
                    "type": image_info.get("type", "output"),
                }
                img_response = requests.get(view_url, params=params, timeout=timeout)
                if img_response.status_code != 200:
                    continue
                try:
                    img = Image.open(io.BytesIO(img_response.content))
                    # Image.open is lazy; decode now so corrupt data fails here
                    # (and is skipped) instead of later in the UI code.
                    img.load()
                except OSError as e:
                    print(f"Skipping undecodable image {params['filename']}: {e}")
                    continue
                images.append(img)
        return images
    except Exception as e:
        print(f"Error retrieving images: {e}")
        return []


def get_clip_embeddings_from_comfy(prompt_id, server_address, timeout=30):
    """Collect CLIP vision embeddings recorded in ComfyUI's history for *prompt_id*.

    Scans every node output in ``<server>/history/<prompt_id>`` whose
    ``class_type`` (or ``type``) is ``"CLIPVisionEncode"``. Each entry of its
    ``embeds`` (or ``samples``) list is converted to a float32 array.

    NOTE(review): it is unverified whether ComfyUI history outputs ever carry
    a ``class_type``/``type`` of "CLIPVisionEncode". If they do not, this
    always returns []. Confirm against the workflow's output nodes.

    Args:
        prompt_id: Id under which the prompt appears in ComfyUI's history.
        server_address: Base URL of the server, e.g. COMFYUI_SERVER.
        timeout: Seconds to wait for the HTTP request. Without it ``requests``
            can block indefinitely.

    Returns:
        List of float32 numpy arrays (possibly empty). Errors are printed and yield [].
    """
    try:
        response = requests.get(f"{server_address}/history/{prompt_id}", timeout=timeout)
        if response.status_code != 200:
            return []
        history = response.json()
        if prompt_id not in history:
            return []
        outputs = history[prompt_id].get("outputs", {})
        clip_embs = []
        for node_output in outputs.values():
            class_type = node_output.get("class_type") or node_output.get("type")
            if class_type == "CLIPVisionEncode":
                data = node_output.get("embeds") or node_output.get("samples") or []
                for d in data:
                    clip_embs.append(np.array(d, dtype=np.float32))
        return clip_embs
    except Exception as e:
        print("Error retrieving CLIP embeddings:", e)
        return []


def wait_for_completion(prompt_id, server_address, timeout=300, poll_interval=1.0,
                        request_timeout=10.0):
    """Poll ComfyUI until *prompt_id* appears in its history, or give up.

    Args:
        prompt_id: Id under which the prompt appears in ComfyUI's history.
        server_address: Base URL of the server, e.g. COMFYUI_SERVER.
        timeout: Overall deadline in seconds.
        poll_interval: Seconds to sleep between polls.
        request_timeout: Upper bound in seconds for one HTTP request. Each
            request is also capped by the time left, so a stalled server cannot
            hold the call past *timeout* (a request without a timeout can block forever).

    Returns:
        True once the prompt is present in ``/history/<prompt_id>``. False if
        the deadline passes first. Request errors are printed and polling continues.
    """
    history_url = f"{server_address}/history/{prompt_id}"
    # monotonic() is immune to wall-clock adjustments during the wait.
    deadline = time.monotonic() + timeout
    while True:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            return False
        try:
            response = requests.get(
                history_url, timeout=max(0.1, min(request_timeout, remaining)))
            if response.status_code == 200 and prompt_id in response.json():
                return True
        except Exception as e:
            print(f"Error checking status: {e}")
        time.sleep(max(0.0, min(poll_interval, deadline - time.monotonic())))


def image_to_vector_from_clip(clip_embedding: np.ndarray) -> np.ndarray:
    """Flatten an embedding to 1-D float32 and scale it to unit L2 norm.

    A 1e-8 epsilon keeps an all-zero input from dividing by zero; such input
    is returned as zeros.
    """
    flat = np.ravel(np.asarray(clip_embedding, dtype=np.float32))
    return flat / (np.linalg.norm(flat) + 1e-8)


def compute_2d_layout_smart(images, batch_ids, subjects, clip_embeddings=None):
    """Compute 2-D gallery positions for *images*, grouped by batch and subject.

    Each image's feature vector concatenates three parts:
      * a one-hot batch id (x25),
      * a one-hot subject (x20),
      * a visual vector (x0.05).
    The visual vector is the L2-normalised entry of *clip_embeddings* when
    exactly one is given per image; otherwise it is the flattened 64x64 RGB
    thumbnail scaled to [0, 1]. The combined features are projected to 2-D with UMAP.

    Args:
        images: Sequence of PIL images.
        batch_ids: Per-image batch identifier (hashable); must match len(images).
        subjects: Per-image subject label (hashable); must match len(images).
        clip_embeddings: Optional per-image embedding arrays.

    Returns:
        ndarray of shape (N, 2) with coordinates as fractions of the canvas.
        Fewer than 3 images are placed on a horizontal line (x spread over
        [0.2, 0.8], y = 0.5). Otherwise the UMAP output is rescaled to
        [0.1, 0.9] per axis and passed through resolve_overlaps_smart.

    NOTE(review): the per-batch jitter below uses the unseeded global
    np.random, so repeated calls on the same input give slightly different layouts.
    """
    if not images or len(images) == 0:
        return np.zeros((0, 2))
    # Visual part of the feature vector: per-image embedding when one is
    # supplied for every image, otherwise a coarse 64x64 RGB thumbnail.
    if clip_embeddings is not None and len(clip_embeddings) == len(images):
        visual_feats = np.stack(
            [image_to_vector_from_clip(emb) for emb in clip_embeddings], axis=0)
    else:
        visual_feats = np.stack(
            [np.asarray(im.convert("RGB").resize((64, 64)), dtype=np.float32).reshape(-1) / 255.0
             for im in images], axis=0)
    # Too few points for UMAP: spread them evenly on a horizontal line.
    if visual_feats.shape[0] < 3:
        xs = np.linspace(0.2, 0.8, visual_feats.shape[0])
        ys = np.ones_like(xs) * 0.5
        return np.stack([xs, ys], axis=1)
    # One-hot encode batch membership (column order follows set iteration order).
    unique_batches = list(set(batch_ids))
    batch_features = np.zeros((len(batch_ids), len(unique_batches)))
    for i, batch_id in enumerate(batch_ids):
        batch_features[i, unique_batches.index(batch_id)] = 1.0
    # One-hot encode subject membership the same way.
    unique_subjects = list(set(subjects))
    subject_features = np.zeros((len(subjects), len(unique_subjects)))
    for i, subject in enumerate(subjects):
        subject_features[i, unique_subjects.index(subject)] = 1.0
    # The weights make batch (25) and subject (20) identity dominate the
    # distances; the visual part (x0.05) contributes comparatively little.
    combined_features = np.concatenate(
        [batch_features * 25.0, subject_features * 20.0, visual_feats * 0.05], axis=1)
    # n_neighbors must be smaller than the number of samples.
    n_neighbors = min(8, len(images) - 1)
    reducer = umap.UMAP(n_components=2, random_state=42, n_neighbors=n_neighbors,
                        min_dist=0.3, spread=3.0, metric="euclidean")
    emb = reducer.fit_transform(combined_features)
    # Collapse each multi-image batch onto its centroid plus a small random
    # offset, so a batch renders as one tight cluster.
    for batch in set(batch_ids):
        batch_indices = [i for i, b in enumerate(batch_ids) if b == batch]
        if len(batch_indices) > 1:
            batch_center = emb[batch_indices].mean(axis=0)
            for idx in batch_indices:
                emb[idx] = batch_center + np.random.uniform(-0.015, 0.015, 2)
    # Rescale each axis to [0.1, 0.9]; the epsilon guards a zero-width axis.
    emb = 0.1 + 0.8 * (emb - emb.min(0)) / (emb.max(0) - emb.min(0) + 1e-8)
    return resolve_overlaps_smart(emb, batch_ids, subjects)


def resolve_overlaps_smart(positions, batch_ids, subjects, min_distance=90):
    """Iteratively push apart overlapping points on a 1000x800 virtual canvas.

    Each of up to 150 iterations applies two forces:
      * subject clusters whose centroids are closer than 150 px are shifted
        apart as rigid groups;
      * every pair of points closer than an effective distance is repelled.
        That distance is 0.5x *min_distance* within a batch, 1.2x within a
        subject, and 2.0x otherwise.
    After its repulsion step each point is clamped to an 80 px margin. Updates
    are applied in place, so later points see earlier points' new positions.
    The loop stops early only if no force was applied in an iteration.

    Args:
        positions: (N, 2) array of coordinates as canvas fractions. It is not
            modified; a copy is made.
        batch_ids: Per-point batch ids, used to choose the repulsion distance.
        subjects: Per-point subject labels, used for cluster separation.
        min_distance: Base repulsion distance in canvas pixels.

    Returns:
        (N, 2) array of adjusted coordinates, again as canvas fractions. Fewer
        than two points are returned unchanged (as a copy).

    Cost is O(150 * N^2) Python-level iterations in the worst case.
    """
    positions = positions.copy()
    n = len(positions)
    if n < 2:
        return positions
    # Work in pixel units of a virtual canvas; converted back at the end.
    canvas_width, canvas_height = 1000, 800
    positions[:, 0] *= canvas_width
    positions[:, 1] *= canvas_height
    unique_subjects = list(set(subjects))
    subject_to_idx = {subj: [] for subj in unique_subjects}
    for i, subj in enumerate(subjects):
        subject_to_idx[subj].append(i)
    for _ in range(150):
        moved = False
        # Step 1: separate subject clusters whose centroids are within 150 px.
        if len(unique_subjects) > 1:
            cluster_centers = {subj: positions[idxs].mean(axis=0)
                               for subj, idxs in subject_to_idx.items()}
            subs = list(cluster_centers.keys())
            for i in range(len(subs)):
                for j in range(i + 1, len(subs)):
                    sub_a, sub_b = subs[i], subs[j]
                    diff = cluster_centers[sub_a] - cluster_centers[sub_b]
                    dist = np.linalg.norm(diff)
                    if 0 < dist < 150:
                        force = (diff / dist) * (150 - dist) * 0.5
                        for idx in subject_to_idx[sub_a]:
                            positions[idx] += force; moved = True
                        for idx in subject_to_idx[sub_b]:
                            positions[idx] -= force; moved = True
        # Step 2: pairwise repulsion; the allowed spacing grows as points are
        # less related (same batch < same subject < unrelated).
        for i in range(n):
            forces = np.zeros(2)
            for j in range(n):
                if i == j: continue
                diff = positions[i] - positions[j]
                dist = np.linalg.norm(diff)
                if dist <= 0: continue
                if batch_ids[i] == batch_ids[j]:       eff = min_distance * 0.5
                elif subjects[i] == subjects[j]:        eff = min_distance * 1.2
                else:                                   eff = min_distance * 2.0
                if dist < eff:
                    forces += (diff / dist) * ((eff - dist) / eff) * 12; moved = True
            positions[i] += forces * 0.4
            # Keep every point at least 80 px inside the canvas.
            positions[i, 0] = np.clip(positions[i, 0], 80, canvas_width - 80)
            positions[i, 1] = np.clip(positions[i, 1], 80, canvas_height - 80)
        if not moved: break
    positions[:, 0] /= canvas_width
    positions[:, 1] /= canvas_height
    return positions


def img_to_base64(img: Image.Image):
    """Encode *img* as PNG and return the bytes as a base64 ASCII string."""
    with io.BytesIO() as buffer:
        img.save(buffer, format="PNG")
        encoded = base64.b64encode(buffer.getvalue())
    return encoded.decode()

# ####################
# Similarity + radar chart
# ####################


def compute_attribute_similarity(
    selected_image,
    selected_image_clip,   
    selected_attribute,
    comparison_modifiers,
    subject: str = "",
):
    """
    Bipolar Softmax scoring.
      p(pos) = sigmoid(SCALE * (cos(img_vec, t_pos) - cos(img_vec, t_neg)))
      p(neg) = 1 - p(pos)
    img_vec  = CLIP image encoder  (get_image_features)
    t_pos/neg = CLIP text encoder  (get_text_features)
    Both live in the shared CLIP embedding space. ✓
    """
    if selected_image is None or not comparison_modifiers:
        sims, pair_pos = {}, {}
        for mod in comparison_modifiers:
            info = _LABEL_TO_BIPOLAR.get(mod.lower())
            if info:
                pk = info["pair_key"]
                if pk not in pair_pos:
                    pair_pos[pk] = float(np.random.uniform(0.35, 0.65))
                sims[mod] = pair_pos[pk] if info["is_positive"] else 1.0 - pair_pos[pk]
            else:
                sims[mod] = float(np.random.uniform(0.35, 0.65))
        return sims

    # Encode PIL image → CLIP image space
    img = encode_image_with_clip(selected_image)

    scale = float(BIPOLAR_EXTRA_SCALE)
    if BIPOLAR_USE_CLIP_LOGIT_SCALE:
        try:
            scale *= float(_clip_model.logit_scale.exp().detach().cpu().item())
        except Exception:
            pass

    # Pre-encode all descriptions
    cache: dict = {}
    needed_descs = []
    for mod in comparison_modifiers:
        info = _LABEL_TO_BIPOLAR.get(mod.lower())
        if info:
            needed_descs += [info["description"], info["opposite_desc"]]
        else:
            d = CLIP_ATTRIBUTE_DESCRIPTIONS.get(mod.lower(), mod)
            needed_descs += [d, f"not {d}"]

    _species_tpls = ANIMAL_CONTEXT_TEMPLATES.get((subject or "").lower(), DEFAULT_ANIMAL_TEMPLATES)
    _all_tpls = list(dict.fromkeys(_species_tpls + TEXT_TEMPLATES))
    all_variants = list(dict.fromkeys(
        tpl.format(desc=d)
        for d in dict.fromkeys(needed_descs)
        for tpl in _all_tpls
    ))
    _encode_text_batch(all_variants, cache)

    pair_pos_prob: dict = {}
    sims: dict = {}

    for mod in comparison_modifiers:
        m = mod.lower()
        info = _LABEL_TO_BIPOLAR.get(m)
        if info:
            pk = info["pair_key"]
            if pk not in pair_pos_prob:
                pos_desc = info["description"] if info["is_positive"] else info["opposite_desc"]
                neg_desc = info["opposite_desc"] if info["is_positive"] else info["description"]
                vpos = _encode_ensemble_for_subject(pos_desc, subject, cache)
                vneg = _encode_ensemble_for_subject(neg_desc, subject, cache)
                s1 = float(vpos @ img)
                s2 = float(vneg @ img)
                pair_pos_prob[pk] = float(1.0 / (1.0 + np.exp(-scale * (s1 - s2))))
            ppos = pair_pos_prob[pk]
            sims[mod] = ppos if info["is_positive"] else (1.0 - ppos)
        else:
            d = CLIP_ATTRIBUTE_DESCRIPTIONS.get(m, mod)
            v1 = _encode_ensemble_for_subject(d, subject, cache)
            v2 = _encode_ensemble_for_subject(f"not {d}", subject, cache)
            delta = float(v1 @ img) - float(v2 @ img)
            sims[mod] = float(1.0 / (1.0 + np.exp(-scale * delta)))

    return sims


def generate_radar_chart(
    selected_image,
    selected_image_clip,
    selected_attribute,
    comparison_modifiers,
    subject: str = "",
):
    """Render attribute-similarity scores as an inline SVG radar chart.

    There is one spoke per entry of *comparison_modifiers*. Spokes are evenly
    spaced, starting at 12 o'clock and running clockwise. The polygon vertex
    on each spoke sits at ``radius * score``, where the score comes from
    compute_attribute_similarity. The spoke whose label matches
    *selected_attribute* (case-insensitive) is highlighted.

    Args:
        selected_image: PIL image to score, or None.
        selected_image_clip: Passed through to compute_attribute_similarity.
        selected_attribute: Currently selected attribute label (may be None/"").
        comparison_modifiers: Attribute labels to plot, one spoke each.
        subject: Subject noun used for prompt ensembling (e.g. "cat").

    Returns:
        HTML string: a placeholder <div> when there is no image or no
        modifiers, otherwise an <svg> element.
    """
    # Local import keeps the file's import block untouched.
    from html import escape

    if selected_image is None or not comparison_modifiers:
        return ("<div style='padding:8px;text-align:center;color:#9CA3AF;font-size:11px;'>"
                "Select an image to see attribute similarity</div>")

    similarities = compute_attribute_similarity(
        selected_image, selected_image_clip,
        selected_attribute, comparison_modifiers, subject=subject)

    n_axes = len(comparison_modifiers)
    angles = [2 * np.pi * i / n_axes for i in range(n_axes)]
    values = [similarities[mod] for mod in comparison_modifiers]

    view_width, view_height = 520, 540
    center_x = view_width / 2
    center_y = view_height / 2 + 22
    radius = 180

    def polar_to_cart(angle, r):
        # Angle 0 points up; SVG y grows downward, so angles advance clockwise.
        return (center_x + r * np.cos(angle - np.pi / 2),
                center_y + r * np.sin(angle - np.pi / 2))

    # Labels can be user free text ("Custom..."). Escape them so characters
    # such as "&" or "<" cannot break the SVG markup or inject HTML.
    selected_key = (selected_attribute or "").lower()
    selected_label = escape(str(selected_attribute or ""))

    svg_parts = [f"""
    <svg width="100%" height="100%" viewBox="0 0 {view_width} {view_height}"
         preserveAspectRatio="xMidYMid meet"
         style="background:#F8FAFC;border-radius:4px;border:1px solid #E5E7EB;
                width:100%;min-height:440px;display:block;">
      <text x="{center_x}" y="18" text-anchor="middle" font-size="15" font-weight="700" fill="#111827">
        Attribute Similarity Radar</text>
      <text x="{center_x}" y="34" text-anchor="middle" font-size="10" fill="#6B7280">
        s&#x0305; = exp(s&#x2081;)/(exp(s&#x2081;)+exp(s&#x2082;)) &#x00B7; bipolar Softmax [0,1] &#x00B7; opposite pairs 180&#xB0; apart</text>
      <text x="{center_x}" y="50" text-anchor="middle" font-size="11" fill="#374151">
        Selected: <tspan font-weight="700" fill="#8F0E2F">{selected_label}</tspan></text>"""]

    # Concentric grid rings at 20 % steps of the full radius.
    for i in range(1, 6):
        r = radius * i / 5
        svg_parts.append(f'<circle cx="{center_x}" cy="{center_y}" r="{r}" fill="none" stroke="#E5E7EB" stroke-width="1"/>')

    # Spokes and axis labels; the selected attribute is drawn bold in the accent colour.
    label_radius = radius + 30
    for angle, mod in zip(angles, comparison_modifiers):
        xe, ye = polar_to_cart(angle, radius)
        svg_parts.append(f'<line x1="{center_x}" y1="{center_y}" x2="{xe}" y2="{ye}" stroke="#CBD5E1" stroke-width="1"/>')
        lx, ly = polar_to_cart(angle, label_radius)
        is_selected = mod.lower() == selected_key
        fw = "700" if is_selected else "500"
        lc = "#8F0E2F" if is_selected else get_attr_color(mod)
        svg_parts.append(f'<text x="{lx}" y="{ly}" text-anchor="middle" dominant-baseline="middle" font-size="13" font-weight="{fw}" fill="{lc}">{escape(mod)}</text>')

    # Score polygon, closed by repeating the first vertex.
    vertices = [polar_to_cart(a, radius * v) for a, v in zip(angles, values)]
    pts = " ".join(f"{x},{y}" for x, y in vertices + vertices[:1])
    svg_parts.append(f'<polygon points="{pts}" fill="rgba(143,14,47,0.18)" stroke="#8F0E2F" stroke-width="2.2"/>')

    # Vertex markers with their numeric score just outside each point.
    for angle, value, (x, y) in zip(angles, values, vertices):
        svg_parts.append(f'<circle cx="{x}" cy="{y}" r="5.5" fill="#8F0E2F" stroke="#fff" stroke-width="2"/>')
        tx, ty = polar_to_cart(angle, radius * value + 12)
        svg_parts.append(f'<text x="{tx}" y="{ty}" text-anchor="middle" font-size="11" font-weight="600" fill="#1F2937">{value:.2f}</text>')

    svg_parts.append("</svg>")
    return "".join(svg_parts)


def generate_sensitivity_plot(selected_image, current_attribute, current_weight):
    """Render an SVG line chart of attribute strength vs. "embedding distance".

    NOTE(review): the curve is synthetic, not measured on ``selected_image``.
    It is ``1 - exp(-0.5 * (w - 1))`` normalised to a maximum of 1. Small
    Gaussian jitter from NumPy's global (unseeded here) RNG is added, and the
    result is clipped to [0, 1]. The chart therefore differs slightly between
    calls.

    Args:
        selected_image: Only checked against ``None``. Any other value
            produces the chart.
        current_attribute: Not used by the rendering. It is kept for the
            caller's signature.
        current_weight: Attribute strength to mark on the curve. Values
            outside the plotted range [1.0, 4.0] are clamped so the marker
            always stays inside the plot area.

    Returns:
        An HTML string. It is a placeholder ``<div>`` when ``selected_image``
        is ``None``, otherwise an ``<svg>`` element.
    """
    if selected_image is None:
        return "<div style='padding:8px;text-align:center;color:#9CA3AF;font-size:11px;'>Select an image from gallery</div>"

    # Plotted strength range; matches the hard-coded x-axis labels 1.0 / 2.5 / 4.0 below.
    w_lo, w_hi = 1.0, 4.0
    weights = np.linspace(w_lo, w_hi, 13)
    distances = 1.0 - np.exp(-0.5 * (weights - w_lo))
    distances = distances / distances.max()
    # Cosmetic jitter only; uses the global NumPy RNG.
    distances += np.random.normal(0, 0.02, len(distances))
    distances = np.clip(distances, 0, 1)

    # SVG viewport and padding (left, right, top, bottom) in user units.
    vw, vh = 420, 220
    pl, pr, pt, pb = 46, 20, 20, 54
    pw, ph = vw - pl - pr, vh - pt - pb
    xc = pl + (weights - w_lo) / (w_hi - w_lo) * pw
    yc = pt + ph - distances * ph
    path = f"M {xc[0]},{yc[0]}" + "".join(f" L {x},{y}" for x, y in zip(xc[1:], yc[1:]))

    # The strength slider allows values below 1.0 (down to 0.1). Unclamped,
    # such weights would place the marker left of the y-axis / outside the SVG.
    marker_w = min(max(float(current_weight), w_lo), w_hi)
    cx_ = pl + (marker_w - w_lo) / (w_hi - w_lo) * pw
    cy_ = pt + ph - np.interp(marker_w, weights, distances) * ph
    svg = f"""<svg width="100%" height="{vh}" viewBox="0 0 {vw} {vh}"
     preserveAspectRatio="xMidYMid meet"
     style="background:#F8FAFC;border-radius:4px;border:1px solid #E5E7EB;max-width:100%;">
  <text x="{vw/2}" y="{pt-4}" text-anchor="middle" font-size="11" font-weight="600" fill="#111827">Attribute Sensitivity</text>
  <line x1="{pl}" y1="{pt}" x2="{pl}" y2="{pt+ph}" stroke="#CBD5E1" stroke-width="1"/>
  <line x1="{pl}" y1="{pt+ph}" x2="{pl+pw}" y2="{pt+ph}" stroke="#CBD5E1" stroke-width="1"/>
  <text x="{pl-6}" y="{pt+ph}" text-anchor="end" font-size="8" fill="#9CA3AF">0.0</text>
  <text x="{pl-6}" y="{pt+ph*0.75}" text-anchor="end" font-size="8" fill="#9CA3AF">0.25</text>
  <text x="{pl-6}" y="{pt+ph*0.5}" text-anchor="end" font-size="8" fill="#9CA3AF">0.5</text>
  <text x="{pl-6}" y="{pt+ph*0.25}" text-anchor="end" font-size="8" fill="#9CA3AF">0.75</text>
  <text x="{pl-6}" y="{pt}" text-anchor="end" font-size="8" fill="#9CA3AF">1.0</text>
  <text x="{pl-28}" y="{pt+ph/2}" text-anchor="middle" font-size="9" fill="#6B7280"
        transform="rotate(-90,{pl-28},{pt+ph/2})">Embedding Distance</text>
  <text x="{pl}" y="{pt+ph+10}" text-anchor="middle" font-size="8" fill="#9CA3AF">1.0</text>
  <text x="{pl+pw/2}" y="{pt+ph+10}" text-anchor="middle" font-size="8" fill="#9CA3AF">2.5</text>
  <text x="{pl+pw}" y="{pt+ph+10}" text-anchor="middle" font-size="8" fill="#9CA3AF">4.0</text>
  <text x="{pl+pw/2}" y="{pt+ph+24}" text-anchor="middle" font-size="9" fill="#6B7280">Attribute Strength</text>
  <path d="{path}" fill="none" stroke="#8F0E2F" stroke-width="2.0"/>
  <circle cx="{cx_}" cy="{cy_}" r="4.5" fill="#8F0E2F" stroke="#fff" stroke-width="1.5"/>
  <text x="{vw/2}" y="{vh-22}" text-anchor="middle" font-size="9" fill="#374151">How much does the latent representation move when this attribute changes?</text>
  <text x="{vw/2}" y="{vh-10}" text-anchor="middle" font-size="9" fill="#4B5563">Sensitivity(s) = || z(s) - z(s-&#916;) ||,  z = image/latent embedding</text>
</svg>"""
    return svg

# ####################
# Cluster layout & gallery
# ####################


def generate_cluster_html():
    """Render every stored thumbnail as an absolutely positioned HTML "map".

    The function flattens ``BATCH_HISTORY`` into one list, oldest batch first
    and slot order within a batch, skipping slots whose thumbnail is ``None``.
    That is the same order the Gallery uses, so ``SELECTED_IMAGE_INDEX``
    refers to the same image in both views. Positions come from
    ``compute_2d_layout_smart``. Per-image CLIP embeddings are passed to it
    only when every image has one; otherwise ``None`` is passed.
    NOTE(review): positions are multiplied by 100 and used as CSS
    percentages, so they are presumably fractions in [0, 1]. Confirm against
    ``compute_2d_layout_smart``.

    Attribute and seed text is HTML-escaped before insertion. The attribute
    can come from the free-text "Custom Attribute" box, and this markup is
    rendered by ``gr.HTML``.

    Returns:
        An HTML string. It is a placeholder ``<div>`` when there are no
        images. Otherwise it is a ``<style>`` block plus the image-browser
        container and, when any batch has an attribute, a legend of
        colour-coded attribute pills.
    """
    from html import escape as _esc  # stdlib; local import keeps module imports untouched

    # Legend entries: first-seen order of non-empty attributes -> border colour.
    seen_attrs: dict = {}
    for batch in BATCH_HISTORY:
        a = batch.get("attribute", "")
        if a and a not in seen_attrs:
            seen_attrs[a] = get_attr_color(a)

    def _bcolor(attr: str) -> str:
        return seen_attrs.get(attr, get_attr_color(attr))

    all_images, all_batch_ids, all_subjects, all_clip_embs = [], [], [], []
    for batch_idx, batch in enumerate(BATCH_HISTORY):
        clip_batch = CLIP_HISTORY[batch_idx]["embeddings"] if batch_idx < len(CLIP_HISTORY) else None
        for i, thumb in enumerate(batch["thumbs"]):
            if thumb is not None:
                img_item = {
                    "img": batch["full"][i], "thumb": thumb,
                    "prompt": batch["prompt"], "seed": batch["seed"],
                    "subject": batch["subject"], "attribute": batch.get("attribute", ""),
                    "clip": clip_batch[i] if clip_batch and i < len(clip_batch) else None,
                }
                all_clip_embs.append(img_item["clip"])
                all_images.append(img_item)
                all_batch_ids.append(batch_idx)
                all_subjects.append(batch["subject"])

    if not all_images:
        return "<div style='padding:8px;text-align:center;color:#9CA3AF;font-size:12px;'>No images yet</div>"

    imgs_only = [item["img"] for item in all_images]
    # Use CLIP-based layout only when every image has an embedding.
    clip_for_layout = None if any(e is None for e in all_clip_embs) else all_clip_embs
    positions = compute_2d_layout_smart(imgs_only, all_batch_ids, all_subjects, clip_embeddings=clip_for_layout)

    html_parts = ["""<style>
    .image-browser{position:relative;width:100%;height:calc(100vh - 40px);
        background:linear-gradient(135deg,#F4F6FB 0%,#EEF0F8 100%);
        border-radius:6px;overflow:hidden;box-shadow:inset 0 2px 10px rgba(0,0,0,0.05);}
    .image-item{position:absolute;transition:all 0.3s cubic-bezier(0.4,0,0.2,1);
        filter:drop-shadow(0 4px 8px rgba(0,0,0,0.1));z-index:1;}
    .image-item:hover{transform:translate(-50%,-50%) scale(1.2)!important;
        filter:drop-shadow(0 8px 16px rgba(0,0,0,0.2));z-index:100!important;}
    .image-item.selected{transform:translate(-50%,-50%) scale(1.15)!important;
        filter:drop-shadow(0 6px 20px rgba(143,14,47,0.5));z-index:99!important;}
    .image-item img{width:70px;height:70px;border-radius:8px;border:3px solid #CCC;
        object-fit:cover;background:white;}
    .image-item.selected img{border-width:4px!important;
        box-shadow:0 0 0 2px #fff,0 0 0 5px rgba(0,0,0,0.3)!important;}
    .image-label{position:absolute;bottom:-22px;left:50%;transform:translateX(-50%);
        background:rgba(255,255,255,0.95);color:#111827;padding:2px 7px;border-radius:4px;
        font-size:10px;font-weight:600;white-space:nowrap;opacity:0;transition:opacity 0.3s;
        pointer-events:none;font-family:monospace;box-shadow:0 2px 8px rgba(0,0,0,0.1);
        border:1px solid #E5E7EB;}
    .image-item:hover .image-label{opacity:1;}
    </style><div class="image-browser" id="imageBrowser">"""]

    for idx, (item, pos) in enumerate(zip(all_images, positions)):
        x_pct, y_pct = pos[0] * 100, pos[1] * 100
        b64 = img_to_base64(item["thumb"].resize((70, 70), Image.Resampling.LANCZOS))
        sel = "selected" if SELECTED_IMAGE_INDEX == idx else ""
        # User-controlled text must not be injected as raw markup.
        attr_text = _esc(str(item.get("attribute", "")))
        seed_text = _esc(str(item["seed"]))
        html_parts.append(
            f'<div class="image-item {sel}" id="img-item-{idx}" data-index="{idx}" '
            f'style="left:{x_pct}%;top:{y_pct}%;transform:translate(-50%,-50%);">'
            f'<img src="data:image/png;base64,{b64}" alt="Image {idx}" '
            f'style="border:3px solid {_bcolor(item.get("attribute",""))}!important;">'
            f'<div class="image-label">{attr_text} &middot; {seed_text}</div></div>'
        )
    if seen_attrs:
        pills = "".join(
            f'<div style="display:inline-flex;align-items:center;gap:5px;margin:0 4px 0 0;'
            f'padding:3px 9px;background:rgba(255,255,255,0.92);border:1.5px solid {col};'
            f'border-radius:20px;box-shadow:0 1px 4px rgba(0,0,0,0.10);">'
            f'<div style="width:10px;height:10px;border-radius:50%;background:{col};flex-shrink:0;"></div>'
            f'<span style="font-size:11px;font-weight:600;color:#111827;white-space:nowrap;">{_esc(str(attr))}</span></div>'
            for attr, col in seen_attrs.items()
        )
        html_parts.append(
            f'<div style="position:absolute;top:8px;left:50%;transform:translateX(-50%);'
            f'display:flex;flex-wrap:wrap;justify-content:center;gap:4px;'
            f'z-index:200;max-width:95%;pointer-events:none;">{pills}</div>'
        )
    html_parts.append("</div>")
    return "".join(html_parts)


def create_gallery_data():
    """Build the ``(image, caption)`` list shown in the Gallery tab.

    The function walks ``BATCH_HISTORY`` oldest batch first, and each batch
    in slot order. It keeps only slots whose thumbnail is not ``None``.
    ``on_gallery_select`` and ``generate_cluster_html`` flatten the history
    in the same order. A gallery index must map to the same image in all
    three, so this ordering must not be changed on its own.

    Returns:
        A list of ``(full_image, "Seed: <seed> | <subject>")`` tuples.
    """
    return [
        (full_img, f"Seed: {batch['seed']} | {batch['subject']}")
        for batch in BATCH_HISTORY
        for full_img, thumb in zip(batch["full"], batch["thumbs"])
        if thumb is not None
    ]


def on_gallery_select(evt: gr.SelectData):
    """Handle a click in the Gallery tab and refresh the detail panels.

    ``evt.index`` is a position in the flattened list of non-``None``
    thumbnails across ``BATCH_HISTORY``, in the same order that
    ``create_gallery_data`` produces.

    On a valid index:

    * The index is stored in ``SELECTED_IMAGE_INDEX``.
    * The sensitivity and radar charts are rendered for that image.
    * The image's attribute is added to a copy of ``COMPARISON_MODIFIERS``,
      case-insensitively de-duplicated, so the radar always has an axis
      for it.

    A batch missing a field falls back to attribute ``"cute"``, weight
    ``2.0`` and subject ``"dog"``.

    On an out-of-range index the selection is cleared as well. Otherwise the
    cluster map would keep highlighting an image whose details are no longer
    shown.

    Returns:
        ``(prompt_and_seed_text, cluster_html, sensitivity_html, radar_html)``.
    """
    global SELECTED_IMAGE_INDEX
    index = evt.index
    empty_plot = "<div style='padding:8px;text-align:center;color:#9CA3AF;font-size:11px;'>Select an image from gallery</div>"

    # Flatten in gallery order so that `index` maps to the right batch slot.
    all_images = []
    for batch_idx, batch in enumerate(BATCH_HISTORY):
        clip_batch = CLIP_HISTORY[batch_idx]["embeddings"] if batch_idx < len(CLIP_HISTORY) else None
        for i, thumb in enumerate(batch["thumbs"]):
            if thumb is None:
                continue
            all_images.append({
                "img": batch["full"][i], "prompt": batch["prompt"],
                "seed": batch["seed"], "attribute": batch.get("attribute", "cute"),
                "weight": batch.get("weight", 2.0), "subject": batch.get("subject", "dog"),
                "clip": clip_batch[i] if clip_batch and i < len(clip_batch) else None,
            })

    if not (0 <= index < len(all_images)):
        # Invalid or stale index: drop the highlight so the map matches the blank panels.
        SELECTED_IMAGE_INDEX = None
        return "", generate_cluster_html(), empty_plot, empty_plot

    SELECTED_IMAGE_INDEX = index
    item = all_images[index]
    selected_attr = item["attribute"]
    # Guarantee the selected attribute has its own axis on the radar chart.
    dynamic_modifiers = list(COMPARISON_MODIFIERS)
    if selected_attr and selected_attr.lower() not in [m.lower() for m in dynamic_modifiers]:
        dynamic_modifiers.append(selected_attr)
    sensitivity_html = generate_sensitivity_plot(item["img"], selected_attr, item["weight"])
    radar_html = generate_radar_chart(
        item["img"], item["clip"], selected_attr, dynamic_modifiers,
        subject=item.get("subject", ""))
    return f"{item['prompt']}\nSeed: {item['seed']}", generate_cluster_html(), sensitivity_html, radar_html

# ####################
# Generation pipeline
# ####################


def generate_image(
    attribute_dropdown, attribute_custom, subject_dropdown, subject_custom,
    attribute_weight, art_style, art_style_weight, negative_prompt,
    seed, control_mode, steps, cfg, sampler, scheduler, denoise,
    width, height, batch_size, progress=gr.Progress(),
):
    """Queue one ComfyUI generation, store the resulting batch, and refresh the UI.

    Steps:

    1. Build the positive prompt from the attribute and subject (the dropdown
       value or its custom override) plus the art style.
    2. Fill the workflow template and POST it to ``{COMFYUI_SERVER}/prompt``.
    3. Wait for completion, then fetch the images and CLIP embeddings.
    4. Append the batch to ``BATCH_HISTORY`` / ``CLIP_HISTORY``.

    Once more than ``MAX_BATCHES`` batches are stored, the oldest is dropped.
    ``SELECTED_IMAGE_INDEX`` is then shifted to keep pointing at the same
    image, or cleared if that image was in the dropped batch.

    Args:
        attribute_dropdown, attribute_custom: Attribute choice and its
            free-text override, resolved by ``get_effective_value``.
        subject_dropdown, subject_custom: Subject choice and its override.
        attribute_weight: Prompt weight for the attribute. It is also stored
            with the batch.
        art_style, art_style_weight: Style combined into the prompt.
        negative_prompt, seed, control_mode, steps, cfg, sampler,
            scheduler, denoise, width, height: Forwarded to
            ``update_workflow_parameters``. With ``control_mode ==
            "randomize"`` the seed is first replaced by a random value in
            ``[0, 2**32 - 1]``. That random value is both forwarded and stored.
        batch_size: Number of image slots per batch. It is coerced to
            ``int`` because it is used for slicing.
        progress: Gradio progress tracker.

    Returns:
        A 6-tuple ``(status, cluster_html, gallery_items, selected_prompt,
        sensitivity_html, radar_html)``. On failure the status starts with
        ``"Error:"``. The gallery is still rebuilt from history, so a failed
        run does not blank previously generated images.
    """
    global SELECTED_IMAGE_INDEX
    empty_plot = "<div style='padding:8px;text-align:center;color:#9CA3AF;font-size:11px;'>Select an image from gallery</div>"

    def _fail(message):
        # Shared shape for every error path; keeps cluster map and gallery in sync with history.
        return (message, generate_cluster_html(), create_gallery_data(), "", empty_plot, empty_plot)

    effective_attribute = get_effective_value(attribute_dropdown, attribute_custom)
    effective_subject = get_effective_value(subject_dropdown, subject_custom)
    if not effective_attribute:
        return _fail("Error: Please provide an attribute")
    if not effective_subject:
        return _fail("Error: Please provide a subject")

    progress(0, desc="Loading workflow...")
    base_prompt = build_base_prompt(effective_attribute, attribute_weight, effective_subject)
    combined_prompt = combine_prompt_with_style(base_prompt, art_style, art_style_weight)
    workflow = load_workflow_template()
    if not workflow:
        return _fail("Error: Could not load workflow template")
    if control_mode == "randomize":
        import random
        seed = random.randint(0, 2**32 - 1)
    # Slider values may arrive as floats; slicing below requires an int.
    batch_size = int(batch_size)

    progress(0.2, desc="Updating parameters...")
    updated_workflow = update_workflow_parameters(
        workflow, combined_prompt, negative_prompt, seed, control_mode,
        steps, cfg, sampler, scheduler, denoise, width, height, batch_size)
    if not updated_workflow:
        return _fail("Error: Could not update workflow parameters")

    try:
        progress(0.3, desc="Queuing prompt...")
        response = requests.post(f"{COMFYUI_SERVER}/prompt", json={"prompt": updated_workflow}, timeout=30)
        if response.status_code != 200:
            return _fail(f"Error: API returned status {response.status_code}")
        prompt_id = response.json().get("prompt_id")
        if not prompt_id:
            return _fail("Error: No prompt_id returned")

        progress(0.4, desc=f"Generating... ID: {prompt_id}")
        if not wait_for_completion(prompt_id, COMFYUI_SERVER):
            return _fail("Error: Generation timeout")

        progress(0.9, desc="Retrieving images...")
        images = get_images_from_comfy(prompt_id, COMFYUI_SERVER)
        clip_embs = get_clip_embeddings_from_comfy(prompt_id, COMFYUI_SERVER)
        if not images:
            return _fail("Error: No images generated")

        # Normalise to exactly `batch_size` slots; missing slots are None.
        images = list(images[:batch_size])
        actual_count = len(images)
        images += [None] * (batch_size - actual_count)
        # Keep embeddings only if there is at least one per returned image.
        if clip_embs and len(clip_embs) >= actual_count:
            clip_embs = list(clip_embs[:batch_size])
            clip_embs += [None] * (batch_size - len(clip_embs))
        else:
            clip_embs = [None] * batch_size

        thumbs = [im.resize((70, 70), Image.Resampling.LANCZOS) if im is not None else None for im in images]
        BATCH_HISTORY.append({
            "thumbs": thumbs, "full": images, "prompt": combined_prompt,
            "seed": seed, "subject": effective_subject,
            "attribute": effective_attribute, "weight": attribute_weight})
        CLIP_HISTORY.append({"embeddings": clip_embs})

        if len(BATCH_HISTORY) > MAX_BATCHES:
            removed = BATCH_HISTORY.pop(0)
            if CLIP_HISTORY:
                CLIP_HISTORY.pop(0)
            removed_count = sum(t is not None for t in removed["thumbs"])
            if SELECTED_IMAGE_INDEX is not None:
                shifted = SELECTED_IMAGE_INDEX - removed_count
                # A negative index means the selected image was in the dropped batch.
                SELECTED_IMAGE_INDEX = shifted if shifted >= 0 else None

        progress(1.0, desc="Complete!")
        total_images = sum(t is not None for b in BATCH_HISTORY for t in b["thumbs"])
        return (f"Generated {actual_count} image(s)!  ||  Total: {total_images}",
                generate_cluster_html(), create_gallery_data(), "", empty_plot, empty_plot)
    except requests.exceptions.ConnectionError:
        return _fail("Error: Could not connect to ComfyUI server")
    except Exception as e:
        # Top-level UI boundary: log the traceback and surface the message in the status box.
        import traceback
        traceback.print_exc()
        return _fail(f"Error: {str(e)}")


def update_preview(attribute_dropdown, attribute_custom, attribute_weight,
                   subject_dropdown, subject_custom, art_style, art_style_weight):
    """Return the positive prompt that the current form values would produce.

    Uses the same two steps as generation: ``build_base_prompt`` followed by
    ``combine_prompt_with_style``. Returns ``""`` while either the effective
    attribute or the effective subject is empty.
    """
    attribute = get_effective_value(attribute_dropdown, attribute_custom)
    subject = get_effective_value(subject_dropdown, subject_custom)
    if attribute and subject:
        base = build_base_prompt(attribute, attribute_weight, subject)
        return combine_prompt_with_style(base, art_style, art_style_weight)
    return ""


def toggle_attribute_custom(v):
    """Show the custom-attribute textbox only when "Custom..." is selected."""
    show = v == "Custom..."
    return gr.update(visible=show)


def toggle_subject_custom(v):
    """Show the custom-subject textbox only when "Custom..." is selected."""
    show = v == "Custom..."
    return gr.update(visible=show)

# ####################
# Frontend with Gradio
# ####################

# UI definition: three columns inside one row.
#   left   - prompt inputs, art style, negative prompt, advanced sampler settings
#   centre - tabs with the spatial "UMAP View" (generate_cluster_html) and a plain gallery
#   right  - prompt/seed of the selected image plus its sensitivity and radar charts
# The CSS string restyles Gradio's default widgets (compact panels, dark
# buttons, slider colours, compact plot containers).
with gr.Blocks(theme=gr.themes.Default(), css="""
html,body{margin:0!important;padding:0!important;height:100%!important;overflow-x:hidden!important;overflow-y:auto!important;}
.gradio-container{max-width:100vw!important;width:100vw!important;min-height:100vh!important;height:auto!important;background-color:#EDEBEB!important;padding:0!important;margin:0!important;overflow:visible!important;}
.panel{background:#FFFFFF!important;border-radius:0!important;border-right:1px solid #E5E7EB!important;padding:4px 8px!important;margin:0!important;max-height:calc(100vh - 32px)!important;height:auto!important;overflow-y:auto!important;}
.panel:last-child{border-right:none!important;}
.panel h3{display:none!important;}
.panel label{font-size:13px!important;font-weight:500!important;margin-bottom:0!important;color:#111827!important;}
.gradio-container .panel .wrap{margin-bottom:2px!important;}
.panel .accordion{margin-bottom:2px!important;}
.panel .wrap:has(button){margin-top:0!important;margin-bottom:2px!important;}
.panel input,.panel select,.panel textarea{font-size:13px!important;background:#FFFFFF!important;border:1px solid #E5E7EB!important;color:#111827!important;padding:3px 7px!important;}
button{background:#302F2F!important;color:white!important;font-weight:600!important;border-radius:6px!important;font-size:14px!important;padding-top:4px!important;padding-bottom:4px!important;}
button:hover{background:#252424!important;transform:translateY(-1px);box-shadow:0 4px 12px rgba(0,0,0,0.25)!important;}
.row{margin:0!important;gap:0!important;}.col{padding:0!important;margin:0!important;}
.panel .tabs{margin-bottom:0!important;padding-bottom:0!important;}
.panel .tabitem{padding-top:2px!important;padding-bottom:2px!important;}
input[type="range"]{accent-color:#302F2F!important;}
input[type="range"]::-webkit-slider-thumb{background:#302F2F!important;border:2px solid #302F2F!important;}
input[type="range"]::-webkit-slider-runnable-track{background:#302F2F22!important;}
input[type="range"]::-moz-range-thumb{background:#302F2F!important;border:2px solid #302F2F!important;}
input[type="range"]::-moz-range-track{background:#302F2F22!important;}
.accordion{border-radius:4px!important;margin-top:4px!important;background:#F8FAFC!important;border:1px solid #E5E7EB!important;}
.accordion>div:first-child{padding:4px 8px!important;}
.compact-plot{margin:0!important;padding:0!important;width:100%!important;overflow:hidden!important;background:#F8FAFC!important;}
.compact-plot svg{max-width:100%!important;height:auto!important;}
""") as demo:
    with gr.Row(elem_classes="row"):
        # Left column: generation inputs. The "Custom ..." textboxes start
        # hidden and are shown by toggle_*_custom when "Custom..." is chosen.
        with gr.Column(scale=3, elem_classes="panel"):
            attribute_dropdown = gr.Dropdown(label="Attribute", choices=BASE_ATTRIBUTES, value="cute")
            attribute_custom   = gr.Textbox(label="Custom Attribute", placeholder="e.g., playful, mystical", visible=False)
            attribute_weight   = gr.Slider(label="Attribute Strength", minimum=0.1, maximum=4.0, value=2.0, step=0.1)
            subject_dropdown   = gr.Dropdown(label="Subject", choices=BASE_SUBJECTS, value="dog")
            subject_custom     = gr.Textbox(label="Custom Subject", placeholder="e.g., wolf, spaceship", visible=False)
            # Read-only; filled by update_preview whenever a prompt-related input changes.
            # NOTE(review): no load event is wired here, so it presumably stays empty until the first change.
            positive_prompt_preview = gr.Textbox(label="Auto-generated Prompt", interactive=False, lines=2)
            art_style          = gr.Dropdown(label="Art Style", choices=list(ART_STYLES.keys()), value="Fantasy")
            art_style_weight   = gr.Slider(label="Art Style Strength", minimum=0.1, maximum=2.0, value=1.0, step=0.1)
            negative_prompt    = gr.Textbox(label="Negative Prompt", lines=2, value=DEFAULT_NEGATIVE_PROMPT)
            # Sampler settings; generate_image forwards them to update_workflow_parameters.
            with gr.Accordion("Advanced Settings", open=False):
                seed         = gr.Number(label="Seed", value=1093030236344156, precision=0)
                control_mode = gr.Dropdown(label="Control", choices=["randomize","fixed","increment","decrement"], value="randomize")
                steps        = gr.Slider(label="Steps", minimum=1, maximum=150, value=30, step=1)
                cfg          = gr.Slider(label="CFG", minimum=0, maximum=20, value=7.0, step=0.1)
                sampler      = gr.Dropdown(label="Sampler", choices=["euler","euler_ancestral","dpm_2"], value="euler")
                scheduler    = gr.Dropdown(label="Scheduler", choices=["sgm_uniform","normal","karras"], value="sgm_uniform")
                denoise      = gr.Slider(label="Denoise", minimum=0, maximum=1, value=1.0, step=0.01)
                width        = gr.Slider(label="Width",  minimum=512, maximum=1024, value=512, step=64)
                height       = gr.Slider(label="Height", minimum=512, maximum=1024, value=512, step=64)
                batch_size   = gr.Slider(label="Batch Size", minimum=1, maximum=4, value=4, step=1)
            generate_btn = gr.Button("Generate Images", variant="primary", size="lg")
            status_text  = gr.Textbox(label="Status", interactive=False, lines=1)

        # Centre column: spatial map (HTML from generate_cluster_html) and the clickable gallery.
        with gr.Column(scale=8, elem_classes="panel col"):
            with gr.Tabs():
                with gr.Tab("UMAP View"):
                    cluster_canvas = gr.HTML(value="<div style='padding:8px;text-align:center;color:#9CA3AF;font-size:12px;'>Generate images to see spatial clustering</div>")
                with gr.Tab("Gallery"):
                    image_gallery  = gr.Gallery(label="All Generated Images", show_label=False, columns=6, rows=2, height=600, object_fit="cover")

        # Right column: details of the image selected in the gallery.
        with gr.Column(scale=4, elem_classes="panel col"):
            selected_prompt  = gr.Textbox(label="Prompt & Seed", lines=3, interactive=False)
            sensitivity_plot = gr.HTML(value="<div style='padding:8px;text-align:center;color:#9CA3AF;font-size:11px;'>Select an image from gallery</div>", elem_classes="compact-plot")
            radar_plot       = gr.HTML(value="<div style='padding:8px;text-align:center;color:#9CA3AF;font-size:11px;'>Select an image to see attribute similarity</div>", elem_classes="compact-plot")

    # Event wiring.
    # Show/hide the custom textboxes.
    attribute_dropdown.change(fn=toggle_attribute_custom, inputs=[attribute_dropdown], outputs=[attribute_custom])
    subject_dropdown.change(fn=toggle_subject_custom, inputs=[subject_dropdown], outputs=[subject_custom])
    # Any prompt-related input refreshes the prompt preview.
    for comp in [attribute_dropdown, attribute_custom, attribute_weight,
                 subject_dropdown, subject_custom, art_style, art_style_weight]:
        comp.change(fn=update_preview,
                    inputs=[attribute_dropdown, attribute_custom, attribute_weight,
                            subject_dropdown, subject_custom, art_style, art_style_weight],
                    outputs=[positive_prompt_preview])
    # `inputs` order must match generate_image's positional parameters, and
    # `outputs` order must match the 6-tuple it returns.
    generate_btn.click(fn=generate_image,
        inputs=[attribute_dropdown, attribute_custom, subject_dropdown, subject_custom,
                attribute_weight, art_style, art_style_weight, negative_prompt,
                seed, control_mode, steps, cfg, sampler, scheduler, denoise,
                width, height, batch_size],
        outputs=[status_text, cluster_canvas, image_gallery, selected_prompt, sensitivity_plot, radar_plot])
    # No explicit inputs: Gradio supplies the gr.SelectData argument that on_gallery_select declares.
    image_gallery.select(fn=on_gallery_select,
        outputs=[selected_prompt, cluster_canvas, sensitivity_plot, radar_plot])


if __name__ == "__main__":
    # NOTE(review): this listens on all interfaces and requests a public Gradio
    # share link. Anyone with the URL can trigger generations against the
    # ComfyUI server configured in COMFYUI_SERVER. Confirm this is intended.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
        "debug": True,
    }
    demo.launch(**launch_options)