diff --git a/predictionguard/client.py b/predictionguard/client.py index 00b911b..a1e44f3 100644 --- a/predictionguard/client.py +++ b/predictionguard/client.py @@ -12,11 +12,12 @@ from .src.toxicity import Toxicity from .src.pii import Pii from .src.injection import Injection +from .src.models import Models from .version import __version__ __all__ = [ "PredictionGuard", "Chat", "Completions", "Embeddings", "Tokenize", - "Translate", "Factuality", "Toxicity", "Pii", "Injection" + "Translate", "Factuality", "Toxicity", "Pii", "Injection", "Models" ] class PredictionGuard: @@ -80,6 +81,9 @@ def __init__( self.tokenize: Tokenize = Tokenize(self.api_key, self.url) """Tokenize generates tokens for input text.""" + self.models: Models = Models(self.api_key, self.url) + """Models lists all of the models available in the Prediction Guard API.""" + def _connect_client(self) -> None: # Prepare the proper headers. diff --git a/predictionguard/src/chat.py b/predictionguard/src/chat.py index 289fd8d..12532f8 100644 --- a/predictionguard/src/chat.py +++ b/predictionguard/src/chat.py @@ -7,7 +7,6 @@ from typing import Any, Dict, List, Optional, Union import urllib.request import urllib.parse -from warnings import warn import uuid from ..version import __version__ @@ -272,7 +271,7 @@ def stream_generator(url, headers, payload, stream): else: return return_dict(self.url, headers, payload) - def list_models(self) -> List[str]: + def list_models(self, capability: Optional[str] = "chat-completion") -> List[str]: # Get the list of current models. headers = { "Content-Type": "application/json", @@ -280,6 +279,17 @@ def list_models(self) -> List[str]: "User-Agent": "Prediction Guard Python Client: " + __version__ } - response = requests.request("GET", self.url + "/chat/completions", headers=headers) + if capability != "chat-completion" and capability != "chat-with-image": + raise ValueError( + "Please enter a valid model type (chat-completion or chat-with-image)." 
+ ) + else: + model_path = "/models/" + capability + + response = requests.request("GET", self.url + model_path, headers=headers) + + response_list = [] + for model in response.json()["data"]: + response_list.append(model["id"]) - return list(response.json()) \ No newline at end of file + return response_list \ No newline at end of file diff --git a/predictionguard/src/completions.py b/predictionguard/src/completions.py index 1237991..8a48f1d 100644 --- a/predictionguard/src/completions.py +++ b/predictionguard/src/completions.py @@ -2,7 +2,6 @@ import requests from typing import Any, Dict, List, Optional, Union -from warnings import warn from ..version import __version__ @@ -110,6 +109,10 @@ def list_models(self) -> List[str]: "User-Agent": "Prediction Guard Python Client: " + __version__, } - response = requests.request("GET", self.url + "/completions", headers=headers) + response = requests.request("GET", self.url + "/models/completion", headers=headers) - return list(response.json()) + response_list = [] + for model in response.json()["data"]: + response_list.append(model["id"]) + + return response_list diff --git a/predictionguard/src/embeddings.py b/predictionguard/src/embeddings.py index 3268fd1..cdf1419 100644 --- a/predictionguard/src/embeddings.py +++ b/predictionguard/src/embeddings.py @@ -4,7 +4,7 @@ import base64 import requests -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Union, Optional import urllib.request import urllib.parse import uuid @@ -174,7 +174,7 @@ def _generate_embeddings(self, model, input, truncate, truncation_direction): pass raise ValueError("Could not generate embeddings. " + err) - def list_models(self) -> List[str]: + def list_models(self, capability: Optional[str] = "embedding") -> List[str]: # Get the list of current models. 
headers = { "Content-Type": "application/json", @@ -182,6 +182,18 @@ def list_models(self) -> List[str]: "User-Agent": "Prediction Guard Python Client: " + __version__, } - response = requests.request("GET", self.url + "/embeddings", headers=headers) + if capability != "embedding" and capability != "embedding-with-image": + raise ValueError( + "Please enter a valid model type " + "(embedding or embedding-with-image)." + ) + else: + model_path = "/models/" + capability + + response = requests.request("GET", self.url + model_path, headers=headers) + + response_list = [] + for model in response.json()["data"]: + response_list.append(model["id"]) - return list(response.json()) + return response_list diff --git a/predictionguard/src/models.py b/predictionguard/src/models.py new file mode 100644 index 0000000..c00a060 --- /dev/null +++ b/predictionguard/src/models.py @@ -0,0 +1,89 @@ +import requests +from typing import Any, Dict, Optional + +from ..version import __version__ + + +class Models: + """Models lists all the models available in the Prediction Guard Platform. + + Usage:: + + import os + import json + + from predictionguard import PredictionGuard + + # Set your Prediction Guard token as an environment variable. + os.environ["PREDICTIONGUARD_API_KEY"] = "" + + client = PredictionGuard() + + response = client.models.list() + + print(json.dumps(response, sort_keys=True, indent=4, separators=(",", ": "))) + """ + + def __init__(self, api_key, url): + self.api_key = api_key + self.url = url + + def list(self, capability: Optional[str] = "") -> Dict[str, Any]: + """ + Creates a models list request in the Prediction Guard REST API. + + :param capability: The capability of models to list. + :return: A dictionary containing the metadata of all the models. + """ + + # Run _list_models + choices = self._list_models(capability) + return choices + + def _list_models(self, capability): + """ + Function to list available models. 
+ """ + + capabilities = [ + "chat-completion", "chat-with-image", "completion", + "embedding", "embedding-with-image", "tokenize" + ] + + headers = { + "Content-Type": "application/json", + "Authorization": "Bearer " + self.api_key, + "User-Agent": "Prediction Guard Python Client: " + __version__, + } + + models_path = "/models" + if capability != "": + if capability not in capabilities: + raise ValueError( + "If specifying a capability, please use one of the following: " + ", ".join(capabilities) + ) + else: + models_path += "/" + capability + + response = requests.request( + "GET", self.url + models_path, headers=headers + ) + + if response.status_code == 200: + ret = response.json() + return ret + elif response.status_code == 429: + raise ValueError( + "Could not connect to Prediction Guard API. " + "Too many requests, rate limit or quota exceeded." + ) + else: + # Check if there is a json body in the response. Read that in, + # print out the error field in the json body, and raise an exception. + err = "" + try: + err = response.json()["error"] + except Exception: + pass + raise ValueError("Could not list models. " + err) diff --git a/predictionguard/src/tokenize.py b/predictionguard/src/tokenize.py index 8bacd25..919422c 100644 --- a/predictionguard/src/tokenize.py +++ b/predictionguard/src/tokenize.py @@ -89,3 +89,19 @@ def _create_tokens(self, model, input): except Exception: pass raise ValueError("Could not generate tokens. " + err) + + def list_models(self): + # Get the list of current models. 
+ headers = { + "Content-Type": "application/json", + "Authorization": "Bearer " + self.api_key, + "User-Agent": "Prediction Guard Python Client: " + __version__ + } + + response = requests.request("GET", self.url + "/models/tokenize", headers=headers) + + response_list = [] + for model in response.json()["data"]: + response_list.append(model["id"]) + + return response_list diff --git a/tests/test_chat.py b/tests/test_chat.py index 717516d..7c45a71 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -193,3 +193,4 @@ def test_chat_completions_list_models(): response = test_client.chat.completions.list_models() assert len(response) > 0 + assert type(response[0]) is str diff --git a/tests/test_completions.py b/tests/test_completions.py index 3029a18..778583b 100644 --- a/tests/test_completions.py +++ b/tests/test_completions.py @@ -32,3 +32,4 @@ def test_completions_list_models(): response = test_client.completions.list_models() assert len(response) > 0 + assert type(response[0]) is str diff --git a/tests/test_embeddings.py b/tests/test_embeddings.py index 73aa24b..68edec0 100644 --- a/tests/test_embeddings.py +++ b/tests/test_embeddings.py @@ -1,9 +1,6 @@ import os -import re import base64 -import pytest - from predictionguard import PredictionGuard @@ -215,3 +212,4 @@ def test_embeddings_list_models(): response = test_client.embeddings.list_models() assert len(response) > 0 + assert type(response[0]) is str \ No newline at end of file diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..80d9079 --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,76 @@ +from predictionguard import PredictionGuard + + +def test_models_list(): + test_client = PredictionGuard() + + response = test_client.models.list() + + assert len(response["data"]) > 0 + assert type(response["data"][0]["id"]) is str + + +def test_models_list_chat_completion(): + test_client = PredictionGuard() + + response = test_client.models.list( + 
capability="chat-completion" + ) + + assert len(response["data"]) > 0 + assert type(response["data"][0]["id"]) is str + + +def test_models_list_chat_with_image(): + test_client = PredictionGuard() + + response = test_client.models.list( + capability="chat-with-image" + ) + + assert len(response["data"]) > 0 + assert type(response["data"][0]["id"]) is str + + +def test_models_list_completion(): + test_client = PredictionGuard() + + response = test_client.models.list( + capability="completion" + ) + + assert len(response["data"]) > 0 + assert type(response["data"][0]["id"]) is str + + +def test_models_list_embedding(): + test_client = PredictionGuard() + + response = test_client.models.list( + capability="embedding" + ) + + assert len(response["data"]) > 0 + assert type(response["data"][0]["id"]) is str + + +def test_models_list_embedding_with_image(): + test_client = PredictionGuard() + + response = test_client.models.list( + capability="embedding-with-image" + ) + + assert len(response["data"]) > 0 + assert type(response["data"][0]["id"]) is str + + +def test_models_list_tokenize(): + test_client = PredictionGuard() + + response = test_client.models.list( + capability="tokenize" + ) + + assert len(response["data"]) > 0 + assert type(response["data"][0]["id"]) is str diff --git a/tests/test_tokenize.py b/tests/test_tokenize.py index 6276a84..39c04be 100644 --- a/tests/test_tokenize.py +++ b/tests/test_tokenize.py @@ -14,3 +14,11 @@ def test_tokenize_create(): assert len(response) > 0 assert type(response["tokens"][0]["id"]) is int + +def test_tokenize_list(): + test_client = PredictionGuard() + + response = test_client.tokenize.list_models() + + assert len(response) > 0 + assert type(response[0]) is str