Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 103 additions & 8 deletions src/steamship/agents/mixins/transports/slack.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import logging
import urllib.parse
from enum import Enum
from typing import List, Optional

import requests
Expand All @@ -19,6 +20,42 @@
SETTINGS_KVSTORE_KEY = "slack-transport"


class SlackContextBehavior(Enum):
"""Defines how history between agent and users is tracked.

These specifications are specifically in regard to how the agent interacts with Slack as it pertains to Agent
Context.
"""

ENTIRE_CHANNEL = "entire-channel"
"""
Agent context is per channel as a whole, which includes bot mentions sent to the top level channel, and across *any*
thread in that channel.
"""

THREADS_ARE_NEW_CONVERSATIONS = "threads-are-new-conversations"
"""
Agent context is thread-aware. The top level channel is treated as its own context, and threads have their own
contexts.
"""


class SlackThreadingBehavior(Enum):
"""Defines how responses from the agent will be delivered in response to user mentions."""

FOLLOW_THREADS = "follow-threads"
"""
If the bot is mentioned from the top-level channel, the response will be in the channel. If the bot is mentioned
from within a thread, the response will be to that thread.
"""

ALWAYS_THREADED = "always-threaded"
"""
Responses from the bot will always be threaded. If the bot was mentioned at the top level of the channel, a new
thread will be created for the response.
"""


class SlackElement(BaseModel):
"""An element of a Slack Block."""

Expand Down Expand Up @@ -95,6 +132,9 @@ class SlackEvent(BaseModel):
ts: Optional[str] = Field(
description="Timestamp of the message. A string, but is a floating point number within that."
)
thread_ts: Optional[str] = Field(
description="Timestamp of the thread this message is a part of, if any. Same format as `ts`."
)
item: Optional[str] = Field(
description="Data specific to the underlying object type being described."
)
Expand All @@ -121,8 +161,13 @@ def to_blocks(self) -> Optional[List[Block]]:
for slack_block in self.blocks or []:
for block in slack_block.to_blocks() or []:
if self.channel:
block.set_chat_id(str(self.channel))
# TODO: Do we want to encode other things like the tab, user, etc?
block.set_chat_id(self.channel)
if self.ts:
block.set_message_id(self.ts)
if self.user:
block.set_user_id(self.user)
if self.thread_ts:
block.set_thread_id(self.thread_ts)
ret.append(block)
return ret

Expand Down Expand Up @@ -159,6 +204,14 @@ class SlackTransportConfig(Config):
slack_api_base: str = Field(
SLACK_API_BASE, description="Slack API base URL. If blank defaults to production Slack."
)
threading_behavior: SlackThreadingBehavior = Field(
SlackThreadingBehavior.FOLLOW_THREADS.value,
description="Whether the bot will always respond in threads, or only if the invocation was threaded",
)
context_behavior: SlackContextBehavior = Field(
SlackContextBehavior.ENTIRE_CHANNEL.value,
description="Whether the bot will be provided conversation context from the channel as a whole, or per thread.",
)


class SlackTransport(Transport):
Expand Down Expand Up @@ -284,6 +337,7 @@ def _send(self, blocks: List[Block], metadata: Metadata): # noqa: C901
text = None
slack_blocks = []
chat_id = None
thread_ts = None
for block in blocks:
# This is required for the public_url creation below.

Expand All @@ -292,6 +346,9 @@ def _send(self, blocks: List[Block], metadata: Metadata): # noqa: C901
if block.chat_id:
chat_id = block.chat_id

if block.thread_id:
thread_ts = block.thread_id

if block.is_text() or block.text:
if not text:
# This is the fallback for mobile notifications
Expand Down Expand Up @@ -346,6 +403,8 @@ def _send(self, blocks: List[Block], metadata: Metadata): # noqa: C901
"text": text, # This is for mobile previews. The "block" key has the real content.
"channel": chat_id,
}
if thread_ts:
body["thread_ts"] = thread_ts

post_url = f"{self.config.slack_api_base}chat.postMessage"

Expand All @@ -355,28 +414,63 @@ def _send(self, blocks: List[Block], metadata: Metadata): # noqa: C901
json=body,
)

def build_emit_func(self, chat_id: str) -> EmitFunc:
def build_emit_func(
self, chat_id: str, incoming_message_ts: str, thread_ts: Optional[str]
) -> EmitFunc:
"""Return an EmitFun that sends messages to the appropriate Slack channel."""
if self.config.threading_behavior == SlackThreadingBehavior.FOLLOW_THREADS.value:
reply_thread_ts = thread_ts
elif self.config.threading_behavior == SlackThreadingBehavior.ALWAYS_THREADED.value:
reply_thread_ts = thread_ts or incoming_message_ts
else:
raise ValueError(f"Unhandled threading behavior: {self.config.threading_behavior}")

