In [1]:
import os
import json
import hashlib
from datetime import datetime
from PIL import Image, PngImagePlugin

import unibox as ub
logger = ub.UniLogger()  # module-level logger shared by later cells (e.g. roll_image)


GPT-2 client (remote prompt-augmentation endpoint):

In [2]:
# gpt2 api
from gradio_client import Client

# Gradio share links (*.gradio.live) expire after 72 hours, so the endpoint can be
# overridden with the GPT2_ENDPOINT environment variable instead of editing this cell.
GPT2_ENDPOINT = os.environ.get("GPT2_ENDPOINT", "https://ab2b7e0b874cc07fb9.gradio.live/")
# NOTE(review): originally annotated "version epoch_0_batch_11727"; calls below
# request checkpoints by name (e.g. epoch_0_batch_17619), so that note may be stale.
client = Client(GPT2_ENDPOINT)



def get_gpt2_pred(client, prompt:str, max_length:int, models:list[str]):
    """Augment ``prompt`` with one or more GPT-2 checkpoints served by a Gradio app.

    Args:
        client: a ``gradio_client.Client`` connected to the GPT-2 Gradio app.
        prompt: seed prompt (comma-separated tags) to extend.
        max_length: generation length limit forwarded to the endpoint
            (documented by the endpoint as a number between 10 and 300).
        models: names of the checkpoints to sample from,
            e.g. ``["epoch_0_batch_17619"]``.

    Returns:
        The raw result of the ``/predict`` endpoint. In the run recorded in this
        notebook it is a dict: ``"original"`` maps to ``prompt`` and each
        requested model name maps to that model's augmented prompt.
    """
    # Arguments are positional because the Gradio endpoint binds them by order.
    result = client.predict(
        prompt,
        max_length,
        models,
        api_name="/predict",
    )
    return result

# Smoke test: augment a single prompt with one checkpoint and display the result.
get_gpt2_pred(
    client,
    prompt="hatsune miku",
    max_length=100,
    models=["epoch_0_batch_17619"],
)

Loaded as API: https://ab2b7e0b874cc07fb9.gradio.live/ ✔


{'original': 'hatsune miku',
 'epoch_0_batch_17619': 'hatsune miku, 1girl, twintails, thighhighs, very long hair, aqua hair, detached sleeves, necktie, nail polish, aqua eyes, smile, blush, zettai ryouiki, open mouth, headphones, striped, headphones around neck, black legwear, hair ornament, shirt, 39, miku100, long hair, solo, skirt, headset, headphones, green hair, ahoge, thighhighs, tattoo, ahoge over one ey'}

WebUI client (text-to-image generation and local saving):

In [3]:
# https://github.com/troph-team/eval-it/blob/aa0cb59983e2b0385ef03328b2ce6a3c36a073a0/evalit/webui/webui_t2i_client.py#L38C7-L38C27
from evalit.webui.webui_t2i_client import WebuiT2iClient
from evalit.webui.webui_t2i_client import WebuiT2iClient, SdxlGenerationConfig
from evalit.webui.webui_options_manager import OptionsManager


