From a77a69a81c4174fca4f144d3f89659537974f331 Mon Sep 17 00:00:00 2001 From: Anwaar Khalid Date: Mon, 15 Apr 2024 06:46:15 +0530 Subject: [PATCH] Refactor (#5) * renamed unifyai to unify so that we can use import unify instead of unifyai * updated the _update_message_history method to have role arg * Added getter/setter for model, provider and endpoint in the ChatBot class * updated version * updated README --- README.md | 12 +- pyproject.toml | 5 +- unify/__init__.py | 4 + unify/chat.py | 232 +++++++++++++++++++++++++++++++ {unifyai => unify}/clients.py | 14 +- {unifyai => unify}/exceptions.py | 0 {unifyai => unify}/tests.py | 4 +- {unifyai => unify}/utils.py | 0 unifyai/__init__.py | 4 - unifyai/chat.py | 130 ----------------- 10 files changed, 254 insertions(+), 151 deletions(-) create mode 100644 unify/__init__.py create mode 100644 unify/chat.py rename {unifyai => unify}/clients.py (96%) rename {unifyai => unify}/exceptions.py (100%) rename {unifyai => unify}/tests.py (97%) rename {unifyai => unify}/utils.py (100%) delete mode 100644 unifyai/__init__.py delete mode 100644 unifyai/chat.py diff --git a/README.md b/README.md index 3903b04..23eb3f8 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ pip install unifyai ## Basic Usage ```python import os -from unifyai import Unify +from unify import Unify unify = Unify( # This is the default and optional to include. api_key=os.environ.get("UNIFY_KEY"), @@ -88,7 +88,7 @@ To use the AsyncUnify client, simply import `AsyncUnify` instead of `Unify` and use `await` with the `.generate` function. ```python -from unifyai import AsyncUnify +from unify import AsyncUnify import os import asyncio async_unify = AsyncUnify( @@ -110,7 +110,7 @@ You can enable streaming responses by setting `stream=True` in the `.generate` f ```python import os -from unifyai import Unify +from unify import Unify unify = Unify( # This is the default and optional to include. api_key=os.environ.get("UNIFY_KEY"), @@ -124,7 +124,7 @@ for chunk in stream: It works in exactly the same way with Async clients. ```python -from unifyai import AsyncUnify +from unify import AsyncUnify import os import asyncio async_unify = AsyncUnify( @@ -152,7 +152,7 @@ As evidenced by our [benchmarks](https://unify.ai/hub), the optimal provider for ```python import os -from unifyai import Unify +from unify import Unify unify = Unify( # This is the default and optional to include. api_key=os.environ.get("UNIFY_KEY"), @@ -172,7 +172,7 @@ Dynamic routing works with both Synchronous and Asynchronous clients. For more i Our `ChatBot` allows you to start an interactive chat session with any of our supported llm endpoints with only a few lines of code: ```python -from unifyai import ChatBot +from unify import ChatBot agent = ChatBot( # This is the default and optional to include. api_key=os.environ.get("UNIFY_KEY"), diff --git a/pyproject.toml b/pyproject.toml index 0d60748..e551d5c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,10 +1,11 @@ [tool.poetry] name = "unifyai" -version = "0.7.2" +packages = [{include = "unify"}] +version = "0.8.0" readme = "README.md" description = "A Python package for interacting with the Unify API" authors = ["Unify "] -repository = "https://github.com/unifyai/unify-llm-python" +repository = "https://github.com/unifyai/unify" [tool.poetry.dependencies] python = "^3.9" diff --git a/unify/__init__.py b/unify/__init__.py new file mode 100644 index 0000000..b45f41a --- /dev/null +++ b/unify/__init__.py @@ -0,0 +1,4 @@ +"""Unify python module.""" + +from unify.clients import AsyncUnify, Unify # noqa: F403 +from unify.chat import ChatBot # noqa: F403 diff --git a/unify/chat.py b/unify/chat.py new file mode 100644 index 0000000..abb0a06 --- /dev/null +++ b/unify/chat.py @@ -0,0 +1,232 @@ +import sys + +from typing import Optional +from unify.clients import Unify + + +class ChatBot: # noqa: WPS338 + """Agent class represents an LLM chat agent.""" + + def __init__( + self, + api_key: Optional[str] = None, + endpoint: Optional[str] = None, + model: Optional[str] = None, + provider: Optional[str] = None, + ) -> None: + """ + Initializes the ChatBot object. + + Args: + api_key (str, optional): API key for accessing the Unify API. + If None, it attempts to retrieve the API key from the + environment variable UNIFY_KEY. + Defaults to None. + + endpoint (str, optional): Endpoint name in OpenAI API format: + /@ + Defaults to None. + + model (str, optional): Name of the model. If None, + endpoint must be provided. + + provider (str, optional): Name of the provider. If None, + endpoint must be provided. + Raises: + UnifyError: If the API key is missing. + """ + self._message_history = [] + self._paused = False + self._client = Unify( + api_key=api_key, + endpoint=endpoint, + model=model, + provider=provider, + ) + + @property + def client(self) -> str: + """ + Get the client object. # noqa: DAR201. + + Returns: + str: The model name. + """ + return self._client + + def set_client(self, value: Unify) -> None: + """ + Set the model name. # noqa: DAR101. + + Args: + value: The unify client. + """ + if isinstance(value, Unify): + self._client = value + else: + raise UnifyError("Invalid client!") + + + @property + def model(self) -> str: + """ + Get the model name. # noqa: DAR201. + + Returns: + str: The model name. + """ + return self._client.model + + def set_model(self, value: str) -> None: + """ + Set the model name. # noqa: DAR101. + + Args: + value (str): The model name. + """ + self._client.set_model(value) + if self._client.provider: + self._client.set_endpoint("@".join([value, self._client.provider])) + else: + mode = self._client.endpoint.split("@")[1] + self._client.set_endpoint("@".join([value, mode])) + + @property + def provider(self) -> Optional[str]: + """ + Get the provider name. # noqa :DAR201. + + Returns: + str: The provider name. + """ + return self._client.provider + + def set_provider(self, value: str) -> None: + """ + Set the provider name. # noqa: DAR101. + + Args: + value (str): The provider name. + """ + self._client.set_provider(value) + self._client.set_endpoint("@".join([self._model, value])) + + @property + def endpoint(self) -> str: + """ + Get the endpoint name. # noqa: DAR201. + + Returns: + str: The endpoint name. + """ + return self._client.endpoint + + def set_endpoint(self, value: str) -> None: + """ + Set the model name. # noqa: DAR101. + + Args: + value (str): The endpoint name. + """ + self._client.set_endpoint(value) + self._client.set_model(value.split("@")[0]) + self._client.set_provider(value.split("@")[1]) + + def _get_credits(self): + """ + Retrieves the current credit balance from associated with the UNIFY account. + + Returns: + float: Current credit balance. + """ + return self._client.get_credit_balance() + + def _process_input(self, inp: str, show_credits: bool, show_provider: bool): + """ + Processes the user input to generate AI response. + + Args: + inp (str): User input message. + show_credits (bool): Whether to show credit consumption. + show_credits (bool): Whether to show provider used. + + Yields: + str: Generated AI response chunks. + """ + self._update_message_history(role = "user", content = inp) + initial_credit_balance = self._get_credits() + stream = self._client.generate( + messages=self._message_history, + stream=True, + ) + words = "" + for chunk in stream: + words += chunk + yield chunk + + self._update_message_history( + role = "assistant", + content = words, + ) + final_credit_balance = self._get_credits() + if show_credits: + sys.stdout.write( + "\n(spent {:.6f} credits)".format( + initial_credit_balance - final_credit_balance, + ), + ) + if show_provider: + sys.stdout.write("\n(provider: {})".format(self._client.provider)) + + def _update_message_history(self, role: str, content: str): + """ + Updates message history with user input. + + Args: + role (str): Either "assistant" or "user". + content (str): User input message. + """ + self._message_history.append( + { + "role": role, + "content": content, + }, + ) + + def clear_chat_history(self): + """Clears the chat history.""" + self._message_history.clear() + + def run(self, show_credits: bool = False, show_provider: bool = False): + """ + Starts the chat interaction loop. + + Args: + show_credits (bool, optional): Whether to show credit consumption. + Defaults to False. + show_provider (bool, optional): Whether to show the provider used. + Defaults to False. + """ + if not self._paused: + sys.stdout.write( + "Let's have a chat. (Enter `pause` to pause and `quit` to exit)\n", + ) + self.clear_chat_history() + else: + sys.stdout.write( + "Welcome back! (Remember, enter `pause` to pause and `quit` to exit)\n", + ) + self._paused = False + while True: + sys.stdout.write("> ") + inp = input() + if inp == "quit": + self.clear_chat_history() + break + elif inp == "pause": + self._paused = True + break + for word in self._process_input(inp, show_credits, show_provider): + sys.stdout.write(word) + sys.stdout.flush() + sys.stdout.write("\n") diff --git a/unifyai/clients.py b/unify/clients.py similarity index 96% rename from unifyai/clients.py rename to unify/clients.py index dea7998..2f2a93c 100644 --- a/unifyai/clients.py +++ b/unify/clients.py @@ -2,8 +2,8 @@ import openai import requests -from unifyai.exceptions import BadRequestError, UnifyError, status_error_map -from unifyai.utils import ( # noqa:WPS450 +from unify.exceptions import BadRequestError, UnifyError, status_error_map +from unify.utils import ( # noqa:WPS450 _available_dynamic_modes, _validate_api_key, _validate_endpoint, @@ -205,7 +205,7 @@ def _generate_stream( ) for chunk in chat_completion: content = chunk.choices[0].delta.content # type: ignore[union-attr] - self._provider = chunk.model.split("@")[-1] # type: ignore[union-attr] + self.set_provider(chunk.model.split("@")[-1]) # type: ignore[union-attr] if content is not None: yield content except openai.APIStatusError as e: @@ -222,9 +222,9 @@ def _generate_non_stream( messages=messages, # type: ignore[arg-type] stream=False, ) - self._provider = chat_completion.model.split( # type: ignore[union-attr] + self.set_provider(chat_completion.model.split( # type: ignore[union-attr] "@", - )[-1] + )[-1]) return chat_completion.choices[0].message.content.strip(" ") # type: ignore # noqa: E501, WPS219 except openai.APIStatusError as e: @@ -429,7 +429,7 @@ async def _generate_stream( stream=True, ) async for chunk in async_stream: # type: ignore[union-attr] - self._provider = chunk.model.split("@")[-1] + self.set_provider(chunk.model.split("@")[-1]) yield chunk.choices[0].delta.content or "" except openai.APIStatusError as e: raise status_error_map[e.status_code](e.message) from None @@ -445,7 +445,7 @@ async def _generate_non_stream( messages=messages, # type: ignore[arg-type] stream=False, ) - self._provider = async_response.model.split("@")[-1] # type: ignore + self.set_provider(async_response.model.split("@")[-1]) # type: ignore return async_response.choices[0].message.content.strip(" ") # type: ignore # noqa: E501, WPS219 except openai.APIStatusError as e: raise status_error_map[e.status_code](e.message) from None diff --git a/unifyai/exceptions.py b/unify/exceptions.py similarity index 100% rename from unifyai/exceptions.py rename to unify/exceptions.py diff --git a/unifyai/tests.py b/unify/tests.py similarity index 97% rename from unifyai/tests.py rename to unify/tests.py index 085efc2..880590c 100644 --- a/unifyai/tests.py +++ b/unify/tests.py @@ -3,8 +3,8 @@ from types import AsyncGeneratorType, GeneratorType from unittest.mock import MagicMock, patch -from unifyai.clients import AsyncUnify, Unify -from unifyai.exceptions import AuthenticationError, UnifyError +from unify.clients import AsyncUnify, Unify +from unify.exceptions import AuthenticationError, UnifyError class TestUnify(unittest.TestCase): diff --git a/unifyai/utils.py b/unify/utils.py similarity index 100% rename from unifyai/utils.py rename to unify/utils.py diff --git a/unifyai/__init__.py b/unifyai/__init__.py deleted file mode 100644 index a3c295e..0000000 --- a/unifyai/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -"""Unify python module.""" - -from unifyai.clients import AsyncUnify, Unify # noqa: F403 -from unifyai.chat import ChatBot # noqa: F403 diff --git a/unifyai/chat.py b/unifyai/chat.py deleted file mode 100644 index 10ff07b..0000000 --- a/unifyai/chat.py +++ /dev/null @@ -1,130 +0,0 @@ -import sys - -from typing import Optional -from unifyai.clients import Unify - - -class ChatBot: # noqa: WPS338 - """Agent class represents an LLM chat agent.""" - - def __init__(self, api_key: Optional[str] = None, endpoint: Optional[str] = "llama-2-7b-chat@anyscale") -> None: - """ - Initializes the ChatBot object. - - Args: - api_key (optional, str): Your UNIFY key. - endpoint (optional, str): The endpoint for the chatbot. - """ - self._message_history = [] - self._endpoint = endpoint - self._api_key = api_key - self._paused = False - self._client = Unify( - api_key=self._api_key, - endpoint=self._endpoint, - ) - - def _get_credits(self): - """ - Retrieves the current credit balance from associated with the UNIFY account. - - Returns: - float: Current credit balance. - """ - return self._client.get_credit_balance() - - def _process_input(self, inp: str, show_credits: bool, show_provider: bool): - """ - Processes the user input to generate AI response. - - Args: - inp (str): User input message. - show_credits (bool): Whether to show credit consumption. - - Yields: - str: Generated AI response chunks. - """ - self._update_message_history(inp) - initial_credit_balance = self._get_credits() - stream = self._client.generate( - messages=self._message_history, - stream=True, - ) - words = "" - for chunk in stream: - words += chunk - yield chunk - - self._message_history.append( - { - "role": "assistant", - "content": words, - }, - ) - final_credit_balance = self._get_credits() - if show_credits: - sys.stdout.write( - "\n(spent {:.6f} credits)".format( - initial_credit_balance - final_credit_balance, - ), - ) - if show_provider: - sys.stdout.write("\n(provider: {})".format(self._client.provider)) - - def _update_message_history(self, inp): - """ - Updates message history with user input. - - Args: - inp (str): User input message. - """ - self._message_history.append( - { - "role": "user", - "content": inp, - }, - ) - - @property - def endpoint(self): - return self._endpoint - - @endpoint.setter - def endpoint(self, value): - self._endpoint = value - self._client.set_endpoint(self._endpoint) - - def clear_chat_history(self): - """Clears the chat history.""" - self._message_history.clear() - - def run(self, show_credits: bool = False, show_provider: bool = False): - """ - Starts the chat interaction loop. - - Args: - show_credits (bool, optional): Whether to show credit consumption. Defaults to False. - """ - if not self._paused: - sys.stdout.write( - "Let's have a chat. (Enter `pause` to pause and `quit` to exit)\n", - ) - self.clear_chat_history() - else: - sys.stdout.write( - "Welcome back! (Remember, enter `pause` to pause and `quit` to exit)\n", - ) - self._paused = False - while True: - sys.stdout.write("> ") - inp = input() - if inp == "quit": - self.clear_chat_history() - break - elif inp == "pause": - self._paused = True - break - for word in self._process_input(inp, show_credits, show_provider): - sys.stdout.write(word) - sys.stdout.flush() - sys.stdout.write("\n")