Added system prompt to the .generate function and updated the README
hello-fri-end committed Mar 29, 2024
1 parent ab4fa88 commit 4d21bfb
Showing 2 changed files with 39 additions and 7 deletions.
6 changes: 6 additions & 0 deletions README.md
@@ -31,6 +31,12 @@ response = unify.generate(messages="Hello Llama! Who was Isaac Newton?", model="

Here, `response` is a string containing the model's output.

You can influence the model's persona using the `system_prompt` argument in the `.generate` function.

```python
response = unify.generate(messages="Hello Llama! Who was Isaac Newton?", system_prompt="You should always talk in rhymes", model="llama-2-13b-chat", provider="anyscale")
```
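When `messages` is a list of message dictionaries rather than a single string, the system prompt is prepended to the conversation as a `system` message, as the diff to `unify/clients.py` below shows. A minimal sketch (the conversation content here is illustrative):

```python
history = [
    {"role": "user", "content": "Hello Llama! Who was Isaac Newton?"},
    {"role": "assistant", "content": "A man of great renown, whose laws the apple found..."},
    {"role": "user", "content": "And what did he discover?"},
]
response = unify.generate(
    messages=history,
    system_prompt="You should always talk in rhymes",
    model="llama-2-13b-chat",
    provider="anyscale",
)
```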

### Supported Models
The list of supported models and providers is available in [the platform](https://unify.ai/hub).

40 changes: 33 additions & 7 deletions unify/clients.py
@@ -42,9 +42,10 @@ def __init__(
except openai.OpenAIError as e:
raise UnifyError(f"Failed to initialize Unify client: {str(e)}")

def generate( # noqa: WPS234
def generate( # noqa: WPS234, WPS211
self,
messages: Union[str, List[Dict[str, str]]],
system_prompt: Optional[str] = None,
model: str = "llama-2-13b-chat",
provider: str = "anyscale",
stream: bool = False,
@@ -54,6 +55,8 @@ def generate( # noqa: WPS234
Args:
messages (Union[str, List[Dict[str, str]]]): A single prompt as a
string or a dictionary containing the conversation history.
system_prompt (Optional[str]): An optional string containing the
system prompt.
model (str): The name of the model. Defaults to "llama-2-13b-chat".
provider (str): The provider of the model. Defaults to "anyscale".
stream (bool): If True, generates content as a stream.
@@ -69,9 +72,19 @@ def generate( # noqa: WPS234
UnifyError: If an error occurs during content generation.
"""
if isinstance(messages, str):
contents = [{"role": "user", "content": messages}]
if system_prompt is None:
contents = [{"role": "user", "content": messages}]
else:
contents = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": messages},
]
else:
contents = messages
if system_prompt is None:
contents = messages
else:
contents = [{"role": "system", "content": system_prompt}]
contents.extend(messages)

if stream:
return self._generate_stream(contents, model, provider)
@@ -140,9 +153,10 @@ def __init__(
except openai.APIStatusError as e:
raise UnifyError(f"Failed to initialize Unify client: {str(e)}")

async def generate( # noqa: WPS234
async def generate( # noqa: WPS234, WPS211
self,
messages: Union[str, List[Dict[str, str]]],
system_prompt: Optional[str] = None,
model: str = "llama-2-13b-chat",
provider: str = "anyscale",
stream: bool = False,
@@ -152,6 +166,9 @@ async def generate( # noqa: WPS234
Args:
messages (Union[str, List[Dict[str, str]]]): A single prompt as a string
or a dictionary containing the conversation history.
system_prompt (Optional[str]): An optional string containing the
system prompt.
model (str): The name of the model.
provider (str): The provider of the model.
stream (bool): If True, generates content as a stream.
@@ -167,10 +184,19 @@ async def generate( # noqa: WPS234
UnifyError: If an error occurs during content generation.
"""
if isinstance(messages, str):
contents = [{"role": "user", "content": messages}]
if system_prompt is None:
contents = [{"role": "user", "content": messages}]
else:
contents = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": messages},
]
else:
contents = messages

if system_prompt is None:
contents = messages
else:
contents = [{"role": "system", "content": system_prompt}]
contents.extend(messages)
if stream:
return self._generate_stream(contents, model, provider)
return await self._generate_non_stream(contents, model, provider)
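For reference, the message-assembly behaviour introduced in both clients can be summarised in a standalone sketch (the helper name `build_contents` is illustrative, not part of the library):

```python
from typing import Dict, List, Optional, Union

def build_contents(
    messages: Union[str, List[Dict[str, str]]],
    system_prompt: Optional[str] = None,
) -> List[Dict[str, str]]:
    # A single string becomes a one-element user-message list;
    # an existing history is copied so the caller's list is untouched.
    if isinstance(messages, str):
        contents = [{"role": "user", "content": messages}]
    else:
        contents = list(messages)
    # The optional system prompt, if given, is always prepended first.
    if system_prompt is not None:
        contents.insert(0, {"role": "system", "content": system_prompt})
    return contents

# Example:
assert build_contents("Hi!", "Talk in rhymes") == [
    {"role": "system", "content": "Talk in rhymes"},
    {"role": "user", "content": "Hi!"},
]
```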
