In [1]:
from PIL import Image, PngImagePlugin
import hashlib
import json
from datetime import datetime
import os

import unibox as ub

GPT-2 client (prompt augmentation via a remote Gradio endpoint):

In [2]:
# gpt2 api
from gradio_client import Client

# gradio.live share links are temporary (see Gradio sharing docs), so allow the
# endpoint to be overridden via the GPT2_ENDPOINT environment variable without
# editing the notebook. The default is the endpoint this notebook was run against.
GPT2_ENDPOINT = os.environ.get("GPT2_ENDPOINT", "https://e45d0e6c55d3364043.gradio.live/")
client = Client(GPT2_ENDPOINT)  # version epoch_0_batch_11727


def get_gpt2_pred(prompt, max_length, gpt2_client=None):
    """Ask the GPT-2 Gradio endpoint to continue a tag prompt.

    Args:
        prompt (str): the prompt to augment (e.g. comma-separated tags).
        max_length (float): numeric value between 10 and 300, forwarded as-is to
            the endpoint. The range is not checked here; presumably the endpoint
            enforces it (TODO confirm).
        gpt2_client: optional object exposing a gradio_client-style
            ``predict(*args, api_name=...)`` method. Defaults to the module-level
            ``client`` created in this cell. Pass another client to target a
            different deployment, or a stub for testing.

    Returns:
        str: the value returned by the endpoint's "/predict" route. In the sample
        output, that value is the original prompt followed by extra tags.
    """
    # Only look up the global `client` when no explicit client was supplied.
    api = client if gpt2_client is None else gpt2_client
    return api.predict(
        prompt,
        max_length,
        api_name="/predict",
    )

# Smoke test of the endpoint; the raw augmented prompt is shown as the cell output.
get_gpt2_pred(prompt="hatsune miku", max_length=100)

Loaded as API: https://e45d0e6c55d3364043.gradio.live/ ✔


'hatsune miku, 1girl, blonde hair, twintails, closed eyes, open mouth, smile, hair ornament, very long hair, sleeveless, white dress, hair flower, outdoors, outstretched arms, bare shoulders, bangs, :d, white footwear, armpits, hair between eyes, arm up, blurry background, depth of field, spread arms, standing on one leg, cloudy sky, pink flower, full body, white pantyhose, VOCALOI'

webui client (SDXL text-to-image generation and local saving):

In [3]:
# https://github.com/troph-team/eval-it/blob/aa0cb59983e2b0385ef03328b2ce6a3c36a073a0/evalit/webui/webui_t2i_client.py#L38C7-L38C27
from evalit.webui.webui_t2i_client import WebuiT2iClient
from evalit.webui.webui_t2i_client import WebuiT2iClient, SdxlGenerationConfig
from evalit.webui.webui_options_manager import OptionsManager


def _unique_png_path(save_dir: str, image) -> str:
    """Return a not-yet-existing `<YYYYmmddHHMMSS>_<md5[:4]>.png` path inside save_dir."""
    timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
    img_hash = hashlib.md5(image.tobytes()).hexdigest()
    stem = f"{timestamp}_{img_hash[:4]}"
    file_path = os.path.join(save_dir, f"{stem}.png")
    # Second-resolution timestamp + 4 hex chars is not unique: identical images
    # (or a hash-prefix collision) saved in the same second would overwrite each
    # other. Append a counter until the name is free.
    suffix = 1
    while os.path.exists(file_path):
        file_path = os.path.join(save_dir, f"{stem}_{suffix}.png")
        suffix += 1
    return file_path


def save_image(image_paths: list["Image.Image"], param_strs: list[str], api_args, save_dir: str = "saved_images"):
    """Save one webui generation result (possibly several images) to disk as PNGs.

    Despite its name, ``image_paths`` holds image objects (they are hashed via
    ``.tobytes()`` and written via ``.save()``), not file paths.

    Args:
        image_paths: images returned by the webui client.
        param_strs: one webui parameter string per image. It is stored in the
            PNG "parameters" text chunk. Images and strings are paired with
            ``zip``, so any surplus entries on either side are ignored.
        api_args: request arguments, stored JSON-encoded in the "api_args" text
            chunk of every PNG. Assumes they are JSON-serializable.
        save_dir: output directory, created if missing.

    Returns:
        list[str]: paths of the written files, in input order.
    """
    os.makedirs(save_dir, exist_ok=True)
    # api_args is identical for every image of this call; serialize it once.
    api_args_json = json.dumps(api_args)

    saved_files = []
    for image, param in zip(image_paths, param_strs):
        pnginfo = PngImagePlugin.PngInfo()
        pnginfo.add_text("parameters", param)
        pnginfo.add_text("api_args", api_args_json)

        file_path = _unique_png_path(save_dir, image)
        image.save(file_path, pnginfo=pnginfo)
        saved_files.append(file_path)

    return saved_files

