Skip to content

Commit

Permalink
Added support for streaming in the Assistants API
Browse files Browse the repository at this point in the history
  • Loading branch information
tatsuiman committed Mar 17, 2024
1 parent ea738a2 commit 6beef29
Show file tree
Hide file tree
Showing 12 changed files with 586 additions and 568 deletions.
466 changes: 163 additions & 303 deletions src/scripts/ai.py

Large diffs are not rendered by default.

82 changes: 17 additions & 65 deletions src/scripts/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,36 +6,29 @@
import yaml
import logging
import sentry_sdk
from tempfile import mkdtemp
from sentry_sdk import set_user, set_tag
from sentry_sdk.integrations.aws_lambda import AwsLambdaIntegration
from slack_sdk import WebClient
from slack_bolt import App, Ack, BoltContext, Respond
from slack_bolt.adapter.aws_lambda import SlackRequestHandler
from slack_bolt.context import BoltContext
from blockkit import (
Divider,
Input,
Message,
PlainTextInput,
Button,
Actions,
Home,
Header,
Section,
)
from tools import add_notion_page, truncate_strings
from ui import generate_unfurl_message
from ai import CODE_INTERPRETER_EXTS
from store import Assistant
from store import get_thread_info, publish_event
from slacklib import (
get_thread_messages,
get_slack_file_bytes,
add_reaction,
delete_message,
get_im_channel_id,
BOT_USER_ID,
)
from blockkit import (
Divider,
Home,
Header,
Section,
)

# 環境変数からプロジェクト名と関数名を取得
BOT_NAME = os.getenv("BOT_NAME")
Expand Down Expand Up @@ -89,29 +82,6 @@ def generate_auto_reply_message(event):
return None


def generate_unfurl_message():
elements = [
Button(
action_id="ask_button",
text="送信",
value="ask",
style="primary",
)
]
return Message(
blocks=[
Input(
element=PlainTextInput(
action_id="ask-action",
placeholder="質問や回答内容が他の人に見られる心配はありません",
),
label="リンクの内容についてAIに質問してみましょう",
),
Actions(elements=elements),
],
).build()


@app.action("send_api_key")
def handle_send_api_key_action(ack, body, client):
ack()
Expand Down Expand Up @@ -277,35 +247,17 @@ def notion_button(ack: Ack, body: dict, action: dict, respond: Respond):
channel_id = body["channel"]["id"]
thread_ts = body["container"]["thread_ts"]
message_ts = body["container"]["message_ts"]
assistant = Assistant(user_id)
client = assistant.get_client()
add_reaction("eyes", channel_id, message_ts)
set_user({"id": user_id})
try:
thread_messages = get_thread_messages(channel_id, thread_ts)
if len(thread_messages) == 0:
return
message_content = thread_messages[-1].get("text")
message = "Notionページを作成しています。\nしばらくお待ちください..."
respond(message)
# 最初の1kトークンからタイトルを決める
truncate_message_content = truncate_strings(message_content, max_tokens=1000)
title = client.generate_title(truncate_message_content)
slack_url = (
f"https://slack.com/archives/{channel_id}/p{thread_ts.replace('.', '')}"
)
# ファイルタイプがテキストならダウンロードする
for file in thread_messages[-1].get("files", []):
if file["mimetype"] != "text/plain":
continue
url_private = file["url_private_download"]
file_data = get_slack_file_bytes(url_private)
message_content += f"{file_data.decode()}\n"

# メッセージ内容のNotionページを作成
result = add_notion_page(title, message_content, slack_url)
respond(result)
except Exception as e:
respond(f"Notionページの作成に失敗しました。{e}")
event = {
"user": user_id,
"type": "message",
"channel": channel_id,
"thread_ts": thread_ts,
"text": "会話の内容をNotionにまとめてください。",
"ts": message_ts,
}
publish_event(event)


@app.action("delete_button")
Expand Down
166 changes: 166 additions & 0 deletions src/scripts/callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
import os
import sys
import time
import json
import logging
import sentry_sdk
from pluginbase import PluginBase
from slacklib import post_message
from plugin import handle_output_plugin
from slacklib import post_message, update_message, upload_file
from ui import generate_faq_block, generate_step_block
from tools import truncate_strings, calculate_token_size

STREAM_RATE = 1
SLACK_MAX_TOKEN_SIZE = 1500

# PluginBase インスタンスを作成
plugin_base = PluginBase(package="plugins")
# プラグインのソースを作成(プラグインが置かれるディレクトリを指定)
function_source = plugin_base.make_plugin_source(searchpath=["./functions"])


# Slackへのメッセージコールバックを処理するクラスです。
class MessageCallback:
def __init__(self, channel_id, message_ts) -> None:
self.count = 0
self.message = ""
self.channel_id = channel_id
self.message_ts = message_ts
self.ts = 0
self.last_update_time = 0
self.current_message = ""

# メッセージの作成を行います。
def create(self) -> None:
res = post_message(self.channel_id, self.message_ts, "Message Typing...")
self.ts = res["ts"]

# ファイルを設定します。
def set_files(self, files):
message = ""
for file in files:
filename = os.path.basename(file)
permalink = upload_file(file, filename)
message += f"<{permalink}|{filename}>\n"
post_message(self.channel_id, self.message_ts, message)

# メッセージを更新します。
def update(self, message: str) -> None:
current_time = time.time()
output_token = calculate_token_size(self.current_message + message)
if output_token > SLACK_MAX_TOKEN_SIZE or self.ts == 0:
if output_token > SLACK_MAX_TOKEN_SIZE:
update_message(self.channel_id, self.ts, self.current_message)
self.create()
self.current_message = message
else:
self.current_message += message

