Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Multi-modal chat #65

Open
wants to merge 3 commits into
base: feat/v1.4
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ simsimd = "^3.7.4"
youtube-transcript-api = "^0.6.2"
pytube = "^15.0.0"
pydub = "^0.25.1"
langchain-google-genai = "^0.0.9"

[tool.poetry.dev-dependencies]
#fasttext = {git = "https://github.com/cfculhane/fastText"} # FastText doesn't come with pybind11 and we need to use this workaround.
Expand Down
11 changes: 8 additions & 3 deletions src/wandbot/adcopy/adcopy.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import json
import random
import logging
import random
from operator import itemgetter
from typing import Any, Dict, List

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import Runnable, RunnableParallel
from langchain_openai import ChatOpenAI

from wandbot.chat.chat import Chat
from wandbot.chat.schemas import ChatRequest
from wandbot.rag.utils import ChatModel
Expand Down Expand Up @@ -79,7 +80,9 @@ def build_prompt_input_variables(
self, query: str, persona: str, action: str
) -> Dict[str, Any]:
wandbot_response = self.query_wandbot(query)
additional_context = "\n".join(random.choices(self.contexts[action], k=2))
additional_context = "\n".join(
random.choices(self.contexts[action], k=2)
)
persona_prompt = (
TECHNICAL_PROMPT if persona == "technical" else EXECUTIVE_PROMPT
)
Expand Down Expand Up @@ -122,7 +125,9 @@ def _load_chain(self, model: ChatOpenAI) -> Runnable:
return chain

def __call__(self, query: str, persona: str, action: str) -> str:
logging.info(f"Generating ad copy for {persona} {action} with query: '{query}'")
logging.info(
f"Generating ad copy for {persona} {action} with query: '{query}'"
)
inputs = self.build_inputs_for_ad_formats(query, persona, action)
outputs = self.chain.batch(inputs)
str_output = ""
Expand Down
2 changes: 1 addition & 1 deletion src/wandbot/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@
from datetime import datetime, timezone

import pandas as pd
import wandb
from fastapi import FastAPI

import wandb
from wandbot.api.routers import adcopy as adcopy_router
from wandbot.api.routers import chat as chat_router
from wandbot.api.routers import content_navigator as content_navigator_router
Expand Down
16 changes: 9 additions & 7 deletions src/wandbot/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@

import aiohttp
import requests

from wandbot.api.routers.adcopy import AdCopyRequest, AdCopyResponse
from wandbot.api.routers.chat import APIQueryRequest, APIQueryResponse
from wandbot.api.routers.content_navigator import (
ContentNavigatorRequest,
ContentNavigatorResponse,
)
from wandbot.api.routers.adcopy import AdCopyRequest, AdCopyResponse
from wandbot.api.routers.chat import APIQueryRequest, APIQueryResponse
from wandbot.api.routers.database import (
APIFeedbackRequest,
APIFeedbackResponse,
Expand Down Expand Up @@ -66,7 +67,9 @@ def __init__(self, url: str):
)
self.retrieve_endpoint = urljoin(str(self.url), "retrieve")
self.generate_ads_endpoint = urljoin(str(self.url), "generate_ads")
self.generate_content_suggestions_endpoint = urljoin(str(self.url), "generate_content_suggestions")
self.generate_content_suggestions_endpoint = urljoin(
str(self.url), "generate_content_suggestions"
)

