In [1]:
import os
import json
import hashlib
from datetime import datetime
from PIL import Image, PngImagePlugin

import unibox as ub
# Shared logger for the whole notebook; the cells below log through it.
logger = ub.UniLogger()


GPT-2 client (prompt augmentation over a Gradio API):

In [2]:
# gpt2 api
from gradio_client import Client

def get_gpt2_pred_old(client, prompt:str, max_length:int, models:list[str]):
    """Call the legacy GPT-2 "/predict" endpoint (prompt-only signature).

    Superseded by ``get_gpt2_pred``, which additionally sends rating, date,
    quality and character conditioning.

    Args:
        client: gradio client connected to the GPT-2 app.
        prompt: prompt text to extend.
        max_length: generation length; the endpoint notes give a range of 10-300.
        models: names of the checkpoints to query.

    Returns:
        The endpoint's response, unchanged (noted as a str for this endpoint).
    """
    endpoint_args = (prompt, max_length, models)
    response = client.predict(*endpoint_args, api_name="/predict")
    return response





def get_gpt2_pred(client, rating:str, character:str, prompt:str, max_length:int, models:list[str],
                  date:str="2020s", quality:str="excellent"):
    """Query the GPT-2 prompt-augmentation "/predict" endpoint.

    Args:
        client: gradio client connected to the GPT-2 app.
        rating: one of 'safe', 'sensitive', 'nsfw', 'nsfw, explicit'.
        character: character tag(s) for the 'Character' textbox; may be empty.
        prompt: tag prompt to extend.
        max_length: generation length; the endpoint slider accepts 40-300.
        models: checkpoint names to query (the 'Select Models' checkbox group).
        date: date conditioning, one of '2000s', '2010s', '2015s', '2020s'.
        quality: quality conditioning, one of 'bad', 'normal', 'good', 'excellent'.

    Returns:
        The raw result of ``client.predict`` (callers in this notebook iterate it
        with ``.items()``).

    Raises:
        ValueError: if ``rating``, ``date`` or ``quality`` is not an accepted choice.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    choice_checks = (
        ("rating", rating, ('safe', 'sensitive', 'nsfw', 'nsfw, explicit')),
        ("date", date, ('2000s', '2010s', '2015s', '2020s')),
        ("quality", quality, ('bad', 'normal', 'good', 'excellent')),
    )
    for field, value, allowed in choice_checks:
        if value not in allowed:
            raise ValueError(f"{field} must be one of {allowed}, got {value!r}")

    # Positional order must match the endpoint's component order.
    result = client.predict(
        rating,      # 'Rating' Radio component
        date,        # 'Date' Radio component
        quality,     # 'Quality' Radio component
        character,   # 'Character' Textbox component
        prompt,      # 'prompt' Textbox component
        max_length,  # 'max_length' Slider component (40-300)
        models,      # 'Select Models' Checkboxgroup component
        api_name="/predict",
    )
    return result


# GPT2_ENDPOINT = "https://ab2b7e0b874cc07fb9.gradio.live/"
# client = Client(GPT2_ENDPOINT) # version epoch_0_batch_11727
# get_gpt2_pred(client, "hatsune miku", 100, ["epoch_0_batch_17619"])

WebUI client (SDXL image generation and saving):

In [3]:
# https://github.com/troph-team/eval-it/blob/aa0cb59983e2b0385ef03328b2ce6a3c36a073a0/evalit/webui/webui_t2i_client.py#L38C7-L38C27
from evalit.webui.webui_t2i_client import WebuiT2iClient
from evalit.webui.webui_t2i_client import WebuiT2iClient, SdxlGenerationConfig
from evalit.webui.webui_options_manager import OptionsManager


def save_image(image_paths:list[str], param_strs:list[str], api_args, save_dir:str="saved_images"):
    """Save one webui generation result (possibly several images) to disk as PNGs.

    Despite its name, ``image_paths`` holds the image objects returned by the webui
    client (each must provide ``tobytes()`` and ``save(path, pnginfo=...)``, e.g. PIL
    images), not file paths.

    Args:
        image_paths: generated images, paired index-wise with ``param_strs``.
        param_strs: per-image webui infotext, stored in the PNG ``parameters`` chunk.
        api_args: request arguments, JSON-encoded into the PNG ``api_args`` chunk
            (must be JSON-serializable).
        save_dir: output directory; created if missing.

    Returns:
        list[str]: paths of the written files, in input order. Items beyond the
        shorter of ``image_paths`` / ``param_strs`` are ignored (``zip`` semantics).
    """
    os.makedirs(save_dir, exist_ok=True)
    api_args_json = json.dumps(api_args)  # identical for every image; encode once

    saved_files = []
    for image, param in zip(image_paths, param_strs):
        pnginfo = PngImagePlugin.PngInfo()
        pnginfo.add_text("parameters", param)
        pnginfo.add_text("api_args", api_args_json)

        # File name: <YYYYmmddHHMMSS>_<first 4 hex chars of md5(pixel bytes)>.png.
        # That can collide (e.g. identical images saved within the same second), so
        # append a counter rather than silently overwriting an earlier file.
        timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
        img_hash = hashlib.md5(image.tobytes()).hexdigest()
        stem = f"{timestamp}_{img_hash[:4]}"
        file_path = os.path.join(save_dir, f"{stem}.png")
        counter = 1
        while os.path.exists(file_path):
            file_path = os.path.join(save_dir, f"{stem}_{counter}.png")
            counter += 1

        image.save(file_path, pnginfo=pnginfo)
        saved_files.append(file_path)

    return saved_files

def generate_and_save_images(prompt:str, config:SdxlGenerationConfig, save_dir:str="saved_images",
                             baseurl:str="http://0.0.0.0:7862"):
    """Generate images for ``prompt`` through the webui API and save them locally.

    Assumes the desired checkpoint has already been selected manually in the webui;
    this function does not switch models.

    Args:
        prompt: positive prompt sent to the webui.
        config: generation settings (the negative prompt is expected to live here).
        save_dir: directory the images are written to.
        baseurl: webui API base URL.

    Returns:
        list[str]: local paths of the saved images, as returned by ``save_image``.
    """
    client = WebuiT2iClient(baseurl=baseurl)

    # negative prompt is taken from `config`
    images, param_strs, api_args = client.generate(prompt, config)

    return save_image(images, param_strs, api_args, save_dir)

# save_dir = "saved_images"
# img_path = generate_and_save_images("a cat")[0]
# display(ub.loads(img_path))

Gradio demo app:

In [4]:
import gradio as gr
from PIL import Image



def get_sdxl_generation_config(steps:int=24):
    """Build the SDXL generation config used by the demo.

    Starts from ``SdxlGenerationConfig()`` defaults and overrides the sampler
    ("Euler a"), CFG scale (6.5), resolution (768 wide x 1280 tall) and step count.

    Args:
        steps: number of sampling steps.

    Returns:
        SdxlGenerationConfig: the configured instance.
    """
    # Field reference:
    # https://github.com/troph-team/eval-it/blob/aa0cb59983e2b0385ef03328b2ce6a3c36a073a0/evalit/webui/webui_t2i_client.py#L38
    config = SdxlGenerationConfig()
    config.sampler_name = "Euler a"
    config.cfg_scale = 6.5
    config.height = 1280
    config.width = 768
    config.steps = steps

    # Return the configured object; returning a fresh SdxlGenerationConfig() here
    # would silently discard every override above (including `steps`).
    return config


def stitch_images_horizontally(imgs:list[Image.Image]):
    """Concatenate images left to right into a single RGB image.

    Images are placed in list order, top-aligned, each starting where the previous
    one ends. The canvas is as wide as the sum of all widths and as tall as the
    tallest image, so inputs of different sizes are neither overlapped nor cropped;
    any uncovered area keeps ``Image.new``'s default (black) fill.

    Args:
        imgs: non-empty list of PIL images.

    Returns:
        Image.Image: the stitched RGB image.

    Raises:
        ValueError: if ``imgs`` is empty.
    """
    if not imgs:
        raise ValueError("stitch_images_horizontally() needs at least one image")

    total_width = sum(img.width for img in imgs)
    max_height = max(img.height for img in imgs)
    stitch_img = Image.new('RGB', (total_width, max_height))

    x_offset = 0
    for img in imgs:
        stitch_img.paste(img, (x_offset, 0))
        x_offset += img.width
    return stitch_img


def roll_image(rating:str, character:str, prompt:str, prompt_len:int, img_steps:int, prompt_only:bool):
    """Gradio callback: augment a prompt with GPT-2, then render one image per prompt variant.

    Reads the module-level ``client`` (GPT-2 gradio client) and ``gpt2_models``.

    Args:
        rating: rating radio value forwarded to the GPT-2 endpoint.
        character: character textbox value (may be empty).
        prompt: user tag prompt.
        prompt_len: max_length forwarded to the GPT-2 endpoint.
        img_steps: SDXL sampling steps.
        prompt_only: if True, return the generated prompts without rendering images.

    Returns:
        tuple: ("[variant] prompt" entries joined by blank lines, stitched image).
        The image is the horizontal stitch of the first saved image of each variant,
        or None when ``prompt_only`` is set or no image was generated.
    """
    logger.info(f"got prompt: {prompt}")

    # GPT-2 predictions: mapping of variant name -> generated prompt text
    gpt2_res = get_gpt2_pred(client, rating, character, prompt, prompt_len, gpt2_models)
    logger.info(f"got gpt2 res of len={len(gpt2_res)}")

    prompt_return_str = "\n\n".join([f"[{k}] {v}" for k, v in gpt2_res.items()])
    if prompt_only:
        return prompt_return_str, None

    # generate images with the webui API, one request per prompt variant
    config = get_sdxl_generation_config(steps=img_steps)
    image_paths = []
    for model_key, model_res in gpt2_res.items():
        # NOTE(review): in the saved run log the "original" entry is the raw template
        # `<input rating=... tags="..."><output>`; only "</output>" is stripped here, so
        # that markup reaches SDXL verbatim. TODO confirm whether the original prompt
        # should be rebuilt from `character`/`prompt` instead.
        _model_res = model_res.replace("</output>", "")
        curr_img_results = generate_and_save_images(prompt=_model_res, config=config)
        logger.info(f"model: {model_key} | prompt: {_model_res} | image: {curr_img_results}")
        if curr_img_results:
            image_paths.append(curr_img_results[0])  # only the first image per variant is shown
        else:
            logger.info(f"model: {model_key} | no image generated, skipping")

    if not image_paths:
        return prompt_return_str, None

    # load image paths as PIL images, then stitch them side by side
    imgs = [ub.loads(img_path) for img_path in image_paths]
    stitch_img = stitch_images_horizontally(imgs)
    logger.info("stitched images")

    return prompt_return_str, stitch_img

# GPT-2 checkpoints queried per request; shown alongside the "original" prompt.
gpt2_models = [
    # 'checkpoint-e0_s32000',
    # 'checkpoint-e0_s68000',
    '8xh100_run2_e2_s50k',
]


description = f"""
# LM-Augmented SDXL Demo
Augments the input prompt with gpt-2, then generates {len(gpt2_models) + 1} images (original prompt + one per model) for comparison. takes about 50 seconds to run.
 - generated prompts: {" | ".join(["original"] + gpt2_models)}
 - generation config: Euler a | cfg6.5
