Commit: credentials check and improve context save/read from ttl cache
Hugo Saporetti Junior committed Aug 1, 2024
1 parent 4b41ec8 commit d93db97
Showing 11 changed files with 87 additions and 44 deletions.
11 changes: 7 additions & 4 deletions src/main/askai/__main__.py
@@ -12,6 +12,7 @@
Copyright (c) 2024, HomeSetup
"""
import atexit
import logging as log
import os
import sys
@@ -32,6 +33,7 @@
from askai.__classpath__ import classpath
from askai.core.askai import AskAi
from askai.core.askai_configs import configs
from askai.core.support.shared_instances import shared
from askai.tui.askai_app import AskAiApp

if not is_a_tty():
@@ -83,14 +85,14 @@ def _setup_arguments(self) -> None:
"debug", "d", "debug",
"whether you want to run under debug mode.",
nargs="?", action=ParserAction.STORE_TRUE) \
.option(
"cache", "c", "cache",
"whether you want to cache AI replies.",
nargs="?", action=ParserAction.STORE_TRUE) \
.option(
"ui", "u", "ui",
"whether to use the new AskAI TUI (experimental).",
nargs="?", action=ParserAction.STORE_TRUE)\
.option(
"cache", "c", "cache",
"whether you want to cache AI replies.",
nargs="?", default=True) \
.option(
"tempo", "t", "tempo",
"specifies the playback and streaming speed.",
@@ -157,6 +159,7 @@ def _main(self, *params, **kwargs) -> ExitStatus:
def _exec_application(self) -> ExitStatus:
"""Execute the application main flow."""
self._askai.run()
shared.context.save()

return ExitStatus.SUCCESS

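Together with the `chat_context.py` and `shared_instances.py` changes below, this gives sessions a persistence lifecycle: the context window is restored on startup and flushed on a clean exit. A minimal sketch of that flow, assuming a configured engine; `run_session` is a hypothetical name:

# Hypothetical sketch of the session lifecycle introduced by this commit.
from askai.core.support.shared_instances import shared

def run_session(token_limit: int = 8) -> None:
    # create_context() now restores cached "role: message" entries when caching is on.
    context = shared.create_context(token_limit)
    ...  # the interactive session runs here (AskAi.run() in the real flow)
    # On exit, the whole context window is re-serialized into the TTL cache.
    context.save()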
56 changes: 36 additions & 20 deletions src/main/askai/core/component/cache_service.py
@@ -12,16 +12,19 @@
Copyright (c) 2024, HomeSetup
"""
from askai.core.askai_configs import configs
from askai.core.askai_settings import ASKAI_DIR
from clitt.core.tui.line_input.keyboard_input import KeyboardInput
import re
from collections import namedtuple
from pathlib import Path
from typing import Optional, Tuple

from clitt.core.tui.line_input.keyboard_input import KeyboardInput
from hspylib.core.metaclass.singleton import Singleton
from hspylib.core.tools.commons import file_is_not_empty
from hspylib.core.tools.text_tools import hash_text
from hspylib.modules.cache.ttl_cache import TTLCache
from pathlib import Path
from typing import Optional, Tuple

from askai.core.askai_configs import configs
from askai.core.askai_settings import ASKAI_DIR

# AskAI cache root directory.
CACHE_DIR: Path = Path(f"{ASKAI_DIR}/cache")
@@ -80,9 +83,13 @@ class CacheService(metaclass=Singleton):

INSTANCE: "CacheService"

ASKAI_CACHE_KEYS = "askai-cache-keys"
ASKAI_CACHE_KEYS: str = "askai-cache-keys"

ASKAI_INPUT_CACHE_KEY: str = "askai-input-history"

ASKAI_CONTEXT_KEY: str = "askai-context-key"

ASKAI_INPUT_CACHE_KEY = "askai-input-history"
# ASKAI_CONTEXT

_TTL_CACHE: TTLCache[str] = TTLCache(ttl_minutes=configs.ttl)

@@ -104,6 +111,19 @@ def audio_file_path(self, text: str, voice: str = "onyx", audio_format: str = "m
audio_file_path = f"{str(AUDIO_DIR)}/askai-{hash_text(key)}.{audio_format}"
return audio_file_path, file_is_not_empty(audio_file_path)

def save_reply(self, text: str, reply: str) -> Optional[str]:
"""Save a AI reply into the TTL cache.
:param text: Text to be cached.
:param reply: The reply associated to this text.
"""
if configs.is_cache:
key = text.strip().lower()
self._TTL_CACHE.save(key, reply)
self.keys.add(key)
self._TTL_CACHE.save(self.ASKAI_CACHE_KEYS, ",".join(self._cache_keys))
return text
return None

def read_reply(self, text: str) -> Optional[str]:
"""Read AI replies from TTL cache.
:param text: Text to be cached.
@@ -125,19 +145,6 @@ def del_reply(self, text: str) -> Optional[str]:
return text
return None

def save_reply(self, text: str, reply: str) -> Optional[str]:
"""Save a AI reply into the TTL cache.
:param text: Text to be cached.
:param reply: The reply associated to this text.
"""
if configs.is_cache:
key = text.strip().lower()
self._TTL_CACHE.save(key, reply)
self.keys.add(key)
self._TTL_CACHE.save(self.ASKAI_CACHE_KEYS, ",".join(self._cache_keys))
return text
return None

def clear_replies(self) -> list[str]:
"""Clear all cached replies."""
return list(map(self.del_reply, sorted(self.keys)))
@@ -158,5 +165,14 @@ def load_input_history(self, predefined: list[str] = None) -> list[str]:
history.extend(list(filter(lambda c: c not in history, predefined)))
return history

def save_context(self, context: list[str]) -> None:
"""Save the Context window entries into the TTL cache."""
self._TTL_CACHE.save(self.ASKAI_CONTEXT_KEY, "%EOL%".join(context))

def read_context(self) -> list[str]:
"""Read the Context window entries from the TTL cache."""
ctx_str: str = self._TTL_CACHE.read(self.ASKAI_CONTEXT_KEY)
return re.split(r'%EOL%', ctx_str, flags=re.MULTILINE | re.IGNORECASE) if ctx_str else []


assert (cache := CacheService().INSTANCE) is not None
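A short usage sketch of the new context round-trip; the entries are made up, but the `%EOL%` separator and the empty-cache fallback come straight from the two methods above:

# Sketch: persist a context window and read it back through the TTL cache.
entries = ["human: what is my os?", "assistant: You are on macOS."]
cache.save_context(entries)      # stored under ASKAI_CONTEXT_KEY, joined with "%EOL%"
restored = cache.read_context()  # split on "%EOL%"; returns [] when nothing is cached
assert restored == entries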
2 changes: 1 addition & 1 deletion src/main/askai/core/enums/acc_response.py
@@ -41,7 +41,7 @@ def matches(cls, output: str) -> re.Match:

@classmethod
def _re(cls) -> str:
return rf"^\$?({'|'.join(cls.values())})[:,-]\s*(.+)"
return rf"^\$?({'|'.join(cls.values())})[:,-]\s*[0-9]+%\s+(.+)"

@classmethod
def strip_code(cls, message: str) -> str:
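The tightened `_re()` pattern now requires an accuracy percentage between the status and the reasoning. A hypothetical line it accepts, with `cls.values()` expanded by hand for illustration:

import re

pattern = r"^\$?(Red|Orange|Yellow|Green|Blue)[:,-]\s*[0-9]+%\s+(.+)"
m = re.match(pattern, "Green: 95% The response addresses the question directly.")
print(m.groups())  # ('Green', 'The response addresses the question directly.')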
4 changes: 2 additions & 2 deletions src/main/askai/core/features/router/procs/task_splitter.py
@@ -64,7 +64,7 @@ def __init__(self):
def template(self) -> ChatPromptTemplate:
"""Retrieve the processor Template."""

rag: str = str(shared.context.flat("SCRATCHPAD"))
rag: str = str(shared.context.flat("EVALUATION"))
template = PromptTemplate(
input_variables=["os_type", "shell", "datetime", "home"], template=prompt.read_prompt("task-split.txt")
)
@@ -88,7 +88,7 @@ def process(self, question: str, **_) -> Optional[str]:
"""

os.chdir(Path.home())
shared.context.forget("SCRATCHPAD") # Erase previous scratchpad.
shared.context.forget("EVALUATION") # Erase previous scratchpad.
model: ModelResult = ModelResult.default() # Hard-coding the result model for now.

@retry(exceptions=self.RETRIABLE_ERRORS, tries=configs.max_router_retries, backoff=0)
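The `SCRATCHPAD` to `EVALUATION` rename here pairs with the `accuracy.py` change below: `assert_accuracy` now pushes the evaluation guide and issue reports under the `EVALUATION` key, which this processor reads back as RAG input when splitting tasks.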
4 changes: 3 additions & 1 deletion src/main/askai/core/features/router/task_toolkit.py
@@ -20,6 +20,7 @@
from askai.core.features.router.tools.summarization import summarize
from askai.core.features.router.tools.terminal import execute_command, list_contents, open_command
from askai.core.features.router.tools.vision import image_captioner
from askai.core.support.shared_instances import shared
from askai.exception.exceptions import TerminatingQuery
from clitt.core.tui.line_input.line_input import line_input
from functools import lru_cache
@@ -129,7 +130,8 @@ def direct_answer(self, question: str, answer: str) -> str:
:param question: The original user question.
:param answer: Your direct answer to the user.
"""
return final_answer(question=question, response=answer, persona_prompt="taius-jarvis")
args = {'user': shared.username, 'idiom': shared.idiom, 'context': answer, 'question': question}
return final_answer("taius-jarvis", [k for k in args.keys()], **args)

def list_tool(self, folder: str, filters: str | None = None) -> str:
"""
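The rewritten `direct_answer` above names the persona prompt and passes its input variables explicitly instead of hard-coding keyword arguments. A sketch of the new call shape, with hypothetical values:

# Sketch of the new final_answer() call shape (argument values are hypothetical).
args = {"user": "hugo", "idiom": "en_US", "context": "It is sunny today.", "question": "How is the weather?"}
final_answer("taius-jarvis", list(args.keys()), **args)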
1 change: 0 additions & 1 deletion src/main/askai/core/features/router/tools/general.py
@@ -38,7 +38,6 @@ def final_answer(
**prompt_args
) -> str:
"""Provide the final response to the user.
:param temperature: The LLM temperature to be used.
:param persona_prompt: The persona prompt to be used.
:param input_variables: The prompt input variables.
:param prompt_args: The prompt input arguments.
18 changes: 9 additions & 9 deletions src/main/askai/core/features/validation/accuracy.py
@@ -29,13 +29,14 @@

import logging as log

ACC_GUIDELINES: str = dedent(
EVALUATION_GUIDE: str = dedent(
"""
**Performance Evaluation Guidelines**
**Accuracy Evaluation Guidelines:**
1. Continuously review and analyze your actions to ensure optimal performance.
2. Constructively self-criticize your overall behavior regularly.
1. Review and analyze your past responses to ensure optimal accuracy.
2. Constructively self-criticize your overall responses regularly.
3. Reflect on past decisions and strategies to refine your approach.
4. Try something different.
"""
).strip()

@@ -69,13 +70,12 @@ def assert_accuracy(
events.reply.emit(message=msg.assert_acc(status, details), verbosity="debug")
if not (rag_resp := AccResponse.of_status(status, details)).passed(pass_threshold):
# Include the guidelines for the first mistake.
if not shared.context.get("SCRATCHPAD"):
shared.context.push("SCRATCHPAD", ACC_GUIDELINES)
# Include the RYG issues.
shared.context.push("SCRATCHPAD", issues_prompt.format(problems=AccResponse.strip_code(output)))
if not shared.context.get("EVALUATION"):
shared.context.push("EVALUATION", EVALUATION_GUIDE)
shared.context.push("EVALUATION", issues_prompt.format(problems=AccResponse.strip_code(output)))
raise InaccurateResponse(f"AI Assistant failed to respond => '{response.content}'")
return rag_resp
# At this point, the response was not Good enough.
# At this point, the response was not Good.
raise InaccurateResponse(f"AI Assistant didn't respond accurately. Response: '{response}'")

events.reply.emit(message=msg.no_output("query"))
23 changes: 20 additions & 3 deletions src/main/askai/core/support/chat_context.py
@@ -12,7 +12,9 @@
Copyright (c) 2024, HomeSetup
"""
import re

from askai.core.component.cache_service import cache
from askai.exception.exceptions import TokenLengthExceeded
from collections import defaultdict, deque, namedtuple
from functools import partial, reduce
@@ -36,10 +38,19 @@ class ChatContext:

LANGCHAIN_ROLE_MAP: dict = {"human": HumanMessage, "system": SystemMessage, "assistant": AIMessage}

def __init__(self, token_limit: int, _max_context_size: int):
self._store: dict[AnyStr, deque] = defaultdict(partial(deque, maxlen=_max_context_size))
@staticmethod
def of(context: list[str], token_limit: int, max_context_size: int) -> 'ChatContext':
"""Create a chat context from a context list on the format: <ROLE: MSG>"""
ctx = ChatContext(token_limit, max_context_size)
for e in context:
role, reply = list(filter(None, re.split(r'(human|assistant|system):', e, flags=re.MULTILINE | re.IGNORECASE)))
ctx.push("HISTORY", reply, role)
return ctx

def __init__(self, token_limit: int, max_context_size: int):
self._store: dict[AnyStr, deque] = defaultdict(partial(deque, maxlen=max_context_size))
self._token_limit: int = token_limit * 1024 # The limit is given in KB
self._max_context_size: int = _max_context_size
self._max_context_size: int = max_context_size

def __str__(self):
ln: str = os.linesep
@@ -138,3 +149,9 @@ def forget(self, *keys: str) -> None:
def size(self, key: str) -> int:
"""Return the amount of entries a context specified by key have."""
return len(self._store[key])

def save(self) -> None:
"""Save the current context window."""
ctx: LangChainContext = self.join(*self._store.keys())
ctx_str: list[str] = [f"{role}: {msg}" for role, msg in ctx]
cache.save_context(ctx_str)
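A small sketch of how `of()` and `save()` compose, assuming `join()` yields `(role, message)` pairs as the list comprehension in `save()` implies; the entries are hypothetical:

# Sketch: rebuild a context window from cached "<role>: <message>" strings.
ctx = ChatContext.of(
    ["human: list my downloads", "assistant: Here are the files I found..."],
    token_limit=8,
    max_context_size=10,
)
print(ctx.size("HISTORY"))  # 2 -> of() pushes each parsed entry under "HISTORY"
ctx.save()                  # re-serializes the window back into the TTL cache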
6 changes: 5 additions & 1 deletion src/main/askai/core/support/shared_instances.py
@@ -16,6 +16,7 @@
from askai.core.askai_configs import configs
from askai.core.askai_messages import msg
from askai.core.askai_prompt import prompt
from askai.core.component.cache_service import cache
from askai.core.component.geo_location import geo_location
from askai.core.component.recorder import recorder
from askai.core.engine.ai_engine import AIEngine
@@ -127,7 +128,10 @@ def create_engine(self, engine_name: str, model_name: str) -> AIEngine:
def create_context(self, token_limit: int) -> ChatContext:
"""TODO"""
if self._context is None:
self._context = ChatContext(token_limit, self.max_short_memory_size)
if configs.is_cache:
self._context = ChatContext.of(cache.read_context(), token_limit, self.max_short_memory_size)
else:
self._context = ChatContext(token_limit, self.max_short_memory_size)
return self._context

def create_memory(self, memory_key: str = "chat_history") -> ConversationBufferWindowMemory:
2 changes: 2 additions & 0 deletions src/main/askai/resources/prompts/accuracy.txt
@@ -62,6 +62,8 @@ Classification Guidelines (rate from 0% to 100%):

- Revise the classifications for responses from the AI that contain irrelevant information to 'Yellow' instead of 'Red', as any additional information is still valued.

- "I don't know." may be a good response. Before classifying, check the chat context or provided contexts to make sure the AI understood the question, but, it does not have an answer. If that's the case, classify as 'Green'.

- Do not include any part of the question in your response.

- Indicate your classification choice ('Red', 'Orange', 'Yellow', 'Green', or 'Blue') followed by the reasoning behind your decision.
4 changes: 2 additions & 2 deletions src/main/askai/resources/prompts/assert.txt
@@ -1,4 +1,4 @@
The (AI) provided an inadequate answer.
Please improve you response by addressing the following problems:
The (AI-Assistant) provided a bad answer.
Improve subsequent responses by addressing the following:

{problems}
