In [227]:
from pydantic import BaseModel
from typing import Optional, List, Tuple, Literal
import os
import dotenv

In [228]:
# Load settings from the local ".env" file; the call returned True in the recorded run below.
dotenv.load_dotenv(".env")

True

In [229]:
class ChatMessageSource(BaseModel):
    """Base class representing a source of information for a message."""

    # Identifier of the source itself.
    sourceId: str
    # Identifier of a message — presumably the ChatMessage this source backs; TODO confirm.
    messageId: str
    # Text content of the source.
    pageContent: str
    # Free-form metadata; no schema is enforced beyond "is a dict".
    metadata: dict
    # Integer creation timestamp — unit not shown here; presumably epoch ms like
    # ChatMessage.createdAt in the recorded output. TODO confirm.
    createdAt: int


class ChatMessage(BaseModel):
    """Base class representing a chat message."""

    # NOTE: this class used to be declared twice back-to-back with identical
    # bodies; the second declaration silently shadowed the first. A single
    # definition is kept.

    messageId: str
    # Summarizer only understands "ai" and "human" here and raises ValueError
    # for anything else.
    messageType: str
    userId: str
    chatId: str
    content: str
    # Integer timestamp; the recorded DynamoDB rows hold 13-digit values.
    createdAt: int
    # Optional supporting sources; list_chat_messages leaves this unset.
    sources: Optional[List[ChatMessageSource]] = None


# Define the ModelProvider type
ModelProvider = Literal["sagemaker", "bedrock"]


class ModelKwargs(BaseModel):
    # Optional inference settings. None means "not specified"; Summarizer
    # substitutes its DEFAULT_* values for missing entries.
    #
    # NOTE: this class used to be declared twice (before and after
    # ModelProvider); the later declaration won. Only that one is kept.
    maxTokens: Optional[int] = None
    temperature: Optional[float] = None
    topP: Optional[float] = None
    stopSequences: Optional[list[str]] = None


# Fields common to every model configuration declared in this cell.
class ModelBase(BaseModel):
    # One of the ModelProvider literals ("sagemaker" or "bedrock").
    provider: ModelProvider
    # Model identifier; Summarizer forwards it to Bedrock as `modelId`.
    modelId: str
    # Optional region name; None means no region was specified.
    region: Optional[str] = None


# ModelBase plus optional inference settings.
class LLMModelBase(ModelBase):
    # Optional inference settings (see ModelKwargs); may be None.
    modelKwargs: Optional[ModelKwargs] = None


# LLM configuration restricted to the Bedrock provider.
class BedRockLLMModel(LLMModelBase):
    # Narrows the inherited `provider` field so only "bedrock" validates.
    provider: Literal["bedrock"]


# Configuration consumed by Summarizer: a Bedrock model plus summarization options.
class HandoffConfig(BedRockLLMModel):
    # Types of details the summary should emphasize; Summarizer renders them as
    # a bullet list in the tail of the prompt.
    details: Optional[list[str]] = None
    # Not read by any code visible in this notebook — presumably message-window
    # settings for incremental summarization; TODO confirm.
    windowSize: Optional[int] = None
    windowOverlap: Optional[int] = None


# interface HandoffConfig extends BedRockLLMModel {
#     details?: string[];
# }

In [230]:
from langchain_core.messages import HumanMessage, AIMessage, AnyMessage
from langchain_core.messages.utils import get_buffer_string
from typing import Optional, Iterator
import boto3

# Sentinel text returned by Summarizer.summarize when the model call raises or
# stops for an unexpected reason.
FAILED_TO_SUMMARIZE = "Summarizer failed to generate a response."

# Fallback inference settings used by Summarizer.summarize when the handoff
# config does not supply a value.
DEFAULT_MAX_TOKENS = 1024
DEFAULT_TEMPERATURE = 0.1
DEFAULT_TOP_P = 0.90
# NOTE(review): mutable module-level list; it is handed to the request as-is,
# so nothing should mutate it in place.
DEFAULT_STOP_SEQUENCES = []