def save_image(image_paths:list, param_strs:list[str], api_args, save_dir:str="saved_images"):
    """Save one webui generation result (possibly several images) as PNG files.

    Despite its name, ``image_paths`` must contain image objects (e.g. PIL
    images), not paths: each element's ``tobytes()`` is hashed to build the file
    name and its ``save()`` writes the PNG.

    Args:
        image_paths: generated images, paired by position with ``param_strs``.
            If the two lists differ in length, the extra entries are ignored
            (``zip`` semantics).
        param_strs: per-image parameter strings, stored in the PNG text chunk
            ``"parameters"``.
        api_args: arguments of the generation call, stored as JSON in the PNG
            text chunk ``"api_args"`` (values json cannot encode are stringified).
        save_dir: output directory, created if missing.

    Returns:
        list of the written file paths, in input order. Names look like
        ``<YYYYmmddHHMMSS>_<4 hex of md5>.png``; a ``_<n>`` suffix is appended
        when that name already exists, so earlier files are never overwritten.
    """
    os.makedirs(save_dir, exist_ok=True)
    # api_args is the same for every image in the batch, so serialize it once.
    api_args_json = json.dumps(api_args, default=str)

    saved_files = []
    for image, param in zip(image_paths, param_strs):
        pnginfo = PngImagePlugin.PngInfo()
        pnginfo.add_text("parameters", param)
        pnginfo.add_text("api_args", api_args_json)

        timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
        img_hash = hashlib.md5(image.tobytes()).hexdigest()
        stem = f"{timestamp}_{img_hash[:4]}"
        file_path = os.path.join(save_dir, f"{stem}.png")
        # Second-resolution timestamps + a 4-hex-digit hash can collide (e.g. the
        # same image saved twice within one second); pick a fresh name instead
        # of silently overwriting.
        suffix = 1
        while os.path.exists(file_path):
            file_path = os.path.join(save_dir, f"{stem}_{suffix}.png")
            suffix += 1

        image.save(file_path, pnginfo=pnginfo)
        saved_files.append(file_path)

    return saved_files

def generate_and_save_images(prompt:str, save_dir:str="saved_images",
                             baseurl:str="http://0.0.0.0:7862", config=None):
    """Generate images for ``prompt`` through the webui API and save them locally.

    Assumes the desired checkpoint has already been selected manually in the
    webui; this function does not switch models.

    Args:
        prompt: positive prompt sent to the webui.
        save_dir: directory the PNGs are written to (created if missing).
        baseurl: webui API address; defaults to the previously hard-coded value.
        config: an ``SdxlGenerationConfig`` used as-is. When omitted, a default
            config is built with sampler "Euler a", 24 steps, CFG 6.5 and
            width x height = 768 x 1280.

    Returns:
        list of local file paths of the saved images (see ``save_image``).
    """
    # Named webui_client so it does not shadow the module-level GPT-2 ``client``.
    webui_client = WebuiT2iClient(baseurl=baseurl)

    if config is None:
        config = SdxlGenerationConfig()
        config.sampler_name = "Euler a"
        config.steps = 24
        config.cfg_scale = 6.5
        config.height = 1280
        config.width = 768

    # The negative prompt is taken from the config, not passed here.
    images, param_strs, api_args = webui_client.generate(prompt, config)

    return save_image(images, param_strs, api_args, save_dir)

# save_dir = "saved_images"
# img_path = generate_and_save_images("a cat")[0]
# display(ub.loads(img_path))

Gradio demo app:

In [4]:
import gradio as gr
from PIL import Image


# GPT-2 checkpoints queried on every demo request; each yields one augmented prompt.
gpt2_models = ["epoch_0_batch_17619", "epoch_0_batch_52986", "epoch_0_batch_120933"]



def stitch_images_horizontally(imgs:"list[Image.Image]"):
    """Concatenate images left-to-right into a single RGB image.

    Images may differ in size. The canvas is as wide as the sum of all widths
    and as tall as the tallest image. Each image is pasted top-aligned at the
    running x offset, and any uncovered area keeps the canvas fill color
    (black by default).

    Args:
        imgs: images to stitch, in display order.

    Returns:
        A new RGB image.

    Raises:
        ValueError: if ``imgs`` is empty.
    """
    if not imgs:
        raise ValueError("stitch_images_horizontally() needs at least one image")

    total_width = sum(img.width for img in imgs)
    max_height = max(img.height for img in imgs)
    canvas = Image.new('RGB', (total_width, max_height))

    # Use a cumulative offset rather than width * index, so mixed widths line up.
    x_offset = 0
    for img in imgs:
        canvas.paste(img, (x_offset, 0))
        x_offset += img.width
    return canvas


