# L4: Describe-and-Generate game 🖍️

Load your Hugging Face (HF) API key and the relevant Python libraries.

In [1]:
import os
import io
from IPython.display import Image, display, HTML
from PIL import Image
import base64 

from dotenv import load_dotenv, find_dotenv
_ = load_dotenv(find_dotenv()) # read local .env file
# Hugging Face token used by the hosted-endpoint helper (get_completion).
# Read with .get so a missing key does not stop Restart & Run All: the
# local-model cells later in this notebook never use it. Calls to the hosted
# endpoints will not authenticate without it.
hf_api_key = os.environ.get('HF_API_KEY')

In [2]:
#### Helper function
import requests, json

# Here we are going to call multiple endpoints!
def get_completion(inputs, parameters=None, ENDPOINT_URL="", timeout=None):
    """POST `inputs` (plus optional `parameters`) to a Hugging Face inference endpoint.

    Args:
        inputs: JSON-serialisable payload sent as the request's "inputs" field
            (in this notebook: a prompt string or a base64-encoded image).
        parameters: optional dict sent as the "parameters" field.
        ENDPOINT_URL: full URL of the endpoint to call.
        timeout: seconds to wait for the server, passed to requests. None (the
            default) waits indefinitely, as before.

    Returns:
        The decoded JSON response body.

    Raises:
        RuntimeError: if no HF API key is configured.
        ValueError: if ENDPOINT_URL is empty.
        requests.HTTPError: if the endpoint answers with an error status. Raising
            here avoids handing an error payload to callers that index into a
            successful result.
    """
    if not hf_api_key:
        raise RuntimeError(
            "HF_API_KEY is not set; add it to your .env file to use the hosted endpoints."
        )
    if not ENDPOINT_URL:
        raise ValueError("ENDPOINT_URL is empty; pass the endpoint to call.")
    headers = {
        "Authorization": f"Bearer {hf_api_key}",
        "Content-Type": "application/json",
    }
    data = {"inputs": inputs}
    if parameters is not None:
        data["parameters"] = parameters
    response = requests.post(
        ENDPOINT_URL,
        headers=headers,
        data=json.dumps(data),
        timeout=timeout,
    )
    # Surface HTTP failures (e.g. 4xx/5xx) instead of returning their JSON body.
    response.raise_for_status()
    return json.loads(response.content.decode("utf-8"))

In [3]:
# Hosted endpoint URLs for the Hugging Face helper functions.
# Read with .get so the notebook still runs top to bottom when only the
# local-model cells are used. A missing URL means the hosted helpers cannot be called.
# text-to-image
TTI_ENDPOINT = os.environ.get('HF_API_TTI_BASE')
# image-to-text
ITT_ENDPOINT = os.environ.get('HF_API_ITT_BASE')

## Building your game with `gr.Blocks()`

In [4]:
# Bringing the functions from lessons 3 and 4!
def image_to_base64_str(pil_image):
    """Serialise `pil_image` as PNG and return those bytes as a base64 text string."""
    with io.BytesIO() as buffer:
        pil_image.save(buffer, format='PNG')
        png_bytes = buffer.getvalue()
    encoded = base64.b64encode(png_bytes)
    return encoded.decode('utf-8')

def base64_to_pil(img_base64):
    """Decode a base64 string and open the resulting bytes as a PIL image.

    The in-memory stream is deliberately not closed here: PIL reads image data
    lazily, so the stream must stay alive with the returned image.
    """
    raw_bytes = base64.b64decode(img_base64)
    return Image.open(io.BytesIO(raw_bytes))

def captioner(image):
    """Caption `image` via the hosted image-to-text endpoint and return the text."""
    payload = image_to_base64_str(image)
    first_prediction = get_completion(payload, None, ITT_ENDPOINT)[0]
    return first_prediction['generated_text']

def generate(prompt):
    """Render `prompt` with the hosted text-to-image endpoint.

    The endpoint's response is decoded with base64_to_pil and returned as a PIL image.
    """
    encoded_image = get_completion(prompt, None, TTI_ENDPOINT)
    return base64_to_pil(encoded_image)

### First attempt: just captioning

In [5]:
"""
import gradio as gr 
with gr.Blocks() as demo:
    gr.Markdown("# Describe-and-Generate game 🖍️")
    image_upload = gr.Image(label="Your first image",type="pil")
    btn_caption = gr.Button("Generate caption")
    caption = gr.Textbox(label="Generated caption")
    
    btn_caption.click(fn=captioner, inputs=[image_upload], outputs=[caption])

gr.close_all()
demo.launch(share=True, server_port=int(os.environ['PORT1']))
"""