class Summarizer:
    def __init__(self, handoff_config: HandoffConfig):
        self.handoff_config = handoff_config
        self.bedrock = boto3.client("bedrock-runtime")
        self.model_id = handoff_config.modelId

        # A model prompt consists of a role definition, a prompt describing the summarization task,
        # a conversation, and a tail prompt (which includes a list of types of details to focus on)

        self.role_definition = str(
            "You are a detailed note-taker for a customer service chatbot that helps "
            "users solve technical issues. You carefully read through conversations "
            "and focus on the details you're asked to find by the system. Do not "
            "output anything except text."
        )

        # For extending a summary
        self.recursive_prompt = lambda summary: (
            f"This is a summary of the conversation so far:\n\n{summary}\n\n"
            "Extend the summary by taking into account the new messages below. "
            "Be sure to return a summary of the whole conversation, not just the "
            "new messages. Use bullet points, being sure to add new bullet points to "
            "reflect these new messages. Feel free to modify the summary to make it more "
            "concise, but make sure to keep all the important details.",
        )

        # For creating a new summary
        self.base_prompt: str = str(
            "Create a summary of the conversation below that captures the key "
            "points of the conversation. Keep the summary to a few bullet points."
        )

        # Separates instructions from conversation in the prompt
        self.conv_start_delim = "CONVERSATION START:"
        self.conv_end_delim = "CONVERSATION END"

        # Create a non-empty tail prompt if config provides details to focus on
        self.types_of_details = self._details_string(handoff_config.details)
        self.tail_prompt: str = (
            (
                "Especially focus on the following types of details:\n"
                f"{self.types_of_details}\n"
                "...as well as any other details that are important to the purpose of the conversation."
            )
            if handoff_config.details
            else ""
        )

    def _details_string(self, details: list[str]) -> str:
        return "\n".join([f"- {detail}" for detail in details])

    def _conversator_name(self, message_type: Literal["ai", "human"]) -> str:
        match message_type:
            case "ai":
                return "Chatbot"
            case "human":
                return "Human"
            case _:
                raise ValueError(f"Unknown message type: {message_type}")

    def _create_conversation_string(self, messages: Iterator[ChatMessage]) -> str:
        return "\n\n".join(
            [f"{self._conversator_name(m.messageType)}: {m.content}" for m in messages]
        )

    def _create_summarization_prompt(
        self,
        messages: Iterator[ChatMessage],
        existing_summary=None,
        supports_system_prompt=False,
    ) -> dict:

        if existing_summary:
            task_prompt = self.recursive_prompt(existing_summary)
        else:
            task_prompt = self.base_prompt

        conversation = self._create_conversation_string(messages)

        complete_prompt_components = [
            task_prompt,
            self.conv_start_delim,
            conversation,
            self.conv_end_delim,
            self.tail_prompt,
        ]

        if supports_system_prompt:
            # Add the role definition to the system prompts
            complete_prompt = "\n\n".join(complete_prompt_components)
            return {"prompt": complete_prompt, "system_prompts": [self.role_definition]}
        else:
            # Add the role definition to the user prompt
            complete_prompt = "\n\n".join(
                [self.role_definition] + complete_prompt_components
            )
            return {"prompt": complete_prompt}

    def _non_text_response_types(self, response: dict) -> set:
        response_content = response["output"]["message"]["content"]
        content_types = set()
        for content in response_content:
            content_types.update(content.keys())
        return content_types - {"text"}

    def summarize(self, francis_messages: Iterator[ChatMessage]) -> Optional[str]:
        prompt = self._create_summarization_prompt(francis_messages)

        messages = [{"role": "user", "content": [{"text": prompt["prompt"]}]}]
        system_prompts = prompt.get("system_prompts", None)

        inference_config = {
            "maxTokens": self.handoff_config.modelKwargs.maxTokens
            or DEFAULT_MAX_TOKENS,
            "temperature": self.handoff_config.modelKwargs.temperature
            or DEFAULT_TEMPERATURE,
            "topP": self.handoff_config.modelKwargs.topP or DEFAULT_TOP_P,
            "stopSequences": self.handoff_config.modelKwargs.stopSequences
            or DEFAULT_STOP_SEQUENCES,
        } | (
            {"region": self.handoff_config.region} if self.handoff_config.region else {}
        )

        converse_kwargs = {
            "modelId": self.handoff_config.modelId,
            "messages": messages,
            "inferenceConfig": inference_config,
        } | ({"systemPrompts": system_prompts} if system_prompts else {})
        # Some models support system prompts; some do not

        try:
            response = self.bedrock.converse(**converse_kwargs)
        except Exception as e:
            print(f"Error while summarizing messages: {e}")
            return FAILED_TO_SUMMARIZE

        if response.get("stopReason") not in [
            "end_turn",
            "stop_sequence",
            "max_tokens",
        ]:
            print(
                f"Unexpected stop reason from model {self.model_id}: {response.get('stop_reason')}"
            )
            return FAILED_TO_SUMMARIZE

        if non_text_types := self._non_text_response_types(response):
            # TODO: log this
            print(
                f"Unexpected response mode; did not expect non-text content: {non_text_types}"
            )

        # Aggregate all text outputs from the response
        response_content = response["output"]["message"]["content"]
        text_outputs = [
            content.get("text") for content in response_content if "text" in content
        ]

        return "\n\n".join(text_outputs)