if current_time - self.last_update_time >= STREAM_RATE:
update_message(self.channel_id, self.ts, self.current_message)
self.last_update_time = current_time

# メッセージの処理が完了したときの処理を行います。
def done(self, message):
files = handle_output_plugin(message)
self.set_files(files)
update_message(self.channel_id, self.ts, self.current_message)
self.ts = 0

# メッセージの終了処理を行います。
def end(self):
blocks = generate_faq_block()
post_message(self.channel_id, self.message_ts, blocks=blocks)


# ステップコールバックを処理するクラスです。
class StepCallback:
def __init__(self, channel_id, message_ts) -> None:
self.output = "none"
self.channel_id = channel_id
self.message_ts = message_ts
self.ts = 0
self.last_update_time = 0
self.current_message = ""

# コード入力の開始を通知します。
def create(self) -> None:
res = post_message(self.channel_id, self.message_ts, "Code Typing...")
self.ts = res["ts"]

# コードメッセージを更新します。
def _update_code_message(self, message):
message = f"コード\n```\n{message}\n```\n"
update_message(self.channel_id, self.ts, message)

# メッセージを更新します。
def update(self, message: str) -> None:
current_time = time.time()
output_token = calculate_token_size(self.current_message + message)
if output_token > SLACK_MAX_TOKEN_SIZE or self.ts == 0:
if output_token > SLACK_MAX_TOKEN_SIZE:
self._update_code_message(self.current_message)
self.create()
self.current_message = message
else:
self.current_message += message

if current_time - self.last_update_time >= STREAM_RATE:
self._update_code_message(self.current_message)
self.last_update_time = current_time

# 関数呼び出しを行います。
def function_call(self, function_name, arguments) -> None:
self.create()
argument = json.loads(arguments)
argument_truncated = truncate_strings(arguments, max_tokens=20)
message = f"call: `{function_name}({argument_truncated})`\n"
update_message(self.channel_id, self.ts, message)

# プラグインとその優先度を格納するリスト
plugins_with_priority = []
for plugin_name in function_source.list_plugins():
# プラグインの優先度を取得(デフォルトは最低優先度)
plugin_module = function_source.load_plugin(plugin_name)
priority = getattr(plugin_module, "PRIORITY", float("inf"))
# プラグインとその優先度をリストに追加
plugins_with_priority.append((priority, plugin_name))
result = ""
# 優先度に基づいてプラグインをソート
plugins_with_priority.sort()
# ソートされた順序でプラグインを実行
for _, plugin_name in plugins_with_priority:
# プラグインをロード
plugin_module = function_source.load_plugin(plugin_name)
if function_name == plugin_name:
try:
logging.info(f"run extractor {plugin_name}")
# プラグインモジュールから関数を呼び出す
plugin_result = plugin_module.run(**argument)
# 8kになるように切り捨てする
result = truncate_strings(plugin_result, max_tokens=16000)
except Exception as e:
sentry_sdk.capture_exception(e)
logging.exception(e)
break
if len(result) == 0:
result = "no result"
output_token = calculate_token_size(result)
message += (
f"`{function_name}`を実行しました。結果のトークン数: {output_token}\n"
)
update_message(self.channel_id, self.ts, message)
return result

# 出力を設定します。
def set_output(self, output: str) -> None:
self.output = output

# 処理が完了したときの処理を行います。
def done(self):
message = f"コード\n```\n{self.current_message}\n```\n実行結果\n```\n{self.output}\n```\n"
blocks = generate_step_block(message)
update_message(self.channel_id, self.ts, blocks=blocks)
self.ts = 0
self.output = "none"
10 changes: 10 additions & 0 deletions src/scripts/data/assistant.yara
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,16 @@ rule open_notion_url
($url and $scheme and $domain)
}

rule create_notion_page
{
strings:
$keyword1 = /notion/i
$keyword2 = "作成"
$keyword3 = "まとめて"
condition:
($keyword1 and $keyword2) or ($keyword1 and $keyword3)
}

rule open_slack_url
{
strings:
Expand Down
21 changes: 21 additions & 0 deletions src/scripts/data/assistant.yml
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,27 @@ open_youtube_url:
type: object
type: function

create_notion_page:
name: Create Notion Page
instructions: |
Notionにまとめる必要がある場合はページを作成します。
tools:
- function:
description: Create Notion Page
name: create_notion_page
parameters:
properties:
title:
description: page title
type: string
content:
description: markdown text
type: string
required:
- url
type: object
type: function

open_notion_url:
name: Open Notion URL
instructions: |
Expand Down
24 changes: 24 additions & 0 deletions src/scripts/functions/create_notion_page.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import os
from notion.util import (
create_notion_page,
markdown_to_notion_blocks,
append_blocks_to_page,
)

# NotionデータベースID
DATABASE_ID = os.getenv("DATABASE_ID")


def run(title, content):
# 新しいページを作成
properties = {}
resp = create_notion_page(DATABASE_ID, title, properties)
try:
# MarkdownをNotionブロックに変換
blocks = markdown_to_notion_blocks(content)
# Notionページにブロックを追加
append_blocks_to_page(resp["id"], blocks)
url = resp["url"]
return f"\nNotionページを作成しました。\n\n間違いがあれば修正し、ページは適当な場所に移動してください。\n{title}\n{url}"
except:
return f"\nNotionページの追加に失敗しました。{resp}"
Loading

0 comments on commit 6beef29

Please sign in to comment.