def generate_and_save_images(prompt:str, save_dir:str="saved_images", baseurl:str="http://0.0.0.0:7862"):
    """Generate images for ``prompt`` through the webui API and save them locally.

    Assumes the desired checkpoint has already been selected manually in the
    webui; this function does not switch models.

    Args:
        prompt: positive prompt sent to the webui.
        save_dir: directory passed to ``save_image``.
        baseurl: webui address. The default is the original hard-coded value.
            0.0.0.0 is a bind-all address and is not a portable client
            destination, so pass e.g. "http://127.0.0.1:7862" where it fails.

    Returns:
        list[str]: paths of the saved images, as returned by ``save_image``.
    """
    # Named webui_client so it does not shadow the module-level gradio `client`
    # that get_gpt2_pred relies on.
    webui_client = WebuiT2iClient(baseurl=baseurl)

    # NOTE(review): options_manager is never used below. It is kept only in case
    # OptionsManager's constructor has side effects; confirm, then drop it.
    options_manager = OptionsManager(webui_client.baseurl)

    # https://github.com/troph-team/eval-it/blob/aa0cb59983e2b0385ef03328b2ce6a3c36a073a0/evalit/webui/webui_t2i_client.py#L38
    # These values are also advertised in the Gradio `description`; keep them in sync.
    config = SdxlGenerationConfig()
    config.sampler_name = "Euler a"
    config.steps = 24
    config.cfg_scale = 6.5

    # The negative prompt comes from the config defaults.
    images, param_strs, api_args = webui_client.generate(prompt, config)

    return save_image(images, param_strs, api_args, save_dir)

# save_dir = "saved_images"
# img_path = generate_and_save_images("a cat")[0]
# display(ub.loads(img_path))

Gradio demo (original vs. GPT-2-augmented prompt, side by side):

In [4]:
import gradio as gr

def _stitch_horizontally(images):
    """Paste images left to right onto one RGB canvas as tall as the tallest image."""
    canvas = Image.new('RGB', (sum(img.width for img in images), max(img.height for img in images)))
    x_offset = 0
    for img in images:
        canvas.paste(img, (x_offset, 0))
        x_offset += img.width
    return canvas


def roll_image(prompt:str):
    """Gradio handler: render the raw prompt and its GPT-2-augmented version side by side.

    Args:
        prompt: user prompt (comma-separated danbooru tags, per the input label).

    Returns:
        tuple: (markdown text listing both prompts, stitched image). The first
        saved image of the original prompt is on the left and the first saved
        image of the augmented prompt is on the right.

    Raises:
        gr.Error: if the webui produced no image for either prompt.
    """
    # max_length sent to the GPT-2 endpoint when augmenting the prompt.
    augment_max_length = 60

    print("got prompt", prompt)
    orig_prompt = prompt

    augmented_prompt = get_gpt2_pred(prompt, augment_max_length)
    print("got augmented prompt", augmented_prompt)

    orig_img_results = generate_and_save_images(orig_prompt)
    print(f"got image0: {orig_img_results}")

    augmented_img_results = generate_and_save_images(augmented_prompt)
    print(f"got image1: {augmented_img_results}")

    # Surface a readable message in the UI instead of a bare IndexError below.
    if not orig_img_results or not augmented_img_results:
        raise gr.Error("webui returned no images")

    img_paths = [orig_img_results[0], augmented_img_results[0]]
    imgs = [ub.loads(img_path) for img_path in img_paths]

    # Canvas height follows the taller image, so neither is cropped if sizes differ.
    stitch_img = _stitch_horizontally(imgs)
    print("stitched images")

    prompt_return_str = f"**Original:** {orig_prompt}\n\n**Augmented:** {augmented_prompt}"

    return prompt_return_str, stitch_img

# Markdown shown above the Gradio UI. The sampler / cfg / steps listed here mirror
# the SdxlGenerationConfig values hard-coded in generate_and_save_images. The model
# name is not set anywhere in this notebook (the checkpoint is selected manually in
# the webui). Keep this text in sync with both by hand.
description = """# LM-Augmented SDXL Demo
 Augments the input prompt with gpt-2, then generates 2 images for comparison. takes about 20 seconds to run.
 - model: fulldan-5m-9e 
 - generation config: Euler a; cfg6.5; 24 steps
"""

# One free-text prompt in; the prompt comparison and the stitched image out.
inputs = [gr.Textbox(label="Enter prompt (comma-separated danbooru tags)", placeholder="hatsune miku"), ]
outputs = [
    # roll_image returns markdown ("**Original:** ...\n\n**Augmented:** ...").
    # Render it, rather than showing raw asterisks in a plain Textbox.
    gr.Markdown(label="Generated Prompts"),
    gr.Image(label="Generated Images"),
    ]

# Wire the handler into a Gradio UI: the prompt box feeds roll_image, whose
# (prompt text, stitched image) result fills the two output components.
interface = gr.Interface(
    roll_image,
    inputs,
    outputs,
    description=description,
)

# share=True asks Gradio for a public share link in addition to the local URL.
interface.launch(share=True)

Running on local URL:  http://127.0.0.1:7863
