diff --git a/fixtures/test_audio.wav b/fixtures/test_audio.wav new file mode 100644 index 0000000..3184d37 Binary files /dev/null and b/fixtures/test_audio.wav differ diff --git a/fixtures/test_csv.csv b/fixtures/test_csv.csv new file mode 100644 index 0000000..e7bba0d --- /dev/null +++ b/fixtures/test_csv.csv @@ -0,0 +1,6 @@ +John,Doe,120 jefferson st.,Riverside, NJ, 08075 +Jack,McGinnis,220 hobo Av.,Phila, PA,09119 +"John ""Da Man""",Repici,120 Jefferson St.,Riverside, NJ,08075 +Stephen,Tyler,"7452 Terrace ""At the Plaza"" road",SomeTown,SD, 91234 +,Blankman,,SomeTown, SD, 00298 +"Joan ""the bone"", Anne",Jet,"9th, at Terrace plc",Desert City,CO,00123 diff --git a/fixtures/test_pdf.pdf b/fixtures/test_pdf.pdf new file mode 100644 index 0000000..c01805e Binary files /dev/null and b/fixtures/test_pdf.pdf differ diff --git a/predictionguard/client.py b/predictionguard/client.py index c4762de..0bc9e9f 100644 --- a/predictionguard/client.py +++ b/predictionguard/client.py @@ -3,8 +3,10 @@ import requests from typing import Optional +from .src.audio import Audio from .src.chat import Chat from .src.completions import Completions +from .src.documents import Documents from .src.embeddings import Embeddings from .src.rerank import Rerank from .src.tokenize import Tokenize @@ -17,9 +19,9 @@ from .version import __version__ __all__ = [ - "PredictionGuard", "Chat", "Completions", "Embeddings", "Rerank", - "Tokenize", "Translate", "Factuality", "Toxicity", "Pii", "Injection", - "Models" + "PredictionGuard", "Chat", "Completions", "Embeddings", + "Audio", "Documents", "Rerank", "Tokenize", "Translate", + "Factuality", "Toxicity", "Pii", "Injection", "Models" ] class PredictionGuard: @@ -65,6 +67,12 @@ def __init__( self.embeddings: Embeddings = Embeddings(self.api_key, self.url) """Embedding generates chat completions based on a conversation history.""" + self.audio: Audio = Audio(self.api_key, self.url) + """Audio allows for the transcription of audio files.""" + + 
self.documents: Documents = Documents(self.api_key, self.url) + """Documents allows you to extract text from various document file types.""" + self.rerank: Rerank = Rerank(self.api_key, self.url) """Rerank sorts text inputs by semantic relevance to a specified query.""" diff --git a/predictionguard/src/audio.py b/predictionguard/src/audio.py new file mode 100644 index 0000000..a7348c3 --- /dev/null +++ b/predictionguard/src/audio.py @@ -0,0 +1,98 @@ +import json + +import requests +from typing import Any, Dict, Optional + +from ..version import __version__ + + +class Audio: + """Audio generates a response based on audio data. + + Usage:: + + import os + import json + + from predictionguard import PredictionGuard + + # Set your Prediction Guard token as an environmental variable. + os.environ["PREDICTIONGUARD_API_KEY"] = "" + + client = PredictionGuard() + + result = client.audio.transcriptions.create( + model="whisper-3-large-instruct", file="sample_audio.wav" + ) + + print(json.dumps(result, sort_keys=True, indent=4, separators=(",", ": "))) + """ + + def __init__(self, api_key, url): + self.api_key = api_key + self.url = url + + self.transcriptions: AudioTranscriptions = AudioTranscriptions(self.api_key, self.url) + +class AudioTranscriptions: + def __init__(self, api_key, url): + self.api_key = api_key + self.url = url + + def create( + self, + model: str, + file: str + ) -> Dict[str, Any]: + """ + Creates an audio transcription request to the Prediction Guard /audio/transcriptions API + + :param model: The model to use + :param file: Audio file to be transcribed + :result: A dictionary containing the transcribed text. + """ + + # Collect the parameters for + # a call to _transcribe_audio + args = (model, file) + + # Run _transcribe_audio + choices = self._transcribe_audio(*args) + return choices + + def _transcribe_audio(self, model, file): + """ + Function to transcribe an audio file. 
+ """ + + headers = { + "Authorization": "Bearer " + self.api_key, + "User-Agent": "Prediction Guard Python Client: " + __version__, + } + + with open(file, "rb") as audio_file: + files = {"file": (file, audio_file, "audio/wav")} + data = {"model": model} + + response = requests.request( + "POST", self.url + "/audio/transcriptions", headers=headers, files=files, data=data + ) + + # If the request was successful, return the transcription JSON. + if response.status_code == 200: + ret = response.json() + return ret + elif response.status_code == 429: + raise ValueError( + "Could not connect to Prediction Guard API. " + "Too many requests, rate limit or quota exceeded." + ) + else: + # Check if there is a json body in the response. Read that in, + # print out the error field in the json body, and raise an exception. + err = "" + try: + err = response.json()["error"] + except Exception: + pass + raise ValueError("Could not transcribe the audio file. " + err) \ No newline at end of file diff --git a/predictionguard/src/completions.py b/predictionguard/src/completions.py index 6b3fe12..854e830 100644 --- a/predictionguard/src/completions.py +++ b/predictionguard/src/completions.py @@ -22,6 +22,7 @@ def create( prompt: Union[str, List[str]], input: Optional[Dict[str, Any]] = None, output: Optional[Dict[str, Any]] = None, + echo: Optional[bool] = None, frequency_penalty: Optional[float] = None, logit_bias: Optional[Dict[str, int]] = None, max_completion_tokens: Optional[int] = 100, @@ -40,6 +41,7 @@ def create( :param prompt: The prompt(s) to generate completions for. :param input: A dictionary containing the PII and injection arguments. :param output: A dictionary containing the consistency, factuality, and toxicity arguments. + :param echo: A boolean indicating whether to echo the prompt(s) to the output. :param frequency_penalty: The frequency penalty to use. :param logit_bias: The logit bias to use. 
:param max_completion_tokens: The maximum number of tokens to generate in the completion(s). @@ -68,6 +70,7 @@ def create( prompt, input, output, + echo, frequency_penalty, logit_bias, max_completion_tokens, @@ -91,6 +94,7 @@ def _generate_completion( prompt, input, output, + echo, frequency_penalty, logit_bias, max_completion_tokens, @@ -165,6 +169,7 @@ def stream_generator(url, headers, payload, stream): payload_dict = { "model": model, "prompt": prompt, + "echo": echo, "frequency_penalty": frequency_penalty, "logit_bias": logit_bias, "max_completion_tokens": max_completion_tokens, diff --git a/predictionguard/src/documents.py b/predictionguard/src/documents.py new file mode 100644 index 0000000..69b3ae9 --- /dev/null +++ b/predictionguard/src/documents.py @@ -0,0 +1,88 @@ +import json + +import requests +from typing import Any, Dict, Optional + +from ..version import __version__ + + +class Documents: + """Documents allows you to extract text from various document file types. + + Usage:: + + from predictionguard import PredictionGuard + + # Set your Prediction Guard token as an environmental variable. + os.environ["PREDICTIONGUARD_API_KEY"] = "" + + client = PredictionGuard() + + response = client.documents.extract.create( + file="sample.pdf" + ) + + print(json.dumps(response, sort_keys=True, indent=4, separators=(",", ": "))) + """ + + def __init__(self, api_key, url): + self.api_key = api_key + self.url = url + + self.extract: DocumentsExtract = DocumentsExtract(self.api_key, self.url) + +class DocumentsExtract: + def __init__(self, api_key, url): + self.api_key = api_key + self.url = url + + def create( + self, + file: str + ) -> Dict[str, Any]: + """ + Creates a document extraction request to the Prediction Guard /documents/extract API + + :param file: Document to be parsed + :result: A dictionary containing the title, content, and length of the document. 
+ """ + + # Run _extract_documents + choices = self._extract_documents(file) + return choices + + def _extract_documents(self, file): + """ + Function to extract a document. + """ + + headers = { + "Authorization": "Bearer " + self.api_key, + "User-Agent": "Prediction Guard Python Client: " + __version__, + } + + with open(file, "rb") as doc_file: + files = {"file": (file, doc_file)} + + response = requests.request( + "POST", self.url + "/documents/extract", headers=headers, files=files + ) + + # If the request was successful, return the extracted document JSON. + if response.status_code == 200: + ret = response.json() + return ret + elif response.status_code == 429: + raise ValueError( + "Could not connect to Prediction Guard API. " + "Too many requests, rate limit or quota exceeded." + ) + else: + # Check if there is a json body in the response. Read that in, + # print out the error field in the json body, and raise an exception. + err = "" + try: + err = response.json()["error"] + except Exception: + pass + raise ValueError("Could not extract document. 
" + err) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 31718e0..da06461 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,10 @@ build-backend = "flit_core.buildapi" [project] name = "predictionguard" -authors = [{name = "Daniel Whitenack", email = "dan@predictionguard.com"}] +authors = [ + {name = "Daniel Whitenack", email = "dan@predictionguard.com"}, + {name = "Jacob Mansdorfer", email = "jacob@predictionguard.com"} + ] readme = "README.md" license = {file = "LICENSE"} classifiers = ["License :: OSI Approved :: MIT License"] diff --git a/tests/test_audio.py b/tests/test_audio.py new file mode 100644 index 0000000..cdb08d6 --- /dev/null +++ b/tests/test_audio.py @@ -0,0 +1,16 @@ +import os + +from predictionguard import PredictionGuard + + +def test_audio_transcribe_success(): + test_client = PredictionGuard() + + response = test_client.audio.transcriptions.create( + model="whisper-3-base", + file="fixtures/test_audio.wav" + ) + + print(response) + + assert len(response["text"]) > 0 \ No newline at end of file diff --git a/tests/test_documents.py b/tests/test_documents.py new file mode 100644 index 0000000..c877c0e --- /dev/null +++ b/tests/test_documents.py @@ -0,0 +1,13 @@ +import os + +from predictionguard import PredictionGuard + + +def test_documents_extract_success(): + test_client = PredictionGuard() + + response = test_client.documents.extract.create( + file="fixtures/test_pdf.pdf" + ) + + assert len(response["contents"]) > 0 \ No newline at end of file