In [1]:
import clip
import torch
import random
import spacy
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torchvision.transforms as T
import ast
import torch
import os 
from PIL import Image
from diffusers import StableDiffusionPipeline
from transformers import CLIPProcessor, CLIPModel
from groq import Groq  # Assuming Groq has an OpenAI-compatible client

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# === Load CLIP Model ===
# Use the GPU when torch reports CUDA support; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# clip.load yields a pair that is unpacked into the model and its image preprocessing callable.
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)

In [3]:

# Initialize the Groq client.
# The API key is read from the GROQ_API_KEY environment variable so the secret
# never lives in the notebook source or in its saved outputs.
# NOTE(review): the key that used to be hard-coded in this cell has been
# exposed and should be revoked.
groq_api_key = os.environ.get("GROQ_API_KEY")
if not groq_api_key:
    raise RuntimeError("Set the GROQ_API_KEY environment variable before running this cell.")
client = Groq(api_key=groq_api_key)


In [4]:
# ----------- Set device for Stable Diffusion -----------
device = "cuda" if torch.cuda.is_available() else "cpu"

# ----------- Load the Stable Diffusion v1.5 pipeline -----------
# Half precision (and the matching "fp16" weight variant) is requested only
# when running on CUDA; on CPU the default float32 weights are loaded instead.
# The pipeline is moved to `device` rather than a hard-coded "cuda", so this
# cell also runs on machines without a GPU.
use_half_precision = device == "cuda"
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16 if use_half_precision else torch.float32,
    variant="fp16" if use_half_precision else None,
    use_safetensors=True
).to(device)

Loading pipeline components...: 100%|██████████| 7/7 [00:00<00:00, 11.25it/s]


In [5]:
# Negative-prompt terms, joined into one comma-separated string.
NEGATIVE_PROMPT = ", ".join([
    "blurry",
    "deformed",
    "ugly",
    "bad anatomy",
    "bad proportions",
    "low quality",
    "mutated",
    "out of frame",
    "extra limbs",
    "poorly drawn",
    "jpeg artifacts",
    "text",
    "watermark",
    "signature",
])


In [6]:
# The two base prompts to blend (prompt_a keeps its trailing space).
prompt_a, prompt_b = "dragon ", "forest"

In [7]:

def combine_prompts_interactively(prompt_a, prompt_b, client):
    """
    Blend two image prompts into several combined prompts via an LLM chat call.

    Blend weights (alpha) are sampled between 0.4 and 0.6 and rounded to one
    decimal, which leaves the distinct values 0.4, 0.5 and 0.6. Exactly one
    request is sent per distinct value.

    Parameters:
        prompt_a (str): First input prompt; presented with weight ``1 - alpha``.
        prompt_b (str): Second input prompt; presented with weight ``alpha``.
        client: Chat client exposing ``client.chat.completions.create(...)``
            (called with ``model``, ``messages`` and ``temperature``).

    Returns:
        dict: Maps ``'prompt_<alpha>'`` (``'prompt_0.4'``, ``'prompt_0.5'``,
        ``'prompt_0.6'``) to the stripped text of the model's reply, in
        increasing alpha order.
    """

    # Every instruction ends with an explicit separator: adjacent string
    # literals are concatenated verbatim, so a missing trailing space would
    # fuse two sentences into one word (e.g. "mouthif subject ...").
    system_prompt = (
        "You are a visual scene composer. Your job is to merge two image generation prompts "
        "into a single, meaningful visual scene. Ensure that the subjects of both prompts interact "
        "in a realistic and visually coherent way. "
        "if the subject is a living -there should be no facial features like eyes ,nose or mouth. "
        "if subject is a person then he should be fully clothed in a modern attire, his face should be realisitc with no distortions and All limbs, hands, and body proportions appear natural and anatomically correct. "
        "Each time, emphasize one prompt more than the other using an alpha value between 0.4 and 0.6. Output a single descriptive sentence suitable for a text-to-image model. "
        "keep less than 15 words. "
        "the image should be photorealistic. "
        "avoid AI type images. "
        "Return only the prompt. No explanations."
    )

    # Rounding to one decimal collapses the five samples onto three values.
    # Deduplicate (order-preserving) so no request is sent whose reply would
    # only overwrite an earlier entry stored under the same key.
    alphas = list(dict.fromkeys(round(float(a), 1) for a in np.linspace(0.40, 0.60, 5)))
    prompt_dict = {}

    for alpha_rounded in alphas:
        user_prompt = (
            f"Prompt 1 (weight {1 - alpha_rounded:.1f}): {prompt_a}\n"
            f"Prompt 2 (weight {alpha_rounded:.1f}): {prompt_b}\n\n"
            "Create one coherent prompt that blends them into an interactive visual scene, reflecting the given weights."
        )

        response = client.chat.completions.create(
            model="llama3-70b-8192",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ],
            temperature=0.7
        )

        combined_prompt = response.choices[0].message.content.strip()
        prompt_dict[f"prompt_{alpha_rounded}"] = combined_prompt

    return prompt_dict