'\nimport gradio as gr \nwith gr.Blocks() as demo:\n    gr.Markdown("# Describe-and-Generate game 🖍️")\n    image_upload = gr.Image(label="Your first image",type="pil")\n    btn_caption = gr.Button("Generate caption")\n    caption = gr.Textbox(label="Generated caption")\n\n    btn_caption.click(fn=captioner, inputs=[image_upload], outputs=[caption])\n\ngr.close_all()\ndemo.launch(share=True, server_port=int(os.environ[\'PORT1\']))\n'

In [7]:
# one-time (in your venv with uv):
# uv pip install "torch>=2.2" torchvision accelerate transformers pillow gradio

import os

# PyTorch picks up PYTORCH_ENABLE_MPS_FALLBACK when it initialises the MPS
# backend on import, so the variable must be set *before* `import torch`.
# Setting it afterwards has no effect. If torch was already imported in this
# kernel, restart the kernel for it to apply.
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")  # run unsupported MPS ops on CPU

import torch, gradio as gr
from transformers import pipeline

# Use MPS on Apple silicon; fall back to CPU if unavailable
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Lazy-load a single global pipeline (fast on subsequent calls)
_CAPTION_PIPE = None
def get_pipe():
    """Return the shared ViT-GPT2 captioning pipeline, building it on first call.

    The pipeline is placed on the module-level `device` (MPS or CPU). On older
    Transformers releases that reject a torch.device argument, a CPU pipeline is
    built instead.
    """
    global _CAPTION_PIPE
    if _CAPTION_PIPE is None:
        try:
            # Small & fast; great default for local captioning
            _CAPTION_PIPE = pipeline(
                "image-to-text",
                model="nlpconnect/vit-gpt2-image-captioning",
                device=device,                # send to MPS/CPU
            )
        except TypeError:
            # Older Transformers may not accept torch.device directly.
            # `device` is only ever "mps" or "cpu" here, so the integer form is -1 (CPU).
            # Stay on CPU: the Transformers Pipeline wrapper is not a torch module,
            # so calling `.to(device)` on it is not a reliable way to move it.
            _CAPTION_PIPE = pipeline(
                "image-to-text",
                model="nlpconnect/vit-gpt2-image-captioning",
                device=-1,
            )
    return _CAPTION_PIPE

def captioner(pil_image):
    """Caption an uploaded PIL image with the local pipeline from get_pipe().

    Returns a prompt asking for an upload when no image was provided.
    """
    if pil_image is not None:
        caption_pipe = get_pipe()
        with torch.inference_mode():
            predictions = caption_pipe(pil_image, max_new_tokens=24)   # small cap for speed
        return predictions[0]["generated_text"]
    return "Please upload an image."

with gr.Blocks() as demo:
    gr.Markdown("# Describe-and-Generate game 🖍️ (Local, Fast)")

    with gr.Row():
        with gr.Column(scale=4):
            image_upload = gr.Image(label="Your first image", type="pil")
        with gr.Column(scale=1, min_width=50):
            btn_caption = gr.Button("Generate caption")

    caption = gr.Textbox(label="Generated caption")

    btn_caption.click(fn=captioner, inputs=[image_upload], outputs=[caption])

gr.close_all()
# Use PORT1 when it is set. Otherwise pass None and let Gradio choose a free port:
# a hard-coded 7860 fails to launch if that port is already in use.
_port = os.environ.get('PORT1')
demo.launch(share=False, server_port=int(_port) if _port else None)


Closing server running on port: 7860
* Running on local URL:  http://127.0.0.1:7860
* To create a public link, set `share=True` in `launch()`.




Device set to use mps


### Let's add generation

In [8]:
"""
with gr.Blocks() as demo:
    gr.Markdown("# Describe-and-Generate game 🖍️")
    image_upload = gr.Image(label="Your first image",type="pil")
    btn_caption = gr.Button("Generate caption")
    caption = gr.Textbox(label="Generated caption")
    btn_image = gr.Button("Generate image")
    image_output = gr.Image(label="Generated Image")
    btn_caption.click(fn=captioner, inputs=[image_upload], outputs=[caption])
    btn_image.click(fn=generate, inputs=[caption], outputs=[image_output])

gr.close_all()
demo.launch(share=True, server_port=int(os.environ['PORT2']))
"""

