Skip to content

Commit

Permalink
Add image variations
Browse files Browse the repository at this point in the history
  • Loading branch information
seratch committed May 24, 2024
1 parent 5c5c362 commit 6562d85
Show file tree
Hide file tree
Showing 3 changed files with 287 additions and 1 deletion.
154 changes: 154 additions & 0 deletions app/bolt_listeners.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import json
import logging
import re
import threading
import time
from typing import List

import requests
from openai import APITimeoutError
Expand All @@ -21,6 +23,7 @@
from app.openai_image_ops import (
append_image_content_if_exists,
generate_image,
generate_image_variations,
)
from app.openai_ops import (
start_receiving_openai_response,
Expand Down Expand Up @@ -67,6 +70,11 @@
build_image_generation_result_modal,
build_image_generation_text_modal,
build_image_generation_result_blocks,
build_image_variations_text_modal,
build_image_variations_result_blocks,
build_image_variations_result_modal,
build_image_variations_wip_modal,
build_image_variations_input_modal,
)


Expand Down Expand Up @@ -845,6 +853,144 @@ def display_image_generation_result(
)


def start_image_variations(client: WebClient, body: dict, payload: dict):
    """Open the image-variations input modal in response to a Home tab button click."""
    trigger_id = body.get("trigger_id")
    modal_view = build_image_variations_input_modal(payload.get("value"))
    client.views_open(trigger_id=trigger_id, view=modal_view)


def ack_image_variations_modal_submission(ack: Ack):
    """Immediately swap the submitted modal for a work-in-progress view."""
    wip_view = build_image_variations_wip_modal()
    ack(response_action="update", view=wip_view)


def display_image_variations_result(
    client: WebClient,
    context: BoltContext,
    logger: logging.Logger,
    payload: dict,
):
    """Generate DALL·E 2 variations for the uploaded images and share them.

    Downloads each file the user attached to the modal, asks OpenAI for a
    variation of each (one worker thread per file), uploads the results to a
    DM with the requester, and finally updates the originating modal with the
    outcome (or an error message).

    Args:
        client: Slack WebClient used for file download auth, uploads, and
            modal updates.
        context: Bolt request context (bot token, actor user id, OpenAI config).
        logger: Logger for error reporting.
        payload: The submitted view payload (state values + view id).
    """
    try:
        # https://platform.openai.com/docs/guides/images/variations-dall-e-2-only
        model = "dall-e-2"  # The variations endpoint supports DALL·E 2 only
        size = extract_state_value(payload, "size").get("selected_option").get("value")
        image_files = extract_state_value(payload, "input_files").get("files")

        start_time = time.time()
        threads = []
        file_uploads: List[dict] = []
        try:
            for image_file in image_files:
                # Bind the loop variable as a default argument: a plain
                # closure would late-bind `image_file`, so a thread that
                # starts slowly could process the wrong (e.g. last) file.
                def generate_variations(image_file=image_file):
                    uploaded_image_url = image_file["url_private"]
                    # Slack-hosted files require the bot token for download.
                    image_data: bytes = requests.get(
                        uploaded_image_url,
                        headers={"Authorization": f"Bearer {context.bot_token}"},
                    ).content

                    image_url = generate_image_variations(
                        context=context,
                        image=image_data,
                        size=size,
                        timeout_seconds=OPENAI_TIMEOUT_SECONDS,
                    )
                    image_content = requests.get(image_url).content
                    # list.append is atomic under the GIL, so concurrent
                    # workers can safely share this list.
                    file_uploads.append(
                        {"file": image_content, "filename": image_file["name"]}
                    )

                thread = threading.Thread(target=generate_variations)
                thread.daemon = True
                thread.start()
                threads.append(thread)

        finally:
            # Wait for all workers; a failed join must not mask other results.
            for t in threads:
                try:
                    if t.is_alive():
                        t.join()
                except Exception:
                    pass

        spent_seconds = str(round((time.time() - start_time), 2))

        if len(file_uploads) == 0:
            # Every worker failed (e.g. OpenAI errors); tell the user instead
            # of calling files_upload_v2 with an empty list.
            logger.error("Failed to prepare any upload content")
            client.views_update(
                view_id=payload["id"],
                view=build_image_variations_text_modal(
                    "Failed to generate variations. Please check your OpenAI platform usage."
                ),
            )
            return

        users = [context.actor_user_id]
        dm_id = client.conversations_open(users=users)["channel"]["id"]
        message_text = (
            "Here are the generated image variations for your inputs:\n"
            f"model: {model}, size: {size}, time spent: {spent_seconds} s"
        )
        upload = client.files_upload_v2(
            initial_comment=message_text,
            channel=dm_id,
            file_uploads=file_uploads,
        )
        file_id = upload["files"][0]["id"]
        # Slack registers a file's channel shares asynchronously; poll until
        # the DM share appears so url_private links render in the modal.
        # Bounded by a deadline so a Slack-side delay can't spin forever.
        shared = False
        time.sleep(1.5)
        deadline = time.time() + 60
        while not shared and time.time() < deadline:
            latest_file_info = client.files_info(file=file_id)
            shares = latest_file_info.get("file").get("shares")
            # Truthiness check avoids a TypeError when "private" is absent.
            shared = bool(shares and shares.get("private"))
            if not shared:
                time.sleep(0.5)

        blocks = build_image_variations_result_blocks(
            text=message_text,
            generated_image_urls=[f["url_private"] for f in upload["files"]],
            model=model,
        )
        client.views_update(
            view_id=payload["id"],
            view=build_image_variations_result_modal(blocks),
        )

    except (APITimeoutError, TimeoutError):
        client.chat_postMessage(
            channel=context.actor_user_id,
            text=TIMEOUT_ERROR_MESSAGE,
        )
        client.views_update(
            view_id=payload["id"],
            view=build_image_variations_text_modal(TIMEOUT_ERROR_MESSAGE),
        )
    except SlackApiError as e:
        logger.exception(f"Failed to call Slack APIs for image variations: {e}")
        client.views_update(
            view_id=payload["id"],
            view=build_image_variations_text_modal(
                f":warning: *My apologies!* "
                f"An error occurred while calling Slack APIs: `{e}`"
            ),
        )
    except Exception as e:
        logger.exception(f"Failed to share a generated image: {e}")
        error = (
            f"\n\n:warning: *My apologies!* "
            f"An error occurred while generating image variations: `{e}`"
        )
        client.chat_postMessage(
            channel=context.actor_user_id,
            text=error,
        )
        client.views_update(
            view_id=payload["id"],
            view=build_image_variations_text_modal(error),
        )


