Skip to content

Commit 36a2b69

Browse files
"Refine" button
1 parent 4f4895e commit 36a2b69

File tree

3 files changed

+93
-25
lines changed

3 files changed

+93
-25
lines changed

core/CAIAssistant.py

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,13 @@
55
from langchain.prompts import PromptTemplate
66
from langchain.chains import LLMChain
77

8+
from collections import namedtuple
9+
10+
CAITranslationResult = namedtuple(
11+
'CAITranslationResult',
12+
['translation', 'pending', 'text', 'language', 'InputLanguage', 'Flags']
13+
)
14+
815
class CAIAssistant:
916
def __init__(self, promptsFolder=None, openai_api_key=None):
1017
if promptsFolder is None:
@@ -56,6 +63,8 @@ def _extractParts(self, text):
5663
return {k: v for k, v in tmp}
5764

5865
def _executePrompt(self, prompt, variables):
66+
if not self._connected: raise Exception('Not connected to API')
67+
5968
rawPrompt = prompt.prompt.format_prompt(**variables).to_string()
6069
logging.info('Raw prompt: ' + rawPrompt)
6170
res = prompt.run(variables)
@@ -106,18 +115,35 @@ def _translateDeep(self, text, translation, language, inputLanguage, flags):
106115
return res['Translation']
107116

108117
def translate(self, text, fastTranslation, language):
109-
if not self._connected:
110-
raise Exception('Not connected to API')
111118
# run shallow translation
112119
raw, translation, done = self._translateShallow(
113120
text=text, translation=fastTranslation, language=language
114121
)
115-
yield translation, not done
116-
if done: return
117-
# run deep translation
118-
yield self._translateDeep(
119-
text=text, translation=translation, language=language,
120-
inputLanguage=raw.get('Input language', 'unknown'),
121-
flags=raw['Flags']
122-
), False # no more pending translations
123-
return
122+
translationResult = CAITranslationResult(
123+
translation=translation, pending=not done,
124+
text=text, language=language,
125+
InputLanguage=raw.get('Input language', 'unknown'),
126+
Flags=raw['Flags'],
127+
)
128+
yield translationResult
129+
if not done: # run deep translation
130+
yield self.refine(translationResult)
131+
return
132+
133+
def refine(self, previousTranslation: CAITranslationResult):
134+
res = self._translateDeep(
135+
text=previousTranslation.text,
136+
translation=previousTranslation.translation,
137+
language=previousTranslation.language,
138+
inputLanguage=previousTranslation.InputLanguage,
139+
flags=previousTranslation.Flags,
140+
)
141+
142+
translationResult = CAITranslationResult(
143+
translation=res,
144+
text=previousTranslation.text, language=previousTranslation.language,
145+
InputLanguage=previousTranslation.InputLanguage,
146+
Flags=previousTranslation.Flags,
147+
pending=False, # no more steps
148+
)
149+
return translationResult

core/worker.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ def __init__(self, events):
1111
self._forceTranslateEvent = threading.Event()
1212
self._translatorFast = Translator(service_urls=['translate.google.com'])
1313
self._assistant = CAIAssistant()
14+
15+
# TODO: Fix this hack. It's not thread safe. Also check related issues with UI during refinement.
16+
self._refinement = None
1417
return
1518

1619
def run(self):
@@ -22,6 +25,10 @@ def run(self):
2225
isForceTranslate = self._forceTranslateEvent.wait(5)
2326
self._forceTranslateEvent.clear()
2427

28+
if self._refinement is not None:
29+
self._doRefine()
30+
continue
31+
2532
userInput = self._events.userInput()
2633
text = userInput['text']
2734
language = userInput['language']
@@ -63,22 +70,23 @@ def _performTranslate(self, text, force, language):
6370
text = text.strip()
6471
try:
6572
self._events.startTranslate(force)
66-
fastText = self._fastTranslate(text, language=language)
73+
fastText = self._fastTranslate(text, languageCode=language['code'])
6774
self._events.fastTranslated(fastText)
6875
if not force: return
6976

7077
translationProcess = self._fullTranslate(text, fastTranslation=fastText, language=language)
71-
for fullText, hasMore in translationProcess:
72-
self._events.fullTranslated(fullText, pending=hasMore)
78+
for translationResult in translationProcess:
79+
self._events.fullTranslated(translationResult)
7380
if self._forceTranslateEvent.is_set(): break # stop if force another translate
7481
continue
7582
finally:
7683
self._events.endTranslate()
7784
return
7885