'\nwith gr.Blocks() as demo:\n    gr.Markdown("# Describe-and-Generate game 🖍️")\n    image_upload = gr.Image(label="Your first image",type="pil")\n    btn_caption = gr.Button("Generate caption")\n    caption = gr.Textbox(label="Generated caption")\n    btn_image = gr.Button("Generate image")\n    image_output = gr.Image(label="Generated Image")\n    btn_caption.click(fn=captioner, inputs=[image_upload], outputs=[caption])\n    btn_image.click(fn=generate, inputs=[caption], outputs=[image_output])\n\ngr.close_all()\ndemo.launch(share=True, server_port=int(os.environ[\'PORT2\']))\n'

In [9]:
# one-time installs (inside your venv with uv):
# uv pip install "torch>=2.2" torchvision accelerate transformers diffusers safetensors pillow gradio

import os

# ---- MPS-friendly defaults for Apple silicon ----
# Set these *before* importing torch. PyTorch reads them while initialising its
# MPS backend, so setting them after `import torch` may have no effect. If torch
# is already imported in this kernel, restart the kernel for them to apply.
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")
# NOTE(review): a ratio of 0.0 removes PyTorch's MPS memory cap entirely. The
# PyTorch docs warn this can destabilise the whole machine under memory pressure.
os.environ.setdefault("PYTORCH_MPS_HIGH_WATERMARK_RATIO", "0.0")

import torch, gradio as gr
from transformers import pipeline
from diffusers import AutoPipelineForText2Image
from PIL import Image

DEVICE = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")
# Half precision only on CUDA. MPS and CPU stay in float32 (fp16 on MPS risks NaNs/black frames).
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

# ---- Lazy-load local captioner (small & fast) ----
_CAPTION_PIPE = None
def _get_caption_pipe():
    """Return the shared ViT-GPT2 captioning pipeline, creating it on first call.

    The pipeline is built directly on DEVICE, so the model weights and the inputs
    the pipeline prepares land on the same device.
    """
    global _CAPTION_PIPE
    if _CAPTION_PIPE is None:
        # Pass the device to pipeline() rather than moving only `.model` afterwards.
        # A model-only move does not update the pipeline's own device, and the
        # old blanket `except Exception: pass` hid any failure to move it.
        try:
            _CAPTION_PIPE = pipeline(
                "image-to-text",
                model="nlpconnect/vit-gpt2-image-captioning",
                device=DEVICE,
            )
        except TypeError:
            # Older Transformers releases may not accept a device string; use CPU.
            _CAPTION_PIPE = pipeline(
                "image-to-text",
                model="nlpconnect/vit-gpt2-image-captioning",
                device=-1,
            )
    return _CAPTION_PIPE

# ---- Lazy-load local T2I (SDXL-Turbo = super fast drafts) ----
_TTI_PIPE = None
def _get_tti_pipe():
    """Return the shared SDXL-Turbo text-to-image pipeline, loading it on first call.

    The model is loaded in DTYPE and moved to DEVICE. Then whichever of the
    progress-bar, attention-slicing and VAE-tiling switches the pipeline exposes
    are applied.
    """
    global _TTI_PIPE
    if _TTI_PIPE is not None:
        return _TTI_PIPE
    _TTI_PIPE = AutoPipelineForText2Image.from_pretrained(
        "stabilityai/sdxl-turbo",
        torch_dtype=DTYPE,              # keep f32 on MPS to avoid NaNs/black frames
        # variant="fp16"  # (only if you force CUDA fp16; skip on MPS)
    ).to(DEVICE)
    optional_setup = (
        ("set_progress_bar_config", {"disable": True}),
        ("enable_attention_slicing", {}),
        ("enable_vae_tiling", {}),
    )
    for method_name, kwargs in optional_setup:
        if hasattr(_TTI_PIPE, method_name):
            getattr(_TTI_PIPE, method_name)(**kwargs)
    return _TTI_PIPE

def _snap8(x: int) -> int:
    return int(max(8, round(int(x) / 8) * 8))

# ---- Handlers ----
def captioner(pil_image: Image.Image):
    """Return a short caption for `pil_image`, or an upload prompt when it is None."""
    if pil_image is not None:
        caption_pipe = _get_caption_pipe()
        with torch.inference_mode():
            predictions = caption_pipe(pil_image, max_new_tokens=24)  # short = fast
        return predictions[0]["generated_text"]
    return "Please upload an image."

