In [None]:
%load_ext nb_black

In [None]:
import lmql
from enum import Enum
import inspect
import textwrap
import asyncio
import chromadb
from chromadb.utils import embedding_functions
import requests
import re
import os
import pandas as pd
from functools import lru_cache
from dataclasses import dataclass, field
from datetime import datetime
from typing import (
    Any,
    Union,
    ClassVar,
    Dict,
    Generator,
    List,
    Optional,
    Protocol,
    Tuple,
    Type,
    Optional,
    TypeVar,
    Callable,
    AsyncGenerator,
    TypedDict,
    Generic,
    Coroutine,
    Set,
)
from getpass import getpass
from itertools import chain
from uuid import UUID, uuid4
from glob import glob
from pathlib import Path

# Base URL of the DJ API server queried below; override with the DJ_URL environment variable.
DJ_URL = os.environ.get("DJ_URL", "http://localhost:8000")

In [None]:
@lru_cache(maxsize=1)
def get_chroma():
    """Return the single shared Chroma client, creating it on the first call."""
    client = chromadb.Client()
    return client

In [None]:
@dataclass
class VectorStore:
    """
    A named Chroma collection (created if missing) on the shared client,
    embedded with the `all-MiniLM-L6-v2` sentence-transformer model.
    """

    collection_name: str

    def __post_init__(self):
        self.client = get_chroma()
        embedder = embedding_functions.SentenceTransformerEmbeddingFunction(
            model_name="all-MiniLM-L6-v2"
        )
        self.collection = self.client.get_or_create_collection(
            self.collection_name, embedding_function=embedder
        )

In [None]:
# Prompt for the key without echoing it and export it for libraries that read the environment.
os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY = getpass()

In [None]:
# Fetch all metric definitions. A timeout keeps a hung server from blocking the
# notebook forever, and raise_for_status surfaces HTTP errors instead of
# tabulating an error payload.
metrics_response = requests.get(f"{DJ_URL}/metrics", timeout=30)
metrics_response.raise_for_status()
metrics_json = metrics_response.json()

metrics = pd.DataFrame(metrics_json)

In [None]:
# Metric dimensions appear to be "<node>.<attribute>" strings; collect the distinct node
# names. chain.from_iterable avoids Series.sum()'s quadratic list concatenation, and it
# also avoids sum()'s `0` result on an empty frame.
# NOTE(review): if node names are themselves dotted (namespaced), split(".")[0] keeps only
# the first segment; rsplit(".", 1)[0] would be needed — confirm against the API.
dimensions = {d.split(".")[0] for d in chain.from_iterable(metrics.dimensions)}

In [None]:
# Fetch each dimension node's definition. The names are sorted so the resulting table has a
# stable row order, because set iteration order over strings varies between interpreter runs.
dimension_nodes = []
for dimension_name in sorted(dimensions):
    node_response = requests.get(f"{DJ_URL}/nodes/{dimension_name}", timeout=30)
    node_response.raise_for_status()  # fail loudly instead of tabulating error payloads
    dimension_nodes.append(node_response.json())
dimensions = dimension_nodes

In [None]:
# Tabulate the fetched dimension node payloads, one row per node.
dimensions = pd.DataFrame(data=dimensions)

In [None]:
# Invert metric -> dimensions into dimension node -> [metric names]. Metric order is kept,
# and each metric is listed at most once per node.
dimensions_metrics = {}
for metric_name, metric_dimensions in zip(metrics.name, metrics.dimensions):
    for node_name in {attr.split(".")[0] for attr in metric_dimensions}:
        dimensions_metrics.setdefault(node_name, []).append(metric_name)

In [None]:
# For every dimension node, attach the list of metric names that reference it.
dimensions["metrics"] = dimensions["name"].map(dimensions_metrics)

In [None]:
# One Chroma collection per corpus, created in this order: metrics, dimensions, knowledge.
metrics_vectorstore, dimensions_vectorstore, knowledge_vectorstore = (
    VectorStore(collection_name=name) for name in ("metrics", "dimensions", "knowledge")
)