In [8]:
import ast
import re

def clean_and_parse_weights(raw_response):
    """
    Extract and parse the weights dictionary embedded in an LLM reply.

    The span from the first ``{`` to the last ``}`` is parsed with
    ``ast.literal_eval``. The text is parsed as-is first; the regex repairs
    for common malformed patterns are applied only if that fails, so a reply
    that is already valid is never altered by the repairs.

    Parameters:
        raw_response (str): Raw text returned by the model.

    Returns:
        dict: The parsed dictionary. Its ``"Final Prompt"`` entry is a dict;
        weight values that arrived as numeric strings (e.g. ``"1.7"``) are
        converted to ``float``.

    Raises:
        ValueError: If no ``{...}`` span is found, if the text cannot be
            parsed even after repair, or if the result is not a dict whose
            ``"Final Prompt"`` entry is itself a dict.
    """
    # Attempt to extract the dictionary block
    match = re.search(r'\{[\s\S]+\}', raw_response)
    if not match:
        raise ValueError("No dictionary found in response.")

    text = match.group(0)
    # Exception types ast.literal_eval is documented to raise on bad input.
    parse_errors = (ValueError, SyntaxError, TypeError, MemoryError, RecursionError)

    try:
        parsed = ast.literal_eval(text)
    except parse_errors:
        # Fix common malformed patterns
        text = text.replace('":', '": ')           # Ensure spacing after colons
        text = text.replace('",', '", ')           # Ensure spacing after commas
        text = re.sub(r'(\d)\s*"}', r'\1"}', text) # Fix trailing number then quote
        text = re.sub(r'":\s*([0-9.]+)"', r'": \1', text)  # Fix numeric values followed by a stray quote
        text = re.sub(r'([0-9.])"([,\}])', r'\1\2', text)  # Remove trailing quote after numbers

        try:
            parsed = ast.literal_eval(text)
        except parse_errors as e:
            raise ValueError(f"Cleaned version still invalid: {e}\nCleaned:\n{text}") from e

    if not isinstance(parsed, dict) or not isinstance(parsed.get("Final Prompt"), dict):
        raise ValueError(f"Structure incorrect. Parsed:\n{parsed}")

    # Normalise weights that were returned as quoted numbers.
    weights = {}
    for word, weight in parsed["Final Prompt"].items():
        if isinstance(weight, str):
            try:
                weight = float(weight)
            except ValueError:
                pass  # not numeric: leave the value untouched
        weights[word] = weight
    parsed["Final Prompt"] = weights

    return parsed


