Binary file added fixtures/test_audio.wav
Binary file not shown.
6 changes: 6 additions & 0 deletions fixtures/test_csv.csv
@@ -0,0 +1,6 @@
John,Doe,120 jefferson st.,Riverside, NJ, 08075
Jack,McGinnis,220 hobo Av.,Phila, PA,09119
"John ""Da Man""",Repici,120 Jefferson St.,Riverside, NJ,08075
Stephen,Tyler,"7452 Terrace ""At the Plaza"" road",SomeTown,SD, 91234
,Blankman,,SomeTown, SD, 00298
"Joan ""the bone"", Anne",Jet,"9th, at Terrace plc",Desert City,CO,00123
Binary file added fixtures/test_pdf.pdf
Binary file not shown.
14 changes: 11 additions & 3 deletions predictionguard/client.py
@@ -3,8 +3,10 @@
import requests
from typing import Optional

from .src.audio import Audio
from .src.chat import Chat
from .src.completions import Completions
from .src.documents import Documents
from .src.embeddings import Embeddings
from .src.rerank import Rerank
from .src.tokenize import Tokenize
@@ -17,9 +19,9 @@
from .version import __version__

__all__ = [
"PredictionGuard", "Chat", "Completions", "Embeddings", "Rerank",
"Tokenize", "Translate", "Factuality", "Toxicity", "Pii", "Injection",
"Models"
"PredictionGuard", "Chat", "Completions", "Embeddings",
"Audio", "Documents", "Rerank", "Tokenize", "Translate",
"Factuality", "Toxicity", "Pii", "Injection", "Models"
]

class PredictionGuard:
@@ -65,6 +67,12 @@ def __init__(
self.embeddings: Embeddings = Embeddings(self.api_key, self.url)
"""Embedding generates chat completions based on a conversation history."""

self.audio: Audio = Audio(self.api_key, self.url)
"""Audio allows for the transcription of audio files."""

self.documents: Documents = Documents(self.api_key, self.url)
"""Documents allows you to extract text from various document file types."""

self.rerank: Rerank = Rerank(self.api_key, self.url)
"""Rerank sorts text inputs by semantic relevance to a specified query."""

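
For orientation, a minimal usage sketch of the two handles this hunk attaches (the model name and file path are placeholders; the Documents API itself lives in predictionguard/src/documents.py, which is not part of this excerpt):

import os
from predictionguard import PredictionGuard

os.environ["PREDICTIONGUARD_API_KEY"] = "<api key>"
client = PredictionGuard()

# New in this PR: audio transcription via the audio handle...
result = client.audio.transcriptions.create(
    model="whisper-3-large-instruct", file="fixtures/test_audio.wav"
)
# ...and document text extraction via client.documents (methods not shown here).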
98 changes: 98 additions & 0 deletions predictionguard/src/audio.py
@@ -0,0 +1,98 @@
import json

import requests
from typing import Any, Dict, Optional

from ..version import __version__


class Audio:
"""Audio generates a response based on audio data.

Usage::

import os
import json

from predictionguard import PredictionGuard

# Set your Prediction Guard token as an environmental variable.
os.environ["PREDICTIONGUARD_API_KEY"] = "<api key>"

client = PredictionGuard()

result = client.audio.transcriptions.create(
model="whisper-3-large-instruct", file="sample_audio.wav"
)

print(json.dumps(result, sort_keys=True, indent=4, separators=(",", ": ")))
"""

def __init__(self, api_key, url):
self.api_key = api_key
self.url = url

self.transcriptions: AudioTranscriptions = AudioTranscriptions(self.api_key, self.url)

class AudioTranscriptions:
def __init__(self, api_key, url):
self.api_key = api_key
self.url = url

def create(
self,
model: str,
file: str
) -> Dict[str, Any]:
"""
Creates an audio transcription request to the Prediction Guard /audio/transcriptions API.

:param model: The model to use for transcription.
:param file: Path to the audio file to be transcribed.
:return: A dictionary containing the transcribed text.
"""

# Gather the parameters for a call to _transcribe_audio
args = (model, file)

# Run _transcribe_audio
choices = self._transcribe_audio(*args)
return choices

def _transcribe_audio(self, model, file):
"""
Function to transcribe an audio file.
"""

headers = {
"Authorization": "Bearer " + self.api_key,
"User-Agent": "Prediction Guard Python Client: " + __version__,
}

with open(file, "rb") as audio_file:
files = {"file": (file, audio_file, "audio/wav")}
data = {"model": model}

response = requests.request(
"POST", self.url + "/audio/transcriptions", headers=headers, files=files, data=data
)

# If the request was successful, return the transcription.
if response.status_code == 200:
ret = response.json()
return ret
elif response.status_code == 429:
raise ValueError(
"Could not connect to Prediction Guard API. "
"Too many requests, rate limit or quota exceeded."
)
else:
# Check if there is a JSON body in the response. If so, read the
# error field and include it in the raised exception.
err = ""
try:
err = response.json()["error"]
except Exception:
pass
raise ValueError("Could not transcribe the audio file. " + err)
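
Given the branches above, every failure surfaces as a ValueError, so a caller might wrap the create call like this hedged sketch (reusing the client from the earlier sketch; the "text" response field is an assumption based on OpenAI-style transcription responses):

try:
    result = client.audio.transcriptions.create(
        model="whisper-3-large-instruct", file="sample_audio.wav"
    )
    print(result["text"])  # assumed field name
except ValueError as err:
    # Raised for rate limits (429) and any other non-200 response.
    print("Transcription failed:", err)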
65 changes: 57 additions & 8 deletions predictionguard/src/chat.py
@@ -44,7 +44,7 @@ class Chat:
{
"role": "user",
"content": "Haha. Good one."
},
}
]

result = client.chat.completions.create(
@@ -69,15 +69,36 @@ def __init__(self, api_key, url):
def create(
self,
model: str,
messages: Union[str, List[Dict[str, Any]]],
messages: Union[
str, List[
Dict[str, Any]
]
],
input: Optional[Dict[str, Any]] = None,
output: Optional[Dict[str, Any]] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[
Dict[str, int]
] = None,
max_completion_tokens: Optional[int] = 100,
max_tokens: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
stop: Optional[
Union[
str, List[str]
]
] = None,
stream: Optional[bool] = False,
temperature: Optional[float] = 1.0,
tool_choice: Optional[Union[
str, Dict[
str, Dict[str, str]
]
]] = "none",
tools: Optional[List[Dict[str, Union[str, Dict[str, str]]]]] = None,
top_p: Optional[float] = 0.99,
top_k: Optional[float] = 50,
stream: Optional[bool] = False,
) -> Dict[str, Any]:
"""
Creates a chat request for the Prediction Guard /chat API.
@@ -86,11 +107,18 @@ def create(
:param messages: The content of the call, an array of dictionaries containing a role and content.
:param input: A dictionary containing the PII and injection arguments.
:param output: A dictionary containing the consistency, factuality, and toxicity arguments.
:param frequency_penalty: Penalizes new tokens based on how often they already appear in the text, discouraging repetition.
:param logit_bias: A map of token IDs to bias values that raise or lower the likelihood of those tokens being sampled.
:param max_completion_tokens: The maximum amount of tokens the model should return.
:param max_tokens: The maximum amount of tokens the model should return (older alias for max_completion_tokens).
:param parallel_tool_calls: Whether the model may issue multiple tool calls in parallel.
:param presence_penalty: Penalizes tokens that have already appeared at all, encouraging the model to move to new topics.
:param stop: The completion stopping criteria.
:param stream: Option to stream the API response.
:param temperature: The sampling temperature; higher values make responses more random, lower values more deterministic.
:param tool_choice: Controls which (if any) tool the model calls, e.g. "none", "auto", or a named tool.
:param tools: A list of tool definitions the model may call.
:param top_p: The nucleus (top-p) sampling parameter.
:param top_k: The top-k sampling parameter.
:param stream: Option to stream the API response.
:return: A dictionary containing the chat response.
"""

@@ -110,11 +138,18 @@
messages,
input,
output,
frequency_penalty,
logit_bias,
max_completion_tokens,
parallel_tool_calls,
presence_penalty,
stop,
stream,
temperature,
tool_choice,
tools,
top_p,
top_k,
stream,
top_k
)

# Run _generate_chat
@@ -128,11 +163,18 @@ def _generate_chat(
messages,
input,
output,
frequency_penalty,
logit_bias,
max_completion_tokens,
parallel_tool_calls,
presence_penalty,
stop,
stream,
temperature,
tool_choice,
tools,
top_p,
top_k,
stream,
):
"""
Function to generate a single chat response.
@@ -257,11 +299,18 @@ def stream_generator(url, headers, payload, stream):
payload_dict = {
"model": model,
"messages": messages,
"frequency_penalty": frequency_penalty,
"logit_bias": logit_bias,
"max_completion_tokens": max_completion_tokens,
"parallel_tool_calls": parallel_tool_calls,
"presence_penalty": presence_penalty,
"stop": stop,
"stream": stream,
"temperature": temperature,
"tool_choice": tool_choice,
"tools": tools,
"top_p": top_p,
"top_k": top_k,
"stream": stream,
}

if input:
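
Since stream is threaded through the payload to stream_generator, here is a hedged sketch of consuming a streamed response (the chunk shape is an assumption based on OpenAI-style streaming deltas):

for chunk in client.chat.completions.create(
    model="Hermes-3-Llama-3.1-8B",
    messages=[{"role": "user", "content": "Write a haiku about rain."}],
    stream=True,
):
    # Each chunk is assumed to carry an incremental delta payload.
    print(chunk["choices"][0]["delta"].get("content", ""), end="")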