In [None]:
def window_document(
    file_name: str, document_text: str, window_size: int = 200, overlap: int = 50
) -> List[str]:
    """
    Split a document into overlapping, fixed-size word windows.

    Every window except the first is prefixed with a "title". The title is made of
    the parts of `file_name` (split on `.`, `_` and `-`) plus the first 10 words of
    the document's first line, so later passages keep context about their origin.
    The first window already starts with that line, so it gets no prefix.

    Args:
        file_name (str): Name of the source file; its parts become title words.
        document_text (str): The document to split.
        window_size (int): Number of words in each window.
        overlap (int): Number of words shared by adjacent windows.

    Returns:
        List[str]: The windows in document order; empty if the document has no words.

    Raises:
        ValueError: If `window_size` is not positive or `overlap >= window_size`
            (the window would never advance).
    """
    if window_size <= 0 or overlap >= window_size:
        raise ValueError(
            f"need window_size > 0 and overlap < window_size, got {window_size=} {overlap=}"
        )

    # str.split() drops the empty strings that re.split(r"\s+") yields at the text's edges.
    words = document_text.split()
    title = [part for part in re.split(r"[._-]+", file_name) if part]
    title += document_text.split("\n")[0].split()[:10]

    step = window_size - overlap
    windows = []
    for start in range(0, len(words), step):
        prefix = title if start != 0 else []
        windows.append(" ".join(prefix + words[start : start + window_size]))
        # Stop once a window reaches the end; another window would only repeat its tail.
        if start + window_size >= len(words):
            break
    return windows

In [None]:
# Knowledge documents to index. glob() returns files in arbitrary filesystem order, so the
# list is sorted to keep indexing reproducible.
KNOWLEDGE_GLOB = "../examples/knowledge/*.txt"
knowledge_files = sorted(glob(KNOWLEDGE_GLOB))

In [None]:
# Map each file's stem (its name without the final ".txt") to its text. The UTF-8 encoding
# is explicit because open()'s default encoding is platform-dependent.
knowledge_doc_texts = {}
for knowledge_path in map(Path, knowledge_files):
    knowledge_doc_texts[knowledge_path.stem] = knowledge_path.read_text(encoding="utf-8")

In [None]:
# One row per overlapping passage. Ids are "<doc name>_<passage index>", and the
# metadata records the source doc and the passage index.
knowledge_docs = pd.DataFrame(
    [
        {
            "ids": f"{doc_name}_{part}",
            "documents": passage,
            "metadatas": {"file": doc_name, "part": part},
        }
        for doc_name, doc_text in knowledge_doc_texts.items()
        for part, passage in enumerate(window_document(doc_name, doc_text))
    ]
)

In [None]:
# Index the passages. With no knowledge files, the frame has no columns, and add() would
# be called without ids/documents and fail, so skip it in that case.
if not knowledge_docs.empty:
    knowledge_vectorstore.collection.add(**knowledge_docs.to_dict(orient="list"))

In [None]:
# One Chroma record per metric: the description is the embedded document, and name,
# query and dimensions ride along as metadata.
metric_docs = pd.DataFrame(
    [
        {
            "ids": str(row["id"]),
            "documents": row["description"],
            "metadatas": {
                "name": row["name"],
                "query": row["query"],
                "dimensions": str(row["dimensions"]),
            },
        }
        for _, row in metrics.iterrows()
    ]
)

metrics_vectorstore.collection.add(**metric_docs.to_dict(orient="list"))

In [None]:
# One Chroma record per dimension node: the description is the embedded document, and
# name, query and the metrics using the node are metadata.
dimension_docs = pd.DataFrame(
    [
        {
            "ids": str(d.node_revision_id),
            "documents": d.description,
            "metadatas": {
                "name": d["name"],
                "query": d.query,
                "metrics": str(d.metrics),
            },
        }
        for _, d in dimensions.iterrows()
    ]
)

