diff --git a/pyproject.toml.jinja b/pyproject.toml.jinja index 9d9e42a..84f2118 100644 --- a/pyproject.toml.jinja +++ b/pyproject.toml.jinja @@ -1,5 +1,5 @@ [project] -name = "{{ project_name_snake }}" +name = "{{ project_name }}" version = "0.1.0" description = "Add your description here" readme = "README.md" @@ -7,7 +7,7 @@ authors = [] requires-python = ">=3.12" dependencies = [ "llama-index-workflows>=2.2.0,<3.0.0", - "llama-cloud-services>=0.6.68", + "llama-cloud-services>=0.6.69", "llama-index-core>=0.14.0", "llama-index-llms-openai>=0.5.6", "llama-index-embeddings-openai>=0.5.1", diff --git a/src/{{ project_name_snake }}/clients.py b/src/{{ project_name_snake }}/clients.py index 9f8d0d8..ee67170 100644 --- a/src/{{ project_name_snake }}/clients.py +++ b/src/{{ project_name_snake }}/clients.py @@ -3,7 +3,8 @@ import httpx from llama_cloud.client import AsyncLlamaCloud -from llama_cloud_services import LlamaParse +from llama_cloud_services import LlamaCloudIndex, LlamaParse +from llama_cloud_services.parse import ResultType # deployed agents may infer their name from the deployment name # Note: Make sure that an agent deployment with this name actually exists @@ -18,7 +19,8 @@ INDEX_NAME = "document_qa_index" -def get_custom_client() -> httpx.AsyncClient: +@functools.cache +def get_base_cloud_client() -> httpx.AsyncClient: return httpx.AsyncClient( timeout=60, headers={"Project-Id": LLAMA_CLOUD_PROJECT_ID} @@ -32,7 +34,7 @@ def get_llama_cloud_client() -> AsyncLlamaCloud: return AsyncLlamaCloud( base_url=LLAMA_CLOUD_BASE_URL, token=LLAMA_CLOUD_API_KEY, - httpx_client=get_custom_client(), + httpx_client=get_base_cloud_client(), ) @@ -45,8 +47,20 @@ def get_llama_parse_client() -> LlamaParse: adaptive_long_table=True, outlined_table_extraction=True, output_tables_as_HTML=True, - result_type="markdown", + result_type=ResultType.MD, api_key=LLAMA_CLOUD_API_KEY, project_id=LLAMA_CLOUD_PROJECT_ID, - custom_client=get_custom_client(), + custom_client=get_base_cloud_client(), + ) + + +@functools.lru_cache(maxsize=None) +def get_index(index_name: str) -> LlamaCloudIndex: + return LlamaCloudIndex.create_index( + name=index_name, + project_id=LLAMA_CLOUD_PROJECT_ID, + api_key=LLAMA_CLOUD_API_KEY, + base_url=LLAMA_CLOUD_BASE_URL, + show_progress=True, + custom_client=get_base_cloud_client(), ) diff --git a/src/{{ project_name_snake }}/qa_workflows.py b/src/{{ project_name_snake }}/qa_workflows.py index a6a5878..a9a2bc4 100644 --- a/src/{{ project_name_snake }}/qa_workflows.py +++ b/src/{{ project_name_snake }}/qa_workflows.py @@ -1,16 +1,21 @@ +from __future__ import annotations +from datetime import datetime import logging import os import tempfile +from typing import Any, Literal import httpx -from dotenv import load_dotenv -from llama_cloud.types import RetrievalMode from llama_index.core import Settings -from llama_index.core.chat_engine.types import BaseChatEngine, ChatMode -from llama_index.core.memory import ChatMemoryBuffer +from llama_index.core.chat_engine.types import ( + BaseChatEngine, + ChatMode, +) +from llama_index.core.llms import ChatMessage +import asyncio from llama_index.embeddings.openai import OpenAIEmbedding from llama_index.llms.openai import OpenAI -from llama_cloud_services import LlamaCloudIndex +from pydantic import BaseModel, Field from workflows import Workflow, step, Context from workflows.events import ( StartEvent, @@ -22,17 +27,12 @@ from workflows.retry_policy import ConstantDelayRetryPolicy from .clients import ( - LLAMA_CLOUD_API_KEY, - LLAMA_CLOUD_BASE_URL, 
- get_custom_client, + get_index, get_llama_cloud_client, get_llama_parse_client, LLAMA_CLOUD_PROJECT_ID, ) -load_dotenv() - - logger = logging.getLogger(__name__) @@ -53,15 +53,13 @@ class FileDownloadedEvent(Event): class ChatEvent(StartEvent): index_name: str - session_id: str + conversation_history: list[ConversationMessage] = Field(default_factory=list) # Configure LLM and embedding model Settings.llm = OpenAI(model="gpt-4", temperature=0.1) Settings.embed_model = OpenAIEmbedding(model="text-embedding-3-small") -custom_client = get_custom_client() - class DocumentUploadWorkflow(Workflow): """Workflow to upload and index documents using LlamaParse and LlamaCloud Index""" @@ -131,15 +129,7 @@ async def parse_document(self, ev: FileDownloadedEvent, ctx: Context) -> StopEve documents = result.get_text_documents() # Create or connect to LlamaCloud Index - index = LlamaCloudIndex.create_index( - documents=documents, - name=index_name, - project_id=LLAMA_CLOUD_PROJECT_ID, - api_key=LLAMA_CLOUD_API_KEY, - base_url=LLAMA_CLOUD_BASE_URL, - show_progress=True, - custom_client=custom_client, - ) + index = get_index(index_name) # Insert documents to index logger.info(f"Inserting {len(documents)} documents to {index_name}") @@ -158,18 +148,14 @@ async def parse_document(self, ev: FileDownloadedEvent, ctx: Context) -> StopEve ) except Exception as e: - logger.error(e.stack_trace) - return StopEvent( - result={"success": False, "error": str(e), "stack_trace": e.stack_trace} - ) + logger.error(f"Error parsing document {ev.file_id}: {e}", exc_info=True) + return StopEvent(result={"success": False, "error": str(e)}) -class ChatResponseEvent(Event): - """Event emitted when chat engine generates a response""" +class AppendChatMessage(Event): + """Event emitted when chat engine appends a message to the conversation history""" - response: str - sources: list - query: str + message: ConversationMessage class ChatDeltaEvent(Event): @@ -178,88 +164,119 @@ class ChatDeltaEvent(Event): delta: str +class QueryConversationHistoryEvent(HumanResponseEvent): + """Client can call this to trigger replaying AppendChatMessage events""" + + pass + + +class ErrorEvent(Event): + """Event emitted when an error occurs""" + + error: str + + +class ChatWorkflowState(BaseModel): + index_name: str | None = None + conversation_history: list[ConversationMessage] = Field(default_factory=list) + + def chat_messages(self) -> list[ChatMessage]: + return [ + ChatMessage(role=message.role, content=message.text) + for message in self.conversation_history + ] + + +class SourceMessage(BaseModel): + text: str + score: float + metadata: dict[str, Any] + + +class ConversationMessage(BaseModel): + """ + Mostly just a wrapper for a ChatMessage with extra context for UI. Includes a timestamp and source references. + """ + + role: Literal["user", "assistant"] + text: str + sources: list[SourceMessage] = Field(default_factory=list) + timestamp: str = Field(default_factory=lambda: datetime.now().isoformat()) + + +def get_chat_engine(index_name: str) -> BaseChatEngine: + index = get_index(index_name) + return index.as_chat_engine( + chat_mode=ChatMode.CONTEXT, + llm=Settings.llm, + context_prompt=( + "You are a helpful assistant that answers questions based on the provided documents. " + "Always cite specific information from the documents when answering. " + "If you cannot find the answer in the documents, say so clearly." 
+ ), + ) + + class ChatWorkflow(Workflow): """Workflow to handle continuous chat queries against indexed documents""" - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.chat_engines: dict[ - str, BaseChatEngine - ] = {} # Cache chat engines per index - @step - async def initialize_chat(self, ev: ChatEvent, ctx: Context) -> InputRequiredEvent: + async def initialize_chat( + self, ev: ChatEvent, ctx: Context[ChatWorkflowState] + ) -> InputRequiredEvent | StopEvent | None: """Initialize the chat session and request first input""" try: logger.info(f"Initializing chat {ev.index_name}") index_name = ev.index_name - session_id = ev.session_id + initial_state = await ctx.store.get_state() # Store session info in context await ctx.store.set("index_name", index_name) - await ctx.store.set("session_id", session_id) - await ctx.store.set("conversation_history", []) - - # Create cache key for chat engine - cache_key = f"{index_name}_{session_id}" - - # Initialize chat engine if not exists - if cache_key not in self.chat_engines: - logger.info(f"Initializing chat engine {cache_key}") - # Connect to LlamaCloud Index - index = LlamaCloudIndex( - name=index_name, - project_id=LLAMA_CLOUD_PROJECT_ID, - api_key=LLAMA_CLOUD_API_KEY, - base_url=LLAMA_CLOUD_BASE_URL, - async_httpx_client=custom_client, - ) + messages = initial_state.conversation_history - # Create chat engine with memory - memory = ChatMemoryBuffer.from_defaults(token_limit=3900) - self.chat_engines[cache_key] = index.as_chat_engine( - chat_mode=ChatMode.CONTEXT, - memory=memory, - llm=Settings.llm, - context_prompt=( - "You are a helpful assistant that answers questions based on the provided documents. " - "Always cite specific information from the documents when answering. " - "If you cannot find the answer in the documents, say so clearly." - ), - verbose=False, - retriever_mode=RetrievalMode.CHUNKS, - ) + for item in messages: + ctx.write_event_to_stream(AppendChatMessage(message=item)) - # Request first user input - return InputRequiredEvent( - prefix="Chat initialized. 
Ask a question (or type 'exit' to quit): " - ) + if ev.conversation_history: + async with ctx.store.edit_state() as state: + state.conversation_history.extend(ev.conversation_history) except Exception as e: - return StopEvent( - result={ - "success": False, - "error": f"Failed to initialize chat: {str(e)}", - } + logger.error(f"Error initializing chat: {str(e)}", exc_info=True) + ctx.write_event_to_stream( + ErrorEvent(error=f"Failed to initialize chat: {str(e)}") ) + return InputRequiredEvent() + + @step + async def get_conversation_history( + self, ev: QueryConversationHistoryEvent, ctx: Context[ChatWorkflowState] + ) -> None: + """Get the conversation history from the database""" + hist = (await ctx.store.get_state()).conversation_history + for item in hist: + ctx.write_event_to_stream(AppendChatMessage(message=item)) @step async def process_user_response( - self, ev: HumanResponseEvent, ctx: Context - ) -> InputRequiredEvent | HumanResponseEvent | StopEvent | None: + self, ev: HumanResponseEvent, ctx: Context[ChatWorkflowState] + ) -> InputRequiredEvent | HumanResponseEvent | None: """Process user input and generate response""" try: logger.info(f"Processing user response {ev.response}") user_input = ev.response.strip() + initial_state = await ctx.store.get_state() + conversation_history = initial_state.conversation_history + index_name = initial_state.index_name + if not index_name: + raise ValueError("Index name not found in context") + logger.info(f"User input: {user_input}") # Check for exit command if user_input.lower() == "exit": logger.info("User input is exit") - conversation_history = await ctx.store.get( - "conversation_history", default=[] - ) return StopEvent( result={ "success": True, @@ -268,72 +285,52 @@ async def process_user_response( } ) - # Get session info from context - index_name = await ctx.store.get("index_name") - session_id = await ctx.store.get("session_id") - cache_key = f"{index_name}_{session_id}" - - # Get chat engine - chat_engine = self.chat_engines[cache_key] + chat_engine = get_chat_engine(index_name) - # Process query with chat engine (streaming) - stream_response = await chat_engine.astream_chat(user_input) + stream_response = await chat_engine.astream_chat( + user_input, chat_history=initial_state.chat_messages() + ) full_text = "" # Emit streaming deltas to the event stream async for token in stream_response.async_response_gen(): full_text += token ctx.write_event_to_stream(ChatDeltaEvent(delta=token)) + await asyncio.sleep( + 0 + ) # Temp workaround. Some sort of bug in the server drops events without flushing the event loop # Extract source nodes for citations sources = [] - if hasattr(stream_response, "source_nodes"): + if stream_response.source_nodes: for node in stream_response.source_nodes: sources.append( - { - "text": node.text[:200] + "..." - if len(node.text) > 200 + SourceMessage( + text=node.text[:197] + "..." 
+ if len(node.text) >= 200 else node.text, - "score": node.score if hasattr(node, "score") else None, - "metadata": node.metadata - if hasattr(node, "metadata") - else {}, - } + score=float(node.score) if node.score else 0.0, + metadata=node.metadata, + ) ) - # Update conversation history - conversation_history = await ctx.store.get( - "conversation_history", default=[] - ) - conversation_history.append( - { - "query": user_input, - "response": full_text.strip() - if full_text - else str(stream_response), - "sources": sources, - } - ) - await ctx.store.set("conversation_history", conversation_history) - # After streaming completes, emit a summary response event to stream for frontend/main printing - ctx.write_event_to_stream( - ChatResponseEvent( - response=full_text.strip() if full_text else str(stream_response), - sources=sources, - query=user_input, - ) - ) - - # Prompt for next input - return InputRequiredEvent( - prefix="\nAsk another question (or type 'exit' to quit): " + assistant_response = ConversationMessage( + role="assistant", text=full_text, sources=sources ) + ctx.write_event_to_stream(AppendChatMessage(message=assistant_response)) + async with ctx.store.edit_state() as state: + state.conversation_history.extend( + [ + ConversationMessage(role="user", text=user_input), + assistant_response, + ] + ) except Exception as e: - return StopEvent( - result={"success": False, "error": f"Error processing query: {str(e)}"} - ) + logger.error(f"Error processing query: {str(e)}", exc_info=True) + ctx.write_event_to_stream(ErrorEvent(error=str(e))) + return InputRequiredEvent() upload = DocumentUploadWorkflow(timeout=None) diff --git a/test-proj/pyproject.toml b/test-proj/pyproject.toml index f7706af..937a683 100644 --- a/test-proj/pyproject.toml +++ b/test-proj/pyproject.toml @@ -1,5 +1,5 @@ [project] -name = "test_proj" +name = "test-proj" version = "0.1.0" description = "Add your description here" readme = "README.md" @@ -7,7 +7,7 @@ authors = [] requires-python = ">=3.12" dependencies = [ "llama-index-workflows>=2.2.0,<3.0.0", - "llama-cloud-services>=0.6.68", + "llama-cloud-services>=0.6.69", "llama-index-core>=0.14.0", "llama-index-llms-openai>=0.5.6", "llama-index-embeddings-openai>=0.5.1", diff --git a/test-proj/src/test_proj/clients.py b/test-proj/src/test_proj/clients.py index 9f8d0d8..ee67170 100644 --- a/test-proj/src/test_proj/clients.py +++ b/test-proj/src/test_proj/clients.py @@ -3,7 +3,8 @@ import httpx from llama_cloud.client import AsyncLlamaCloud -from llama_cloud_services import LlamaParse +from llama_cloud_services import LlamaCloudIndex, LlamaParse +from llama_cloud_services.parse import ResultType # deployed agents may infer their name from the deployment name # Note: Make sure that an agent deployment with this name actually exists @@ -18,7 +19,8 @@ INDEX_NAME = "document_qa_index" -def get_custom_client() -> httpx.AsyncClient: +@functools.cache +def get_base_cloud_client() -> httpx.AsyncClient: return httpx.AsyncClient( timeout=60, headers={"Project-Id": LLAMA_CLOUD_PROJECT_ID} @@ -32,7 +34,7 @@ def get_llama_cloud_client() -> AsyncLlamaCloud: return AsyncLlamaCloud( base_url=LLAMA_CLOUD_BASE_URL, token=LLAMA_CLOUD_API_KEY, - httpx_client=get_custom_client(), + httpx_client=get_base_cloud_client(), ) @@ -45,8 +47,20 @@ def get_llama_parse_client() -> LlamaParse: adaptive_long_table=True, outlined_table_extraction=True, output_tables_as_HTML=True, - result_type="markdown", + result_type=ResultType.MD, api_key=LLAMA_CLOUD_API_KEY, project_id=LLAMA_CLOUD_PROJECT_ID, 
- custom_client=get_custom_client(), + custom_client=get_base_cloud_client(), + ) + + +@functools.lru_cache(maxsize=None) +def get_index(index_name: str) -> LlamaCloudIndex: + return LlamaCloudIndex.create_index( + name=index_name, + project_id=LLAMA_CLOUD_PROJECT_ID, + api_key=LLAMA_CLOUD_API_KEY, + base_url=LLAMA_CLOUD_BASE_URL, + show_progress=True, + custom_client=get_base_cloud_client(), ) diff --git a/test-proj/src/test_proj/qa_workflows.py b/test-proj/src/test_proj/qa_workflows.py index a6a5878..a9a2bc4 100644 --- a/test-proj/src/test_proj/qa_workflows.py +++ b/test-proj/src/test_proj/qa_workflows.py @@ -1,16 +1,21 @@ +from __future__ import annotations +from datetime import datetime import logging import os import tempfile +from typing import Any, Literal import httpx -from dotenv import load_dotenv -from llama_cloud.types import RetrievalMode from llama_index.core import Settings -from llama_index.core.chat_engine.types import BaseChatEngine, ChatMode -from llama_index.core.memory import ChatMemoryBuffer +from llama_index.core.chat_engine.types import ( + BaseChatEngine, + ChatMode, +) +from llama_index.core.llms import ChatMessage +import asyncio from llama_index.embeddings.openai import OpenAIEmbedding from llama_index.llms.openai import OpenAI -from llama_cloud_services import LlamaCloudIndex +from pydantic import BaseModel, Field from workflows import Workflow, step, Context from workflows.events import ( StartEvent, @@ -22,17 +27,12 @@ from workflows.retry_policy import ConstantDelayRetryPolicy from .clients import ( - LLAMA_CLOUD_API_KEY, - LLAMA_CLOUD_BASE_URL, - get_custom_client, + get_index, get_llama_cloud_client, get_llama_parse_client, LLAMA_CLOUD_PROJECT_ID, ) -load_dotenv() - - logger = logging.getLogger(__name__) @@ -53,15 +53,13 @@ class FileDownloadedEvent(Event): class ChatEvent(StartEvent): index_name: str - session_id: str + conversation_history: list[ConversationMessage] = Field(default_factory=list) # Configure LLM and embedding model Settings.llm = OpenAI(model="gpt-4", temperature=0.1) Settings.embed_model = OpenAIEmbedding(model="text-embedding-3-small") -custom_client = get_custom_client() - class DocumentUploadWorkflow(Workflow): """Workflow to upload and index documents using LlamaParse and LlamaCloud Index""" @@ -131,15 +129,7 @@ async def parse_document(self, ev: FileDownloadedEvent, ctx: Context) -> StopEve documents = result.get_text_documents() # Create or connect to LlamaCloud Index - index = LlamaCloudIndex.create_index( - documents=documents, - name=index_name, - project_id=LLAMA_CLOUD_PROJECT_ID, - api_key=LLAMA_CLOUD_API_KEY, - base_url=LLAMA_CLOUD_BASE_URL, - show_progress=True, - custom_client=custom_client, - ) + index = get_index(index_name) # Insert documents to index logger.info(f"Inserting {len(documents)} documents to {index_name}") @@ -158,18 +148,14 @@ async def parse_document(self, ev: FileDownloadedEvent, ctx: Context) -> StopEve ) except Exception as e: - logger.error(e.stack_trace) - return StopEvent( - result={"success": False, "error": str(e), "stack_trace": e.stack_trace} - ) + logger.error(f"Error parsing document {ev.file_id}: {e}", exc_info=True) + return StopEvent(result={"success": False, "error": str(e)}) -class ChatResponseEvent(Event): - """Event emitted when chat engine generates a response""" +class AppendChatMessage(Event): + """Event emitted when chat engine appends a message to the conversation history""" - response: str - sources: list - query: str + message: ConversationMessage class ChatDeltaEvent(Event): @@ 
-178,88 +164,119 @@ class ChatDeltaEvent(Event): delta: str +class QueryConversationHistoryEvent(HumanResponseEvent): + """Client can call this to trigger replaying AppendChatMessage events""" + + pass + + +class ErrorEvent(Event): + """Event emitted when an error occurs""" + + error: str + + +class ChatWorkflowState(BaseModel): + index_name: str | None = None + conversation_history: list[ConversationMessage] = Field(default_factory=list) + + def chat_messages(self) -> list[ChatMessage]: + return [ + ChatMessage(role=message.role, content=message.text) + for message in self.conversation_history + ] + + +class SourceMessage(BaseModel): + text: str + score: float + metadata: dict[str, Any] + + +class ConversationMessage(BaseModel): + """ + Mostly just a wrapper for a ChatMessage with extra context for UI. Includes a timestamp and source references. + """ + + role: Literal["user", "assistant"] + text: str + sources: list[SourceMessage] = Field(default_factory=list) + timestamp: str = Field(default_factory=lambda: datetime.now().isoformat()) + + +def get_chat_engine(index_name: str) -> BaseChatEngine: + index = get_index(index_name) + return index.as_chat_engine( + chat_mode=ChatMode.CONTEXT, + llm=Settings.llm, + context_prompt=( + "You are a helpful assistant that answers questions based on the provided documents. " + "Always cite specific information from the documents when answering. " + "If you cannot find the answer in the documents, say so clearly." + ), + ) + + class ChatWorkflow(Workflow): """Workflow to handle continuous chat queries against indexed documents""" - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.chat_engines: dict[ - str, BaseChatEngine - ] = {} # Cache chat engines per index - @step - async def initialize_chat(self, ev: ChatEvent, ctx: Context) -> InputRequiredEvent: + async def initialize_chat( + self, ev: ChatEvent, ctx: Context[ChatWorkflowState] + ) -> InputRequiredEvent | StopEvent | None: """Initialize the chat session and request first input""" try: logger.info(f"Initializing chat {ev.index_name}") index_name = ev.index_name - session_id = ev.session_id + initial_state = await ctx.store.get_state() # Store session info in context await ctx.store.set("index_name", index_name) - await ctx.store.set("session_id", session_id) - await ctx.store.set("conversation_history", []) - - # Create cache key for chat engine - cache_key = f"{index_name}_{session_id}" - - # Initialize chat engine if not exists - if cache_key not in self.chat_engines: - logger.info(f"Initializing chat engine {cache_key}") - # Connect to LlamaCloud Index - index = LlamaCloudIndex( - name=index_name, - project_id=LLAMA_CLOUD_PROJECT_ID, - api_key=LLAMA_CLOUD_API_KEY, - base_url=LLAMA_CLOUD_BASE_URL, - async_httpx_client=custom_client, - ) + messages = initial_state.conversation_history - # Create chat engine with memory - memory = ChatMemoryBuffer.from_defaults(token_limit=3900) - self.chat_engines[cache_key] = index.as_chat_engine( - chat_mode=ChatMode.CONTEXT, - memory=memory, - llm=Settings.llm, - context_prompt=( - "You are a helpful assistant that answers questions based on the provided documents. " - "Always cite specific information from the documents when answering. " - "If you cannot find the answer in the documents, say so clearly." - ), - verbose=False, - retriever_mode=RetrievalMode.CHUNKS, - ) + for item in messages: + ctx.write_event_to_stream(AppendChatMessage(message=item)) - # Request first user input - return InputRequiredEvent( - prefix="Chat initialized. 
Ask a question (or type 'exit' to quit): " - ) + if ev.conversation_history: + async with ctx.store.edit_state() as state: + state.conversation_history.extend(ev.conversation_history) except Exception as e: - return StopEvent( - result={ - "success": False, - "error": f"Failed to initialize chat: {str(e)}", - } + logger.error(f"Error initializing chat: {str(e)}", exc_info=True) + ctx.write_event_to_stream( + ErrorEvent(error=f"Failed to initialize chat: {str(e)}") ) + return InputRequiredEvent() + + @step + async def get_conversation_history( + self, ev: QueryConversationHistoryEvent, ctx: Context[ChatWorkflowState] + ) -> None: + """Get the conversation history from the database""" + hist = (await ctx.store.get_state()).conversation_history + for item in hist: + ctx.write_event_to_stream(AppendChatMessage(message=item)) @step async def process_user_response( - self, ev: HumanResponseEvent, ctx: Context - ) -> InputRequiredEvent | HumanResponseEvent | StopEvent | None: + self, ev: HumanResponseEvent, ctx: Context[ChatWorkflowState] + ) -> InputRequiredEvent | HumanResponseEvent | None: """Process user input and generate response""" try: logger.info(f"Processing user response {ev.response}") user_input = ev.response.strip() + initial_state = await ctx.store.get_state() + conversation_history = initial_state.conversation_history + index_name = initial_state.index_name + if not index_name: + raise ValueError("Index name not found in context") + logger.info(f"User input: {user_input}") # Check for exit command if user_input.lower() == "exit": logger.info("User input is exit") - conversation_history = await ctx.store.get( - "conversation_history", default=[] - ) return StopEvent( result={ "success": True, @@ -268,72 +285,52 @@ async def process_user_response( } ) - # Get session info from context - index_name = await ctx.store.get("index_name") - session_id = await ctx.store.get("session_id") - cache_key = f"{index_name}_{session_id}" - - # Get chat engine - chat_engine = self.chat_engines[cache_key] + chat_engine = get_chat_engine(index_name) - # Process query with chat engine (streaming) - stream_response = await chat_engine.astream_chat(user_input) + stream_response = await chat_engine.astream_chat( + user_input, chat_history=initial_state.chat_messages() + ) full_text = "" # Emit streaming deltas to the event stream async for token in stream_response.async_response_gen(): full_text += token ctx.write_event_to_stream(ChatDeltaEvent(delta=token)) + await asyncio.sleep( + 0 + ) # Temp workaround. Some sort of bug in the server drops events without flushing the event loop # Extract source nodes for citations sources = [] - if hasattr(stream_response, "source_nodes"): + if stream_response.source_nodes: for node in stream_response.source_nodes: sources.append( - { - "text": node.text[:200] + "..." - if len(node.text) > 200 + SourceMessage( + text=node.text[:197] + "..." 
+ if len(node.text) >= 200 else node.text, - "score": node.score if hasattr(node, "score") else None, - "metadata": node.metadata - if hasattr(node, "metadata") - else {}, - } + score=float(node.score) if node.score else 0.0, + metadata=node.metadata, + ) ) - # Update conversation history - conversation_history = await ctx.store.get( - "conversation_history", default=[] - ) - conversation_history.append( - { - "query": user_input, - "response": full_text.strip() - if full_text - else str(stream_response), - "sources": sources, - } - ) - await ctx.store.set("conversation_history", conversation_history) - # After streaming completes, emit a summary response event to stream for frontend/main printing - ctx.write_event_to_stream( - ChatResponseEvent( - response=full_text.strip() if full_text else str(stream_response), - sources=sources, - query=user_input, - ) - ) - - # Prompt for next input - return InputRequiredEvent( - prefix="\nAsk another question (or type 'exit' to quit): " + assistant_response = ConversationMessage( + role="assistant", text=full_text, sources=sources ) + ctx.write_event_to_stream(AppendChatMessage(message=assistant_response)) + async with ctx.store.edit_state() as state: + state.conversation_history.extend( + [ + ConversationMessage(role="user", text=user_input), + assistant_response, + ] + ) except Exception as e: - return StopEvent( - result={"success": False, "error": f"Error processing query: {str(e)}"} - ) + logger.error(f"Error processing query: {str(e)}", exc_info=True) + ctx.write_event_to_stream(ErrorEvent(error=str(e))) + return InputRequiredEvent() upload = DocumentUploadWorkflow(timeout=None) diff --git a/test-proj/ui/index.html b/test-proj/ui/index.html index 37e7c42..ce3774b 100644 --- a/test-proj/ui/index.html +++ b/test-proj/ui/index.html @@ -4,6 +4,12 @@ Quick Start UI +
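Taken together, the backend hunks above (applied identically to the template and the generated test-proj) replace the ad-hoc httpx client and per-session chat engines with process-wide caches. A minimal sketch of the intended effect, assuming the LLAMA_CLOUD_* environment variables read by clients.py are set — an illustration only, not part of the patch:

```python
# Illustration only -- not part of the patch. Assumes LLAMA_CLOUD_API_KEY,
# LLAMA_CLOUD_PROJECT_ID, etc. are configured as clients.py expects.
from test_proj.clients import get_base_cloud_client, get_index

# functools.cache: every caller shares a single httpx.AsyncClient carrying the
# Project-Id header, instead of constructing a new client per call.
assert get_base_cloud_client() is get_base_cloud_client()

# functools.lru_cache(maxsize=None): one LlamaCloudIndex handle per index name,
# so DocumentUploadWorkflow and ChatWorkflow reuse the same index object.
assert get_index("document_qa_index") is get_index("document_qa_index")
```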
diff --git a/test-proj/ui/package.json b/test-proj/ui/package.json index 8cbfa1e..514717c 100644 --- a/test-proj/ui/package.json +++ b/test-proj/ui/package.json @@ -17,6 +17,7 @@ "@llamaindex/ui": "^2.1.1", "@llamaindex/workflows-client": "^1.2.0", "@radix-ui/themes": "^3.2.1", + "idb": "^8.0.3", "llama-cloud-services": "^0.3.6", "lucide-react": "^0.544.0", "react": "^19.0.0", diff --git a/test-proj/ui/src/App.tsx b/test-proj/ui/src/App.tsx index 6658701..55dffe3 100644 --- a/test-proj/ui/src/App.tsx +++ b/test-proj/ui/src/App.tsx @@ -2,8 +2,29 @@ import { ApiProvider } from "@llamaindex/ui"; import Home from "./pages/Home"; import { Theme } from "@radix-ui/themes"; import { clients } from "@/libs/clients"; +import { useEffect } from "react"; export default function App() { + // Apply dark mode based on system preference + useEffect(() => { + const mediaQuery = window.matchMedia("(prefers-color-scheme: dark)"); + + const updateDarkMode = (e: MediaQueryListEvent | MediaQueryList) => { + if (e.matches) { + document.documentElement.classList.add("dark"); + } else { + document.documentElement.classList.remove("dark"); + } + }; + + // Set initial state + updateDarkMode(mediaQuery); + + // Listen for changes + mediaQuery.addEventListener("change", updateDarkMode); + + return () => mediaQuery.removeEventListener("change", updateDarkMode); + }, []); return ( diff --git a/test-proj/ui/src/components/ChatBot.tsx b/test-proj/ui/src/components/ChatBot.tsx index e65f929..3256e4f 100644 --- a/test-proj/ui/src/components/ChatBot.tsx +++ b/test-proj/ui/src/components/ChatBot.tsx @@ -1,167 +1,30 @@ // This is a temporary chatbot component that is used to test the chatbot functionality. // LlamaIndex will replace it with better chatbot component. -import { useState, useRef, useEffect, FormEvent, KeyboardEvent } from "react"; -import { - Send, - Loader2, - Bot, - User, - MessageSquare, - Trash2, - RefreshCw, -} from "lucide-react"; -import { - Button, - Input, - ScrollArea, - Card, - CardContent, - cn, - useWorkflowRun, - useWorkflowHandler, -} from "@llamaindex/ui"; -import { AGENT_NAME } from "../libs/config"; -import { toHumanResponseRawEvent } from "@/libs/utils"; - -type Role = "user" | "assistant"; -interface Message { - id: string; - role: Role; - content: string; - timestamp: Date; - error?: boolean; -} -export default function ChatBot() { - const { runWorkflow } = useWorkflowRun(); +import { useChatbot } from "@/libs/useChatbot"; +import { Button, cn, ScrollArea, Textarea } from "@llamaindex/ui"; +import { Bot, Loader2, RefreshCw, Send, User } from "lucide-react"; +import { FormEvent, KeyboardEvent, useEffect, useRef } from "react"; + +export default function ChatBot({ + handlerId, + onHandlerCreated, +}: { + handlerId?: string; + onHandlerCreated?: (handlerId: string) => void; +}) { + const inputRef = useRef(null); const messagesEndRef = useRef(null); - const inputRef = useRef(null); - const [messages, setMessages] = useState([]); - const [input, setInput] = useState(""); - const [isLoading, setIsLoading] = useState(false); - const [handlerId, setHandlerId] = useState(null); - const lastProcessedEventIndexRef = useRef(0); - const [canSend, setCanSend] = useState(false); - const streamingMessageIndexRef = useRef(null); - - // Deployment + auth setup - const deployment = AGENT_NAME || "document-qa"; - const platformToken = (import.meta as any).env?.VITE_LLAMA_CLOUD_API_KEY as - | string - | undefined; - const projectId = (import.meta as any).env?.VITE_LLAMA_DEPLOY_PROJECT_ID as - | string - | 
undefined; - const defaultIndexName = - (import.meta as any).env?.VITE_DEFAULT_INDEX_NAME || "document_qa_index"; - const sessionIdRef = useRef( - `chat-${Math.random().toString(36).slice(2)}-${Date.now()}`, - ); + const chatbot = useChatbot({ + handlerId, + onHandlerCreated, + focusInput: () => { + inputRef.current?.focus(); + }, + }); // UI text defaults const title = "AI Document Assistant"; const placeholder = "Ask me anything about your documents..."; - const welcomeMessage = - "Welcome! 👋 Upload a document with the control above, then ask questions here."; - - // Helper functions for message management - const appendMessage = (role: Role, msg: string): void => { - setMessages((prev) => { - const id = `${role}-stream-${Date.now()}`; - const idx = prev.length; - streamingMessageIndexRef.current = idx; - return [ - ...prev, - { - id, - role, - content: msg, - timestamp: new Date(), - }, - ]; - }); - }; - - const updateMessage = (index: number, message: string) => { - setMessages((prev) => { - if (index < 0 || index >= prev.length) return prev; - const copy = [...prev]; - const existing = copy[index]; - copy[index] = { ...existing, content: message }; - return copy; - }); - }; - - // Initialize with welcome message - useEffect(() => { - if (messages.length === 0) { - const welcomeMsg: Message = { - id: "welcome", - role: "assistant", - content: welcomeMessage, - timestamp: new Date(), - }; - setMessages([welcomeMsg]); - } - }, []); - - // Create chat task on init - useEffect(() => { - (async () => { - if (!handlerId) { - const handler = await runWorkflow("chat", { - index_name: defaultIndexName, - session_id: sessionIdRef.current, - }); - setHandlerId(handler.handler_id); - } - })(); - }, []); - - // Subscribe to task/events using hook (auto stream when handler exists) - const { events } = useWorkflowHandler(handlerId ?? "", Boolean(handlerId)); - - // Process streamed events into messages - useEffect(() => { - if (!events || events.length === 0) return; - let startIdx = lastProcessedEventIndexRef.current; - if (startIdx < 0) startIdx = 0; - if (startIdx >= events.length) return; - - for (let i = startIdx; i < events.length; i++) { - const ev: any = events[i]; - const type = ev?.type as string | undefined; - const rawData = ev?.data as any; - if (!type) continue; - const data = (rawData && (rawData._data ?? rawData)) as any; - - if (type.includes("ChatDeltaEvent")) { - const delta: string = data?.delta ?? ""; - if (!delta) continue; - if (streamingMessageIndexRef.current === null) { - appendMessage("assistant", delta); - } else { - const idx = streamingMessageIndexRef.current; - const current = messages[idx!]?.content ?? 
""; - if (current === "Thinking...") { - updateMessage(idx!, delta); - } else { - updateMessage(idx!, current + delta); - } - } - } else if (type.includes("ChatResponseEvent")) { - // finalize current stream - streamingMessageIndexRef.current = null; - } else if (type.includes("InputRequiredEvent")) { - // ready for next user input; enable send - setCanSend(true); - setIsLoading(false); - inputRef.current?.focus(); - } else if (type.includes("StopEvent")) { - // finished; no summary bubble needed (chat response already streamed) - } - } - lastProcessedEventIndexRef.current = events.length; - }, [events, messages]); const scrollToBottom = () => { messagesEndRef.current?.scrollIntoView({ behavior: "smooth" }); @@ -169,324 +32,204 @@ export default function ChatBot() { useEffect(() => { scrollToBottom(); - }, [messages]); - - // No manual SSE cleanup needed + }, [chatbot.messages]); - const getCommonHeaders = () => ({ - ...(platformToken ? { authorization: `Bearer ${platformToken}` } : {}), - ...(projectId ? { "Project-Id": projectId } : {}), - }); - - const startChatIfNeeded = async (): Promise => { - if (handlerId) return handlerId; - const handler = await runWorkflow("chat", { - index_name: defaultIndexName, - session_id: sessionIdRef.current, - }); - setHandlerId(handler.handler_id); - return handler.handler_id; - }; - - // Removed manual SSE ensureEventStream; hook handles streaming + // Reset textarea height when input is cleared + useEffect(() => { + if (!chatbot.input && inputRef.current) { + inputRef.current.style.height = "48px"; // Reset to initial height + } + }, [chatbot.input]); const handleSubmit = async (e: FormEvent) => { e.preventDefault(); - - const trimmedInput = input.trim(); - if (!trimmedInput || isLoading || !canSend) return; - - // Add user message - const userMessage: Message = { - id: `user-${Date.now()}`, - role: "user", - content: trimmedInput, - timestamp: new Date(), - }; - - const newMessages = [...messages, userMessage]; - setMessages(newMessages); - setInput(""); - setIsLoading(true); - setCanSend(false); - - // Immediately create an assistant placeholder to avoid visual gap before deltas - if (streamingMessageIndexRef.current === null) { - appendMessage("assistant", "Thinking..."); - } - - try { - // Ensure chat handler exists (created on init) - const hid = await startChatIfNeeded(); - - // Send user input as HumanResponseEvent - const postRes = await fetch(`/deployments/${deployment}/events/${hid}`, { - method: "POST", - headers: { - "Content-Type": "application/json", - ...getCommonHeaders(), - }, - body: JSON.stringify({ - event: JSON.stringify(toHumanResponseRawEvent(trimmedInput)), - }), - }); - if (!postRes.ok) { - throw new Error( - `Failed to send message: ${postRes.status} ${postRes.statusText}`, - ); - } - - // The assistant reply will be streamed by useWorkflowTask and appended incrementally - } catch (err) { - console.error("Chat error:", err); - - // Add error message - const errorMessage: Message = { - id: `error-${Date.now()}`, - role: "assistant", - content: `Sorry, I encountered an error: ${err instanceof Error ? err.message : "Unknown error"}. 
Please try again.`, - timestamp: new Date(), - error: true, - }; - - setMessages((prev) => [...prev, errorMessage]); - } finally { - setIsLoading(false); - // Focus back on input - inputRef.current?.focus(); - } + await chatbot.submit(); }; - const handleKeyDown = (e: KeyboardEvent) => { + const handleKeyDown = (e: KeyboardEvent) => { // Submit on Enter (without Shift) if (e.key === "Enter" && !e.shiftKey) { e.preventDefault(); handleSubmit(e as any); } + // Allow Shift+Enter to create new line (default behavior) }; - const clearChat = () => { - setMessages([ - { - id: "welcome", - role: "assistant" as const, - content: welcomeMessage, - timestamp: new Date(), - }, - ]); - setInput(""); - inputRef.current?.focus(); + const adjustTextareaHeight = (textarea: HTMLTextAreaElement) => { + textarea.style.height = "auto"; + textarea.style.height = Math.min(textarea.scrollHeight, 128) + "px"; // 128px = max-h-32 }; - const retryLastMessage = () => { - const lastUserMessage = messages.filter((m) => m.role === "user").pop(); - if (lastUserMessage) { - // Remove the last assistant message if it was an error - const lastMessage = messages[messages.length - 1]; - if (lastMessage.role === "assistant" && lastMessage.error) { - setMessages((prev) => prev.slice(0, -1)); - } - setInput(lastUserMessage.content); - inputRef.current?.focus(); - } + const handleInputChange = (e: React.ChangeEvent) => { + chatbot.setInput(e.target.value); + adjustTextareaHeight(e.target); }; return ( -
- {/* Header */} -
-
-
- -

- {title} -

- {isLoading && (
-
- Thinking...
-
- )}
-
- {messages.some((m) => m.error) && (
- )}
- {messages.length > 0 && (
- )}
+ {/* Simplified header - only show retry button when needed */} + {chatbot.messages.some((m) => m.error) && ( +
+
+
-
+ )} {/* Messages */} - - {messages.length === 0 ? ( -
-
- -

- No messages yet -

-

- Start a conversation! -

+ +
+ {chatbot.messages.length === 0 ? ( +
+
+ +

+ Welcome! 👋 Upload a document with the control above, then ask + questions here. +

+

+ Start by uploading a document to begin your conversation +

+
-
- ) : ( -
- {messages.map((message) => ( -
- {message.role !== "user" && ( -
- -
- )} + ) : ( +
+ {chatbot.messages.map((message, i) => (
- + +
+ )} +
- -

- {message.content} -

+
+ {message.isPartial && !message.content ? ( +
+ +
+ ) : ( +

+ {message.content} +

+ )}

{message.timestamp.toLocaleTimeString()}

- - -
- {message.role === "user" && ( -
- +
- )} -
- ))} - - {isLoading && ( -
-
- -
- - -
-
- - - -
+ {message.role === "user" && ( +
+
- - -
- )} -
-
- )} + )} +
+ ))} +
+
+ )} +
{/* Input */} -
-
- setInput(e.target.value)} - onKeyDown={handleKeyDown} - placeholder={placeholder} - disabled={isLoading} - className="flex-1" - autoFocus - /> - -
-

- Press Enter to send • Shift+Enter for new line -

+
+
+
+
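For reference, a rough end-to-end sketch of the event flow the reworked UI consumes from ChatWorkflow: prior turns can be seeded via ChatEvent.conversation_history, tokens arrive as ChatDeltaEvent, completed or replayed turns as AppendChatMessage, and each InputRequiredEvent is answered with a HumanResponseEvent. The event classes come from qa_workflows.py above; the handler calls (run(start_event=...), stream_events(), ctx.send_event()) are assumptions about recent llama-index-workflows releases and are not shown in this diff.

```python
# Illustration only -- event classes are defined in qa_workflows.py above; the
# handler API (run(start_event=...), stream_events(), ctx.send_event()) is an
# assumption about recent llama-index-workflows releases, not part of this diff.
import asyncio

from test_proj.qa_workflows import (
    AppendChatMessage,
    ChatDeltaEvent,
    ChatEvent,
    ChatWorkflow,
    ConversationMessage,
)
from workflows.events import HumanResponseEvent, InputRequiredEvent


async def main() -> None:
    wf = ChatWorkflow(timeout=None)
    handler = wf.run(
        start_event=ChatEvent(
            index_name="document_qa_index",
            # Prior turns seeded here are replayed by initialize_chat as
            # AppendChatMessage events and merged into workflow state.
            conversation_history=[
                ConversationMessage(role="user", text="What does the report cover?"),
            ],
        )
    )

    async for ev in handler.stream_events():
        if isinstance(ev, ChatDeltaEvent):
            print(ev.delta, end="", flush=True)  # incremental tokens
        elif isinstance(ev, AppendChatMessage):
            print(f"\n[{ev.message.role}] {ev.message.text}")  # replayed/completed turns
        elif isinstance(ev, InputRequiredEvent):
            # Workflow is ready for the next user turn; "exit" makes
            # process_user_response return a StopEvent carrying the history.
            handler.ctx.send_event(HumanResponseEvent(response="exit"))

    print(await handler)  # the StopEvent result dict


if __name__ == "__main__":
    asyncio.run(main())
```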