def generate(prompt: str):
    """Render `prompt` with SDXL-Turbo at 512x512 and return the first image.

    Raises gr.Error when the prompt is empty or only whitespace.
    """
    text = (prompt or "").strip()
    if not text:
        raise gr.Error("No prompt. Use the generated caption or type your own.")
    tti_pipe = _get_tti_pipe()
    side = _snap8(512)
    with torch.inference_mode():
        result = tti_pipe(
            text,
            num_inference_steps=2,   # SDXL-Turbo sweet spot: 1–4
            guidance_scale=0.0,      # works best at 0.0
            width=side,
            height=side,
        )
    return result.images[0]

# ---- Your UI (unchanged layout) ----
with gr.Blocks() as demo:
    gr.Markdown("# Describe-and-Generate game 🖍️")
    image_upload = gr.Image(label="Your first image", type="pil")
    btn_caption = gr.Button("Generate caption")
    caption = gr.Textbox(label="Generated caption")
    btn_image = gr.Button("Generate image")
    image_output = gr.Image(label="Generated Image", type="pil")

    # Step 1 fills the caption box; step 2 turns that caption into an image.
    btn_caption.click(fn=captioner, inputs=[image_upload], outputs=[caption])
    btn_image.click(fn=generate, inputs=[caption], outputs=[image_output])

gr.close_all()
# Honour PORT2 when set; otherwise server_port=None lets Gradio choose the port.
_port = os.environ.get("PORT2")
demo.launch(share=False, server_port=int(_port) if _port else None)


Closing server running on port: 7860
* Running on local URL:  http://127.0.0.1:7861
* To create a public link, set `share=True` in `launch()`.


Device set to use mps:0


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

`torch_dtype` is deprecated! Use `dtype` instead!


### Doing it all at once

In [10]:
"""
def caption_and_generate(image):
    caption = captioner(image)
    image = generate(caption)
    return [caption, image]

with gr.Blocks() as demo:
    gr.Markdown("# Describe-and-Generate game 🖍️")
    image_upload = gr.Image(label="Your first image",type="pil")
    btn_all = gr.Button("Caption and generate")
    caption = gr.Textbox(label="Generated caption")
    image_output = gr.Image(label="Generated Image")

    btn_all.click(fn=caption_and_generate, inputs=[image_upload], outputs=[caption, image_output])

gr.close_all()
demo.launch(share=True, server_port=int(os.environ['PORT3']))
"""

'\ndef caption_and_generate(image):\n    caption = captioner(image)\n    image = generate(caption)\n    return [caption, image]\n\nwith gr.Blocks() as demo:\n    gr.Markdown("# Describe-and-Generate game 🖍️")\n    image_upload = gr.Image(label="Your first image",type="pil")\n    btn_all = gr.Button("Caption and generate")\n    caption = gr.Textbox(label="Generated caption")\n    image_output = gr.Image(label="Generated Image")\n\n    btn_all.click(fn=caption_and_generate, inputs=[image_upload], outputs=[caption, image_output])\n\ngr.close_all()\ndemo.launch(share=True, server_port=int(os.environ[\'PORT3\']))\n'

In [11]:
# one-time installs (inside your venv with uv):
# uv pip install "torch>=2.2" torchvision accelerate transformers diffusers safetensors pillow gradio

import os

# --- MPS-friendly env ---
# Must be set *before* `import torch`: PyTorch reads these while initialising its
# MPS backend. If torch is already imported in this kernel, restart the kernel
# for them to apply.
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")
# NOTE(review): 0.0 disables PyTorch's MPS memory cap entirely. The PyTorch docs
# warn this can destabilise the whole machine under memory pressure.
os.environ.setdefault("PYTORCH_MPS_HIGH_WATERMARK_RATIO", "0.0")

import torch, gradio as gr
from transformers import pipeline
from diffusers import AutoPipelineForText2Image
from PIL import Image

DEVICE = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")
# float16 only on CUDA; MPS and CPU stay in float32 for stability.
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

# --- Lazy global pipelines (load once, reuse = fast) ---
_CAPTION_PIPE = None
_TTI_PIPE = None