# Index the dimension records (this previously re-added the metric records by mistake).
dimensions_vectorstore.collection.add(**dimension_docs.to_dict(orient="list"))

In [None]:
# Source text for objects that have no file behind them (e.g. functions built with exec),
# keyed by the object itself.
SOURCE_PATCH = {}

# Capture the *original* inspect.getsourcelines exactly once. Re-running this cell must not
# capture the already-patched function, which would then call itself forever.
if "getsourcelines" not in globals():
    getsourcelines = inspect.getsourcelines


def monkey_patch_getsourcelines(object):
    """inspect.getsourcelines replacement that consults SOURCE_PATCH first."""
    if object not in SOURCE_PATCH:
        return getsourcelines(object)
    return SOURCE_PATCH[object].splitlines(keepends=True), 0


inspect.getsourcelines = monkey_patch_getsourcelines

In [None]:
T = TypeVar("T")


def required_value(message: str, return_type: Type[T]) -> Callable[[], T]:
    """
    Build a dataclass `default_factory` that raises `ValueError(message)`.

    Used as a field's default factory, it makes the field effectively required
    while still allowing it to follow fields that have defaults. `return_type`
    only informs static typing; it is not used at runtime.
    """

    def _missing() -> T:
        raise ValueError(message)

    return _missing


class Stringable(Protocol):
    """Structural type: any object that can be rendered with `str()`."""

    def __str__(self) -> str:
        ...


# Recursive schema shape: each key maps to `str`, `int`, or another schema of the same shape.
SchemaDict = Dict[str, Union["SchemaDict", Type[str], Type[int]]]


@dataclass
class ToolSchema:
    """
    Compile a TypedDict describing a tool's input into LMQL query fragments.

    * `body`  - prompt fragment with one `[VARIABLE]` hole per field. Braces are
      doubled and quotes backslash-escaped so it can be spliced into a prompt
      string inside a format template.
    * `code`  - Python dict literal that maps each key to its generated variable.
    * `where` - the variables' constraints joined with `and`.

    Field types may be `str`, `int`, or nested plain `dict`s of those. Variable
    names are `<TypedDict name>_<key path>` upper-cased. Results are computed on
    first access and cached.
    """

    schema_dict: TypedDict
    _compiled: bool = field(init=False, default=False)
    _body: Optional[str] = field(init=False, default=None)
    _code: Optional[str] = field(init=False, default=None)
    _where: str = field(init=False, default="")

    @property
    def body(self) -> str:
        self._compile()
        return self._body

    @property
    def code(self) -> str:
        self._compile()
        return self._code

    @property
    def where(self) -> str:
        self._compile()
        return self._where

    def _compile(self) -> None:
        if self._compiled:
            return
        where: List[str] = []
        prefix = self.schema_dict.__name__ + "_"

        def _helper(schema, path: Tuple[str, ...]) -> Tuple[str, str]:
            # Returns (prompt fragment, code fragment) for `schema` found at key `path`.
            if schema == int:
                variable = (prefix + "_".join(path)).upper()
                where.append(f'INT({variable}) and STOPS_AT({variable}, ",")')
                return variable, variable
            if schema == str:
                variable = (prefix + "_".join(path)).upper()
                where.append(f"""STOPS_AT({variable}, '"')""")
                return variable, variable
            if not isinstance(schema, dict):
                raise Exception(f"Unnacceptable type in schema: `{schema}`")
            body_parts = []
            code_parts = []
            for key, value in schema.items():
                if "[" in key or "]" in key:
                    raise Exception("schema keys cannot have `[` or `]`")
                body_fragment, code_fragment = _helper(value, path + (key,))
                if isinstance(value, dict):
                    # nested objects are inlined rather than wrapped in a `[...]` hole
                    body_parts.append(f'\\"{key}\\": {body_fragment}')
                else:
                    # string holes are opened with a quote; their STOPS_AT('"') ends them
                    quote = '\\"' if value == str else ""
                    body_parts.append(f'\\"{key}\\": {quote}[{body_fragment}]')
                code_parts.append(f'"{key}": {code_fragment}')
            return (
                "{{" + ", ".join(body_parts) + "}}",
                "{" + ", ".join(code_parts) + "}",
            )

        annotations = self.schema_dict.__annotations__
        if annotations:
            self._body, self._code = _helper(annotations, ())
        else:
            self._body, self._code = "", ""
        self._where = " and ".join(where)
        self._compiled = True


