Add structured logging for query outputs and update service (#54)
This PR adds logging for the query outputs so we know which documents
were fed into the context, which model was used, and which answer was
returned to the user. Structured logging allows us to parse the logs
later and extract the information we are interested in.
pcmoritz committed Sep 17, 2023
1 parent 772a8c4 commit 3cbfe26
Showing 4 changed files with 49 additions and 15 deletions.
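
The structlog configuration introduced here (an iso TimeStamper plus a JSONRenderer, routed through the stdlib logger) writes one JSON object per record. As a rough, hypothetical sketch of the post-hoc parsing the commit message alludes to (not part of this change; the log path is simply the one set in service.yaml):

# Hypothetical sketch: parse the JSON-rendered structlog output written by rag/serve.py.
# Assumes the log path configured in service.yaml; field names mirror the logger.info call.
import json

LOG_PATH = "/mnt/shared_storage/ray-assistant-logs/info.log"

with open(LOG_PATH) as f:
    for line in f:
        # The stdlib handler may prefix each record (e.g. "INFO:root:"),
        # so parse from the first "{" onwards and skip anything else.
        start = line.find("{")
        if start == -1:
            continue
        record = json.loads(line[start:])
        if record.get("event") == "finished streaming query":
            print(record.get("timestamp"), record.get("llm"), record.get("document_ids"))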
6 changes: 4 additions & 2 deletions rag/generate.py
@@ -76,9 +76,10 @@ def get_sources_and_context(query, embedding_model, num_chunks):
                 (embedding, num_chunks),
             )
             rows = cur.fetchall()
+            document_ids = [row[0] for row in rows]
             context = [{"text": row[1]} for row in rows]
             sources = [row[2] for row in rows]
-    return sources, context
+    return document_ids, sources, context
 
 
 class QueryAgent:
@@ -107,7 +108,7 @@ def __init__(
 
     def __call__(self, query, num_chunks=5, stream=True):
         # Get sources and context
-        sources, context = get_sources_and_context(
+        document_ids, sources, context = get_sources_and_context(
             query=query, embedding_model=self.embedding_model, num_chunks=num_chunks
         )
 
@@ -126,6 +127,7 @@ def __call__(self, query, num_chunks=5, stream=True):
         result = {
             "question": query,
             "sources": sources,
+            "document_ids": document_ids,
             "answer": answer,
             "llm": self.llm,
         }
52 changes: 40 additions & 12 deletions rag/serve.py
@@ -1,11 +1,12 @@
 # You can run the whole script locally with
-# serve run rag.serve:deployment
+# serve run rag.serve:deployment --runtime-env-json='{"env_vars": {"RAY_ASSISTANT_LOGS": "/mnt/shared_storage/ray-assistant-logs/info.log", "RAY_ASSISTANT_SECRET": "ray-assistant-prod"}}'
 
 import json
+import logging
 import os
 import pickle
 from pathlib import Path
-from typing import List
+from typing import Any, Dict, List
 
 import openai
 import ray
@@ -17,6 +18,7 @@
 from slack_bolt import App
 from slack_bolt.adapter.socket_mode import SocketModeHandler
 from starlette.responses import StreamingResponse
+import structlog
 
 from rag.config import MAX_CONTEXT_LENGTHS, ROOT_DIR
 from rag.generate import QueryAgent
@@ -37,7 +39,7 @@ def get_secret(secret_name):
     import boto3
 
     client = boto3.client("secretsmanager", region_name="us-west-2")
-    response = client.get_secret_value(SecretId="ray-assistant")
+    response = client.get_secret_value(SecretId=os.environ["RAY_ASSISTANT_SECRET"])
     return json.loads(response["SecretString"])[secret_name]
 
 
@@ -78,6 +80,17 @@ class Answer(BaseModel):
 @serve.ingress(app)
 class RayAssistantDeployment:
     def __init__(self, num_chunks, embedding_model_name, llm, run_slack=False):
+        # Configure logging
+        logging.basicConfig(filename=os.environ["RAY_ASSISTANT_LOGS"], level=logging.INFO, encoding='utf-8')
+        structlog.configure(
+            processors=[
+                structlog.processors.TimeStamper(fmt="iso"),
+                structlog.processors.JSONRenderer(),
+            ],
+            logger_factory=structlog.stdlib.LoggerFactory(),
+        )
+        self.logger = structlog.get_logger()
+
         # Set credentials
         os.environ["ANYSCALE_API_BASE"] = "https://api.endpoints.anyscale.com/v1"
         os.environ["ANYSCALE_API_KEY"] = get_secret("ANYSCALE_API_KEY")
@@ -111,33 +124,48 @@ def __init__(self, num_chunks, embedding_model_name, llm, run_slack=False):
             self.slack_app = SlackApp.remote()
             self.runner = self.slack_app.run.remote()
 
-    @app.post("/query")
-    def query(self, query: Query) -> Answer:
+    def predict(self, query: Query, stream: bool) -> Dict[str, Any]:
         use_oss_agent = self.router.predict([query.query])[0]
         agent = self.oss_agent if use_oss_agent else self.gpt_agent
-        result = agent(query=query.query, num_chunks=self.num_chunks, stream=False)
+        result = agent(query=query.query, num_chunks=self.num_chunks, stream=stream)
+        return result
+
+    @app.post("/query")
+    def query(self, query: Query) -> Answer:
+        result = self.predict(query, stream=False)
         return Answer.parse_obj(result)
 
-    def produce_streaming_answer(self, result):
+    def produce_streaming_answer(self, query, result):
         answer = []
         for answer_piece in result["answer"]:
             answer.append(answer_piece)
             yield answer_piece
 
         if result["sources"]:
             yield "\n\n**Sources:**\n"
             for source in result["sources"]:
                 yield "* " + source + "\n"
 
+        self.logger.info(
+            "finished streaming query",
+            query=query,
+            document_ids=result["document_ids"],
+            llm=result["llm"],
+            answer="".join(answer)
+        )
+
     @app.post("/stream")
     def stream(self, query: Query) -> StreamingResponse:
-        use_oss_agent = self.router.predict([query.query])[0]
-        agent = self.oss_agent if use_oss_agent else self.gpt_agent
-        result = agent(query=query.query, num_chunks=self.num_chunks, stream=True)
+        result = self.predict(query, stream=True)
         return StreamingResponse(
-            self.produce_streaming_answer(result), media_type="text/plain")
+            self.produce_streaming_answer(query.query, result),
+            media_type="text/plain"
+        )
 
 
 # Deploy the Ray Serve app
 deployment = RayAssistantDeployment.bind(
-    num_chunks=7,
+    num_chunks=5,
     embedding_model_name="thenlper/gte-large",
     llm="meta-llama/Llama-2-70b-chat-hf",
 )
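
For reference, a hypothetical client call against the two endpoints defined above might look like the sketch below. The host/port and payload shape are assumptions (Ray Serve's default local address, and a request body with a single "query" field, matching the query.query usage above), not something specified by this diff.

# Hypothetical usage sketch; endpoint address and payload shape are assumptions.
import requests

BASE_URL = "http://localhost:8000"

# Non-streaming endpoint: returns the parsed Answer as JSON.
resp = requests.post(f"{BASE_URL}/query", json={"query": "What is Ray Serve?"})
print(resp.json())

# Streaming endpoint: consume the plain-text answer chunks as they arrive.
with requests.post(f"{BASE_URL}/stream", json={"query": "What is Ray Serve?"}, stream=True) as resp:
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)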
5 changes: 4 additions & 1 deletion rag/service.yaml
@@ -3,4 +3,7 @@ cluster_env: ray-assistant
 ray_serve_config:
   import_path: app.serve:deployment
   runtime_env:
-    working_dir: "https://github.com/ray-project/llm-applications/archive/refs/tags/v0.0.8.zip"
+    working_dir: "https://github.com/ray-project/llm-applications/archive/refs/tags/v0.0.9.zip"
+    env_vars:
+      RAY_ASSISTANT_SECRET: "ray-assistant-prod"
+      RAY_ASSISTANT_LOGS: "/mnt/shared_storage/ray-assistant-logs/info.log"
1 change: 1 addition & 0 deletions requirements.txt
@@ -10,6 +10,7 @@ ray
 sentence_transformers
 slack_bolt
 streamlit
+structlog
 typer
 tiktoken
 