def _get_caption_pipe():
    """Return the shared ViT-GPT2 captioning pipeline, creating it on first call.

    The pipeline is built directly on DEVICE, so the weights and the inputs the
    pipeline prepares share one device. Moving only `.model` afterwards would not
    update the pipeline's own device. The previous `except Exception: pass` also
    hid any failure to move it.
    """
    global _CAPTION_PIPE
    if _CAPTION_PIPE is None:
        try:
            _CAPTION_PIPE = pipeline(
                "image-to-text",
                model="nlpconnect/vit-gpt2-image-captioning",
                device=DEVICE,
            )
        except TypeError:
            # Older Transformers releases may not accept a device string; use CPU.
            _CAPTION_PIPE = pipeline(
                "image-to-text",
                model="nlpconnect/vit-gpt2-image-captioning",
                device=-1,
            )
    return _CAPTION_PIPE

def _get_tti_pipe():
    """Load SDXL-Turbo once (DTYPE on DEVICE) and return the cached pipeline.

    After loading, any of the optional progress-bar / attention-slicing /
    VAE-tiling switches that the pipeline provides are turned on.
    """
    global _TTI_PIPE
    if _TTI_PIPE is not None:
        return _TTI_PIPE
    _TTI_PIPE = AutoPipelineForText2Image.from_pretrained(
        "stabilityai/sdxl-turbo",
        torch_dtype=DTYPE,            # keep f32 on MPS for stability
    ).to(DEVICE)
    for method_name, kwargs in (
        ("set_progress_bar_config", {"disable": True}),
        ("enable_attention_slicing", {}),
        ("enable_vae_tiling", {}),
    ):
        if hasattr(_TTI_PIPE, method_name):
            getattr(_TTI_PIPE, method_name)(**kwargs)
    return _TTI_PIPE

def captioner(pil_image: Image.Image) -> str:
    """Caption `pil_image` locally; return an upload hint when no image is given."""
    if pil_image is None:
        return "Please upload an image."
    caption_pipe = _get_caption_pipe()
    with torch.inference_mode():
        predictions = caption_pipe(pil_image, max_new_tokens=24)  # short = fast
    # typical: [{'generated_text': '...'}]; otherwise fall back to its string form
    top = predictions[0]
    return top.get("generated_text", str(top))

def generate(prompt: str) -> Image.Image:
    """Turn `prompt` into a 512x512 SDXL-Turbo image.

    Raises gr.Error if the prompt is empty or whitespace-only.
    """
    text = (prompt or "").strip()
    if not text:
        raise gr.Error("No prompt. Use the generated caption or type your own.")
    tti_pipe = _get_tti_pipe()
    with torch.inference_mode():
        result = tti_pipe(
            text,
            num_inference_steps=2,   # SDXL-Turbo sweet spot: 1–4
            guidance_scale=0.0,      # SDXL-Turbo works best at 0.0
            width=512,               # fast and stable; feel free to change
            height=512,
        )
    return result.images[0]

def caption_and_generate(image: Image.Image):
    """Caption `image`, then generate a new image from that caption.

    Returns:
        [caption, image] for the two Gradio outputs. When no image was uploaded,
        the caption slot carries captioner's upload hint and the image slot is
        None. The hint is not fed to the image generator as a prompt.
    """
    cap = captioner(image)
    if image is None:
        # `cap` is the "Please upload an image." hint, not a real caption.
        return [cap, None]
    img = generate(cap)
    return [cap, img]

# ---------------- UI ----------------
with gr.Blocks() as demo:
    gr.Markdown("# Describe-and-Generate game 🖍️ (Local & Fast)")
    image_upload = gr.Image(label="Your first image", type="pil")
    btn_all = gr.Button("Caption and generate", variant="primary")
    caption = gr.Textbox(label="Generated caption")
    image_output = gr.Image(label="Generated Image", type="pil")

    # One click runs both steps and fills both outputs.
    btn_all.click(
        fn=caption_and_generate,
        inputs=[image_upload],
        outputs=[caption, image_output],
    )

gr.close_all()
# PORT3 pins the port when set; otherwise server_port=None lets Gradio pick one.
_port = os.environ.get("PORT3")
if _port:
    demo.launch(share=False, server_port=int(_port))
else:
    demo.launch(share=False, server_port=None)


Closing server running on port: 7861
* Running on local URL:  http://127.0.0.1:7862
* To create a public link, set `share=True` in `launch()`.




Device set to use mps:0


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

In [12]:
# Shut down the Gradio apps still running in this kernel before leaving the notebook.
gr.close_all()

Closing server running on port: 7862
