Skip to content

Commit

Permalink
add temp and stop args (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
Infrared1029 committed Jun 6, 2024
1 parent 8346ef1 commit ce6dd4a
Showing 1 changed file with 49 additions and 10 deletions.
59 changes: 49 additions & 10 deletions unify/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,9 @@ def generate( # noqa: WPS234, WPS211
user_prompt: Optional[str] = None,
system_prompt: Optional[str] = None,
messages: Optional[List[Dict[str, str]]] = None,
max_tokens: Optional[int] = None,
max_tokens: Optional[int] = 1024,
temperature: Optional[float] = 1.0,
stop: Optional[List[str]] = None,
stream: bool = False,
) -> Union[Generator[str, None, None], str]: # noqa: DAR101, DAR201, DAR401
"""Generate content using the Unify API.
Expand All @@ -141,8 +143,15 @@ def generate( # noqa: WPS234, WPS211
messages (List[Dict[str, str]]): A list of dictionaries containing the
conversation history. If provided, user_prompt must be None.
max_tokens (Optional[int]): The max number of output tokens, defaults
to the provider's default max_tokens when the value is None.
max_tokens (Optional[int]): The max number of output tokens.
Defaults to the provider's default max_tokens when the value is None.
temperature (Optional[float]): What sampling temperature to use, between 0 and 2.
Higher values like 0.8 will make the output more random,
while lower values like 0.2 will make it more focused and deterministic.
Defaults to the provider's default max_tokens when the value is None.
stop (Optional[List[str]]): Up to 4 sequences where the API will stop generating further tokens.
stream (bool): If True, generates content as a stream.
If False, generates content as a single response.
Expand All @@ -159,7 +168,6 @@ def generate( # noqa: WPS234, WPS211
contents = []
if system_prompt:
contents.append({"role": "system", "content": system_prompt})

if user_prompt:
contents.append({"role": "user", "content": user_prompt})
elif messages:
Expand All @@ -168,8 +176,14 @@ def generate( # noqa: WPS234, WPS211
raise UnifyError("You must provider either the user_prompt or messages!")

if stream:
return self._generate_stream(contents, self._endpoint, max_tokens=max_tokens)
return self._generate_non_stream(contents, self._endpoint, max_tokens=max_tokens)
return self._generate_stream(contents, self._endpoint,
max_tokens=max_tokens,
temperature=temperature,
stop=stop)
return self._generate_non_stream(contents, self._endpoint,
max_tokens=max_tokens,
temperature=temperature,
stop=stop)

def get_credit_balance(self) -> float:
# noqa: DAR201, DAR401
Expand Down Expand Up @@ -201,13 +215,17 @@ def _generate_stream(
self,
messages: List[Dict[str, str]],
endpoint: str,
max_tokens: Optional[int] = None
max_tokens: Optional[int] = 1024,
temperature: Optional[float] = 1.0,
stop: Optional[List[str]] = None,
) -> Generator[str, None, None]:
try:
chat_completion = self.client.chat.completions.create(
model=endpoint,
messages=messages, # type: ignore[arg-type]
max_tokens=max_tokens,
temperature=temperature,
stop=stop,
stream=True,
)
for chunk in chat_completion:
Expand All @@ -222,13 +240,17 @@ def _generate_non_stream(
self,
messages: List[Dict[str, str]],
endpoint: str,
max_tokens: Optional[int] = None
max_tokens: Optional[int] = 1024,
temperature: Optional[float] = 1.0,
stop: Optional[List[str]] = None,
) -> str:
try:
chat_completion = self.client.chat.completions.create(
model=endpoint,
messages=messages, # type: ignore[arg-type]
max_tokens=max_tokens,
temperature=temperature,
stop=stop,
stream=False,
)
self.set_provider(
Expand Down Expand Up @@ -388,6 +410,8 @@ async def generate( # noqa: WPS234, WPS211
system_prompt: Optional[str] = None,
messages: Optional[List[Dict[str, str]]] = None,
max_tokens: Optional[int] = None,
temperature: Optional[float] = 1.0,
stop: Optional[List[str]] = None,
stream: bool = False,
) -> Union[AsyncGenerator[str, None], str]: # noqa: DAR101, DAR201, DAR401
"""Generate content asynchronously using the Unify API.
Expand All @@ -405,6 +429,13 @@ async def generate( # noqa: WPS234, WPS211
max_tokens (Optional[int]): The max number of output tokens, defaults
to the provider's default max_tokens when the value is None.
temperature (Optional[float]): What sampling temperature to use, between 0 and 2.
Higher values like 0.8 will make the output more random,
while lower values like 0.2 will make it more focused and deterministic.
Defaults to the provider's default max_tokens when the value is None.
stop (Optional[List[str]]): Up to 4 sequences where the API will stop generating further tokens.
stream (bool): If True, generates content as a stream.
If False, generates content as a single response.
Defaults to False.
Expand All @@ -429,20 +460,24 @@ async def generate( # noqa: WPS234, WPS211
raise UnifyError("You must provide either the user_prompt or messages!")

if stream:
return self._generate_stream(contents, self._endpoint, max_tokens=max_tokens)
return await self._generate_non_stream(contents, self._endpoint, max_tokens=max_tokens)
return self._generate_stream(contents, self._endpoint, max_tokens=max_tokens, stop=stop, temperature=temperature)
return await self._generate_non_stream(contents, self._endpoint, max_tokens=max_tokens, stop=stop, temperature=temperature)

async def _generate_stream(
self,
messages: List[Dict[str, str]],
endpoint: str,
max_tokens: Optional[int] = None,
temperature: Optional[float] = 1.0,
stop: Optional[List[str]] = None,
) -> AsyncGenerator[str, None]:
try:
async_stream = await self.client.chat.completions.create(
model=endpoint,
messages=messages, # type: ignore[arg-type]
max_tokens=max_tokens,
temperature=temperature,
stop=stop,
stream=True,
)
async for chunk in async_stream: # type: ignore[union-attr]
Expand All @@ -456,12 +491,16 @@ async def _generate_non_stream(
messages: List[Dict[str, str]],
endpoint: str,
max_tokens: Optional[int] = None,
temperature: Optional[float] = 1.0,
stop: Optional[List[str]] = None,
) -> str:
try:
async_response = await self.client.chat.completions.create(
model=endpoint,
messages=messages, # type: ignore[arg-type]
max_tokens=max_tokens,
temperature=temperature,
stop=stop,
stream=False,
)
self.set_provider(async_response.model.split("@")[-1]) # type: ignore
Expand Down

0 comments on commit ce6dd4a

Please sign in to comment.