def roll_image(prompt:str):
    """Gradio handler: augment ``prompt`` with GPT-2, render every variant, stitch.

    Steps:
      1. Query every checkpoint in ``gpt2_models`` via ``get_gpt2_pred``. In the
         recorded output the response maps ``"original"`` and each model name to
         a prompt string.
      2. Generate and save images for each prompt with ``generate_and_save_images``
         and keep the first saved image of each.
      3. Load those images and stitch them left-to-right.

    Returns:
        (prompt_text, stitched_image): one ``[key] prompt`` paragraph per entry,
        and the stitched image with panels in the same order.

    Raises:
        RuntimeError: if the webui returns no images for one of the prompts.
    """
    logger.info(f"got prompt: {prompt}")

    gpt2_res = get_gpt2_pred(client, prompt, max_length=60, models=gpt2_models)
    logger.info(f"got gpt2 res of len={len(gpt2_res)}")

    first_image_paths = []
    for model_key, model_res in gpt2_res.items():
        curr_img_results = generate_and_save_images(model_res)
        logger.info(f"model: {model_key} | prompt: {model_res} | image: {curr_img_results}")
        if not curr_img_results:
            raise RuntimeError(f"webui returned no images for [{model_key}] prompt: {model_res}")
        first_image_paths.append(curr_img_results[0])

    # ub.loads is expected to return PIL images here (assumption — confirm).
    imgs = [ub.loads(img_path) for img_path in first_image_paths]

    stitch_img = stitch_images_horizontally(imgs)
    logger.info("stitched images")

    prompt_return_str = "\n\n".join(f"[{k}] {v}" for k, v in gpt2_res.items())
    return prompt_return_str, stitch_img

# Markdown shown above the demo. The GPT-2 endpoint also returns the unmodified
# prompt under "original", so one image is rendered per model plus one for the
# original prompt; the count is derived from gpt2_models to stay in sync.
description = f"""# LM-Augmented SDXL Demo
 Augments the input prompt with gpt-2, then generates {len(gpt2_models) + 1} images (original prompt + one per gpt-2 model) for comparison. takes about 30 seconds to run.
 - generated prompts: {" | ".join(gpt2_models)}
 - generation config: Euler a | cfg6.5 | **24 steps**
"""

# Gradio UI: a single prompt box in; the generated prompts and the stitched image out.
inputs = [
    gr.Textbox(
        label="Enter prompt (comma-separated danbooru tags)",
        placeholder="hatsune miku",
    ),
]
outputs = [
    gr.Textbox(label="Generated Prompts"),
    gr.Image(label="Generated Images"),
]

interface = gr.Interface(
    fn=roll_image,
    inputs=inputs,
    outputs=outputs,
    description=description,
)

# share=True additionally publishes a temporary public *.gradio.live URL.
interface.launch(share=True)

Running on local URL:  http://127.0.0.1:7865
Running on public URL: https://4a7ae081d8b932fca0.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)




2024-03-16 22:21:36,622 [INFO] UniLogger: roll_image: got prompt: 1girl, long hair, white hair
2024-03-16 22:21:39,403 [INFO] UniLogger: roll_image: got gpt2 res of len=4
Traceback (most recent call last):
  File "/home/ubuntu/miniconda3/lib/python3.10/site-packages/gradio/queueing.py", line 501, in call_prediction
    output = await route_utils.call_process_api(
  File "/home/ubuntu/miniconda3/lib/python3.10/site-packages/gradio/route_utils.py", line 253, in call_process_api
    output = await app.get_blocks().process_api(
  File "/home/ubuntu/miniconda3/lib/python3.10/site-packages/gradio/blocks.py", line 1695, in process_api
    result = await self.call_function(
  File "/home/ubuntu/miniconda3/lib/python3.10/site-packages/gradio/blocks.py", line 1235, in call_function
    prediction = await anyio.to_thread.run_sync(
  File "/home/ubuntu/miniconda3/lib/python3.10/site-packages/anyio/to_thread.py", line 33, in run_sync
    return await get_asynclib().run_sync_in_worker_thread(
  File