In [None]:
import os
import gc

import torch
import gradio as gr

from dataclasses import asdict, dataclass
from textwrap import dedent
from types import SimpleNamespace

from huggingface_hub import login
from transformers import pipeline
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from diffusers import DiffusionPipeline

In [None]:
# Notebook-wide configuration consumed by the cells below.

# NOTE(review): presumably meant to switch off TensorFlow's oneDNN code paths; it is
# set after the imports cell has already run, so confirm it still takes effect.
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"

# The commented-out originals of these settings read them from a `kwargs` mapping
# (task, model_id, world_size, rank, master_add, master_port) plus a project id;
# here the two that are still used are fixed values.
task = "text-to-image"
base_model_id = "black-forest-labs/FLUX.1-dev"

# Custom CSS handed to the text-to-image Blocks.
css = """
    .importantButton {
        background: linear-gradient(45deg, #7e0570,#5d1c99, #6e00ff) !important;
        border: none !important;
    }
    .importantButton:hover {
        background: linear-gradient(45deg, #ff00e0,#8500ff, #6e00ff) !important;
        border: none !important;
    }
    .disclaimer {font-variant-caps: all-small-caps; font-size: xx-small;}
    .xsmall {font-size: x-small;}
"""

# Prompts offered as one-click examples under the prompt box.
example_list = list((
    "A cat balancing on a pole",
    "2 cats fighting each other",
    "A big river flowing near a mountain",
    "Red hair girl, anime style",
    "Black hair girl, oil painting",
    "She sells seashell by the seashore",
))

# Short display name -> repository id; the display name is the part after the "/".
model_list = {
    repo_id.split("/", 1)[1]: repo_id
    for repo_id in (
        "black-forest-labs/FLUX.1-dev",
        "black-forest-labs/FLUX.1-schnell",
        "strangerzonehf/Flux-Super-Realism-LoRA",
    )
}


# Authenticate against the Hugging Face Hub.
#
# SECURITY: a personal access token used to be hard-coded on this line (and in a
# commented-out copy). Anything that has lived in source must be treated as leaked:
# revoke that token and provide a fresh one through the HF_TOKEN environment variable.
hf_access_token = os.environ.get("HF_TOKEN")
if hf_access_token:
    login(token=hf_access_token)
else:
    # Do not block "Run All" on an interactive prompt; model downloads will fall back
    # to whatever credentials are already available to huggingface_hub.
    print("HF_TOKEN is not set; skipping Hugging Face login. "
          "Gated models will fail to download without stored credentials.")




In [None]:
# check gpu(s)
def _gpu_memory_budget(gpu_index: int, reserve_gib: int = 2):
    """Return the usable-memory budget for one GPU as a string such as ``"22GB"``.

    The budget is the device's currently free memory in whole GiB minus
    ``reserve_gib`` of headroom, never below zero. If the device cannot be queried
    the function returns ``0`` so the cell stays best-effort instead of raising.
    """
    try:
        free_bytes, _total_bytes = torch.cuda.mem_get_info(gpu_index)
    except (AssertionError, RuntimeError):
        # CUDA missing or not initialisable for this device.
        return 0
    return f"{max(int(free_bytes / 1024 ** 3) - reserve_gib, 0)}GB"


n_gpus = torch.cuda.device_count()

# Query every device individually; free memory can differ from GPU to GPU.
max_memory = {gpu_index: _gpu_memory_budget(gpu_index) for gpu_index in range(n_gpus)}
print('max memory:', max_memory)

# Start from a clean slate before any model is loaded.
gc.collect()
torch.cuda.empty_cache()

In [None]:
@dataclass
class Config:
    """Default generation settings used to seed the UI controls.

    Every attribute carries a type annotation: ``@dataclass`` only turns annotated
    class attributes into fields, so without them ``asdict(Config())`` is an empty
    dict and the values cannot be overridden through the constructor.
    """

    guidance_scale: float = 3.0  # initial value of the "Guidance scale" slider
    step: int = 1                # initial value of the "Step" slider
    width: int = 64              # initial image width
    height: int = 64             # initial image height
    prompt: str = ""             # initial prompt text


# Initial per-session state object handed to gr.State by the UI cell.
STATS_DEFAULT = SimpleNamespace(llm=None, config=Config())

In [None]:
# handler