def _get_chat_thread(
self, request: APIGetChatThreadRequest
Expand Down Expand Up @@ -640,9 +643,9 @@ async def generate_ads(
response = await self._generate_ads(request)

return response

async def _generate_content_suggestions(
self, request: ContentNavigatorRequest
self, request: ContentNavigatorRequest
) -> ContentNavigatorResponse | None:
"""Call the content navigator API.

Expand All @@ -664,7 +667,7 @@ async def _generate_content_suggestions(
return None

async def generate_content_suggestions(
self, user_id: str, query: str
self, user_id: str, query: str
) -> ContentNavigatorResponse:
"""Generates content suggestions given query.

Expand All @@ -683,4 +686,3 @@ async def generate_content_suggestions(
response = await self._generate_content_suggestions(request)

return response

1 change: 1 addition & 0 deletions src/wandbot/api/routers/adcopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from fastapi import APIRouter
from pydantic import BaseModel
from starlette import status

from wandbot.adcopy.adcopy import AdCopyEngine


Expand Down
4 changes: 4 additions & 0 deletions src/wandbot/api/routers/chat.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

from fastapi import APIRouter
from starlette import status

Expand Down Expand Up @@ -38,12 +40,14 @@ def query(
Returns:
The APIQueryResponse object containing the result of the query.
"""
logger.info(request.images[0][:10])
result = chat(
ChatRequest(
question=request.question,
chat_history=request.chat_history,
language=request.language,
application=request.application,
images=request.images,
),
)
result = APIQueryResponse(**result.model_dump())
Expand Down
16 changes: 10 additions & 6 deletions src/wandbot/api/routers/content_navigator.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
import httpx

import httpx
from fastapi import APIRouter
from pydantic import BaseModel
from starlette import status

CONTENT_NAVIGATOR_ENDPOINT = "https://wandb-content-navigator.replit.app/get_content"
CONTENT_NAVIGATOR_ENDPOINT = (
"https://wandb-content-navigator.replit.app/get_content"
)


class ContentNavigatorRequest(BaseModel):
"""A user query to be used by the content navigator app"""

user_id: str = None
query: str


class ContentNavigatorResponse(BaseModel):
"""Response from the content navigator app"""

Expand All @@ -26,15 +29,17 @@ class ContentNavigatorResponse(BaseModel):
)


@router.post("/", response_model=ContentNavigatorResponse, status_code=status.HTTP_200_OK)
@router.post(
"/", response_model=ContentNavigatorResponse, status_code=status.HTTP_200_OK
)
async def generate_content_suggestions(request: ContentNavigatorRequest):
async with httpx.AsyncClient(timeout=1200.0) as content_client:
response = await content_client.post(
CONTENT_NAVIGATOR_ENDPOINT,
json={"query": request.query, "user_id": request.user_id},
)
response_data = response.json()

slack_response = response_data.get("slack_response", "")
rejected_slack_response = response_data.get("rejected_slack_response", "")
response_items_count = response_data.get("response_items_count", 0)
Expand All @@ -48,4 +53,3 @@ async def generate_content_suggestions(request: ContentNavigatorRequest):
rejected_slack_response=rejected_slack_response,
response_items_count=response_items_count,
)

4 changes: 2 additions & 2 deletions src/wandbot/api/routers/database.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import wandb
from fastapi import APIRouter
from starlette import status
from starlette.responses import Response

import wandb
from wandbot.database.client import DatabaseClient
from wandbot.database.database import engine
from wandbot.database.models import Base
Expand All @@ -13,8 +13,8 @@
FeedbackCreate,
QuestionAnswer,
QuestionAnswerCreate,
YoutubeAssistantThreadCreate,
YoutubeAssistantThread,
YoutubeAssistantThreadCreate,
)
from wandbot.utils import get_logger

Expand Down
6 changes: 5 additions & 1 deletion src/wandbot/apps/slack/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from slack_bolt.adapter.socket_mode.async_handler import AsyncSocketModeHandler
from slack_bolt.async_app import AsyncApp

from wandbot.api.client import AsyncAPIClient
from wandbot.apps.slack.config import SlackAppEnConfig, SlackAppJaConfig
from wandbot.apps.slack.handlers.ad_copy import (
Expand Down Expand Up @@ -62,6 +63,7 @@
api_client = AsyncAPIClient(url=config.WANDBOT_API_URL)
slack_client = app.client


def get_init_block(user: str) -> List[Dict[str, Any]]:
initial_block = [
{
Expand Down Expand Up @@ -133,6 +135,7 @@ def get_init_block(user: str) -> List[Dict[str, Any]]:
]
return initial_block


# --------------------------------------
# Main Wandbot Mention Handler
# --------------------------------------
Expand Down Expand Up @@ -212,10 +215,11 @@ async def handle_mention(event: dict, say: callable) -> None:
)
)


async def main():
handler = AsyncSocketModeHandler(app, config.SLACK_APP_TOKEN)
await handler.start_async()


if __name__ == "__main__":
asyncio.run(main())
asyncio.run(main())
8 changes: 4 additions & 4 deletions src/wandbot/apps/slack/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,12 @@
"#support チャンネルにいるwandbチームに質問してください。この答えは役に立ったでしょうか?下のボタンでお知らせ下さい。"
)

JA_ERROR_MESSAGE = "「おっと、問題が発生しました。しばらくしてからもう一度お試しください。」"

JA_FALLBACK_WARNING_MESSAGE = (
"**警告: {model}** にフォールバックします。これらの結果は **gpt-4** ほど良くない可能性があります\n\n"
JA_ERROR_MESSAGE = (
"「おっと、問題が発生しました。しばらくしてからもう一度お試しください。」"
)

JA_FALLBACK_WARNING_MESSAGE = "**警告: {model}** にフォールバックします。これらの結果は **gpt-4** ほど良くない可能性があります\n\n"


class SlackAppEnConfig(BaseSettings):
APPLICATION: str = Field("Slack_EN")
Expand Down
10 changes: 7 additions & 3 deletions src/wandbot/apps/slack/handlers/ad_copy.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import re
import logging
import re
from typing import Any, Dict, List

from slack_sdk.web.async_client import AsyncWebClient

from wandbot.api.client import AsyncAPIClient
from wandbot.apps.slack.utils import get_initial_message_from_thread

Expand Down Expand Up @@ -131,8 +132,11 @@ async def handle_adcopy_action(
query = re.sub(r"\<@\w+\>", "", query).strip()
logger.info(f"Initial message: {initial_message}")

await say(f"Working on generating ads for '{persona}' focussed on '{action}' \
for the query: '{query}'...", thread_ts=thread_ts)
await say(
f"Working on generating ads for '{persona}' focussed on '{action}' \
for the query: '{query}'...",
thread_ts=thread_ts,
)

api_response = await api_client.generate_ads(
query=query, action=action, persona=persona
Expand Down
18 changes: 12 additions & 6 deletions src/wandbot/apps/slack/handlers/content_navigator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging

from slack_sdk.web.async_client import AsyncWebClient

from wandbot.api.client import AsyncAPIClient
from wandbot.apps.slack.utils import get_initial_message_from_thread

Expand Down Expand Up @@ -32,15 +33,20 @@ async def handle_content_navigator_action(
if api_response.response_items_count > 0:
await say(api_response.slack_response, thread_ts=thread_ts)
else:
await say("No content suggestions found. Try rephrasing your query, but note \
await say(
"No content suggestions found. Try rephrasing your query, but note \
there may also not be any relevant pieces of content for this query. Add '--debug' to \
your query and try again to see a detailed resoning for each suggestion.",
thread_ts=thread_ts)

thread_ts=thread_ts,
)

# if debug mode is enabled, send the rejected suggestions as well
if len(api_response.rejected_slack_response) > 1:
await say("REJECTED SUGGESTIONS:\n{api_response.rejected_slack_response}", thread_ts=thread_ts)

await say(
"REJECTED SUGGESTIONS:\n{api_response.rejected_slack_response}",
thread_ts=thread_ts,
)


def create_content_navigator_handler(
slack_client: AsyncWebClient, api_client: AsyncAPIClient
Expand All @@ -57,4 +63,4 @@ async def executive_signups_handler(
logger=logger,
)

return executive_signups_handler
return executive_signups_handler
1 change: 1 addition & 0 deletions src/wandbot/apps/slack/handlers/docsbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from slack_sdk.web import SlackResponse
from slack_sdk.web.async_client import AsyncWebClient

from wandbot.api.client import AsyncAPIClient
from wandbot.apps.slack.config import SlackAppEnConfig, SlackAppJaConfig
from wandbot.apps.slack.formatter import MrkdwnFormatter
Expand Down
6 changes: 5 additions & 1 deletion src/wandbot/apps/slack/handlers/youtube_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from pytube.exceptions import RegexMatchError
from slack_sdk.web.async_client import AsyncWebClient

from wandbot.apps.slack.utils import get_initial_message_from_thread
from wandbot.youtube_chat.video_utils import YoutubeVideoInfo

Expand Down Expand Up @@ -201,7 +202,10 @@ async def handle_youtube_chat_input(
await ack()
logger.info(f"Received message: {body}")
url = body["actions"][0]["value"]
await say("Working on in it...", thread_ts=body["message"]["thread_ts"],)
await say(
"Working on in it...",
thread_ts=body["message"]["thread_ts"],
)
video_confirmation_block = get_video_confirmation_blocks(url)
await say(
blocks=video_confirmation_block,
Expand Down