def get_token_weights(final_prompt, client):
    """
    Ask the chat model for per-word emphasis weights for ``final_prompt``.

    Parameters:
        final_prompt (str): Prompt text to analyse.
        client: Chat client exposing ``client.chat.completions.create(...)``.

    Returns:
        Whatever ``clean_and_parse_weights`` returns for the model's reply.

    Raises:
        ValueError: If ``clean_and_parse_weights`` raises for the reply; the
            message includes the raw reply text.
    """
    # Built line by line so the exact whitespace sent to the model is explicit.
    instruction = (
        "\n"
        "Extract the main subjects and objects from the prompt and \n"
        "assign:\n"
        "- 1.7 to subject and object \n"
        "- 1.2 to secondary subjects and objects \n"
        "- 1.0 to other important details\n"
        "\n"
        f'"Final Prompt": {final_prompt}\n'
        "\n"
        "Respond ONLY with a valid Python dictionary in this format:\n"
        "{\n"
        '  "Final Prompt": {"word1": weight, "word2": weight, ...}\n'
        "}\n"
        "\n"
        "No explanations, no extra text.\n"
    )

    completion = client.chat.completions.create(
        model="llama3-70b-8192",
        messages=[{"role": "user", "content": instruction}],
        temperature=0.5,
    )
    reply_text = completion.choices[0].message.content.strip()

    try:
        parsed_weights = clean_and_parse_weights(reply_text)
    except Exception as e:
        raise ValueError(f"Failed to parse token weights: {e}\nRaw response:\n{reply_text}")
    return parsed_weights


In [9]:
# Step 1: ask the LLM for the blended prompts
blended_prompts = combine_prompts_interactively(prompt_a, prompt_b, client)

# Step 2: collect token weights, keyed by blended-prompt name
weights_per_prompt = {}

print("\n--- Token Weights per Blended Prompt ---\n")

for name in blended_prompts:
    final_prompt = blended_prompts[name]
    try:
        weights = get_token_weights(final_prompt, client)
    except Exception as e:
        # A failure for one prompt is reported and that prompt is skipped.
        print(f"Failed to extract weights for {name}: {e}\n")
    else:
        weights_per_prompt[name] = weights
        print(f"{name}:\n{weights}\n")




--- Token Weights per Blended Prompt ---

prompt_0.4:
{'Final Prompt': {'dragon': 1.7, 'scales': 1.7, 'sunlight': 1.2, 'trees': 1.2, 'ancient': 1.0, 'majestic': 1.0, 'calmly': 1.0, 'amidst': 1.0, 'lies': 1.0, 'glistening': 1.0, 'dappled': 1.0}}

prompt_0.5:
{'Final Prompt': {'dragon': 1.7, 'forest': 1.7, 'clearing': 1.2, 'scales': 1.2, 'foliage': 1.2, 'misty': 1.0}}

prompt_0.6:
{'Final Prompt': {'dragon': 1.7, 'forest': 1.7, 'trees': 1.2, 'floor': 1.2, 'misty': 1.0}}



In [10]:
import re

def apply_weights(prompt, weights_dict):
    """
    Wrap each weighted word of ``prompt`` as ``(word:weight)``.

    All keys are substituted in a single regex pass, so text inserted for one
    key (including the weight digits) is never re-matched by another key.
    Matching is whole-word and case-sensitive.

    Parameters:
        prompt (str): Prompt text to annotate.
        weights_dict (dict): Maps a word or phrase to its weight. Empty keys
            are ignored.

    Returns:
        str: The prompt with every whole-word occurrence of a key replaced by
        ``(key:weight)``; unchanged when there are no usable keys.
    """
    # Longest keys first so a multi-word key wins over a shorter key it contains.
    words = sorted((word for word in weights_dict if word), key=len, reverse=True)
    if not words:
        return prompt

    pattern = r'\b(?:' + '|'.join(re.escape(word) for word in words) + r')\b'

    def _wrap(match):
        # A callable replacement also avoids interpreting backslashes in the
        # word or weight as regex template escapes.
        word = match.group(0)
        return f"({word}:{weights_dict[word]})"

    return re.sub(pattern, _wrap, prompt)

# Step 3: annotate every blended prompt with its token weights
weighted_prompts = {}

print("\n--- Weighted Prompts ---\n")

for name in blended_prompts:
    final_prompt = blended_prompts[name]
    try:
        token_weights = weights_per_prompt[name]
        weighted_prompt = apply_weights(final_prompt, token_weights["Final Prompt"])
    except Exception as e:
        # Covers prompts with no stored weights as well as substitution errors.
        print(f"Failed to apply weights for {name}: {e}\n")
    else:
        weighted_prompts[name] = weighted_prompt
        print(f"{name}:\n{weighted_prompt}\n")




