In [None]:
# =========================================================
# ðŸ”¹ STEP 1: Install dependencies
# =========================================================
!pip install torch torchvision transformers diffusers accelerate safetensors sentencepiece

# =========================================================
# 🔹 STEP 2: Import required libraries
# =========================================================
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from diffusers import StableDiffusionPipeline
from PIL import Image
import matplotlib.pyplot as plt

# =========================================================
# 🔹 STEP 3: Detect device (GPU or CPU)
# =========================================================
# Use the CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("âœ… Device in use:", device)

# =========================================================
# 🔹 STEP 4: Load LLM model (FLAN-T5) for text refinement
# =========================================================
print("ðŸ”„ Loading LLM model (FLAN-T5)...")
llm_model_name = "google/flan-t5-base"
# Tokenizer and seq2seq weights come from the same Hub checkpoint; the model
# is then moved onto the device selected in STEP 3.
tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
llm_model = AutoModelForSeq2SeqLM.from_pretrained(llm_model_name)
llm_model = llm_model.to(device)

# =========================================================
# 🔹 STEP 5: Load Stable Diffusion pipeline
# =========================================================
print("ðŸŽ¨ Loading Stable Diffusion model...")
# Half precision is used only on the GPU; on the CPU the weights stay float32.
if device == "cuda":
    dtype = torch.float16
else:
    dtype = torch.float32

# NOTE(review): confirm this Hub repo id still resolves — the SD v1.5 weights
# may have been re-hosted under a different repo id.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=dtype
).to(device)

# =========================================================
# 🔹 STEP 6: Define helper functions
# =========================================================
def refine_prompt(prompt, *, max_length=60):
    """
    Use the FLAN-T5 LLM to rewrite *prompt* as a more detailed image prompt.

    Args:
        prompt: The user's short text description.
        max_length: Upper bound on the generated sequence length (in tokens),
            passed straight to ``llm_model.generate``. Defaults to 60, the
            value previously hard-coded here.

    Returns:
        The decoded refinement with surrounding whitespace stripped. If the
        model produces nothing, the original *prompt* is returned instead, so
        Stable Diffusion is never driven by an empty string.
    """
    input_text = f"Generate a detailed and descriptive image prompt for: {prompt}"
    # truncation=True caps very long user input at the tokenizer's
    # model_max_length instead of handing the encoder an oversized sequence.
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True).to(device)
    outputs = llm_model.generate(**inputs, max_length=max_length)
    refined = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
    # Fall back to the user's own wording if the model returned nothing.
    return refined or prompt


def generate_image_from_text(
    prompt,
    *,
    output_path="generated_image.png",
    num_inference_steps=50,
    seed=None,
):
    """
    Refine *prompt* with the LLM, render it with Stable Diffusion, then save
    and display the result.

    Args:
        prompt: The user's text description.
        output_path: File the generated image is written to. Defaults to
            ``generated_image.png``, the previously hard-coded name; an
            existing file at that path is overwritten.
        num_inference_steps: Number of denoising steps passed to the
            diffusion pipeline. 50 matches the previous behaviour; lowering it
            shortens generation time, which matters most on CPU.
        seed: Optional integer seed. When given, the initial noise is drawn
            from a seeded ``torch.Generator`` so the same prompt and seed
            reproduce the same image.

    Returns:
        The generated PIL image. Callers that ignore the return value behave
        exactly as before.
    """
    print("\nðŸ§  Refining your prompt using LLM...")
    refined_prompt = refine_prompt(prompt)
    print(f"\nâœ¨ Refined Prompt:\n{refined_prompt}\n")

    print("ðŸŽ¨ Generating image, please wait...")
    generator = None
    if seed is not None:
        # The diffusers reproducibility guide recommends a CPU generator so
        # that the same seed yields the same noise on any device.
        generator = torch.Generator(device="cpu").manual_seed(seed)
    image = pipe(
        refined_prompt,
        num_inference_steps=num_inference_steps,
        generator=generator,
    ).images[0]

    # Save and display image
    image.save(output_path)
    print(f"âœ… Image saved as {output_path}")

    plt.imshow(image)
    plt.axis("off")
    plt.title("Generated Image")
    plt.show()
    return image

# =========================================================
# 🔹 STEP 7: Run the project
# =========================================================
# Re-prompt on blank input: an empty prompt would ask the LLM to describe
# nothing and send an arbitrary prompt on to Stable Diffusion.
user_input = input("Enter your text prompt: ").strip()
while not user_input:
    print("Prompt cannot be empty, please try again.")
    user_input = input("Enter your text prompt: ").strip()
generate_image_from_text(user_input)


✅ Device in use: cpu
🔄 Loading LLM model (FLAN-T5)...
🎨 Loading Stable Diffusion model...


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

Enter your text prompt: A futuristic city floating above clouds

🧠 Refining your prompt using LLM...

✨ Refined Prompt:
A futuristic city floating above clouds.

🎨 Generating image, please wait...


  0%|          | 0/50 [00:00<?, ?it/s]