In [231]:
# Index queried by list_chat_messages (presumably a global secondary index,
# given the "GSI1" naming — TODO confirm).
index_name = "GSI1"
dynamodb = boto3.resource("dynamodb")
# NOTE(review): deployment-specific table name is hard-coded; consider reading
# it from the environment so the notebook runs against other stacks.
TABLE_NAME = "FrancisChatbotStack-nasrullah-dev-ConversationStoreConversationTable631357AC-CI2U1TKWQOEW"
table = dynamodb.Table(TABLE_NAME)


def get_chat_messages_by_time_key(
    user_id: str, chat_id: str, timestamp: str = ""
) -> dict:
    """Build the GSI1 key attributes for a chat's messages.

    Returns a dict whose "GSI1PK" is "<user_id>#CHAT#<chat_id>" and whose
    "GSI1SK" is ``timestamp`` unchanged.
    """
    partition_key = f"{user_id}#CHAT#{chat_id}"
    return dict(GSI1PK=partition_key, GSI1SK=timestamp)


def parse_next_token(next_token: str) -> dict | None:
    parts = next_token.split("|")
    if len(parts) != 4:
        print(f"Invalid next_token format: {next_token}")
        return None

    PK, SK, GSI1PK, GSI1SK = parts
    return {
        "PK": PK,
        "SK": SK,
        "GSI1PK": GSI1PK,
        "GSI1SK": GSI1SK if GSI1SK else None,
    }


def generate_next_token(params: dict) -> str:
    """Encode a key dict as a "PK|SK|GSI1PK|GSI1SK" pagination token.

    "PK", "SK" and "GSI1PK" are required keys; a missing "GSI1SK" is written
    as an empty string.
    """
    key_parts = (
        params["PK"],
        params["SK"],
        params["GSI1PK"],
        params.get("GSI1SK", ""),
    )
    return "|".join(f"{part}" for part in key_parts)


