Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added fixtures/test_audio.wav
Binary file not shown.
6 changes: 6 additions & 0 deletions fixtures/test_csv.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
John,Doe,120 jefferson st.,Riverside, NJ, 08075
Jack,McGinnis,220 hobo Av.,Phila, PA,09119
"John ""Da Man""",Repici,120 Jefferson St.,Riverside, NJ,08075
Stephen,Tyler,"7452 Terrace ""At the Plaza"" road",SomeTown,SD, 91234
,Blankman,,SomeTown, SD, 00298
"Joan ""the bone"", Anne",Jet,"9th, at Terrace plc",Desert City,CO,00123
Binary file added fixtures/test_pdf.pdf
Binary file not shown.
14 changes: 11 additions & 3 deletions predictionguard/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import requests
from typing import Optional

from .src.audio import Audio
from .src.chat import Chat
from .src.completions import Completions
from .src.documents import Documents
from .src.embeddings import Embeddings
from .src.rerank import Rerank
from .src.tokenize import Tokenize
Expand All @@ -17,9 +19,9 @@
from .version import __version__

__all__ = [
"PredictionGuard", "Chat", "Completions", "Embeddings", "Rerank",
"Tokenize", "Translate", "Factuality", "Toxicity", "Pii", "Injection",
"Models"
"PredictionGuard", "Chat", "Completions", "Embeddings",
"Audio", "Documents", "Rerank", "Tokenize", "Translate",
"Factuality", "Toxicity", "Pii", "Injection", "Models"
]

