Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve user experience when api key is not valid #172

Merged
merged 5 commits into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion app/transcribe/app_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from sdk import transcriber_models as tm # noqa: E402 pylint: disable=C0413


def create_responder(provider_name: str, config, convo, save_to_file: bool,
def create_responder(provider_name: str, config, convo, save_to_file: bool,
response_file_name: str):
"""Creates a responder / Inference provider object based on input parameters
"""
Expand Down
26 changes: 25 additions & 1 deletion app/transcribe/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import interactions # noqa: E402 pylint: disable=C0413
from tsutils import utilities, duration, configuration # noqa: E402 pylint: disable=C0413
from sdk import audio_recorder as ar # noqa: E402 pylint: disable=C0413
import openai


def create_args() -> argparse.Namespace:
Expand Down Expand Up @@ -35,6 +36,9 @@ def create_args() -> argparse.Namespace:
help='Save the API key for accessing OpenAI APIs to override.yaml file.\
\nSubsequent invocations of the program will not require API key on command line.\
\nTo not persist the API key use the -k option.')
cmd_args.add_argument('-vk', '--validate_api_key', action='store', default=None,
help='Validate that it is a valid functioning api_key.\
\nWithout the API Key only transcription works.')
cmd_args.add_argument('-t', '--transcribe', action='store', default=None,
help='Transcribe the given audio file to generate text.\
\nThis option respects the -m (model) option.\
Expand Down Expand Up @@ -64,7 +68,7 @@ def create_args() -> argparse.Namespace:
return args


def handle_args_batch_tasks(args: argparse.Namespace, global_vars: TranscriptionGlobals):
def handle_args_batch_tasks(args: argparse.Namespace, global_vars: TranscriptionGlobals, config: dict):
"""Handle batch tasks, after which the program will exit."""
interactions.params(args)

Expand All @@ -77,6 +81,26 @@ def handle_args_batch_tasks(args: argparse.Namespace, global_vars: Transcription
save_api_key(args)
sys.exit(0)

if args.validate_api_key is not None:
chat_inference_provider = config['General']['chat_inference_provider']
if chat_inference_provider == 'openai':
base_url = config['OpenAI']['base_url']
elif chat_inference_provider == 'together':
base_url = config['Together']['base_url']

if utilities.is_api_key_valid(api_key=args.validate_api_key, base_url=base_url):
print('The api_key is valid')
base_url = config['OpenAI']['base_url']
client = openai.OpenAI(api_key=args.validate_api_key, base_url=base_url)
models = utilities.get_available_models(client=client)
print('Available models: ')
for model in models:
print(f' {model}')
client.close()
else:
print('The api_key is not valid')
sys.exit(0)

if args.transcribe is not None:
with duration.Duration(name='Transcription', log=False, screen=True):
output_file = args.output_file if args.output_file is not None else "transcription.txt"
Expand Down
2 changes: 1 addition & 1 deletion app/transcribe/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def main():
# transcription. This order of initialization results in initialization of Mic, Speaker
# as well which is not necessary for some batch tasks.
# This does not have any side effects.
handle_args_batch_tasks(args, global_vars)
handle_args_batch_tasks(args, global_vars, config)

# Initiate logging
log_listener = al.initiate_log(config=config)
Expand Down
22 changes: 19 additions & 3 deletions app/transcribe/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,8 +441,8 @@ def create_ui_components(root, config: dict):

response_enabled = bool(config['General']['continuous_response'])
b_text = "Suggest Responses Continuously" if not response_enabled else "Do Not Suggest Responses Continuously"
freeze_button = ctk.CTkButton(root, text=b_text, command=None)
freeze_button.grid(row=1, column=1, padx=10, pady=3, sticky="nsew")
continuous_response_button = ctk.CTkButton(root, text=b_text, command=None)
continuous_response_button.grid(row=1, column=1, padx=10, pady=3, sticky="nsew")

response_now_button = ctk.CTkButton(root, text="Suggest Response Now", command=None)
response_now_button.grid(row=2, column=1, padx=10, pady=3, sticky="nsew")
Expand Down Expand Up @@ -483,6 +483,22 @@ def create_ui_components(root, config: dict):
m.add_separator()
m.add_command(label="Quit", command=root.quit)

chat_inference_provider = config['General']['chat_inference_provider']
if chat_inference_provider == 'openai':
api_key = config['OpenAI']['api_key']
base_url = config['OpenAI']['base_url']
elif chat_inference_provider == 'together':
api_key = config['Together']['api_key']
base_url = config['Together']['base_url']

if not utilities.is_api_key_valid(api_key=api_key, base_url=base_url):
# Disable buttons that interact with backend services
continuous_response_button.configure(state='disabled')
response_now_button.configure(state='disabled')
continuous_response_button.configure(state='disabled')
read_response_now_button.configure(state='disabled')
summarize_button.configure(state='disabled')

def show_context_menu(event):
try:
m.tk_popup(event.x_root, event.y_root)
Expand All @@ -494,6 +510,6 @@ def show_context_menu(event):
# Order of returned components is important.
# Add new components to the end
return [transcript_textbox, response_textbox, update_interval_slider,
update_interval_slider_label, freeze_button, lang_combobox,
update_interval_slider_label, continuous_response_button, lang_combobox,
filemenu, response_now_button, read_response_now_button, editmenu,
github_link, issue_link, summarize_button]
7 changes: 5 additions & 2 deletions sdk/audio_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,9 @@ def write_wav_data_to_file(self) -> str:
if self.audio_file_name is None:
return

if not os.path.exists(self.audio_file_name+'.bak'):
return

frame_rate = self.source.SAMPLE_RATE
sample_width = self.source.SAMPLE_WIDTH
channels = self.source.channels
Expand All @@ -172,8 +175,8 @@ def write_wav_data_to_file(self) -> str:
wf.setsampwidth(sample_width) # pylint: disable=E1101
wf.setframerate(frame_rate) # pylint: disable=E1101
wf.writeframes(data) # pylint: disable=E1101
print(f'datasize: {len(data)}')
print(f'filesize: {os.path.getsize(self.audio_file_name)}')
# print(f'datasize: {len(data)}')
# print(f'filesize: {os.path.getsize(self.audio_file_name)}')


class MicRecorder(BaseRecorder):
Expand Down
33 changes: 33 additions & 0 deletions tsutils/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import copy
import subprocess
import zipfile
import openai


def merge(first: dict, second: dict, path=[]):
Expand Down Expand Up @@ -189,3 +190,35 @@ def zip_files_in_folder(folder_path: str, zip_file_name: str,
if skip_zip_files and file.endswith(".zip"):
continue
my_zip.write(f'{folder_path}/{file}')


def get_available_models(client: openai.OpenAI) -> list:
"""Get the list of available models from the provider.
"""
try:
models = client.models.list()
return_val = []
for model in models.data:
return_val.append(model.id)
except openai.AuthenticationError as e:
print(e)
return None

return sorted(return_val)


def is_api_key_valid(api_key: str, base_url) -> bool:
"""Check if it is valid openai compatible openai key for the provider
"""
openai.api_key = api_key

client = openai.OpenAI(api_key=api_key, base_url=base_url)

try:
client.models.list()
client.close()
except openai.AuthenticationError as e:
print(e)
return False

return True