def list_chat_messages(
    user_id: str,
    chat_id: str,
    next_token: Optional[str] = None,
    limit: int = 50,
    ascending: bool = True,
) -> Tuple[List[ChatMessage], Optional[str]]:
    """Fetch one page of a chat's messages from the conversation table.

    Args:
        user_id: Owner of the chat.
        chat_id: Chat whose messages are listed.
        next_token: Pagination token from a previous call; ignored when it is
            not a string or cannot be parsed.
        limit: Value passed as the query's "Limit".
        ascending: Value passed as the query's "ScanIndexForward".

    Returns:
        ``(messages, next_token)``; the token is None when the response has no
        "LastEvaluatedKey".
    """
    partition_key = get_chat_messages_by_time_key(user_id, chat_id, "")["GSI1PK"]

    # Only resume from a token that parses; otherwise start from the beginning.
    start_key = parse_next_token(next_token) if isinstance(next_token, str) else None

    query_kwargs = {
        "IndexName": index_name,
        "KeyConditionExpression": "GSI1PK = :PK",
        "ExpressionAttributeValues": {":PK": partition_key},
        "Limit": limit,
        "ScanIndexForward": ascending,
    }
    if start_key:
        query_kwargs["ExclusiveStartKey"] = start_key

    result = table.query(**query_kwargs)

    chat_messages = []
    for item in result.get("Items", []):
        chat_messages.append(
            ChatMessage(
                chatId=item["chatId"],
                userId=item["userId"],
                messageId=item["messageId"],
                content=item["data"]["content"],
                createdAt=int(item["createdAt"]),
                messageType=item["messageType"],
            )
        )

    last_key = result.get("LastEvaluatedKey")
    continuation = generate_next_token(last_key) if last_key else None

    return chat_messages, continuation

In [232]:
# Detail types passed to HandoffConfig(details=...); Summarizer renders them as
# a bullet list in the tail of the summarization prompt.
details = [
    "The user's primary issue",
    "Questions that the user asked",
    "Places where the user seemed confused",
    "Places where the user got stuck",
    "Recommendations the AI made",
    "Solutions that did or didn't work",
    "Places where the user seemed frustrated or wanted to talk to a human",
]

In [233]:
# Model id used for summarization; also reused as the file-name prefix when the
# summary is saved at the end of the notebook.
MODEL = "anthropic.claude-3-haiku-20240307-v1:0"
# MODEL = "us.meta.llama3-3-70b-instruct-v1:0"
# These modelKwargs values equal the DEFAULT_* fallbacks defined alongside Summarizer.
handoff_config = HandoffConfig(
    provider="bedrock",
    region=None,
    modelId=MODEL,
    modelKwargs=ModelKwargs(
        temperature=0.1,
        maxTokens=1024,
        topP=0.90,
    ),
    details=details,
)
# Constructing the Summarizer also creates its boto3 "bedrock-runtime" client.
summarizer = Summarizer(handoff_config)

In [234]:
# NOTE(review): the DynamoDB resource and the Summarizer's Bedrock client were
# already created in earlier cells. boto3 presumably reads the profile/region
# when its session is first created, so on a fresh kernel these assignments may
# come too late to affect them — confirm, and consider moving this cell up.
os.environ["AWS_PROFILE"] = "usda"
os.environ["AWS_DEFAULT_REGION"] = "us-west-2"
# Echo the values back to confirm they are visible to subprocesses.
!echo $AWS_PROFILE
!echo $AWS_DEFAULT_REGION

usda
us-west-2


In [235]:
# Sample chat used for the rest of the notebook.
chat_id = "9112a4c0-6a2e-4a56-9280-5e25c95f8888"
user_id = "9811f3a0-d0d1-701b-29e4-f7a680d50e85"
# Only the first page (default limit of 50) is fetched; `token` is not used
# again in the cells shown here.
messages, token = list_chat_messages(user_id, chat_id)
messages

