In [None]:
import os
# Point the Hugging Face libraries at a custom cache directory. The variable is
# set before `datasets` is imported below, presumably because the library reads
# HF_HOME at import time -- TODO confirm.
# NOTE(review): this is a machine-specific absolute path; adjust it when running
# the notebook on another machine.
os.environ['HF_HOME'] = 'e:/.cache/huggingface'

from pathlib import Path

from tqdm.auto import tqdm
import gradio as gr
import datasets
from datasets import load_dataset, Audio, Dataset

In [None]:
# Directory of the local dataset copy, relative to the working directory.
# Audio paths stored in the samples are later resolved against it.
dataset_dir = Path('datasets') / 'yodas2_ru000_16k'
# Only the 'train' split is loaded.
yodas2 = load_dataset(f'{dataset_dir}', split='train')

In [None]:
class SimpleTextSearcher:
    def __init__(self):
        self.database = {}
    def add(self, id: str, text: str):
        self.database[id] = text
    def find(self, query: str, max_count: int | None = None) -> list[str]: # TODO switch to Python 3.10
        found_ids = []
        for id, text in self.database.items():
            if max_count is not None and len(found_ids) >= max_count:
                break
            if query.strip().lower() in text.lower():
                found_ids.append(id)
        return found_ids

# Index every sample whose audio file exists under dataset_dir. The sample's
# position in the dataset is used as its id, and its utterance texts are
# joined with spaces into one searchable string.
searcher = SimpleTextSearcher()
for sample_idx, sample in enumerate(tqdm(yodas2)):
    audio_file = dataset_dir / sample['audio']['path']
    if not audio_file.is_file():
        continue  # samples without an audio file on disk are left out
    joined_text = ' '.join(sample['utterances']['text'])
    searcher.add(sample_idx, joined_text)

print(f'Added {len(searcher.database)} files')

In [None]:
# Size of the results grid: the UI below pre-creates results_per_tab * n_tabs
# result slots, and the same product is the maximum number of hits requested
# from the searcher.
results_per_tab = 1
n_tabs = 1
    
with gr.Blocks(fill_height=True, theme=gr.themes.Origin()) as demo:
    
    # Top row: a column with the query box and the list of collected audios,
    # next to a read-only transcription box.
    with gr.Row(equal_height=True):
        with gr.Column(scale=3):
            query = gr.Textbox(label="Query", autofocus=True)
            collected = gr.Textbox(label="Collected audios", lines=4)
        # NOTE(review): no handler visible in this cell writes to this textbox.
        transcription = gr.Textbox(label="Transcription", scale=7, lines=9, interactive=False)
    
    def add_result(query: str, collected: str, id: str) -> gr.Textbox:
        """Append one result to the "collected" text.

        Builds an entry from the result ``id`` and the ``query`` that found it.
        The entry is joined to the existing ``collected`` text with ", ", or
        used on its own when ``collected`` is empty.
        """
        entry = f'{id} (q="{query}")'
        if len(collected):
            updated = f'{collected}, {entry}'
        else:
            updated = entry
        return gr.Textbox(value=updated)
    
    # Pre-build every result slot as hidden components; `search` fills in and
    # reveals the ones it needs. Each slot contributes three components to the
    # flat list, in this order: add button, id label, audio player.
    search_result_elements = []
    for tab_idx in range(n_tabs):
        first_shown = tab_idx * results_per_tab + 1
        last_shown = (tab_idx + 1) * results_per_tab
        with gr.Tab(f"{first_shown}-{last_shown}"):
            for i in range(results_per_tab):
                with gr.Row(variant='compact', equal_height=True):
                    with gr.Column(scale=1, min_width=0):
                        add_button = gr.Button(visible=False, min_width=0)
                        # Never shown; it only carries the result id to add_result.
                        result_id = gr.Label(visible=False, min_width=0, container=False)
                    result_audio = gr.Audio(visible=False, scale=15, editable=False)
                    search_result_elements.extend([add_button, result_id, result_audio])
                    add_button.click(add_result, inputs=[query, collected, result_id], outputs=collected)

    debug_log = gr.Textbox(label="Debug log", lines=3, interactive=False)

    def search(query: str, debug_log: str) -> list:
        """Look up ``query`` in the text index and fill the pre-built result slots.

        Args:
            query: Text taken from the query box.
            debug_log: Current content of the debug-log textbox; a header line
                and one line per hit are appended to it.

        Returns:
            A flat list with three component updates per result slot (add
            button, id label, audio player), in slot order, followed by the
            updated debug-log text. Slots beyond the number of hits are
            returned hidden.
        """
        n_slots = results_per_tab * n_tabs
        found_ids = searcher.find(query, max_count=n_slots)
        # The newline belongs at the end of the header line, not inside the
        # quoted query, so that the first result starts on its own line.
        debug_log += f'Search query: "{query}", found {len(found_ids)} results:\n'

        returned_elements = []
        for i in range(n_slots):
            if i < len(found_ids):
                id = found_ids[i]
                sample = yodas2[id]
                # NOTE: per the original author's remark, `video_id` is not an
                # actual YouTube id, so no watch link is built from it.
                youtube_id = sample['video_id']
                audio_path = str(dataset_dir / sample['audio']['path'])
                print(audio_path)
                filesize = Path(audio_path).stat().st_size
                debug_log += (
                    f'id={id} | {filesize / 1024**2:.0f} MB'
                    '\n'
                )
                returned_elements += [
                    gr.Button(visible=True, value="Add"),
                    gr.Label(visible=False, value=id),  # to save audio id
                    gr.Audio(visible=True, value=audio_path, label=f'id {id}, youtube_id {youtube_id}'),
                ]
            else:
                # No hit for this slot: keep (or make) all three components hidden.
                returned_elements += [
                    gr.Button(visible=False),
                    gr.Label(visible=False),
                    gr.Audio(visible=False),
                ]

        return returned_elements + [debug_log]
    
    # Submitting the query box runs `search`. Its return values are mapped, in
    # order, onto the result-slot components followed by the debug log.
    query.submit(search, inputs=[query, debug_log], outputs=search_result_elements + [debug_log])

# NOTE(review): share=True asks Gradio for a public share link -- confirm that
# exposing this app and the audio files outside the machine is intended.
# NOTE(review): only the Hugging Face cache directory is listed in
# allowed_paths, while `search` serves audio files from under dataset_dir --
# confirm that those files are reachable with this setting.
demo.launch(share=True, allowed_paths=[datasets.config.HF_CACHE_HOME])