Skip to content

Commit

Permalink
Added UNIFY_KEY secret to tests.yml and fixed flake8erros
Browse files Browse the repository at this point in the history
  • Loading branch information
hello-fri-end committed Mar 29, 2024
1 parent a87a978 commit cb68d99
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 29 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,5 @@ jobs:
poetry install --with dev
- name: Run unit tests
run: poetry run python -m unittest discover
env:
UNIFY_KEY: ${{ secrets.UNIFY_KEY }}
40 changes: 20 additions & 20 deletions unify/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from unify.exceptions import UnifyError, status_error_map


def validate_api_key(api_key: Optional[str]) -> str:
def _validate_api_key(api_key: Optional[str]) -> str:
if api_key is None:
api_key = os.environ.get("UNIFY_KEY")
if api_key is None:
Expand All @@ -21,19 +21,19 @@ class Unify:
def __init__(
self,
api_key: Optional[str] = None,
) -> None:
) -> None: # noqa: DAR101, DAR401
"""Initialize the Unify client.
Args:
api_key (str, optional): API key for accessing the Unify API.
If None, it attempts to retrieve the API key from the
environment variable UNIFY_KEY.
Defaults to None.
If None, it attempts to retrieve the API key from the
environment variable UNIFY_KEY.
Defaults to None.
Raises:
UnifyError: If the API key is missing.
"""
self.api_key = validate_api_key(api_key)
self.api_key = _validate_api_key(api_key)
try:
self.client = openai.OpenAI(
base_url="https://api.unify.ai/v0/",
Expand All @@ -42,23 +42,23 @@ def __init__(
except openai.OpenAIError as e:
raise UnifyError(f"Failed to initialize Unify client: {str(e)}")

def generate(
def generate( # noqa: WPS234
self,
messages: Union[str, List[Dict[str, str]]],
model: str = "llama-2-13b-chat",
provider: str = "anyscale",
stream: bool = False,
) -> Union[Generator[str, None, None], str]:
) -> Union[Generator[str, None, None], str]: # noqa: DAR101, DAR201
"""Generate content using the Unify API.
Args:
messages (Union[str, List[Dict[str, str]]]): A single prompt as a
string or a dictionary containing the conversation history.
string or a dictionary containing the conversation history.
model (str): The name of the model. Defaults to "llama-2-13b-chat".
provider (str): The provider of the model. Defaults to "anyscale".
stream (bool): If True, generates content as a stream.
If False, generates content as a single response.
Defaults to False.
If False, generates content as a single response.
Defaults to False.
Returns:
Union[Generator[str, None, None], str]: If stream is True,
Expand Down Expand Up @@ -108,7 +108,7 @@ def _generate_non_stream(
messages=messages, # type: ignore[arg-type]
stream=False,
)
return chat_completion.choices[0].message.content.strip(" ") # type: ignore # noqa: E501
return chat_completion.choices[0].message.content.strip(" ") # type: ignore # noqa: E501, WPS219
except openai.APIStatusError as e:
raise status_error_map[e.status_code](e.message) from None

Expand All @@ -119,7 +119,7 @@ class AsyncUnify:
def __init__(
self,
api_key: Optional[str] = None,
) -> None:
) -> None: # noqa:DAR101, DAR401
"""Initialize the AsyncUnify client.
Args:
Expand All @@ -131,7 +131,7 @@ def __init__(
Raises:
UnifyError: If the API key is missing.
"""
self.api_key = validate_api_key(api_key)
self.api_key = _validate_api_key(api_key)
try:
self.client = openai.AsyncOpenAI(
base_url="https://api.unify.ai/v0/",
Expand All @@ -140,28 +140,28 @@ def __init__(
except openai.APIStatusError as e:
raise UnifyError(f"Failed to initialize Unify client: {str(e)}")

async def generate(
async def generate( # noqa: WPS234
self,
messages: Union[str, List[Dict[str, str]]],
model: str = "llama-2-13b-chat",
provider: str = "anyscale",
stream: bool = False,
) -> Union[AsyncGenerator[str, None], str]:
) -> Union[AsyncGenerator[str, None], str]: # noqa: DAR101, DAR201
"""Generate content asynchronously using the Unify API.
Args:
messages (Union[str, List[Dict[str, str]]]): A single prompt as a string
or a dictionary containing the conversation history.
or a dictionary containing the conversation history.
model (str): The name of the model.
provider (str): The provider of the model.
stream (bool): If True, generates content as a stream.
If False, generates content as a single response.
Defaults to False.
Defaults to False.
Returns:
Union[AsyncGenerator[str, None], List[str]]: If stream is True,
returns an asynchronous generator yielding chunks of content.
If stream is False, returns a list of string responses.
If stream is False, returns a list of string responses.
Raises:
UnifyError: If an error occurs during content generation.
Expand Down Expand Up @@ -206,6 +206,6 @@ async def _generate_non_stream(
messages=messages, # type: ignore[arg-type]
stream=False,
)
return async_response.choices[0].message.content.strip(" ") # type: ignore # noqa: E501
return async_response.choices[0].message.content.strip(" ") # type: ignore # noqa: E501, WPS219
except openai.APIStatusError as e:
raise status_error_map[e.status_code](e.message) from None
18 changes: 9 additions & 9 deletions unify/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,37 @@
class UnifyError(Exception):
pass
"""Base class for all custom exceptions in the Unify application."""


class BadRequestError(UnifyError):
pass
"""Exception raised for HTTP 400 Bad Request errors."""


class AuthenticationError(UnifyError):
pass
"""Exception raised for HTTP 401 Unauthorized errors."""


class PermissionDeniedError(UnifyError):
pass
"""Exception raised for HTTP 403 Forbidden errors."""


class NotFoundError(UnifyError):
pass
"""Exception raised for HTTP 404 Not Found errors."""


class ConflictError(UnifyError):
pass
"""Exception raised for HTTP 409 Conflict errors."""


class UnprocessableEntityError(UnifyError):
pass
"""Exception raised for HTTP 422 Unprocessable Entity errors."""


class RateLimitError(UnifyError):
pass
"""Exception raised for HTTP 429 Too Many Requests errors."""


class InternalServerError(UnifyError):
pass
"""Exception raised for HTTP 500 Internal Server Error errors."""


status_error_map = {
Expand Down

0 comments on commit cb68d99

Please sign in to comment.