/
app.py
153 lines (123 loc) · 4.83 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import json
import logging
import os
import re
import time
from datetime import timedelta
from typing import Any
from add_document import initialize_vectorstore
from dotenv import load_dotenv
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains import ConversationalRetrievalChain
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferMemory, MomentoChatMessageHistory
from langchain.schema import LLMResult
from slack_bolt import App
from slack_bolt.adapter.aws_lambda import SlackRequestHandler
from slack_bolt.adapter.socket_mode import SocketModeHandler
# Initial minimum number of seconds between streaming edits of the Slack
# reply; SlackStreamingCallbackHandler starts from this value and doubles it
# as the number of edits grows.
CHAT_UPDATE_INTERVAL_SEC = 1

# Load variables from a .env file into os.environ before any of them are read.
load_dotenv()

# Logging.
# Remove log handlers that are already installed (NOTE(review): presumably the
# ones the AWS Lambda runtime attaches to the root logger): logging.basicConfig()
# below is a no-op when the root logger already has handlers.
SlackRequestHandler.clear_all_log_handlers()
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s", level=logging.INFO
)
logger = logging.getLogger(__name__)

# Initialize the app with the bot token.
# process_before_response=True is Bolt's setting for FaaS deployments such as
# the AWS Lambda `handler` in this file (per the Bolt docs, listeners then run
# before the HTTP response is returned).
app = App(
    signing_secret=os.environ["SLACK_SIGNING_SECRET"],
    token=os.environ["SLACK_BOT_TOKEN"],
    process_before_response=True,
)
class SlackStreamingCallbackHandler(BaseCallbackHandler):
    """Stream LLM tokens into an existing Slack message.

    Tokens are accumulated in ``self.message``. The Slack message identified by
    ``channel``/``ts`` is edited with the partial text (plus a "Typing..."
    marker) at most once per ``self.interval`` seconds. The interval starts at
    ``CHAT_UPDATE_INTERVAL_SEC`` and doubles as the number of edits grows, which
    keeps ``chat.update`` calls down on long answers. When the LLM finishes, the
    message is replaced with the full text followed by a disclaimer.
    """

    def __init__(self, channel: str, ts: str):
        """Create a handler that edits one Slack message.

        Args:
            channel: ID of the Slack channel containing the message to edit.
            ts: Timestamp (``ts``) of the Slack message to edit.
        """
        super().__init__()
        self.channel = channel
        self.ts = ts
        # Per-instance state. A class-level ``time.time()`` would be evaluated
        # once at import time, so the throttle timer would not start when this
        # handler is created.
        self.message = ""
        self.last_send_time = time.time()
        self.interval = CHAT_UPDATE_INTERVAL_SEC
        # Total number of times the Slack message has been updated so far.
        self.update_count = 0

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        """Append ``token`` and, if the interval has elapsed, edit the message."""
        self.message += token

        now = time.time()
        if now - self.last_send_time > self.interval:
            app.client.chat_update(
                channel=self.channel, ts=self.ts, text=f"{self.message}\n\nTyping..."
            )
            self.last_send_time = now
            self.update_count += 1

            # Whenever update_count exceeds 10x the current interval, double the
            # interval (1s until 11 updates, then 2s until 21, 4s until 41, ...).
            if self.update_count / 10 > self.interval:
                self.interval = self.interval * 2

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
        """Replace the message with the final text and a disclaimer footer."""
        message_context = "OpenAI APIで生成される情報は不正確または不適切な場合がありますが、当社の見解を述べるものではありません。"
        # NOTE(review): Slack limits a section block's text to 3000 characters;
        # a longer answer may be rejected by chat.update here. Confirm/handle.
        message_blocks = [
            {"type": "section", "text": {"type": "mrkdwn", "text": self.message}},
            {"type": "divider"},
            {
                "type": "context",
                "elements": [{"type": "mrkdwn", "text": message_context}],
            },
        ]
        app.client.chat_update(
            channel=self.channel,
            ts=self.ts,
            text=self.message,
            blocks=message_blocks,
        )
