Skip to content

Commit 3758f28

Browse files
refactoring
1 parent 1e48d85 commit 3758f28

File tree

1 file changed

+31
-15
lines changed

1 file changed

+31
-15
lines changed

core/CAIAssistant.py

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def __init__(self, promptsFolder=None):
1111
promptsFolder = os.path.join(os.path.dirname(__file__), '../prompts')
1212
self._LLM = ChatOpenAI(model="gpt-3.5-turbo")
1313

14-
self._translateShallow = LLMChain(
14+
self._translateShallowQuery = LLMChain(
1515
llm=self._LLM,
1616
prompt=ChatPromptTemplate.from_messages([
1717
HumanMessagePromptTemplate(
@@ -22,7 +22,7 @@ def __init__(self, promptsFolder=None):
2222
),
2323
]),
2424
)
25-
self._translateDeep = LLMChain(
25+
self._translateDeepQuery = LLMChain(
2626
llm=self._LLM,
2727
prompt=ChatPromptTemplate.from_messages([
2828
HumanMessagePromptTemplate(
@@ -60,37 +60,53 @@ def _executePrompt(self, prompt, variables):
6060
res['Flags'] = flags
6161
return res
6262

63-
def translate(self, text, fastTranslation, language):
64-
# run shallow translation
63+
def _translateShallow(self, text, translation, language):
6564
res = self._executePrompt(
66-
self._translateShallow,
65+
self._translateShallowQuery,
6766
{
6867
'UserInput': text,
69-
'FastTranslation': fastTranslation,
68+
'FastTranslation': translation,
7069
'Language': language
7170
}
7271
)
72+
translation = res['Translation']
7373
flags = res['Flags']
7474
totalIssues = sum([int(v) for v in flags.values()])
7575
if totalIssues < 2:
76-
yield res['Translation']
77-
return # all ok, no need to run deep translation
78-
yield res['Translation'], res.get('Notification', '')
79-
80-
# run deep translation
81-
inputLanguage = res.get('Input language', 'unknown')
76+
return translation # all ok, no need to run deep translation
77+
78+
return res, translation, res.get('Notification', '')
79+
80+
def _translateDeep(self, text, translation, language, inputLanguage, flags):
8281
# extract first word from input language, can be separated by space, comma, etc.,
8382
inputLanguage = re.split(r'[\s,]+', inputLanguage)[0]
8483
inputLanguage = inputLanguage.strip().capitalize()
84+
8585
res = self._executePrompt(
86-
self._translateDeep,
86+
self._translateDeepQuery,
8787
{
8888
'UserInput': text,
89-
'FastTranslation': res['Translation'], # use shallow translation as reference
89+
'FastTranslation': translation,
9090
'Language': language,
9191
'InputLanguage': inputLanguage,
9292
'Flags': ', '.join([k for k, v in flags.items() if v])
9393
}
9494
)
95-
yield res['Translation']
95+
return res['Translation']
96+
97+
def translate(self, text, fastTranslation, language):
98+
# run shallow translation
99+
res = self._translateShallow(text=text, translation=fastTranslation, language=language)
100+
if isinstance(res, str):
101+
yield res
102+
return # all ok, no need to run deep translation
103+
104+
raw, translation, notification = res
105+
yield translation, notification # yield shallow translation with notification
106+
# run deep translation
107+
yield self._translateDeep(
108+
text=text, translation=translation, language=language,
109+
inputLanguage=raw.get('Input language', 'unknown'),
110+
flags=raw['Flags']
111+
)
96112
return

0 commit comments

Comments
 (0)