Skip to content

Commit

Permalink
Refactor (#5)
Browse files Browse the repository at this point in the history
* renamed unifyai to unify so that we can use import unify instead of unifyai

* updated the _update_message_history method to have role arg

* Added getter/setter for model, provider and endpoint in the ChatBot class

* updated version

* updated README
  • Loading branch information
hello-fri-end committed Apr 15, 2024
1 parent 2290958 commit a77a69a
Show file tree
Hide file tree
Showing 10 changed files with 254 additions and 151 deletions.
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pip install unifyai
## Basic Usage
```python
import os
from unifyai import Unify
from unify import Unify
unify = Unify(
# This is the default and optional to include.
api_key=os.environ.get("UNIFY_KEY"),
Expand Down Expand Up @@ -88,7 +88,7 @@ To use the AsyncUnify client, simply import `AsyncUnify` instead
of `Unify` and use `await` with the `.generate` function.

```python
from unifyai import AsyncUnify
from unify import AsyncUnify
import os
import asyncio
async_unify = AsyncUnify(
Expand All @@ -110,7 +110,7 @@ You can enable streaming responses by setting `stream=True` in the `.generate` f

```python
import os
from unifyai import Unify
from unify import Unify
unify = Unify(
# This is the default and optional to include.
api_key=os.environ.get("UNIFY_KEY"),
Expand All @@ -124,7 +124,7 @@ for chunk in stream:
It works in exactly the same way with Async clients.

```python
from unifyai import AsyncUnify
from unify import AsyncUnify
import os
import asyncio
async_unify = AsyncUnify(
Expand Down Expand Up @@ -152,7 +152,7 @@ As evidenced by our [benchmarks](https://unify.ai/hub), the optimal provider for

```python
import os
from unifyai import Unify
from unify import Unify
unify = Unify(
# This is the default and optional to include.
api_key=os.environ.get("UNIFY_KEY"),
Expand All @@ -172,7 +172,7 @@ Dynamic routing works with both Synchronous and Asynchronous clients. For more i
Our `ChatBot` allows you to start an interactive chat session with any of our supported llm endpoints with only a few lines of code:

```python
from unifyai import ChatBot
from unify import ChatBot
agent = ChatBot(
# This is the default and optional to include.
api_key=os.environ.get("UNIFY_KEY"),
Expand Down
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
[tool.poetry]
name = "unifyai"
version = "0.7.2"
packages = [{include = "unify"}]
version = "0.8.0"
readme = "README.md"
description = "A Python package for interacting with the Unify API"
authors = ["Unify <hello@unify.com>"]
repository = "https://github.com/unifyai/unify-llm-python"
repository = "https://github.com/unifyai/unify"

[tool.poetry.dependencies]
python = "^3.9"
Expand Down
4 changes: 4 additions & 0 deletions unify/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""Unify python module."""

from unify.clients import AsyncUnify, Unify # noqa: F403
from unify.chat import ChatBot # noqa: F403
232 changes: 232 additions & 0 deletions unify/chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
import sys

from typing import Optional
from unify.clients import Unify


class ChatBot: # noqa: WPS338
"""Agent class represents an LLM chat agent."""

def __init__(
self,
api_key: Optional[str] = None,
endpoint: Optional[str] = None,
model: Optional[str] = None,
provider: Optional[str] = None,
) -> None:
"""
Initializes the ChatBot object.
Args:
api_key (str, optional): API key for accessing the Unify API.
If None, it attempts to retrieve the API key from the
environment variable UNIFY_KEY.
Defaults to None.
endpoint (str, optional): Endpoint name in OpenAI API format:
<uploaded_by>/<model_name>@<provider_name>
Defaults to None.
model (str, optional): Name of the model. If None,
endpoint must be provided.
provider (str, optional): Name of the provider. If None,
endpoint must be provided.
Raises:
UnifyError: If the API key is missing.
"""
self._message_history = []
self._paused = False
self._client = Unify(
api_key=api_key,
endpoint=endpoint,
model=model,
provider=provider,
)

@property
def client(self) -> str:
"""
Get the client object. # noqa: DAR201.
Returns:
str: The model name.
"""
return self._client

def set_client(self, value: Unify) -> None:
"""
Set the model name. # noqa: DAR101.
Args:
value: The unify client.
"""
if isinstance(value, Unify):
self._client = value
else:
raise UnifyError("Invalid client!")


@property
def model(self) -> str:
"""
Get the model name. # noqa: DAR201.
Returns:
str: The model name.
"""
return self._client.model

def set_model(self, value: str) -> None:
"""
Set the model name. # noqa: DAR101.
Args:
value (str): The model name.
"""
self._client.set_model(value)
if self._client.provider:
self._client.set_endpoint("@".join([value, self._client.provider]))
else:
mode = self._client.endpoint.split("@")[1]
self._client.set_endpoint("@".join([value, mode]))

@property
def provider(self) -> Optional[str]:
"""
Get the provider name. # noqa :DAR201.
Returns:
str: The provider name.
"""
return self._client.provider

def set_provider(self, value: str) -> None:
"""
Set the provider name. # noqa: DAR101.
Args:
value (str): The provider name.
"""
self._client.set_provider(value)
self._client.set_endpoint("@".join([self._model, value]))

@property
def endpoint(self) -> str:
"""
Get the endpoint name. # noqa: DAR201.
Returns:
str: The endpoint name.
"""
return self._client.endpoint

def set_endpoint(self, value: str) -> None:
"""
Set the model name. # noqa: DAR101.
Args:
value (str): The endpoint name.
"""
self._client.set_endpoint(value)
self._client.set_model(value.split("@")[0])
self._client.set_provider(value.split("@")[1])

def _get_credits(self):
"""
Retrieves the current credit balance from associated with the UNIFY account.
Returns:
float: Current credit balance.
"""
return self._client.get_credit_balance()

def _process_input(self, inp: str, show_credits: bool, show_provider: bool):
"""
Processes the user input to generate AI response.
Args:
inp (str): User input message.
show_credits (bool): Whether to show credit consumption.
show_credits (bool): Whether to show provider used.
Yields:
str: Generated AI response chunks.
"""
self._update_message_history(role = "user", content = inp)
initial_credit_balance = self._get_credits()
stream = self._client.generate(
messages=self._message_history,
stream=True,
)
words = ""
for chunk in stream:
words += chunk
yield chunk

self._update_message_history(
role = "assistant",
content = words,
)
final_credit_balance = self._get_credits()
if show_credits:
sys.stdout.write(
"\n(spent {:.6f} credits)".format(
initial_credit_balance - final_credit_balance,
),
)
if show_provider:
sys.stdout.write("\n(provider: {})".format(self._client.provider))

def _update_message_history(self, role: str, content: str):
"""
Updates message history with user input.
Args:
role (str): Either "assistant" or "user".
content (str): User input message.
"""
self._message_history.append(
{
"role": role,
"content": content,
},
)

def clear_chat_history(self):
"""Clears the chat history."""
self._message_history.clear()

def run(self, show_credits: bool = False, show_provider: bool = False):
"""
Starts the chat interaction loop.
Args:
show_credits (bool, optional): Whether to show credit consumption.
Defaults to False.
show_provider (bool, optional): Whether to show the provider used.
Defaults to False.
"""
if not self._paused:
sys.stdout.write(
"Let's have a chat. (Enter `pause` to pause and `quit` to exit)\n",
)
self.clear_chat_history()
else:
sys.stdout.write(
"Welcome back! (Remember, enter `pause` to pause and `quit` to exit)\n",
)
self._paused = False
while True:
sys.stdout.write("> ")
inp = input()
if inp == "quit":
self.clear_chat_history()
break
elif inp == "pause":
self._paused = True
break
for word in self._process_input(inp, show_credits, show_provider):
sys.stdout.write(word)
sys.stdout.flush()
sys.stdout.write("\n")
14 changes: 7 additions & 7 deletions unifyai/clients.py → unify/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import openai
import requests
from unifyai.exceptions import BadRequestError, UnifyError, status_error_map
from unifyai.utils import ( # noqa:WPS450
from unify.exceptions import BadRequestError, UnifyError, status_error_map
from unify.utils import ( # noqa:WPS450
_available_dynamic_modes,
_validate_api_key,
_validate_endpoint,
Expand Down Expand Up @@ -205,7 +205,7 @@ def _generate_stream(
)
for chunk in chat_completion:
content = chunk.choices[0].delta.content # type: ignore[union-attr]
self._provider = chunk.model.split("@")[-1] # type: ignore[union-attr]
self.set_provider(chunk.model.split("@")[-1]) # type: ignore[union-attr]
if content is not None:
yield content
except openai.APIStatusError as e:
Expand All @@ -222,9 +222,9 @@ def _generate_non_stream(
messages=messages, # type: ignore[arg-type]
stream=False,
)
self._provider = chat_completion.model.split( # type: ignore[union-attr]
self.set_provider(chat_completion.model.split( # type: ignore[union-attr]
"@",
)[-1]
)[-1])

return chat_completion.choices[0].message.content.strip(" ") # type: ignore # noqa: E501, WPS219
except openai.APIStatusError as e:
Expand Down Expand Up @@ -429,7 +429,7 @@ async def _generate_stream(
stream=True,
)
async for chunk in async_stream: # type: ignore[union-attr]
self._provider = chunk.model.split("@")[-1]
self.set_provider(chunk.model.split("@")[-1])
yield chunk.choices[0].delta.content or ""
except openai.APIStatusError as e:
raise status_error_map[e.status_code](e.message) from None
Expand All @@ -445,7 +445,7 @@ async def _generate_non_stream(
messages=messages, # type: ignore[arg-type]
stream=False,
)
self._provider = async_response.model.split("@")[-1] # type: ignore
self.set_provider(async_response.model.split("@")[-1]) # type: ignore
return async_response.choices[0].message.content.strip(" ") # type: ignore # noqa: E501, WPS219
except openai.APIStatusError as e:
raise status_error_map[e.status_code](e.message) from None
File renamed without changes.
4 changes: 2 additions & 2 deletions unifyai/tests.py → unify/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from types import AsyncGeneratorType, GeneratorType
from unittest.mock import MagicMock, patch

from unifyai.clients import AsyncUnify, Unify
from unifyai.exceptions import AuthenticationError, UnifyError
from unify.clients import AsyncUnify, Unify
from unify.exceptions import AuthenticationError, UnifyError


class TestUnify(unittest.TestCase):
Expand Down
File renamed without changes.
4 changes: 0 additions & 4 deletions unifyai/__init__.py

This file was deleted.

0 comments on commit a77a69a

Please sign in to comment.