--- Weighted Prompts ---

prompt_0.4:
A (majestic:1.0) (dragon:1.7), with (scales:1.7) (glistening:1.0) in (dappled:1.0) (sunlight:1.2), (lies:1.0) (calmly:1.0) (amidst:1.0) (ancient:1.0) (trees:1.2).

prompt_0.5:
A majestic (dragon:1.7), (scales:1.2) glistening, lies calmly amidst dense (foliage:1.2) in a (misty:1.0) (forest:1.7) (clearing:1.2).

prompt_0.6:
A majestic (dragon:1.7) sprawls across a (misty:1.0) (forest:1.7) (floor:1.2), surrounded by towering (trees:1.2).



In [11]:
# Show the weighted prompts as this cell's output.
weighted_prompts

{'prompt_0.4': 'A (majestic:1.0) (dragon:1.7), with (scales:1.7) (glistening:1.0) in (dappled:1.0) (sunlight:1.2), (lies:1.0) (calmly:1.0) (amidst:1.0) (ancient:1.0) (trees:1.2).',
 'prompt_0.5': 'A majestic (dragon:1.7), (scales:1.2) glistening, lies calmly amidst dense (foliage:1.2) in a (misty:1.0) (forest:1.7) (clearing:1.2).',
 'prompt_0.6': 'A majestic (dragon:1.7) sprawls across a (misty:1.0) (forest:1.7) (floor:1.2), surrounded by towering (trees:1.2).'}

In [12]:
# === Encode Prompts Once ===
def encode_texts(texts):
    """
    Encode text with the CLIP text encoder and L2-normalise the result.

    Parameters:
        texts: Text (or list of texts) passed to ``clip.tokenize`` with
            ``truncate=True``.

    Returns:
        Float tensor of embeddings on ``device``, each scaled to unit norm
        along the last dimension (presumably one row per input text).
    """
    token_ids = clip.tokenize(texts, truncate=True).to(device)
    with torch.no_grad():
        features = clip_model.encode_text(token_ids).float()
        norms = features.norm(dim=-1, keepdim=True)
        return features / norms

# Embeddings of the two base prompts: row 0 is prompt_a, row 1 is prompt_b.
base_embeddings = encode_texts([prompt_a, prompt_b])
text_A = base_embeddings[0]
text_B = base_embeddings[1]

# One embedding per weighted prompt, keyed like weighted_prompts.
text_embeddings = {}

for key in weighted_prompts:
    prompt = weighted_prompts[key]
    embedding = encode_texts([prompt])[0]  # single embedding
    text_embeddings[key] = embedding

In [13]:
# Generation / scoring settings.
num_variants, num_inference_steps = 3, 50
guidance_scale = 7.5
lambda_ = 0.6

In [14]:
def generate_and_score_all(weighted_prompts, num_variants=3, num_inference_steps=50, guidance_scale=7.5, lambda_=0.5):
    all_results = {}
    base_embeddings = encode_texts([prompt_a, prompt_b])
    text_A, text_B = base_embeddings  # shape: [512]

    for key, prompt in weighted_prompts.items():
        # Encode current blended prompt
        text_blend = encode_texts([prompt])[0]

        best_score = -float("inf")
        best_image = None
        best_seed = None

        for _ in range(num_variants):
            seed = random.randint(0, 100000)
            generator = torch.Generator(device=device).manual_seed(seed)
            image = pipe(
                prompt,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                generator=generator
            ).images[0]

            image_tensor = clip_preprocess(image).unsqueeze(0).to(device)
            with torch.no_grad():
                image_embed = clip_model.encode_image(image_tensor).float()
                image_embed /= image_embed.norm(dim=-1, keepdim=True)

                sim_A = (image_embed @ text_A.T).item()
                sim_B = (image_embed @ text_B.T).item()
                sim_blend = (image_embed @ text_blend.T).item()

                final_score = lambda_ * sim_blend + (1 - lambda_) * min(sim_A, sim_B)

            if final_score > best_score:
                best_score = final_score
                best_image = image
                best_seed = seed

            print(f"[{key}] Seed: {seed} | Score: {final_score:.4f} | sim_A: {sim_A:.4f} | sim_B: {sim_B:.4f} | sim_blend: {sim_blend:.4f}")

        all_results[key] = {
            "image": best_image,
            "seed": best_seed,
            "score": best_score
        }

    return all_results  # Dictionary: key -> best result for each prompt


