In [None]:
import requests
from PIL import Image
from io import BytesIO
from IPython.display import display, clear_output
import ipywidgets as widgets

In [None]:
API_URL: str = "http://127.0.0.1:8000/generate"  # image-generation endpoint on this machine (localhost)

In [None]:
def generate_image(prompt, size, filename, timeout=300):
    """Ask the image-generation API for an image and return it as a PIL image.

    Sends ``prompt``, ``size`` and ``filename`` as a JSON body to ``API_URL``.

    Parameters
    ----------
    prompt : str
        Text description of the image to generate.
    size : str
        Requested resolution string, e.g. ``"1024x1024"``.
    filename : str
        Name forwarded to the server as part of the request payload.
    timeout : float or None, optional
        Seconds to wait for the server before ``requests`` raises a timeout.
        Generation can be slow, so the default is generous; pass ``None`` to
        wait indefinitely (the old behavior).

    Returns
    -------
    PIL.Image.Image or None
        The decoded image when the server answers HTTP 200; otherwise ``None``
        after printing the status code and the server's error detail.

    Raises
    ------
    requests.RequestException
        Network-level failures (connection refused, timeout, ...) are not
        caught here and propagate to the caller.
    """
    payload = {
        "prompt": prompt,
        "size": size,
        "filename": filename,
    }
    print("Generating image... Please wait.")

    # Without a timeout a stalled server would block the kernel forever.
    response = requests.post(API_URL, json=payload, timeout=timeout)
    if response.status_code != 200:
        # Error bodies are not guaranteed to be JSON (e.g. a proxy's HTML 502
        # page); fall back to the raw text instead of raising while reporting.
        try:
            detail = response.json()
        except ValueError:
            detail = response.text
        print(f"Error (HTTP {response.status_code}):", detail)
        return None

    print(f"Image '{filename}' generated successfully!")
    return Image.open(BytesIO(response.content))


def save_image_locally(image, filename):
    """Write ``image`` to ``filename`` on local disk and report the result.

    ``image`` is expected to be a PIL image (any object with ``.save(path)``
    works); PIL infers the file format from the filename's extension.
    """
    target_path = filename
    image.save(target_path)
    confirmation = f"Image saved locally as '{target_path}'."
    print(confirmation)


In [None]:
# Request inputs. description_width="initial" sizes each label to its text.
prompt_widget = widgets.Text(
    description="Prompt:",
    value="A cozy cabin in the snow with northern lights",
    layout=widgets.Layout(width="80%"),
    style=dict(description_width="initial"),
)

size_widget = widgets.Dropdown(
    description="Size:",
    options=["1024x1024", "1792x1024", "1024x1792"],
    value="1024x1024",
    style=dict(description_width="initial"),
)

filename_widget = widgets.Text(
    description="Filename:",
    value="image.png",
    style=dict(description_width="initial"),
)

# Action buttons; Save starts disabled and is enabled by the generate callback.
button_generate = widgets.Button(button_style="success", description="Generate Image")
button_save = widgets.Button(button_style="info", description="Save Image", disabled=True)
# Area where progress messages and the generated image are rendered.
output = widgets.Output()

# Most recent image returned by the API; None until a generation succeeds.
last_generated_image = None

def on_generate_click(b):
    """Handle a click on "Generate Image".

    Sends the current prompt, size and filename to the API via
    ``generate_image`` and shows progress, errors and the resulting image
    inside the ``output`` widget. On success the image is stored in
    ``last_generated_image`` and the Save button is enabled.

    Parameters
    ----------
    b : widgets.Button
        The clicked button, supplied by ``on_click`` (unused).
    """
    global last_generated_image
    # Drop the previous result before requesting a new one. Otherwise a failed
    # generation would leave Save enabled with the *old* image (already
    # cleared from view below), which would then be written under the new
    # filename.
    last_generated_image = None
    button_save.disabled = True
    with output:
        clear_output(wait=True)
        image = generate_image(
            prompt_widget.value,
            size_widget.value,
            filename_widget.value,
        )
        # generate_image returns None on failure; test for that explicitly
        # instead of relying on the truthiness of the image object.
        if image is not None:
            last_generated_image = image
            display(image)
            button_save.disabled = False

button_generate.on_click(on_generate_click)


def on_save_click(b):
    """Handle a click on "Save Image".

    Saves ``last_generated_image`` to the path currently typed in the Filename
    box, or explains that nothing has been generated yet.

    All messages are written inside the ``output`` widget. Printed output from
    a widget callback is otherwise not attached to the notebook cell (see the
    ipywidgets Output widget docs). Using the Output context also displays any
    exception raised while saving, e.g. an unsupported file extension.

    Parameters
    ----------
    b : widgets.Button
        The clicked button, supplied by ``on_click`` (unused).
    """
    # Only reads the module-level image, so no ``global`` declaration is needed.
    with output:
        if last_generated_image is not None:
            save_image_locally(last_generated_image, filename_widget.value)
        else:
            print("No image to save. Please generate an image first.")

button_save.on_click(on_save_click)

# Assemble the form: inputs stacked vertically, buttons side by side, then the
# output area. As the cell's last expression, the box is rendered below it.
widgets.VBox(
    children=[
        prompt_widget,
        size_widget,
        filename_widget,
        widgets.HBox([button_generate, button_save]),
        output,
    ]
)