# Registered further down as a Bolt lazy listener for "app_mention"
# (app.event("app_mention")(ack=just_ack, lazy=[handle_mention])) rather than
# with a decorator.
def handle_mention(event, say):
    """Answer an ``app_mention`` event with a streamed, retrieval-augmented reply.

    Posts a placeholder reply in the thread, then runs a
    ``ConversationalRetrievalChain`` whose streaming LLM edits that placeholder
    through ``SlackStreamingCallbackHandler``. The conversation history is kept
    in Momento, keyed by the ``ts`` of the thread's root message.

    Args:
        event: Slack ``app_mention`` event payload. Reads ``channel``, ``ts``,
            ``text`` and, for mentions inside a thread, ``thread_ts``.
        say: Bolt ``say`` utility, used to post the placeholder reply.
    """
    channel = event["channel"]

    # Remove mention tokens such as "<@U012ABCDEF>". [^>]* stops at the first
    # ">", so other "<...>" tokens (links, channel references) and the text
    # between them are kept. A greedy "<@.*>" would delete everything from the
    # first mention up to the last ">" in the message.
    message = re.sub(r"<@[^>]*>", "", event["text"])

    # Conversation key (= Momento key): the thread's root message.
    # First mention: the message's own ts. Later mentions: event["thread_ts"].
    id_ts = event.get("thread_ts", event["ts"])

    # Post the placeholder into that thread; Slack recommends replying with the
    # parent's ts rather than a reply's ts.
    result = say("\n\nTyping...", thread_ts=id_ts)
    ts = result["ts"]

    history = MomentoChatMessageHistory.from_client_params(
        id_ts,
        os.environ["MOMENTO_CACHE"],
        timedelta(hours=int(os.environ["MOMENTO_TTL"])),
    )
    memory = ConversationBufferMemory(
        chat_memory=history, memory_key="chat_history", return_messages=True
    )

    vectorstore = initialize_vectorstore()

    model_name = os.environ["OPENAI_API_MODEL"]
    temperature = float(os.environ["OPENAI_API_TEMPERATURE"])

    # Only the answering LLM streams into Slack; the question-condensing LLM
    # has no callbacks, so its output is never shown.
    callback = SlackStreamingCallbackHandler(channel=channel, ts=ts)
    llm = ChatOpenAI(
        model_name=model_name,
        temperature=temperature,
        streaming=True,
        callbacks=[callback],
    )
    condense_question_llm = ChatOpenAI(
        model_name=model_name,
        temperature=temperature,
    )

    qa_chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=vectorstore.as_retriever(),
        memory=memory,
        condense_question_llm=condense_question_llm,
    )
    qa_chain.run(message)
def just_ack(ack):
    """Acknowledge the Slack request immediately and do nothing else.

    Used as the ``ack`` function of the ``app_mention`` listener; the actual
    reply is produced by the lazily-run ``handle_mention``.
    """
    ack()
    return None
# Register the app_mention listener: just_ack acknowledges the event, and
# handle_mention is run as a Bolt lazy listener.
app.event("app_mention")(
    lazy=[handle_mention],
    ack=just_ack,
)
# When this file is executed directly, start the app with the Socket Mode
# handler (the AWS Lambda entry point is `handler` below).
if __name__ == "__main__":
    socket_mode_handler = SocketModeHandler(app, os.environ["SLACK_APP_TOKEN"])
    socket_mode_handler.start()
def handler(event, context):
    """AWS Lambda entry point: pass Slack HTTP requests to the Bolt app.

    Requests carrying an ``X-Slack-Retry-Num`` header (Slack redelivering an
    event) are answered with HTTP 200 without being processed, so only the
    original delivery runs the listener.

    Args:
        event: Lambda proxy event from the HTTP front end.
        context: Lambda context object.

    Returns:
        A Lambda proxy response (a dict with ``statusCode``) for skipped retries;
        otherwise whatever ``SlackRequestHandler.handle`` returns.
    """
    logger.info("handler called")

    # The header map may be missing or None. Key casing depends on the front
    # end: HTTP APIs and function URLs lower-case header names, while REST API
    # proxy events keep the original casing. Compare names case-insensitively.
    headers = event.get("headers") or {}
    logger.info(json.dumps(headers))

    retry_num = next(
        (
            value
            for name, value in headers.items()
            if name.lower() == "x-slack-retry-num"
        ),
        None,
    )
    if retry_num is not None:
        logger.info("SKIP > x-slack-retry-num: %s", retry_num)
        # A bare int is not a valid proxy-integration response (REST APIs
        # answer 502); a statusCode dict works for every front end.
        return {"statusCode": 200, "body": ""}

    # Adapter that converts the AWS Lambda request into a form the app can
    # process; its response can be returned directly as the Lambda result.
    slack_handler = SlackRequestHandler(app=app)
    return slack_handler.handle(event, context)