class PredictionGuard:
Expand Down Expand Up @@ -65,6 +67,12 @@ def __init__(
self.embeddings: Embeddings = Embeddings(self.api_key, self.url)
"""Embedding generates chat completions based on a conversation history."""

self.audio: Audio = Audio(self.api_key, self.url)
"""Audio allows for the transcription of audio files."""

self.documents: Documents = Documents(self.api_key, self.url)
"""Documents allows you to extract text from various document file types."""

self.rerank: Rerank = Rerank(self.api_key, self.url)
"""Rerank sorts text inputs by semantic relevance to a specified query."""

Expand Down
98 changes: 98 additions & 0 deletions predictionguard/src/audio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import json

import requests
from typing import Any, Dict, Optional

from ..version import __version__


class Audio:
    """Audio generates a response based on audio data.

    Usage::

        import os
        import json

        from predictionguard import PredictionGuard

        # Set your Prediction Guard token as an environmental variable.
        os.environ["PREDICTIONGUARD_API_KEY"] = "<api key>"

        client = PredictionGuard()

        result = client.audio.transcriptions.create(
            model="whisper-3-large-instruct", file="sample_audio.wav"
        )

        print(json.dumps(result, sort_keys=True, indent=4, separators=(",", ": ")))
    """

    def __init__(self, api_key, url):
        # Credentials and endpoint are shared with the nested transcriptions helper.
        self.api_key = api_key
        self.url = url

        self.transcriptions: AudioTranscriptions = AudioTranscriptions(self.api_key, self.url)

class AudioTranscriptions:
    """Wraps the Prediction Guard /audio/transcriptions endpoint."""

    def __init__(self, api_key, url):
        self.api_key = api_key
        self.url = url

    def create(
        self,
        model: str,
        file: str
    ) -> Dict[str, Any]:
        """
        Creates an audio transcription request to the Prediction Guard /audio/transcriptions API

        :param model: The model to use
        :param file: Path to the audio file to be transcribed
        :result: A dictionary containing the transcribed text.
        """

        return self._transcribe_audio(model, file)

    def _transcribe_audio(self, model, file):
        """
        Uploads the audio file and returns the parsed JSON response.

        :raises ValueError: if the API rate-limits the request (HTTP 429)
            or returns any other non-200 status.
        """

        headers = {
            "Authorization": "Bearer " + self.api_key,
            "User-Agent": "Prediction Guard Python Client: " + __version__,
        }

        # NOTE(review): the MIME type is hard-coded to audio/wav even when the
        # file is not a wav — confirm the API sniffs the real format server-side.
        with open(file, "rb") as audio_file:
            files = {"file": (file, audio_file, "audio/wav")}
            data = {"model": model}

            response = requests.request(
                "POST", self.url + "/audio/transcriptions", headers=headers, files=files, data=data
            )

        if response.status_code == 200:
            return response.json()
        elif response.status_code == 429:
            raise ValueError(
                "Could not connect to Prediction Guard API. "
                "Too many requests, rate limit or quota exceeded."
            )
        else:
            # Surface the API's "error" field when a JSON body is present;
            # otherwise raise with an empty detail string.
            err = ""
            try:
                err = response.json()["error"]
            except Exception:
                pass
            raise ValueError("Could not transcribe the audio file. " + err)
5 changes: 5 additions & 0 deletions predictionguard/src/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def create(
prompt: Union[str, List[str]],
input: Optional[Dict[str, Any]] = None,
output: Optional[Dict[str, Any]] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[Dict[str, int]] = None,
max_completion_tokens: Optional[int] = 100,
Expand All @@ -40,6 +41,7 @@ def create(
:param prompt: The prompt(s) to generate completions for.
:param input: A dictionary containing the PII and injection arguments.
:param output: A dictionary containing the consistency, factuality, and toxicity arguments.
:param echo: A boolean indicating whether to echo the prompt(s) to the output.
:param frequency_penalty: The frequency penalty to use.
:param logit_bias: The logit bias to use.
:param max_completion_tokens: The maximum number of tokens to generate in the completion(s).
Expand Down Expand Up @@ -68,6 +70,7 @@ def create(
prompt,
input,
output,
echo,
frequency_penalty,
logit_bias,
max_completion_tokens,
Expand All @@ -91,6 +94,7 @@ def _generate_completion(
prompt,
input,
output,
echo,
frequency_penalty,
logit_bias,
max_completion_tokens,
Expand Down Expand Up @@ -165,6 +169,7 @@ def stream_generator(url, headers, payload, stream):
payload_dict = {
"model": model,
"prompt": prompt,
"echo": echo,
"frequency_penalty": frequency_penalty,
"logit_bias": logit_bias,
"max_completion_tokens": max_completion_tokens,
Expand Down
89 changes: 89 additions & 0 deletions predictionguard/src/documents.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import json
from pyexpat import model

import requests
from typing import Any, Dict, Optional

from ..version import __version__


class Documents:
    """Documents allows you to extract text from various document file types.

    Usage::

        import os
        import json

        from predictionguard import PredictionGuard

        # Set your Prediction Guard token as an environmental variable.
        os.environ["PREDICTIONGUARD_API_KEY"] = "<api key>"

        client = PredictionGuard()

        response = client.documents.extract.create(
            file="sample.pdf"
        )

        print(json.dumps(response, sort_keys=True, indent=4, separators=(",", ": ")))
    """

    def __init__(self, api_key, url):
        # Credentials and endpoint are shared with the nested extract helper.
        self.api_key = api_key
        self.url = url

        self.extract: DocumentsExtract = DocumentsExtract(self.api_key, self.url)

class DocumentsExtract:
    """Wraps the Prediction Guard /documents/extract endpoint."""

    def __init__(self, api_key, url):
        self.api_key = api_key
        self.url = url

    def create(
        self,
        file: str
    ) -> Dict[str, Any]:
        """
        Creates a documents request to the Prediction Guard /documents/extract API

        :param file: Path to the document to be parsed
        :result: A dictionary containing the title, content, and length of the document.
        """

        return self._extract_documents(file)

    def _extract_documents(self, file):
        """
        Uploads the document and returns the parsed JSON response.

        :raises ValueError: if the API rate-limits the request (HTTP 429)
            or returns any other non-200 status.
        """

        headers = {
            "Authorization": "Bearer " + self.api_key,
            "User-Agent": "Prediction Guard Python Client: " + __version__,
        }

        # No MIME type is supplied here; the server determines the document type.
        with open(file, "rb") as doc_file:
            files = {"file": (file, doc_file)}

            response = requests.request(
                "POST", self.url + "/documents/extract", headers=headers, files=files
            )

        if response.status_code == 200:
            return response.json()
        elif response.status_code == 429:
            raise ValueError(
                "Could not connect to Prediction Guard API. "
                "Too many requests, rate limit or quota exceeded."
            )
        else:
            # Surface the API's "error" field when a JSON body is present;
            # otherwise raise with an empty detail string.
            err = ""
            try:
                err = response.json()["error"]
            except Exception:
                pass
            raise ValueError("Could not extract document. " + err)
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ build-backend = "flit_core.buildapi"

[project]
name = "predictionguard"
authors = [{name = "Daniel Whitenack", email = "dan@predictionguard.com"}]
authors = [
{name = "Daniel Whitenack", email = "dan@predictionguard.com"},
{name = "Jacob Mansdorfer", email = "jacob@predictionguard.com"}
]
readme = "README.md"
license = {file = "LICENSE"}
classifiers = ["License :: OSI Approved :: MIT License"]
Expand Down
16 changes: 16 additions & 0 deletions tests/test_audio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import os

from predictionguard import PredictionGuard


def test_audio_transcribe_success():
    """Transcribes the fixture wav file and checks a non-empty transcript comes back."""
    test_client = PredictionGuard()

    response = test_client.audio.transcriptions.create(
        model="whisper-3-base",
        file="fixtures/test_audio.wav"
    )

    # A successful transcription must contain at least some text.
    # (Debug print removed for consistency with test_documents.py.)
    assert len(response["text"]) > 0
13 changes: 13 additions & 0 deletions tests/test_documents.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import os

from predictionguard import PredictionGuard


def test_documents_extract_success():
    """End-to-end check that document extraction returns non-empty contents."""
    client = PredictionGuard()

    result = client.documents.extract.create(file="fixtures/test_pdf.pdf")

    assert len(result["contents"]) > 0