@dataclass
class Tool:
    """
    Base class for tools an agent can call.

    Subclasses provide the class-level defaults (`default_description`,
    `default_ref_name`, `input_schema`) and implement `__call__`. The
    `description_` / `ref_name_` fields override the defaults when they are truthy.
    """

    default_description: ClassVar[str]
    default_ref_name: ClassVar[str]
    input_schema: ClassVar[ToolSchema]
    model_identifier: Optional[str] = None
    description_: Optional[str] = None
    ref_name_: Optional[str] = None

    @property
    def description(self):
        override = self.description_
        return override if override else self.default_description

    @property
    def ref_name(self):
        override = self.ref_name_
        return override if override else self.default_ref_name

    async def __call__(self, input: dict) -> "Observation":
        """Run the tool on `input`; every concrete tool must override this."""
        raise NotImplementedError()


@dataclass
class Utterance:
    """
    One turn of a conversation, linked to the previous turn through `parent`.

    Construct it with the dataclass field names (`utterance_`, `parent_`). The
    `utterance`, `parent` and `session` properties are the read/write views.
    A session set directly on the utterance wins; otherwise it is looked up
    through the parent chain.
    """

    utterance_: "Stringable"
    marker: str = ""  # speaker prefix prepended by __str__
    timestamp: datetime = field(default_factory=datetime.utcnow)
    context: str = ""
    parent_: Optional["Utterance"] = None
    utterance_id: UUID = field(default_factory=uuid4)
    session_: Optional["Session"] = field(default=None, init=False)

    def __post_init__(self):
        # Adopt the parent's session at construction; a root utterance has no parent.
        if self.parent_ is not None:
            self.session = self.parent_.session

    @property
    def parent(self) -> Optional["Utterance"]:
        return self.parent_

    @parent.setter
    def parent(self, parent: Optional["Utterance"]):
        # Re-parenting also adopts the new parent's session.
        if parent is not None:
            self.session = parent.session
        self.parent_ = parent

    def __str__(self):
        return self.marker + self.utterance

    def history(self, n: Optional[int] = None) -> Generator:
        """
        Yield this utterance and then its ancestors, newest first.

        Args:
            n: maximum number of utterances to yield; None means no limit,
               and 0 yields nothing.
        """
        remaining = float("inf") if n is None else n
        current = self
        while remaining > 0 and current is not None:
            yield current
            current = current.parent
            remaining -= 1

    @property
    def session(self):
        if self.session_ is not None:
            return self.session_
        if self.parent is not None:
            return self.parent.session
        return None

    @session.setter
    def session(self, session: "Session"):
        self.session_ = session

    @property
    def utterance(self):
        return str(self.utterance_)

In [None]:
@dataclass
class User(Utterance):
    """
    Utterance from a user
    """

    # Annotated so the dataclass uses it as the default of the inherited `marker` field;
    # a bare class attribute is ignored by the generated __init__ (which would set "").
    marker: str = "User: "


@dataclass
class Observation(Utterance):
    """
    Value produced from a tool
    """

    # Annotated so the dataclass uses it as the default of the inherited `marker` field;
    # a bare class attribute is ignored by the generated __init__ (which would set "").
    marker: str = "Observation: "
    # The default factory raises, so `tool` must always be passed explicitly.
    tool: Tool = field(
        default_factory=required_value("`tool` is required for an Observation.", Tool)
    )