#
# Chat from scratch
#
Expand Down Expand Up @@ -953,6 +1099,14 @@ def attach_bot_scopes(client: WebClient, context: BoltContext, next_):
ack=ack_image_generation_modal_submission,
lazy=[display_image_generation_result],
)
app.action("templates-image-variations")(
ack=just_ack,
lazy=[start_image_variations],
)
app.view("image-variations")(
ack=ack_image_variations_modal_submission,
lazy=[display_image_variations_result],
)

# Free format chat
app.action("templates-from-scratch")(
Expand Down
18 changes: 18 additions & 0 deletions app/openai_image_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,21 @@ def generate_image(
n=1,
)
return response.data[0].url


def generate_image_variations(
    *,
    context: BoltContext,
    image: bytes,
    size: Literal["256x256", "512x512", "1024x1024"] = "256x256",
    timeout_seconds: int,
) -> str:
    """Ask OpenAI for a single variation of *image* and return its URL.

    Only DALL·E 2 supports the variations endpoint, so the model is fixed.
    """
    openai_client = create_openai_client(context)
    result = openai_client.images.create_variation(
        model="dall-e-2",
        image=BytesIO(image),
        n=1,
        size=size,
        timeout=timeout_seconds,
    )
    return result.data[0].url
116 changes: 115 additions & 1 deletion app/slack_ui.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Optional
from typing import Optional, List
from slack_bolt import BoltContext
from slack_sdk.errors import SlackApiError
from app.i18n import translate
Expand Down Expand Up @@ -224,6 +224,7 @@ def build_home_tab(
"* Chat Templates",
"* Configuration",
"* Can you generate an image as I instruct you?",
"* Can you generate variations for my images?",
]
)
translated_sentences = list(
Expand All @@ -244,6 +245,7 @@ def build_home_tab(
chat_templates = translated_sentences[5]
configuration = translated_sentences[6]
image_generation = translated_sentences[7]
image_variations = translated_sentences[8]

blocks = []
if single_workspace_mode is False:
Expand Down Expand Up @@ -295,6 +297,16 @@ def build_home_tab(
"action_id": "templates-image-generation",
},
},
{
"type": "section",
"text": {"type": "mrkdwn", "text": image_variations or " "},
"accessory": {
"type": "button",
"text": {"type": "plain_text", "text": start or " "},
"value": image_variations or " ",
"action_id": "templates-image-variations",
},
},
{
"type": "section",
"text": {"type": "mrkdwn", "text": from_scratch or " "},
Expand Down Expand Up @@ -745,6 +757,108 @@ def build_image_generation_text_modal(section_text: str) -> dict:
}


def build_image_variations_input_modal(prompt: str) -> dict:
    """Build the modal collecting source PNG files and an output size.

    *prompt* is shown verbatim in the top section; a single space is used
    when it is empty/None because Slack rejects empty text objects.
    """
    size_options = []
    for value in ("256x256", "512x512", "1024x1024"):
        size_options.append(
            {"text": {"type": "plain_text", "text": value}, "value": value}
        )
    prompt_section = {
        "type": "section",
        "text": {"type": "mrkdwn", "text": prompt or " "},
    }
    files_input = {
        "type": "input",
        "block_id": "input_files",
        "label": {"type": "plain_text", "text": "Files to edit"},
        "element": {
            "type": "file_input",
            "action_id": "input",
            "filetypes": ["png"],
            "max_files": 5,
        },
    }
    # https://platform.openai.com/docs/api-reference/images/create
    size_input = {
        "type": "input",
        "block_id": "size",
        "label": {"type": "plain_text", "text": "Size"},
        "element": {
            "type": "static_select",
            "options": size_options,
            "initial_option": size_options[0],
            "action_id": "input",
        },
    }
    return {
        "type": "modal",
        "callback_id": "image-variations",
        "title": {"type": "plain_text", "text": "Image Variations"},
        "submit": {"type": "plain_text", "text": "Submit"},
        "close": {"type": "plain_text", "text": "Close"},
        "blocks": [prompt_section, files_input, size_input],
    }


def build_image_variations_wip_modal() -> dict:
    """Return the in-progress modal shown while variations are being generated."""
    message = (
        "Working on this now ... :hourglass:\n\n"
        "Once the images are ready, this app will send them to you in a DM. "
        "If you don't want to wait here, you can close this modal at any time."
    )
    return build_image_variations_text_modal(message)


def build_image_variations_result_modal(blocks: list) -> dict:
    """Wrap pre-built result *blocks* in the image-variations modal shell."""
    modal = {
        "type": "modal",
        "callback_id": "image-variations",
        "title": {"type": "plain_text", "text": "Image Variations"},
        "close": {"type": "plain_text", "text": "Close"},
    }
    modal["blocks"] = blocks
    return modal


def build_image_variations_result_blocks(
    *,
    text: str,
    generated_image_urls: List[str],
    model: str,
) -> list[dict]:
    """Build message blocks: a summary section followed by one image per URL."""
    summary = {
        "type": "section",
        "text": {"type": "mrkdwn", "text": text},
    }
    images = [
        {
            "type": "image",
            "slack_file": {"url": url},
            "alt_text": f"Generated by {model}",
        }
        for url in generated_image_urls
    ]
    return [summary] + images


def build_image_variations_text_modal(section_text: str) -> dict:
    """Build a simple one-section modal used for WIP / error / info text.

    The callback_id is normalized to "image-variations" (hyphenated) to be
    consistent with the other modals in this flow; this modal has no submit
    button, so no view-submission handler depends on the value.
    """
    return {
        "type": "modal",
        "callback_id": "image-variations",
        "title": {"type": "plain_text", "text": "Image Variations"},
        "close": {"type": "plain_text", "text": "Close"},
        "blocks": [
            {
                "type": "section",
                # Slack rejects empty text objects; fall back to a space.
                "text": {"type": "mrkdwn", "text": section_text or " "},
            },
        ],
    }


#
# From-scratch modal
#
Expand Down

0 comments on commit 6562d85

Please sign in to comment.