In [None]:


import gradio as gr
def generate_final_image(prompt1, prompt2):
    # Step 1: Blend prompts
    blended_prompts = combine_prompts_interactively(prompt1, prompt2, client)

    # Step 2: Get token weights for each blended prompt
    weights_dicts = {
        k: get_token_weights(v, client)["Final Prompt"]
        for k, v in blended_prompts.items()
    }

    # Step 3: Apply weights to each prompt
    weighted_prompts = {
        k: apply_weights(blended_prompts[k], weights_dicts[k])
        for k in blended_prompts
    }

    print("Weighted / Blended Prompts:")
    for k, v in weighted_prompts.items():
        print(f"α = {k}: {v}")

    # Step 4: Generate & score images
    results = generate_and_score_all(
        weighted_prompts,
        num_variants=3,
        num_inference_steps=50,
        guidance_scale=7.5,
        lambda_=0.5
    )

    # Step 5: Select best image
    best_key = max(results, key=lambda k: results[k]["score"])
    best_image = results[best_key]["image"]

    return best_image


# Build the Gradio interface: two prompt text boxes in, one PIL image out.
iface = gr.Interface(
    title="Prompt Blending Image Generator",
    description="Enter two prompts. The system will blend them, generate multiple images, and return the best one.",
    fn=generate_final_image,
    inputs=[gr.Textbox(label="Prompt 1"), gr.Textbox(label="Prompt 2")],
    outputs=gr.Image(type="pil"),
)

# share=True also requests a public gradio.live URL in addition to the local one.
iface.launch(share=True)

* Running on local URL:  http://127.0.0.1:7864
* Running on public URL: https://a0afe371c00be83b88.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)




Token indices sequence length is longer than the specified maximum sequence length for this model (84 > 77). Running this sequence through the model will result in indexing errors
The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: ['( trees : 1 . 2 ).']


Weighted / Blended Prompts:
α = prompt_0.4: A (corgi:1.7) (puppy:1.7) (plays:1.0) (near:1.0) the (serene:1.0) (lake:1.2)'s (edge:1.0), (surrounded:1.0) by (sunflowers:1.2) and (autumn:1.2) (trees:1.2).
α = prompt_0.5: A (corgi:1.7) (puppy:1.7) (plays:1.0) near a (lake:1.2)'s (edge:1.2), (surrounded:1.0) by (sunflowers:1.2) and (reflected:1.0) (autumn:1.2) (trees:1.0).
α = prompt_0.6: A (corgi:1.7) (puppy:1.7) sits at the (lake:1.2)'s edge, surrounded by vibrant (autumn:1.2) (trees:1.2), reflected in calm (waters:1.2).


100%|██████████| 50/50 [00:10<00:00,  4.57it/s]
  sim_A = (image_embed @ text_A.T).item()
The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: ['( trees : 1 . 2 ).']


[prompt_0.4] Seed: 53129 | Score: 0.2353 | sim_A: 0.1769 | sim_B: 0.1831 | sim_blend: 0.2936


100%|██████████| 50/50 [00:11<00:00,  4.28it/s]
The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: ['( trees : 1 . 2 ).']


[prompt_0.4] Seed: 58521 | Score: 0.2134 | sim_A: 0.1657 | sim_B: 0.1704 | sim_blend: 0.2610


100%|██████████| 50/50 [00:11<00:00,  4.36it/s]
The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: [').']


[prompt_0.4] Seed: 47814 | Score: 0.2625 | sim_A: 0.1914 | sim_B: 0.1968 | sim_blend: 0.3336


100%|██████████| 50/50 [00:11<00:00,  4.33it/s]
The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: [').']


[prompt_0.5] Seed: 79871 | Score: 0.2711 | sim_A: 0.1898 | sim_B: 0.1913 | sim_blend: 0.3523