@dataclass
class Thought(Utterance):
    """
    Value produced from an agent
    """

    # The default factory raises, so `agent` must always be passed explicitly.
    agent: "Agent" = field(
        default_factory=required_value(
            "`agent` is required for a Thought.", lambda: Agent()
        )
    )
    # Annotated so the dataclass uses it as the default of the inherited `marker` field;
    # a bare class attribute is ignored by the generated __init__ (which would set "").
    marker: str = "Thought: "
    marker = "Thought: "


@dataclass
class Answer(Utterance):
    """
    Final answer value produced from an agent
    """

    # The default factory raises, so `agent` must always be passed explicitly.
    agent: "Agent" = field(
        default_factory=required_value(
            "`agent` is required for a Answer.", lambda: Agent()
        )
    )
    # Annotated so the dataclass uses it as the default of the inherited `marker` field;
    # a bare class attribute is ignored by the generated __init__ (which would set "").
    marker: str = "Answer: "

In [None]:
# Lifecycle states of a Session; each member's value is its own name.
SessionStatus = Enum(
    "SessionStatus",
    {name: name for name in ("DISCONNECTED", "LIVE", "TIMEOUT")},
)

In [None]:
@dataclass
class Session:
    """
    A conversation between a user and an agent.

    Every session registers itself in `Session.sessions` under its id.
    Awaiting `__aiter__()` returns `session_loop`, an async generator function.
    To use it, prime its generator with `asend(None)`, then feed it `User`
    utterances with `asend(user)`. The latest utterance of the exchange
    (ultimately the agent's `Answer`) is kept in `self.utterance`.

    NOTE(review): because `__aiter__` is `async def`, a Session cannot be used
    directly with `async for`; that protocol expects `__aiter__` to return an
    async iterator synchronously.
    """

    agent: "Agent"  # sessions are with an agent
    agent_utterances: Set[
        Union[Type[Observation], Type[Thought], Type[Answer]]
    ] = field(
        default_factory=lambda: {Answer}
    )  # this determines how verbose the agent will be
    session_id: UUID = field(default_factory=uuid4)
    status: SessionStatus = SessionStatus.LIVE
    # latest utterance exchanged in this session (user or agent side)
    utterance: Optional["Utterance"] = field(default=None, init=False)
    timestamp: datetime = field(default_factory=datetime.utcnow)
    timeout: int = 60 * 10  # NOTE(review): presumably seconds; not enforced here
    sessions: ClassVar[Dict[UUID, "Session"]] = {}
    # this session's agent channel (an async generator), set when the loop is built
    _agent_channel: Optional[AsyncGenerator] = field(
        default=None, init=False, repr=False, compare=False
    )

    def __post_init__(self):
        Session.sessions[self.session_id] = self

    async def check_quit(self, utterance: Optional["Utterance"]) -> bool:
        """
        Disconnect if `utterance` is missing, blank, "quit" or "exit".

        On disconnect, the status becomes DISCONNECTED and this session's agent
        channel (if any) is sent None to end it.

        Returns:
            True if the session was disconnected, else False.
        """
        if utterance is not None and utterance.utterance.strip() not in ("", "quit", "exit"):
            return False
        self.status = SessionStatus.DISCONNECTED
        channel = self._agent_channel
        if channel is not None:
            try:
                await channel.asend(None)  # tell this session's agent it's done
            except StopAsyncIteration:
                pass  # the agent finished, which is what we asked for
        return True

    async def __aiter__(self):
        # The agent's __call__ takes the set of utterance types it should emit.
        agent_channel = self.agent(self.agent_utterances)
        self._agent_channel = agent_channel

        async def session_loop() -> AsyncGenerator[Utterance, Optional[Utterance]]:
            # Start the agent's generator: a fresh async generator only accepts asend(None).
            await agent_channel.asend(None)
            while True:
                # wait for user input
                user: User = yield
                # session is disconnected if a user utterance is none or empty
                if await self.check_quit(user):
                    return
                user.session = self
                self.utterance = user
                # await the agent's response being an answer
                # agent response might not be an answer if it is replying verbosely
                while not isinstance(self.utterance, Answer):
                    response: Union[
                        Observation, Thought, Answer
                    ] = await agent_channel.asend(self.utterance)
                    if await self.check_quit(response):
                        return
                    self.utterance = response

        return session_loop