def new_emit_func(blocks: List[Block], metadata: Metadata):
for block in blocks:
block.set_chat_id(chat_id)
if reply_thread_ts:
block.set_thread_id(reply_thread_ts)
return self.send(blocks, metadata)

return new_emit_func

def _get_context_id_for_response(self, channel: str, thread_ts: Optional[str]) -> str:
if self.config.context_behavior == SlackContextBehavior.ENTIRE_CHANNEL.value:
return channel
elif (
self.config.context_behavior == SlackContextBehavior.THREADS_ARE_NEW_CONVERSATIONS.value
):
return channel if not thread_ts else f"{channel}-{thread_ts}"
else:
raise ValueError(f"Unhandled context behavior: {self.config.context_behavior}")

def _respond_to_block(self, incoming_message: Block):
"""Respond to a single inbound message from Slack, posting the response back to Slack."""
try:
chat_id = incoming_message.chat_id
context = self.agent_service.build_default_context(context_id=chat_id)
thread_ts = incoming_message.thread_id
context_id = self._get_context_id_for_response(chat_id, thread_ts)
context = self.agent_service.build_default_context(context_id=context_id)

context.chat_history.append_user_message(
text=incoming_message.text, tags=incoming_message.tags
)

context.metadata["slack"] = {
"channel": chat_id,
"message_ts": incoming_message.message_id,
}
if thread_ts:
context.metadata["slack"]["thread_ts"] = thread_ts

# TODO: For truly async support, this emit fn will need to be wired in at the Agent level.
context.emit_funcs = [self.build_emit_func(chat_id=chat_id)]
context.emit_funcs = [
self.build_emit_func(
chat_id=chat_id,
incoming_message_ts=incoming_message.message_id,
thread_ts=thread_ts,
)
]

# Add an LLM to the context, using the Agent's if it exists.
llm = None
Expand Down Expand Up @@ -413,9 +507,10 @@ def slack_respond_sync(self, **kwargs) -> InvocableResponse[str]: # noqa: C901
if slack_request.event:
if slack_request.event.bot_id is None:
if slack_request.event.is_message():
logging.info(
f"User {slack_request.event.user} sent message in channel {slack_request.event.channel}"
)
log_message = f"User {slack_request.event.user} sent message in channel {slack_request.event.channel}"
if slack_request.event.thread_ts:
log_message += f" from within thread {slack_request.event.thread_ts}"
logging.info(log_message)
incoming_messages = slack_request.event.to_blocks()
for incoming_message in incoming_messages:
if incoming_message is not None:
Expand Down
22 changes: 22 additions & 0 deletions src/steamship/data/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,28 @@ def set_chat_id(self, chat_id: str):
tag_kind=DocTag.CHAT, tag_name=ChatTag.CHAT_ID, string_value=chat_id
)

@property
def thread_id(self) -> Optional[str]:
return get_tag_value_key(
self.tags, TagValueKey.STRING_VALUE, kind=DocTag.CHAT, name=ChatTag.THREAD_ID
)

def set_thread_id(self, thread_id: str) -> None:
return self._one_time_set_tag(
tag_kind=DocTag.CHAT, tag_name=ChatTag.THREAD_ID, string_value=thread_id
)

@property
def user_id(self) -> Optional[str]:
return get_tag_value_key(
self.tags, TagValueKey.STRING_VALUE, kind=DocTag.CHAT, name=ChatTag.USER_ID
)

def set_user_id(self, user_id: str) -> None:
return self._one_time_set_tag(
tag_kind=DocTag.CHAT, tag_name=ChatTag.USER_ID, string_value=user_id
)

def _one_time_set_tag(self, tag_kind: str, tag_name: str, string_value: str):
existing = get_tag_value_key(
self.tags, TagValueKey.STRING_VALUE, kind=tag_kind, name=tag_name
Expand Down
6 changes: 6 additions & 0 deletions src/steamship/data/tags/tag_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,12 @@ class ChatTag(str, Enum):
# The message id of a message
MESSAGE_ID = "message-id"

# In environments which support threading, the thread id where the message occurred
THREAD_ID = "thread-id"

# In multiuser environments, the ID of the user who created the message
USER_ID = "user-id"

# The role of a message
ROLE = "role"

Expand Down