79-
def _fastTranslate(self, text, language):
86+
@lru_cache(maxsize=20)
87+
def _fastTranslate(self, text, languageCode):
8088
if 0 == len(text): return ""
81-
translated = self._translatorFast.translate(text, dest=language['code'])
89+
translated = self._translatorFast.translate(text, dest=languageCode)
8290
return translated.text
8391

8492
def _fullTranslate(self, text, fastTranslation, language):
@@ -90,9 +98,7 @@ def _fullTranslate(self, text, fastTranslation, language):
9098
text, language=language['name'],
9199
fastTranslation=fastTranslation,
92100
)
93-
for translation in translationProcess:
94-
yield translation
95-
continue
101+
for translation in translationProcess: yield translation
96102
return
97103

98104
@lru_cache(maxsize=None)
@@ -108,4 +114,19 @@ def _updateLocalization(self, languageName, languageCode):
108114
def bindAPI(self, key):
109115
self._assistant.bindAPI(key)
110116
return
111-
117+
118+
def refine(self, previousTranslation):
119+
self._refinement = previousTranslation
120+
self.forceTranslate()
121+
return
122+
123+
def _doRefine(self):
124+
assert self._refinement is not None
125+
previousTranslation = self._refinement
126+
self._refinement = None
127+
try:
128+
res = self._assistant.refine(previousTranslation)
129+
self._events.fullTranslated(res)
130+
except Exception as e:
131+
self._events.error(e)
132+
return

main.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@
99
import dotenv
1010

1111
# main app
12-
# TODO: figure out how to enforce deep translation via UI
1312
# TODO: add translation history
1413
# TODO: add simple translation diff
1514
class App(tk.Frame):
1615
def __init__(self, master, languages, configs):
1716
super().__init__(master)
17+
self._lastAIResult = None
1818
self._configs = configs
1919
self._localizationMap = {}
2020
# predefine messages
@@ -91,6 +91,13 @@ def _UI_fullTranslation(self, owner):
9191
textvariable=self._localization("Slow and improved translation (ChatGPT/AI):")
9292
)
9393
label.pack(side="top", fill=tk.X)
94+
# Button "Refine" to force deep translation, disabled by default
95+
self._refineBtn = btn = tk.Button(
96+
owner, state=tk.DISABLED,
97+
command=self.onRefine,
98+
textvariable=self._localization("Refine"),
99+
)
100+
btn.pack(side="bottom", padx=5, pady=5, anchor="e")
94101

95102
self._fullOutputText = tkst.ScrolledText(owner, wrap=tk.WORD)
96103
self._fullOutputText.pack(side="top", fill=tk.BOTH, expand=tk.YES)
@@ -162,11 +169,14 @@ def fastTranslated(self, text):
162169
self._fastOutputText.insert(tk.END, text)
163170
return
164171

165-
def fullTranslated(self, text, pending):
172+
def fullTranslated(self, translationResult):
173+
self._lastAIResult = translationResult
174+
self._refineBtn.config(state=tk.DISABLED if translationResult.pending else tk.NORMAL)
175+
166176
self._fullOutputText.delete("1.0", tk.END)
167-
self._fullOutputText.insert(tk.END, text)
177+
self._fullOutputText.insert(tk.END, translationResult.translation)
168178

169-
if pending:
179+
if translationResult.pending:
170180
notification = self._localization(
171181
"Translation is not accurate and will be updated soon."
172182
).get()
@@ -185,7 +195,10 @@ def onSelectLanguage(self, event):
185195
language = self._language.get()
186196
code = next((code for code, name in self._languages.items() if name == language), None)
187197
if code is None: return
188-
198+
# discard AI result and disable refine button
199+
self._lastAIResult = None
200+
self._refineBtn.config(state=tk.DISABLED)
201+
# update other stuff
189202
self._currentLanguage = code
190203
self._configs['language'] = code
191204
self._worker.forceTranslate() # hack to force translation
@@ -209,6 +222,14 @@ def onSwitchAPIKey(self):
209222
return
210223

211224
def configs(self): return self._configs
225+
226+
def onRefine(self, event=None):
227+
# check if refine button is enabled
228+
if tk.DISABLED == self._refineBtn['state']: return 'break' # prevent unwanted action
229+
230+
self._worker.refine(self._lastAIResult)
231+
self._refineBtn.config(state=tk.DISABLED)
232+
return 'break'
212233
# End of class
213234

214235
def main():

0 commit comments

Comments
 (0)