In [None]:
@dataclass
class VectorStoreMemory:
    """
    Per-session memory of utterances, stored in a Chroma collection.

    `utterance` anchors the memory to a session. Utterances added later must
    belong to that same session (when the anchor has one).
    """

    utterance: Optional[Utterance] = None
    vector_store: Optional[VectorStore] = None
    default_k: int = 3

    @property
    def session_id(self) -> Optional[UUID]:
        """Id of the anchoring utterance's session, or None."""
        return self._session_id_of(self.utterance)

    @staticmethod
    def _session_id_of(utterance: Optional[Utterance]) -> Optional[UUID]:
        # Utterances expose their Session via `.session`; the id lives on the Session.
        if utterance is None or utterance.session is None:
            return None
        return utterance.session.session_id

    async def add_memories(self, utterances: List[Utterance]):
        """
        Store `utterances` in this memory's collection, creating it on first use.

        Raises:
            Exception: if an utterance belongs to a different session than this memory.
        """
        session_id = self.session_id
        for utterance in utterances:
            if session_id is not None and self._session_id_of(utterance) != session_id:
                raise Exception("utterances must belong to the same session as this memory!")
        if not utterances:
            return
        if self.vector_store is None:
            self.vector_store = VectorStore(collection_name=f"memory_{session_id}")
        self.vector_store.collection.add(
            ids=[str(u.utterance_id) for u in utterances],
            documents=[str(u) for u in utterances],
            metadatas=[
                {"type": type(u).__name__, "timestamp": u.timestamp.isoformat()}
                for u in utterances
            ],
        )

    async def search(self, query: str, k: Optional[int] = None) -> List[str]:
        """Return up to `k` (default `default_k`) stored utterance texts nearest to `query`."""
        k = k or self.default_k
        if self.vector_store is None:
            return []
        # never request more neighbours than the collection holds
        k = min(k, self.vector_store.collection.count())
        if k <= 0:
            return []
        results = self.vector_store.collection.query(query_texts=[query], n_results=k)
        # one result list per query text; we sent a single query
        return results["documents"][0]


@dataclass
class Agent:
    """
    Base class for agents whose reasoning is an LMQL query.

    `query` is a function whose first parameter is `agent` and whose docstring
    holds the LMQL program as a `str.format` template. The placeholders are
    filled from this agent's instance attributes when the query is first run.
    Subclasses implement `__call__` to drive the conversation.
    """

    description: str
    ref_name: str
    query: Callable[["Agent", Any, ...], Coroutine[Any, Any, Utterance]]
    tools: List[Type[Tool]]
    model_identifier: str
    memory: Optional[VectorStoreMemory] = None
    _run: Callable[[Any, ...], Coroutine[Any, Any, Utterance]] = field(
        default=None, init=False
    )

    def __post_init__(self):
        assert self.tools, "This agent requires some tools"

    async def run(self, *args, **kwargs):
        """Run the query (compiling it on first use) and return its first result."""
        if self._run is None:
            self._run = self._compile_query(self.query)
        return await self._run(*args, **kwargs)

    def _compile_query(
        self, f: Callable[["Agent", Any, ...], Coroutine[Any, Any, Utterance]]
    ):
        """
        Turn `f` into an awaitable that runs its LMQL program for this agent.

        Steps:
        * Generate a function `_f` with `f`'s signature. Its docstring is
          `f.__doc__` formatted with `self.__dict__`.
        * Compile `_f` with `exec` and register its source in `SOURCE_PATCH`.
        * Wrap it with `lmql.query`.

        The returned wrapper passes `self` as `agent` and returns the query's
        first result.
        """
        sig = inspect.signature(f)
        # `sig.parameters` is a mapping; its first key is the first parameter's name.
        first_param = next(iter(sig.parameters), None)
        assert first_param == "agent", "first parameter to query must be `agent`"
        assert f.__doc__ is not None, "query must carry its LMQL program as a docstring"
        source = (
            "async def _f"
            + str(sig)
            + ":\n"
            + ("    '''" + f.__doc__.format(**self.__dict__) + "\n    '''")
        )
        # exec into an explicit namespace. Reading exec-defined names back through
        # locals() inside a function is unreliable (it is a fresh snapshot in 3.13+).
        namespace: Dict[str, Any] = {}
        exec(source, globals(), namespace)
        generated = namespace["_f"]
        # The generated function has no source file. Presumably lmql.query reads its
        # source through inspect, which consults SOURCE_PATCH.
        SOURCE_PATCH[generated] = source
        lmql_query = lmql.query(generated)

        async def awrapper(*args, **kwargs):
            return (await lmql_query(self, *args, **kwargs))[0]

        return awrapper

    async def __call__(
        self, utterances: Set[Union[Type[Observation], Type[Thought], Type[Answer]]]
    ) -> AsyncGenerator[Optional[Utterance], Utterance]:
        """
        Open the agent's conversation channel (an async generator).

        `utterances` is the set of utterance types the agent should emit.
        Subclasses must implement this.
        """
        raise NotImplementedError()