[ChatMessage(messageId='819fa771-b117-443c-b55f-9a3aaaefd077', messageType='human', userId='9811f3a0-d0d1-701b-29e4-f7a680d50e85', chatId='9112a4c0-6a2e-4a56-9280-5e25c95f8888', content='hello\n', createdAt=1737745583521, sources=None),
 ChatMessage(messageId='254788aa-9114-470c-be8e-dac47d63c1dc', messageType='ai', userId='9811f3a0-d0d1-701b-29e4-f7a680d50e85', chatId='9112a4c0-6a2e-4a56-9280-5e25c95f8888', content="Hello! I'm the Cal Poly IT Help Desk assistant. How can I help you today?", createdAt=1737745583567, sources=None),
 ChatMessage(messageId='eb7cf1f8-9f3c-49ea-a5e0-95d6273d5053', messageType='human', userId='9811f3a0-d0d1-701b-29e4-f7a680d50e85', chatId='9112a4c0-6a2e-4a56-9280-5e25c95f8888', content='how do I change my display name?', createdAt=1737745609388, sources=None),
 ChatMessage(messageId='5f340136-3dc0-48c8-a4d8-0f51e6e94d04', messageType='ai', userId='9811f3a0-d0d1-701b-29e4-f7a680d50e85', chatId='9112a4c0-6a2e-4a56-9280-5e25c95f8888', content='Here are the step

In [236]:
# Build the prompt dict ({"prompt": ...}) for the fetched messages.
prompt = summarizer._create_summarization_prompt(messages)

# write prompt to a file
# Explicit UTF-8: the prompt embeds chat content, which may contain non-ASCII
# characters, and the default text encoding is platform-dependent.
with open("prompt.txt", "w", encoding="utf-8") as f:
    f.write(prompt["prompt"])

In [None]:
# Rebuild the prompt dict and split it into the user prompt and the optional
# system prompts (None when the dict has no "system_prompts" entry).
prompts = summarizer._create_summarization_prompt(messages)
prompt = prompts["prompt"]
system_prompts = prompts.get("system_prompts")

{'prompt': "You are a detailed note-taker for a customer service chatbot that helps users solve technical issues. You carefully read through conversations and focus on the details you're asked to find by the system. Do not output anything except text.\n\nCreate a summary of the conversation below that captures the key points of the conversation. Keep the summary to a few bullet points.\n\nCONVERSATION START:\n\nHuman: hello\n\n\nChatbot: Hello! I'm the Cal Poly IT Help Desk assistant. How can I help you today?\n\nHuman: how do I change my display name?\n\nChatbot: Here are the steps to change your display name when sending email from Outlook:\n\n1. Log in to the My Cal Poly Portal > Personal Info tab.\n2. Under Personal Info > My Info, by the Preferred Name field click the Edit link.\n3. Click the Continue link, then:\n   a. Employees: Name Type > Preferred, click the Edit button.\n   b. Students: Personal Info > Names, click the Edit button.\n4. Enter your name using upper and lower c

In [237]:
# Calls the Bedrock model; on failure `summary` holds FAILED_TO_SUMMARIZE
# rather than raising.
summary = summarizer.summarize(messages)
summary

"Summary:\n\n- The user's primary issue was how to change their display name when sending email from Outlook.\n- The chatbot provided step-by-step instructions on how to update the preferred name in the My Cal Poly Portal.\n- The steps included logging into the portal, navigating to the Personal Info tab, and editing the Preferred Name field.\n- The chatbot noted that it may take up to 24 hours for the preferred name change to take effect.\n- The user did not seem confused or frustrated during the conversation, and the provided solution appeared to address their issue."

In [238]:
from datetime import datetime


def save_summary(summary: str, model_name: str) -> None:
    """Write ``summary`` to "<dir_name>/<model_name>-<timestamp>.txt".

    The target directory is taken from the ``save_summary.dir_name`` function
    attribute and is created when missing. The timestamp is the local time
    formatted as YYYY-MM-DD-HH-MM-SS.

    Args:
        summary: Text to write (encoded as UTF-8).
        model_name: Used as the file-name prefix; path separators in it are
            replaced with "_" so the file always lands directly in the
            directory.
    """
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(save_summary.dir_name, exist_ok=True)

    timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")

    # A separator inside the name would point into a sub-directory that does
    # not exist and make open() fail.
    safe_model_name = model_name
    for separator in filter(None, (os.sep, os.altsep)):
        safe_model_name = safe_model_name.replace(separator, "_")

    path = os.path.join(save_summary.dir_name, f"{safe_model_name}-{timestamp}.txt")
    # Explicit encoding: summaries may contain non-ASCII text.
    with open(path, "w", encoding="utf-8") as f:
        f.write(summary)


# Output directory, stored as a function attribute so it can be changed
# without editing the function.
save_summary.dir_name = "summaries"

In [239]:
# Persist the summary under save_summary.dir_name, using the model id as the file-name prefix.
save_summary(summary, MODEL)