Skip to content

Commit

Permalink
fix mypy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
guillesanbri committed Apr 29, 2024
1 parent ea1f7bd commit 341845f
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 15 deletions.
19 changes: 10 additions & 9 deletions unify/chat.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import sys

from typing import Optional
from typing import Generator, Optional, List, Dict
from unify.clients import Unify
from unify.exceptions import UnifyError


class ChatBot: # noqa: WPS338
Expand Down Expand Up @@ -35,7 +36,7 @@ def __init__(
Raises:
UnifyError: If the API key is missing.
"""
self._message_history = []
self._message_history: List[Dict[str, str]] = []
self._paused = False
self._client = Unify(
api_key=api_key,
Expand All @@ -45,7 +46,7 @@ def __init__(
)

@property
def client(self) -> str:
def client(self) -> Unify:
"""
Get the client object. # noqa: DAR201.
Expand Down Expand Up @@ -108,7 +109,7 @@ def set_provider(self, value: str) -> None:
value (str): The provider name.
"""
self._client.set_provider(value)
self._client.set_endpoint("@".join([self._model, value]))
self._client.set_endpoint("@".join([self._client._model, value]))

@property
def endpoint(self) -> str:
Expand All @@ -131,7 +132,7 @@ def set_endpoint(self, value: str) -> None:
self._client.set_model(value.split("@")[0])
self._client.set_provider(value.split("@")[1])

def _get_credits(self):
def _get_credits(self) -> float:
"""
Retrieves the current credit balance from associated with the UNIFY account.
Expand All @@ -140,7 +141,7 @@ def _get_credits(self):
"""
return self._client.get_credit_balance()

def _process_input(self, inp: str, show_credits: bool, show_provider: bool):
def _process_input(self, inp: str, show_credits: bool, show_provider: bool) -> Generator[str, None, None]:
"""
Processes the user input to generate AI response.
Expand Down Expand Up @@ -177,7 +178,7 @@ def _process_input(self, inp: str, show_credits: bool, show_provider: bool):
if show_provider:
sys.stdout.write("\n(provider: {})".format(self._client.provider))

def _update_message_history(self, role: str, content: str):
def _update_message_history(self, role: str, content: str) -> None:
"""
Updates message history with user input.
Expand All @@ -192,11 +193,11 @@ def _update_message_history(self, role: str, content: str):
},
)

def clear_chat_history(self):
def clear_chat_history(self) -> None:
"""Clears the chat history."""
self._message_history.clear()

def run(self, show_credits: bool = False, show_provider: bool = False):
def run(self, show_credits: bool = False, show_provider: bool = False) -> None:
"""
Starts the chat interaction loop.
Expand Down
2 changes: 1 addition & 1 deletion unify/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def generate( # noqa: WPS234, WPS211
return self._generate_stream(contents, self._endpoint)
return self._generate_non_stream(contents, self._endpoint)

def get_credit_balance(self) -> Optional[int]:
def get_credit_balance(self) -> float:
# noqa: DAR201, DAR401
"""
Get the remaining credits left on your account.
Expand Down
10 changes: 5 additions & 5 deletions unify/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import requests
import json
from typing import Optional, Tuple
from typing import Optional, Tuple, List

from unify.exceptions import UnifyError

Expand All @@ -21,20 +21,20 @@
_base_url = "https://api.unify.ai/v0"


def _res_to_list(response):
def _res_to_list(response: requests.Response) -> List[str]:
return json.loads(response.text)


def list_models():
def list_models() -> List[str]:
return _res_to_list(requests.get(_base_url + "/models"))


def list_endpoints(model: str):
def list_endpoints(model: str) -> List[str]:
url = _base_url + "/endpoints_of"
return _res_to_list(requests.get(url, params={"model": model}))


def list_providers(model: str):
def list_providers(model: str) -> List[str]:
url = _base_url + "/providers_of"
return _res_to_list(requests.get(url, params={"model": model}))

Expand Down

0 comments on commit 341845f

Please sign in to comment.