In [None]:
@dataclass
class KnowledgeSearchTool(Tool):
    """Tool that returns the knowledge passages nearest to a text query."""

    default_description = "Search for knowledge documents."
    default_ref_name = "knowledge_search"
    input_schema = ToolSchema(TypedDict("KnowledgeQuery", {"query": str}))
    n_docs: int = 3
    # NOTE(review): not used yet — presumably meant to filter results by distance.
    threshold: float = 0.0

    async def __call__(self, input) -> Observation:
        """Search the knowledge store for `input["query"]` and report the passages found."""
        query = input["query"]
        # Chroma returns one result list per query text; send a single query and take [0].
        results = knowledge_vectorstore.collection.query(
            query_texts=[query], n_results=self.n_docs
        )
        metadatas = results["metadatas"][0]
        documents = results["documents"][0]
        res = "".join(f"{meta}: {doc}\n" for meta, doc in zip(metadatas, documents))
        # `utterance_` is the dataclass field name; `utterance` is a read-only property.
        return Observation(utterance_=res, tool=self)

In [None]:
# LMQL query template. Its docstring is the LMQL program and is filled with
# str.format(**agent.__dict__) before compilation. `{tools_prompt}`, `{tool_body}`,
# `{tool_names}`, `{tool_conditions}` and `{model_identifier}` therefore come from the
# agent, while `{{convo}}` survives as the LMQL interpolation `{convo}` of the argument.
# Utterances are built with their dataclass field names (`utterance_`, `parent_`).
async def standard_query(agent, convo: str):
    '''
    argmax
        """
        The following is a conversation between a User and an AI Agent.
        The Agent is talkative and provides lots of specific details from its context.
        The Agent has Thoughts, uses Tools by providing Tool Input, and ultimately provides Answers.
        If the Agent cannot answer a question using its tools, it truthfully says it does not know.
        The Agent uses thoughtful reasoning like so:

        Thought: use tool
        Tool: agent selects appropriate tool
        Tool Input: thoroughly descriptive input for the tool to work
        ===
        Thought: final answer
        Answer: agent describes the answer
        ===
        Thought: no answer
        Answer: Agent explains why it could not find an answer

        {tools_prompt}

        Conversation:
        {{convo}}
        """
        "Thought: [THOUGHT]\\n"
        thought = Thought(utterance_ = THOUGHT, agent = agent)
        if THOUGHT == 'use tool':
            "Tool: [TOOL]\\n"
            {tool_body}
            observation.parent = thought
            return observation
        elif THOUGHT == 'final answer':
            "Answer: [ANSWER]\\n"
            return Answer(utterance_ = ANSWER, agent = agent, parent_ = thought)

        else:
            return Answer(utterance_ = "I cannot find an answer.", agent = agent, parent_ = thought)
    from
        "{model_identifier}"
    where
        THOUGHT in ["use tool", "final answer", "no answer"] and
        TOOL in {tool_names} and
        STOPS_AT(THOUGHT, "\\n") and
        STOPS_AT(TOOL, "\\n") and
        {tool_conditions}
    '''