"""


# gradio.live share links are temporary, so allow overriding the endpoint through the
# GPT2_ENDPOINT environment variable instead of editing this cell.
GPT2_ENDPOINT = os.environ.get("GPT2_ENDPOINT", "https://44b07123159cab6c40.gradio.live")
client = Client(GPT2_ENDPOINT)  # read as a global by roll_image()


# Order must match roll_image(rating, character, prompt, prompt_len, img_steps, prompt_only).
inputs = [
    gr.Radio(choices=["safe", "sensitive", "nsfw", "nsfw, explicit"], label="Rating", value="safe"),
    gr.Textbox(lines=1, placeholder="Character tag, e.g. hatsune miku (optional)", label="Character"),
    gr.Textbox(label="Enter prompt (comma-separated danbooru tags recommended)", placeholder="hatsune miku, aqua hair"),
    gr.Slider(label="Prompt Length", minimum=60, maximum=300, step=10, value=160),
    gr.Slider(label="Image Steps", minimum=8, maximum=40, step=2, value=18),
    gr.Checkbox(label="Prompt Only", value=False),
]
outputs = [
    gr.Textbox(label="Generated Prompts"),
    gr.Image(label="Generated Images"),
]

# Define the Gradio interface
interface = gr.Interface(fn=roll_image,
                         inputs=inputs,
                         outputs=outputs,
                         description=description,
                         )

# Launch the Gradio app; share=True also publishes a temporary public gradio.live URL.
interface.launch(share=True)

Loaded as API: https://44b07123159cab6c40.gradio.live/ ✔
Running on local URL:  http://127.0.0.1:7863
Running on public URL: https://df168bfa1186ccc13f.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)




2024-03-19 21:26:30,533 [INFO] UniLogger: roll_image: got prompt: 1girl
2024-03-19 21:26:32,218 [INFO] UniLogger: roll_image: got gpt2 res of len=2
2024-03-19 21:26:52,060 [INFO] UniLogger: roll_image: model: original | prompt: <input rating="safe" chara="" date="2020s" quality="excellent" tags="1girl"><output> | image: ['saved_images/20240319212651_9285.png']
2024-03-19 21:27:11,377 [INFO] UniLogger: roll_image: model: 8xh100_run2_e2_s50k | prompt: 1girl, mole under eye, underwear, green eyes, blue eyes, looking at viewer, white legwear, panties under pantyhose, bare shoulders, collarbone, medium breasts, nail polish, crotch seam, see-through, cowboy shot, parted lips, eyebrows visible through hair, standing, white pantyhose, kaede takagaki, the idolmaster: cinderella girls, cleavage | image: ['saved_images/20240319212711_f2e5.png']
2024-03-19 21:27:11,378 [INFO] UniLogger: UniLoader.loads: .png LOADED from "saved_images/20240319212651_9285.png" in 0.00s
2024-03-19 21:27:11,379 [INFO]