Skip to content

Add model parameter #11

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/example_script.py
Original file line number Diff line number Diff line change
@@ -2,7 +2,6 @@

import translation_agent as ta


if __name__ == "__main__":
source_lang, target_lang, country = "English", "Spanish", "Mexico"

@@ -17,6 +16,7 @@
print(f"Source text:\n\n{source_text}\n------------\n")

translation = ta.translate(
model_name="gpt-3.5-turbo",
source_lang=source_lang,
target_lang=target_lang,
source_text=source_text,
25 changes: 15 additions & 10 deletions src/translation_agent/utils.py
Original file line number Diff line number Diff line change
@@ -21,7 +21,7 @@
def get_completion(
prompt: str,
system_message: str = "You are a helpful assistant.",
model: str = "gpt-4-turbo",
model_name: str = "gpt-4-turbo",
temperature: float = 0.3,
json_mode: bool = False,
) -> Union[str, dict]:
@@ -32,7 +32,7 @@ def get_completion(
prompt (str): The user's prompt or query.
system_message (str, optional): The system message to set the context for the assistant.
Defaults to "You are a helpful assistant.".
model (str, optional): The name of the OpenAI model to use for generating the completion.
model_name (str, optional): The name of the OpenAI model to use for generating the completion.
Defaults to "gpt-4-turbo".
temperature (float, optional): The sampling temperature for controlling the randomness of the generated text.
Defaults to 0.3.
@@ -47,7 +47,7 @@ def get_completion(

if json_mode:
response = client.chat.completions.create(
model=model,
model=model_name,
temperature=temperature,
top_p=1,
response_format={"type": "json_object"},
@@ -59,7 +59,7 @@ def get_completion(
return response.choices[0].message.content
else:
response = client.chat.completions.create(
model=model,
model=model_name,
temperature=temperature,
top_p=1,
messages=[
@@ -71,12 +71,14 @@ def get_completion(


def one_chunk_initial_translation(
source_lang: str, target_lang: str, source_text: str
source_lang: str, target_lang: str, source_text: str, model_name: str = "gpt-4-turbo"
) -> str:
"""
Translate the entire text as one chunk using an LLM.

Args:
model_name (str): The name of the OpenAI model to use for generating the completion.
Defaults to "gpt-4-turbo".
source_lang (str): The source language of the text.
target_lang (str): The target language for translation.
source_text (str): The text to be translated.
@@ -95,7 +97,7 @@ def one_chunk_initial_translation(

prompt = translation_prompt.format(source_text=source_text)

translation = get_completion(prompt, system_message=system_message)
translation = get_completion(prompt, system_message=system_message, model_name=model_name)

return translation

@@ -238,7 +240,7 @@ def one_chunk_improve_translation(


def one_chunk_translate_text(
source_lang: str, target_lang: str, source_text: str, country: str = ""
source_lang: str, target_lang: str, source_text: str, country: str = "", model_name: str = ""
) -> str:
"""
Translate a single chunk of text from the source language to the target language.
@@ -248,6 +250,8 @@ def one_chunk_translate_text(
2. Reflect on the initial translation and generate an improved translation.

Args:
model_name (str): The name of the OpenAI model to use for generating the completion.
Defaults to "gpt-4-turbo".
source_lang (str): The source language of the text.
target_lang (str): The target language for the translation.
source_text (str): The text to be translated.
@@ -256,7 +260,7 @@ def one_chunk_translate_text(
str: The improved translation of the source text.
"""
translation_1 = one_chunk_initial_translation(
source_lang, target_lang, source_text
source_lang, target_lang, source_text, model_name
)

reflection = one_chunk_reflect_on_translation(
@@ -647,6 +651,7 @@ def translate(
source_text,
country,
max_tokens=MAX_TOKENS_PER_CHUNK,
model_name="gpt-4",
):
"""Translate the source_text from source_lang to target_lang."""

@@ -658,7 +663,7 @@ def translate(
ic("Translating text as single chunk")

final_translation = one_chunk_translate_text(
source_lang, target_lang, source_text, country
source_lang, target_lang, source_text, country, model_name
)

return final_translation
@@ -673,7 +678,7 @@ def translate(
ic(token_size)

text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
model_name="gpt-4",
model_name=model_name,
chunk_size=token_size,
chunk_overlap=0,
)