In [None]:
class StandardAgent(Agent):
    "A standard agent that can answer queries and solve tasks with tools."

    def __init__(
        self,
        *,
        description: str = "",
        ref_name: str = "standard",
        query=standard_query,
        history_length: int = 3,
        history_utterances: Set[Type[Utterance]] = {User, Answer},
        **kwargs,
    ):
        """
        Args:
            description: agent description; defaults to the class docstring.
            ref_name: name the agent is referred to by.
            query: LMQL query function whose docstring is the program template.
            history_length: how many past utterances go into the prompt.
            history_utterances: utterance types eligible for the prompt history.
            **kwargs: forwarded to `Agent` (e.g. `tools`, `model_identifier`).
        """
        super().__init__(
            query=query,
            description=description or StandardAgent.__doc__,
            ref_name=ref_name,
            **kwargs,
        )
        # Both are read by __call__, so they must be stored on the instance.
        self.history_length = history_length
        # Copy so instances never share (or mutate) the default set object.
        self.history_utterances = set(history_utterances)

        # Fragments substituted into the query docstring via str.format(**self.__dict__).
        self.tools_prompt = "Here are the tools you choose from:\n" + "\n".join(
            f"            {tool.ref_name}: {tool.description}" for tool in self.tools
        )
        self.tool_refs = {tool.ref_name: tool for tool in self.tools}
        tool_body = []
        for idx, tool in enumerate(self.tools):
            # `{tool_body}` sits at 12 spaces in `standard_query`. The first line inherits
            # that indent; later branches must supply it or they land at column 0.
            branch_indent = "" if idx == 0 else " " * 12
            tool_body.append(f"{branch_indent}if TOOL=='{tool.ref_name}':")
            tool_body.append(f'                "Tool Input: {tool.input_schema.body}"')
            tool_body.append(f"                input_dict = {tool.input_schema.code}")
            tool_body.append(
                f"                observation = await agent.tool_refs.get(TOOL)(input_dict)"
            )
        self.tool_body = "\n".join(tool_body)
        # Likewise `{tool_conditions}` sits at 8 spaces in the `where` clause.
        self.tool_conditions = " and\n        ".join(
            tool.input_schema.where for tool in self.tools
        )
        self.tool_names = list(self.tool_refs.keys())

    async def __call__(
        self, utterances: Set[Union[Type[Observation], Type[Thought], Type[Answer]]]
    ) -> AsyncGenerator[Optional[Utterance], Optional[Utterance]]:
        """
        Agent channel, driven with `asend`.

        Prime it with `asend(None)`. After that, each `asend(utterance)` runs the
        query over the recent conversation and returns the response when its type
        is in `utterances`, otherwise None. Sending None ends the channel.

        NOTE(review): a filtered-out response comes back as None, which a Session
        treats as a quit; confirm this is the intended verbosity behavior.
        """
        response: Optional[Utterance] = None
        while True:
            incoming = yield response
            # when we send the agent a None it ends
            if incoming is None:
                return

            history = []
            for past in incoming.history():
                if type(past) in self.history_utterances:
                    history.append(past)
                if len(history) == self.history_length:
                    break

            convo = "\n".join(str(u) for u in history[::-1]) + "\n"

            result: Utterance = await self.run(
                convo=convo,
            )

            # hook the earliest ancestor of the result onto the incoming utterance
            response_origin = list(result.history())[-1]
            response_origin.parent = incoming

            response = result if type(result) in utterances else None

In [None]:
# A single tool for now: semantic search over the knowledge documents.
tools = [
    KnowledgeSearchTool(),
]
agent = StandardAgent(tools=tools, model_identifier="openai/babbage")