def _select_torch_dtype():
    """Choose the weight dtype from the first GPU's compute-capability major version.

    Returns ``torch.bfloat16`` for major >= 8, ``torch.float16`` for major 7, and
    ``None`` (library default) for older GPUs or when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return None
    major = torch.cuda.get_device_properties(0).major
    if major >= 8:
        return torch.bfloat16
    if major >= 7:
        return torch.float16
    return None  # auto setup for < 7


def _load_pipeline(model: str, torch_dtype):
    """Load ``model`` as a Flux pipeline, or as LoRA weights on top of ``base_model_id``."""
    try:
        return FluxPipeline.from_pretrained(model, torch_dtype=torch_dtype)
    except Exception as load_error:
        # Best-effort fallback: report why the direct load failed instead of hiding it.
        print(f"Could not load '{model}' as a Flux pipeline ({load_error!r}); "
              f"retrying as LoRA weights on top of '{base_model_id}'.")
        try:
            pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch_dtype)
            pipe.load_lora_weights(model)
        except Exception as lora_error:
            # Keep the first failure attached so the root cause stays visible.
            raise lora_error from load_error
        return pipe


# button to generate image from input text
def generate_btn_handler(model: str, prompt: str, guidance_scale: float, step: int, width: int,
                          height: int) -> tuple:
    """Generate one image for ``prompt`` and return ``(image, "")``.

    Args:
        model: Repository id or local path, loaded as a Flux pipeline or, failing
            that, as LoRA weights applied to ``base_model_id``.
        prompt: Text prompt. ``None``, empty or whitespace-only prompts generate nothing.
        guidance_scale: Passed to the pipeline as ``guidance_scale``.
        step: Passed to the pipeline as ``num_inference_steps``.
        width: Output image width.
        height: Output image height.

    Returns:
        ``(image, "")`` where the empty string is written back to the prompt box,
        or ``(None, "")`` when there is no usable prompt.

    Raises:
        Exception: Whatever model loading or generation raises; the pipeline is
            released and the CUDA cache emptied before the error propagates.
    """
    if prompt is None or not prompt.strip():
        return None, ""

    use_cuda = torch.cuda.is_available()
    pipe = None
    try:
        pipe = _load_pipeline(model, _select_torch_dtype())

        # For even lower GPU RAM the VAE can additionally be tiled/sliced:
        # pipe.vae.enable_tiling()
        # pipe.vae.enable_slicing()

        if use_cuda:
            # Offloading needs an accelerator to stream weights to; skip it on CPU.
            pipe.enable_sequential_cpu_offload()

        image = pipe(
            prompt=prompt,
            # UI widgets may deliver numbers as floats; the pipeline expects integers.
            width=int(width),
            height=int(height),
            num_inference_steps=int(step),
            guidance_scale=float(guidance_scale),
            generator=torch.Generator(device="cuda" if use_cuda else "cpu"),
        ).images[0]
    finally:
        # Release the model even when loading or generation failed.
        pipe = None
        gc.collect()
        torch.cuda.empty_cache()

    return image, ""


In [None]:
# gradio ui

# Text-to-image panel: model selector, output image, prompt box and generation settings.
with gr.Blocks(
        theme=gr.themes.Soft(text_size="sm"),
        title="Flux Image Generator",
        css=css, ) as demo_txt_to_img:

    stats = gr.State(STATS_DEFAULT)
    config = asdict(stats.value.config)

    with gr.Row():
        model = gr.Textbox(value=base_model_id, label="Select VLLM Model from Huggingface/local repo", type="text", interactive=True, )
    with gr.Row():
        image_field = gr.Image(label="Output Image", elem_id="output_image")
    with gr.Row():
        with gr.Column(scale=3):
            # max_lines must not be smaller than lines (it was 8 against lines=10).
            prompt = gr.TextArea(label="Prompt:", elem_id="small-textarea", lines=10, max_lines=10)
            generate_btn = gr.Button("Generate")
        with gr.Column(scale=1):
            guidance_scale = gr.Slider(value=STATS_DEFAULT.config.guidance_scale, minimum=0.0, maximum=30.0,
                                        step=0.1, label="Guidance scale")
            step = gr.Slider(value=STATS_DEFAULT.config.step, minimum=1, maximum=100, step=1, label="Step")
            width = gr.Number(value=STATS_DEFAULT.config.width, label='Image width (64-1024)', precision=0,
                              minimum=64, maximum=1024, interactive=True)
            # Seeded from config.height (it previously read config.width by mistake).
            height = gr.Number(value=STATS_DEFAULT.config.height, label='Image height (64-1024)', precision=0,
                                minimum=64, maximum=1024, interactive=True)

    with gr.Accordion("Example inputs", open=True):
        examples = gr.Examples(
            examples=example_list,
            inputs=[prompt],
            examples_per_page=60,
        )

    # Event handlers
    # The handler returns (image, ""), so the prompt box is cleared after each run.
    generate_btn.click(fn=generate_btn_handler,
                        inputs=[model, prompt, guidance_scale, step, width, height],
                        outputs=[image_field, prompt],
                        api_name="generate")

# Top-level app: one tab per supported task (only text-to-image is implemented).
# NOTE(review): css="style.css" is passed as-is; confirm the installed Gradio version
# treats it as a file path rather than a literal CSS string, and that the file exists.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown("Flux VLLM")
    with gr.Tabs():
        if task == "text-to-image":
            with gr.Tab(label=task):
                demo_txt_to_img.render()


In [None]:
# launch gradio app
# The three values returned by launch() are kept as the app object, the local URL
# and the public share URL.
# NOTE(review): share=True together with server_name='0.0.0.0' makes the app reachable
# beyond this machine — confirm that exposure is intended.
gradio_app, local_url, share_url = demo.launch(
    server_name='0.0.0.0',
    share=True,
    quiet=True,
    show_error=True,
    prevent_thread_lock=True,
)