100%|██████████| 50/50 [00:11<00:00,  4.34it/s]
The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: [').']


[prompt_0.5] Seed: 87265 | Score: 0.2543 | sim_A: 0.1791 | sim_B: 0.1935 | sim_blend: 0.3295


100%|██████████| 50/50 [00:11<00:00,  4.32it/s]


[prompt_0.5] Seed: 86308 | Score: 0.2720 | sim_A: 0.1862 | sim_B: 0.2010 | sim_blend: 0.3578


100%|██████████| 50/50 [00:11<00:00,  4.32it/s]


[prompt_0.6] Seed: 15812 | Score: 0.2299 | sim_A: 0.1824 | sim_B: 0.2351 | sim_blend: 0.2775


100%|██████████| 50/50 [00:11<00:00,  4.30it/s]


[prompt_0.6] Seed: 34240 | Score: 0.2841 | sim_A: 0.1938 | sim_B: 0.2155 | sim_blend: 0.3743


100%|██████████| 50/50 [00:11<00:00,  4.31it/s]


[prompt_0.6] Seed: 51351 | Score: 0.2862 | sim_A: 0.1989 | sim_B: 0.2299 | sim_blend: 0.3734
Weighted / Blended Prompts:
α = prompt_0.4: A 1960s vintage (car:1.7) overlooks a (serene:0) (autumn:1.2) (lake:1.7) through (Parisian:1.0) (streets:1.2)' (misty:1.0) (veil:1.2).
α = prompt_0.5: A 1960s vintage (car:1.7) overlooks a serene (autumn:1.0)-colored (lake:1.7) surrounded by (trees:1.2) in the French (countryside:1.2).
α = prompt_0.6: A 1960s (vintage:1.0) (car:1.7) overlooks a (serene:1.0) (autumn:1.2) (lake:1.7), (surrounded:1.0) (by:1.0) vibrant (trees:1.7), in (Paris:1.2).


100%|██████████| 50/50 [00:10<00:00,  4.57it/s]


[prompt_0.4] Seed: 10322 | Score: 0.2665 | sim_A: 0.1824 | sim_B: 0.2206 | sim_blend: 0.3506


100%|██████████| 50/50 [00:10<00:00,  4.58it/s]


[prompt_0.4] Seed: 43003 | Score: 0.2583 | sim_A: 0.1797 | sim_B: 0.2008 | sim_blend: 0.3369


100%|██████████| 50/50 [00:10<00:00,  4.55it/s]


[prompt_0.4] Seed: 76022 | Score: 0.2397 | sim_A: 0.1657 | sim_B: 0.1823 | sim_blend: 0.3138


100%|██████████| 50/50 [00:11<00:00,  4.52it/s]


[prompt_0.5] Seed: 95670 | Score: 0.2478 | sim_A: 0.1853 | sim_B: 0.2250 | sim_blend: 0.3103


100%|██████████| 50/50 [00:11<00:00,  4.51it/s]


[prompt_0.5] Seed: 96815 | Score: 0.2455 | sim_A: 0.1812 | sim_B: 0.2232 | sim_blend: 0.3098


100%|██████████| 50/50 [00:11<00:00,  4.50it/s]


[prompt_0.5] Seed: 88363 | Score: 0.2377 | sim_A: 0.1815 | sim_B: 0.2243 | sim_blend: 0.2940


100%|██████████| 50/50 [00:11<00:00,  4.49it/s]


[prompt_0.6] Seed: 54239 | Score: 0.2550 | sim_A: 0.1703 | sim_B: 0.1938 | sim_blend: 0.3397


100%|██████████| 50/50 [00:11<00:00,  4.49it/s]


[prompt_0.6] Seed: 87661 | Score: 0.2691 | sim_A: 0.1722 | sim_B: 0.2016 | sim_blend: 0.3659


100%|██████████| 50/50 [00:11<00:00,  4.47it/s]


[prompt_0.6] Seed: 31573 | Score: 0.2756 | sim_A: 0.1883 | sim_B: 0.2293 | sim_blend: 0.3628
