6 changes: 5 additions & 1 deletion predictionguard/client.py
@@ -12,11 +12,12 @@
 from .src.toxicity import Toxicity
 from .src.pii import Pii
 from .src.injection import Injection
+from .src.models import Models
 from .version import __version__

 __all__ = [
     "PredictionGuard", "Chat", "Completions", "Embeddings", "Tokenize",
-    "Translate", "Factuality", "Toxicity", "Pii", "Injection"
+    "Translate", "Factuality", "Toxicity", "Pii", "Injection", "Models"
 ]

 class PredictionGuard:
@@ -80,6 +81,9 @@ def __init__(
         self.tokenize: Tokenize = Tokenize(self.api_key, self.url)
         """Tokenize generates tokens for input text."""

+        self.models: Models = Models(self.api_key, self.url)
+        """Models lists all of the models available in the Prediction Guard API."""
+
     def _connect_client(self) -> None:

         # Prepare the proper headers.
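With Models imported and attached in __init__, the new endpoint is reachable straight off the client. A minimal usage sketch, assuming PREDICTIONGUARD_API_KEY is set in the environment:

    import os

    from predictionguard import PredictionGuard

    os.environ["PREDICTIONGUARD_API_KEY"] = "<api key>"

    client = PredictionGuard()

    # client.models is wired up during __init__, alongside chat, completions, etc.
    print(client.models.list())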
18 changes: 14 additions & 4 deletions predictionguard/src/chat.py
@@ -7,7 +7,6 @@
 from typing import Any, Dict, List, Optional, Union
 import urllib.request
 import urllib.parse
-from warnings import warn
 import uuid

 from ..version import __version__
@@ -272,14 +271,25 @@ def stream_generator(url, headers, payload, stream):
         else:
             return return_dict(self.url, headers, payload)

-    def list_models(self) -> List[str]:
+    def list_models(self, capability: Optional[str] = "chat-completion") -> List[str]:
         # Get the list of current models.
         headers = {
             "Content-Type": "application/json",
             "Authorization": "Bearer " + self.api_key,
             "User-Agent": "Prediction Guard Python Client: " + __version__
         }

-        response = requests.request("GET", self.url + "/chat/completions", headers=headers)
+        if capability != "chat-completion" and capability != "chat-with-image":
+            raise ValueError(
+                "Please enter a valid model type (chat-completion or chat-with-image)."
+            )
+        else:
+            model_path = "/models/" + capability
+
+        response = requests.request("GET", self.url + model_path, headers=headers)
+
+        response_list = []
+        for model in response.json()["data"]:
+            response_list.append(model["id"])

-        return list(response.json())
+        return response_list
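Since list_models now queries /models/<capability> and unpacks the response's "data" field, callers get back a plain list of model ID strings instead of raw response keys. A sketch of both capability values, assuming a configured client:

    client = PredictionGuard()

    # Default capability is "chat-completion".
    chat_models = client.chat.completions.list_models()

    # Vision-capable chat models; any other capability raises ValueError.
    vision_models = client.chat.completions.list_models(capability="chat-with-image")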
9 changes: 6 additions & 3 deletions predictionguard/src/completions.py
@@ -2,7 +2,6 @@

 import requests
 from typing import Any, Dict, List, Optional, Union
-from warnings import warn

 from ..version import __version__

@@ -110,6 +109,10 @@ def list_models(self) -> List[str]:
             "User-Agent": "Prediction Guard Python Client: " + __version__,
         }

-        response = requests.request("GET", self.url + "/completions", headers=headers)
+        response = requests.request("GET", self.url + "/models/completion", headers=headers)

-        return list(response.json())
+        response_list = []
+        for model in response.json()["data"]:
+            response_list.append(model["id"])
+
+        return response_list
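The completions lister follows the same pattern, minus a capability argument, since /models/completion is the only relevant path. A quick sketch:

    client = PredictionGuard()

    completion_models = client.completions.list_models()
    print(completion_models)  # a list of bare model ID strings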
20 changes: 16 additions & 4 deletions predictionguard/src/embeddings.py
@@ -4,7 +4,7 @@
 import base64

 import requests
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List, Union, Optional
 import urllib.request
 import urllib.parse
 import uuid
@@ -174,14 +174,26 @@ def _generate_embeddings(self, model, input, truncate, truncation_direction):
             pass
         raise ValueError("Could not generate embeddings. " + err)

-    def list_models(self) -> List[str]:
+    def list_models(self, capability: Optional[str] = "embedding") -> List[str]:
         # Get the list of current models.
         headers = {
             "Content-Type": "application/json",
             "Authorization": "Bearer " + self.api_key,
             "User-Agent": "Prediction Guard Python Client: " + __version__,
         }

-        response = requests.request("GET", self.url + "/embeddings", headers=headers)
+        if capability != "embedding" and capability != "embedding-with-image":
+            raise ValueError(
+                "Please enter a valid model type "
+                "(embedding or embedding-with-image)."
+            )
+        else:
+            model_path = "/models/" + capability
+
+        response = requests.request("GET", self.url + model_path, headers=headers)
+
+        response_list = []
+        for model in response.json()["data"]:
+            response_list.append(model["id"])

-        return list(response.json())
+        return response_list
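Same shape as the chat change: the default capability is "embedding", with "embedding-with-image" as the one alternative. A usage sketch, assuming a configured client:

    client = PredictionGuard()

    text_models = client.embeddings.list_models()
    multimodal_models = client.embeddings.list_models(capability="embedding-with-image")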
89 changes: 89 additions & 0 deletions predictionguard/src/models.py
@@ -0,0 +1,89 @@
+import requests
+from typing import Any, Dict, Optional
+
+from ..version import __version__
+
+
+class Models:
+    """Models lists all the models available in the Prediction Guard Platform.
+
+    Usage::
+
+        import os
+        import json
+
+        from predictionguard import PredictionGuard
+
+        # Set your Prediction Guard token as an environment variable.
+        os.environ["PREDICTIONGUARD_API_KEY"] = "<api key>"
+
+        client = PredictionGuard()
+
+        response = client.models.list()
+
+        print(json.dumps(response, sort_keys=True, indent=4, separators=(",", ": ")))
+    """
+
+    def __init__(self, api_key, url):
+        self.api_key = api_key
+        self.url = url
+
+    def list(self, capability: Optional[str] = "") -> Dict[str, Any]:
+        """
+        Creates a models list request in the Prediction Guard REST API.
+
+        :param capability: The capability of models to list.
+        :return: A dictionary containing the metadata of all the models.
+        """
+
+        # Run _list_models to retrieve the model metadata.
+        choices = self._list_models(capability)
+        return choices
+
+    def _list_models(self, capability):
+        """
+        Function to list available models.
+        """
+
+        capabilities = [
+            "chat-completion", "chat-with-image", "completion",
+            "embedding", "embedding-with-image", "tokenize"
+        ]
+
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": "Bearer " + self.api_key,
+            "User-Agent": "Prediction Guard Python Client: " + __version__,
+        }
+
+        models_path = "/models"
+        if capability != "":
+            if capability not in capabilities:
+                raise ValueError(
+                    "If specifying a capability, please use one of the following: "
+                    + ", ".join(capabilities)
+                )
+            else:
+                models_path += "/" + capability
+
+        response = requests.request(
+            "GET", self.url + models_path, headers=headers
+        )
+
+        if response.status_code == 200:
+            ret = response.json()
+            return ret
+        elif response.status_code == 429:
+            raise ValueError(
+                "Could not connect to Prediction Guard API. "
+                "Too many requests, rate limit or quota exceeded."
+            )
+        else:
+            # Check if there is a json body in the response. Read that in,
+            # print out the error field in the json body, and raise an exception.
+            err = ""
+            try:
+                err = response.json()["error"]
+            except Exception:
+                pass
+            raise ValueError("Could not list models. " + err)
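Unlike the per-endpoint list_models helpers, Models.list returns the full metadata dictionary rather than bare IDs. A sketch of the capability filter, based on the capabilities list above:

    client = PredictionGuard()

    # Full metadata for every model.
    all_models = client.models.list()

    # Server-side filter; an unknown capability raises ValueError before any request is sent.
    tokenize_models = client.models.list(capability="tokenize")

    ids = [model["id"] for model in tokenize_models["data"]]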
16 changes: 16 additions & 0 deletions predictionguard/src/tokenize.py
@@ -89,3 +89,19 @@ def _create_tokens(self, model, input):
         except Exception:
             pass
         raise ValueError("Could not generate tokens. " + err)
+
+    def list_models(self):
+        # Get the list of current models.
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": "Bearer " + self.api_key,
+            "User-Agent": "Prediction Guard Python Client: " + __version__
+        }
+
+        response = requests.request("GET", self.url + "/models/tokenize", headers=headers)
+
+        response_list = []
+        for model in response.json()["data"]:
+            response_list.append(model["id"])
+
+        return response_list
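And the matching helper on the tokenize handle, returning ID strings like the others; a minimal sketch:

    client = PredictionGuard()

    tokenizer_models = client.tokenize.list_models()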
1 change: 1 addition & 0 deletions tests/test_chat.py
@@ -193,3 +193,4 @@ def test_chat_completions_list_models():
     response = test_client.chat.completions.list_models()

     assert len(response) > 0
+    assert type(response[0]) is str
1 change: 1 addition & 0 deletions tests/test_completions.py
@@ -32,3 +32,4 @@ def test_completions_list_models():
     response = test_client.completions.list_models()

     assert len(response) > 0
+    assert type(response[0]) is str
4 changes: 1 addition & 3 deletions tests/test_embeddings.py
@@ -1,9 +1,6 @@
 import os
-import re
 import base64
-
-import pytest

 from predictionguard import PredictionGuard

@@ -215,3 +212,4 @@ def test_embeddings_list_models():
     response = test_client.embeddings.list_models()

     assert len(response) > 0
+    assert type(response[0]) is str
76 changes: 76 additions & 0 deletions tests/test_models.py
@@ -0,0 +1,76 @@
+from predictionguard import PredictionGuard
+
+
+def test_models_list():
+    test_client = PredictionGuard()
+
+    response = test_client.models.list()
+
+    assert len(response["data"]) > 0
+    assert type(response["data"][0]["id"]) is str
+
+
+def test_models_list_chat_completion():
+    test_client = PredictionGuard()
+
+    response = test_client.models.list(
+        capability="chat-completion"
+    )
+
+    assert len(response["data"]) > 0
+    assert type(response["data"][0]["id"]) is str
+
+
+def test_models_list_chat_with_image():
+    test_client = PredictionGuard()
+
+    response = test_client.models.list(
+        capability="chat-with-image"
+    )
+
+    assert len(response["data"]) > 0
+    assert type(response["data"][0]["id"]) is str
+
+
+def test_models_list_completion():
+    test_client = PredictionGuard()
+
+    response = test_client.models.list(
+        capability="completion"
+    )
+
+    assert len(response["data"]) > 0
+    assert type(response["data"][0]["id"]) is str
+
+
+def test_models_list_embedding():
+    test_client = PredictionGuard()
+
+    response = test_client.models.list(
+        capability="embedding"
+    )
+
+    assert len(response["data"]) > 0
+    assert type(response["data"][0]["id"]) is str
+
+
+def test_models_list_embedding_with_image():
+    test_client = PredictionGuard()
+
+    response = test_client.models.list(
+        capability="embedding-with-image"
+    )
+
+    assert len(response["data"]) > 0
+    assert type(response["data"][0]["id"]) is str
+
+
+def test_models_list_tokenize():
+    test_client = PredictionGuard()
+
+    response = test_client.models.list(
+        capability="tokenize"
+    )
+
+    assert len(response["data"]) > 0
+    assert type(response["data"][0]["id"]) is str
8 changes: 8 additions & 0 deletions tests/test_tokenize.py
@@ -14,3 +14,11 @@ def test_tokenize_create():
     assert len(response) > 0
     assert type(response["tokens"][0]["id"]) is int
+
+
+def test_tokenize_list():
+    test_client = PredictionGuard()
+
+    response = test_client.tokenize.list_models()
+
+    assert len(response) > 0
+    assert type(response[0]) is str