# Introduction

In [None]:
import os
import pandas as pd
import glob

from utils.hparams import HParam, Dotdict
from datasets.get_dataset import get_dataset

import torch
import numpy as np

from utils.audio import Audio
import librosa

from model.get_model import get_vfmodel, get_embedder, get_forward
from loss.get_criterion import get_criterion

from torch_mir_eval import bss_eval_sources

import matplotlib.pylab as plt

import IPython.display


import json

from jupyter_dash import JupyterDash
from dash import dcc, html, Input, Output, State, dash_table, Dash
import plotly.express as px
import plotly.graph_objects as go

import io

In [None]:
from google.cloud import speech
import io


def transcribe_file(speech_file, language="vi-VN", sample_rate_hertz=16000):
    """Transcribe an audio file with Google Cloud Speech-to-Text.

    Each recognized chunk's best transcript is printed as it is read.

    Args:
        speech_file: Path of the audio file whose raw bytes are sent.
        language: Language code passed as ``language_code`` to the recognizer.
        sample_rate_hertz: Sample rate declared to the recognizer. Defaults to
            16000, the rate used elsewhere in this notebook.

    Returns:
        ``response.results``: one entry per consecutive portion of the audio.
    """
    client = speech.SpeechClient()

    with open(speech_file, "rb") as audio_file:
        content = audio_file.read()

    audio = speech.RecognitionAudio(content=content)
    config = speech.RecognitionConfig(
        encoding=speech.RecognitionConfig.AudioEncoding.LINEAR16,
        sample_rate_hertz=sample_rate_hertz,
        language_code=language,
    )

    response = client.recognize(config=config, audio=audio)

    # Each result is for a consecutive portion of the audio; the first
    # alternative is the most likely one for that portion.
    for result in response.results:
        # Guard against results with no alternatives instead of IndexError.
        if result.alternatives:
            print("Transcript: {}".format(result.alternatives[0].transcript))

    return response.results

In [None]:
# %%time
# a, _ = librosa.load("datasets/VinBigdata/speakers/Bongnt_hall_2212/vinfast-vsmart-000004993-Bongnt_hall_2212-HaNoi-MienBac-1960-nu-vinfast-vsmart-1.wav")

In [None]:
# from pathlib import Path

In [None]:
# Path("datasets/VinBigdata/speakers/Bongnt_hall_2212/vinfast-vsmart-000004993-Bongnt_hall_2212-HaNoi-MienBac-1960-nu-vinfast-vsmart-1.wav").exists()

In [None]:
# with open("datasets/LibriSpeech/dev-clean/84/121123/84-121123-0000.txt", "r") as f:
#     t = f.read()
# t.strip()

In [None]:
# Check embedder detail purpose
# embedder_pt = torch.load("chkpt/power_law_zalo_embedder/chkpt_180000.pt", "cpu")

In [None]:
def plot_spectrogram(spectrogram, range=None):
    """Render a 2-D spectrogram and return it as PNG bytes.

    Args:
        spectrogram: 2-D array drawn with ``imshow`` (x-axis labelled
            "Frames", y-axis labelled "Channels", origin at the bottom).
        range: Optional ``(vmin, vmax)`` pair fixing the colour limits. The
            name shadows the builtin but is kept for keyword callers.

    Returns:
        A memoryview over the PNG-encoded image.
    """
    fig, ax = plt.subplots(figsize=(12, 4))
    try:
        im = ax.imshow(spectrogram, aspect='auto', origin='lower',
                       interpolation='none')
        if range:
            im.set_clim(vmin=range[0], vmax=range[1])
        # Work on this figure/axes explicitly rather than pyplot's global
        # "current figure": callbacks served by a threaded server may render
        # several plots at the same time.
        fig.colorbar(im, ax=ax)
        ax.set_xlabel('Frames')
        ax.set_ylabel('Channels')

        buf = io.BytesIO()  # in-memory file
        fig.savefig(buf, format="png")
    finally:
        # Release the figure even if rendering fails, to avoid leaking memory.
        plt.close(fig)

    return buf.getbuffer()

def plotly_spectrogram(spectrogram):
    """Return an interactive Plotly image of ``spectrogram`` (Viridis scale, origin at bottom)."""
    return px.imshow(
        spectrogram,
        aspect='auto',
        origin='lower',
        color_continuous_scale="Viridis",
    )

In [None]:
def get_inference_module(config_p, chkpt_p, device):
    """Build an ``inference(testset, index)`` closure for one checkpoint.

    Args:
        config_p: Path to a config readable by ``HParam``; its ``"experiment"``
            section configures audio processing, model, embedder and loss.
        chkpt_p: Checkpoint path, stored in ``config.model.pretrained_chkpt``
            (presumably read by ``get_vfmodel`` — verify).
        device: Device passed to the embedder/model factories and to the
            forward function (this notebook uses ``"cpu"``).

    Returns:
        ``inference(testset, index)``. It returns a dict with:
        ``"sdr"`` (float) and ``"loss"`` (float);
        ``"spec"``, the estimated magnitude from ``stft2spec``;
        ``"wav"``, the estimated waveform as a NumPy array.
    """
    config = HParam(config_p)["experiment"]
    config.model.pretrained_chkpt = chkpt_p

    # Init model, embedder and criterion (no optimizer needed for inference).
    _audio = Audio(config)
    embedder = get_embedder(config, train=False, device=device)
    model, _ = get_vfmodel(config, train=False, device=device)
    # The training forward is used because it also returns the loss.
    train_forward, _ = get_forward(config)
    criterion = get_criterion(config)

    # Per-sample entries that get a leading batch dimension of size 1.
    batched_keys = (
        "dvec_mel", "target_wav", "mixed_wav", "target_stft", "mixed_stft",
        "mixed_mag", "mixed_phase", "target_mag", "target_phase",
    )

    def inference(testset, index):
        sample = testset.get_item(index, _audio)
        for key in batched_keys:
            sample[key] = sample[key].unsqueeze(0)
        sample["dvec"] = sample["dvec_mel"]

        with torch.no_grad():
            est_stft, _, loss = train_forward(model, embedder, sample, criterion, device)
            # Bring tensors back to host memory: .numpy() fails on CUDA
            # tensors, and both signals must share a device for the SDR.
            target_wav = sample["target_wav"].cpu()
            est_stft = est_stft.detach().cpu().numpy()[0]

        est_mag, _ = _audio.stft2spec(est_stft)
        est_wav = _audio._istft(est_stft.T, length=len(target_wav[0]))

        _est_wav = torch.from_numpy(est_wav).reshape(1, -1)
        _target_wav = target_wav.reshape(1, -1)

        # Only the SDR is reported; SIR/SAR/permutation are discarded.
        sdr, _, _, _ = bss_eval_sources(_target_wav, _est_wav, compute_permutation=False)

        return {
            "sdr": sdr.item(),
            "loss": loss.mean().item(),
            "spec": est_mag,
            "wav": est_wav,
        }

    return inference

In [None]:
# Preload every model listed in model_list.csv so callbacks can reuse them.
model_df = pd.read_csv("model_list.csv")
device = "cpu"

print("Start importing model from model list")

# (config path, checkpoint path) -> inference closure
model_dict = {}
for _, r in model_df.iterrows():
    # Only the first two columns are used; the third is ignored.
    config, chkpt, _ = r
    try:
        model = get_inference_module(config, chkpt, device)
        model_dict[(config, chkpt)] = model
    except Exception as err:
        # Best effort: one broken entry must not stop the dashboard, but
        # report why it failed. Ctrl-C (KeyboardInterrupt) still propagates.
        print(f"Failed with model {config}: {err}")

print("Complete")


def load_test_record(path):
    """Load a test-record JSON and build the table the dashboard explores.

    Record layout:
    - ``"config"``: its ``experiment.dataset`` selects the test set.
    - ``"data"`` (optional): discarded.
    - ``"info"`` (optional): per-sample columns joined onto the dataset table.
    - Every other key is an experiment whose optional ``"metrics"`` maps metric
      names to JSON files (presumably per-sample value lists in dataset order).

    Side effect: overwrites ``hp.experiment.dataset`` on the module-level
    ``hp`` before building the test set.

    Returns:
        ``(df, models, testset)``:
        - ``df``: one row per (experiment, sample) with metric, sample and
          experiment columns. It also has ``output_enhance_dir`` /
          ``output_asr_dir`` when experiments define ``output_dir``.
        - ``models``: experiment name -> preloaded inference module (``None``
          when that (config, chkpt) pair was not preloaded).
        - ``testset``: the test dataset.

    Raises:
        ValueError: if no experiment has any loadable metric.
    """
    with open(path, "r") as f:
        test_record = json.load(f)

    # Point the shared hyper-parameters at this record's dataset.
    data_config = Dotdict(test_record["config"])
    hp.experiment.dataset = data_config.experiment.dataset
    testset = get_dataset(hp, scheme="test")

    # Remove the non-experiment entries so only experiments remain.
    test_record.pop("data", None)
    test_record.pop("config")
    data_info = test_record.pop("info", None)

    dim_sample = testset.data
    if data_info is not None:
        dim_sample = dim_sample.join(pd.DataFrame(data_info))
    dim_sample = dim_sample.reset_index()

    # Replace each metric file path with its loaded contents. Missing files
    # are dropped rather than left as path strings, which would otherwise be
    # broadcast into every row as if they were metric values.
    for e, exp in test_record.items():
        metric_paths = exp.get("metrics")
        if metric_paths is None:
            continue
        loaded = {}
        for m, metric_path in metric_paths.items():
            if not os.path.isfile(metric_path):
                print(f"Metrics {m} in experiment {e} is missing.")
                continue
            with open(metric_path, "r") as f:
                loaded[m] = json.load(f)
        exp["metrics"] = loaded

    test_record_df = (
        pd.DataFrame(test_record)
        .transpose()
        .reset_index()
        .rename(columns={"index": "experiment"})
    )
    # One frame per experiment that has loaded metrics. Experiments without
    # any are skipped: unpacking a missing/None entry with ** raises.
    per_experiment = [
        pd.DataFrame({"experiment": row["experiment"], **row["metrics"]})
        for _, row in test_record_df.iterrows()
        if isinstance(row.get("metrics"), dict) and row["metrics"]
    ]
    if not per_experiment:
        raise ValueError(f"No loadable metrics found in test record {path!r}")
    fact_result = pd.concat(per_experiment)
    dim_test = test_record_df.drop(columns="metrics")
    del test_record_df

    # Metric frames are positionally indexed, so the index join attaches the
    # matching sample row; the merge then adds experiment-level columns.
    df = fact_result.join(dim_sample).merge(dim_test, on="experiment")

    if "output_dir" in df.columns:
        df["output_enhance_dir"] = df["output_dir"] + "/" + df["index"].astype(str) + ".wav"
        df["output_asr_dir"] = df["output_dir"] + "/" + df["index"].astype(str) + "_asr.json"

    models = {}
    for m, exp in test_record.items():
        if "dataset_info" in m:
            continue
        try:
            models[m] = model_dict.get((exp["config"], exp["chkpt"]))
        except (KeyError, TypeError):
            # Experiment lacks config/chkpt, or they are not hashable.
            print("Failed with model " + m)

    return df, models, testset

In [None]:
# Default hyper-parameters (CPU only); load_test_record later overwrites
# hp.experiment.dataset for the selected record.
hp = HParam("config/default.yaml")
hp.experiment.use_cuda = False
audio = Audio(hp["experiment"])

# Sorted so the dropdown order is stable across runs (glob order is arbitrary).
test_records = sorted(glob.glob("test_results/*.json"))

In [None]:
import base64

def get_audio_file_b64(f):
    """Return the raw bytes of file ``f`` base64-encoded as an ASCII ``str``."""
    # Context manager closes the handle promptly instead of leaking it.
    with open(f, "rb") as fh:
        enc = base64.b64encode(fh.read())
    return enc.decode()

def get_audio_b64(w, sr=16000):
    """Return base64 text of the audio bytes IPython builds for ``w``.

    ``w`` may be a waveform array (interpreted at ``sr`` Hz) or a file path.
    The encoded payload is whatever ``IPython.display.Audio(...).data`` holds,
    with normalization disabled.
    """
    raw_bytes = IPython.display.Audio(w, rate=sr, normalize=False).data
    return base64.b64encode(raw_bytes).decode()

def get_svg_b64(buf):
    """Return the bytes-like ``buf`` base64-encoded as a ``str``.

    Despite the name, the image bytes passed here in this notebook are PNG.
    """
    return base64.b64encode(buf).decode()

def get_audio_block_from_wav(w):
    """Return an ``html.Audio`` player with waveform ``w`` embedded as a data URI."""
    # IPython encodes waveform arrays as WAV, so declare audio/wav rather
    # than audio/mpeg to avoid relying on browser content sniffing.
    return html.Audio(
        src=f"data:audio/wav;base64,{get_audio_b64(w)}",
        controls=True,
    )

def get_audio_block_from_file(p):
    """Return an ``html.Audio`` player with the file at ``p`` embedded as a data URI."""
    # The files shown here are .wav outputs, so declare audio/wav rather
    # than audio/mpeg to avoid relying on browser content sniffing.
    return html.Audio(
        src=f"data:audio/wav;base64,{get_audio_b64(p)}",
        controls=True,
    )

def get_spectrogram_img(spec):
    """Render ``spec`` to PNG via plot_spectrogram and wrap it in an ``html.Img``."""
    png_b64 = get_svg_b64(plot_spectrogram(spec))
    return html.Img(src="data:image/png;base64," + png_b64)

def get_spectrogram_img_from_file(p, sr=16000):
    """Load the audio file ``p`` resampled to ``sr`` Hz and return its spectrogram image."""
    # ``sr`` must be a keyword argument: librosa >= 0.10 made it keyword-only,
    # so the positional form raises TypeError.
    w, _ = librosa.load(p, sr=sr)
    est_mag, _ = audio.wav2spec(w)
    return get_spectrogram_img(est_mag.T)

In [None]:
def audio_sample_div(w, title):
    """Return a Div with ``title``, a player for waveform ``w`` and its spectrogram."""
    magnitude = audio.wav2spec(w)[0]
    children = [
        html.H3(title),
        get_audio_block_from_wav(w),
        get_spectrogram_img(magnitude.T),
    ]
    return html.Div(children)

# Main app

In [None]:
# Optional smoke test: load a known record once at start-up. The result is
# discarded (the state cell below resets df), so skip it when that record is
# not present instead of aborting the notebook run.
_SMOKE_TEST_RECORD = "test_results/libri_gg_test.json"
if os.path.isfile(_SMOKE_TEST_RECORD):
    df, _, _ = load_test_record(_SMOKE_TEST_RECORD)

In [None]:
# State shared by the callbacks; filled in when a test record is selected.
df = models = testset = None

In [None]:
# Bootstrap provides the "btn btn-primary" classes used in the layout.
# It is passed via external_stylesheets: with Dash's default
# serve_locally=True, CSS registered through
# app.css.append_css({'external_url': ...}) is skipped (Dash only warns),
# so the stylesheet would never load.
BOOTSTRAP_CSS = 'https://cdn.jsdelivr.net/npm/bootstrap@4.0.0/dist/css/bootstrap.min.css'

# Use Dash(...) with the same arguments to run outside a notebook.
app = JupyterDash(
    'Explore test result',
    requests_pathname_prefix='/dashboard/',
    external_stylesheets=[BOOTSTRAP_CSS],
)

In [None]:
app.layout = html.Div([
    # Holds the path of the currently loaded record. Callbacks read and write
    # ("t_current_test_record", "value"), so the component needs a ``value``
    # prop (html.P has none); type="hidden" keeps it out of view, which the
    # invalid CSS style {"hidden": True} did not.
    dcc.Input(id="t_current_test_record", type="hidden"),
    
    html.Div(children=[
            
        html.Div(children=[
            html.Div(children=[
                html.Label('Test record'),
                dcc.Dropdown(test_records, id="test-record"),
            ]),

            html.H2("Dimensions"),

            dcc.Loading(
                children=[html.Div(children=[
                    html.Div(children=[
                        html.Label('X-axis'),
                        dcc.Dropdown(id="dim-x-axis"),
                    ], style={'display': 'inline-block'}),

                    html.Div(children=[
                        html.Label('Y-axis'),
                        dcc.Dropdown(id="dim-y-axis"),
                    ], style={'display': 'inline-block'}),
                ], style={'display': 'grid',
                    'grid-template-columns': '1fr 1fr',
                }),

                html.Label('Color'),
                dcc.Dropdown(id="dim-color"),

                html.Label('Size'),
                dcc.Dropdown(id="dim-size"),

                html.H2("Filter"),
                html.Label('Experiment'),
                dcc.Dropdown(id="filter-exp", multi=True),
                ], fullscreen=True, type="cube"
            ),
        ], style={'padding': 10, 'flex': 1}),


        html.Div(children=[
            html.H2("Interactive plot"),
            dcc.RadioItems(["Scatter", "Histogram", "Box plot"], "Histogram", id="plot-type"),
            dcc.Loading(
                children=dcc.Graph(id='main-plot'),
                type="graph",
            )
        ], style={'padding': 10, 'flex': 3}),


    ], style={'display': 'flex', 'flex-direction': 'row'}),
    
    html.Div(children=[
        html.H2("Statistic"),
        dcc.Loading(
            children=dash_table.DataTable(id="statistic-table", 
                merge_duplicate_headers=True, export_format="csv",
                fixed_columns={'headers': True, 'data': 1}, style_table={'minWidth': '100%'}
            ),
            type="dot",
        ),
    ], style={'padding': 10}),

    html.Div(children=[
        html.H2("Raw table"),
        dcc.Loading(
            children=dash_table.DataTable(id="raw-table", 
                merge_duplicate_headers=True, export_format="csv",
                fixed_columns={'headers': True}, style_table={'minWidth': '100%'},
                page_current=0,
                page_size=5,
                page_action='native',
                filter_action='custom',
                filter_query=''
            ),
            type="dot",
        ),
    ], style={'padding': 10}),

    html.Div(children=[
        html.H2("Sample"),
        dcc.Loading(
            children=
            [
                html.Div(id="sample-statistic-table"),
                html.Div(id="show-sample-area", 
                    style={'display': 'grid',
                        'grid-template-columns': '1fr 1fr',
                        'grid-template-rows': '1fr 1fr',
                    }),
            ],
            type="dot"
        ),
    ], style={'padding': 10}),

    html.Div(children=[
        html.H2("Inference"),
        html.Button(id='inference-button-state', n_clicks=0, children='Run', className="btn btn-primary"),
        dcc.Loading(children=html.Div(id="inference-area"), type="dot"),
    ], style={'padding': 10}),
])

In [None]:
@app.callback(
    Output('dim-x-axis', 'options'),
    Output('dim-y-axis', 'options'),
    Output('dim-color', 'options'),
    Output('dim-size', 'options'),
    Output('filter-exp', 'options'),
    Output("t_current_test_record", "value"),
    Input('test-record', 'value'),
)
def change_test_record(test_record_path):
    """Load the selected record into the shared state and refresh dropdowns.

    Returns:
        Six values, one per Output:
        - four option lists for the dimension pickers (all columns of ``df``);
        - the distinct experiment names for the filter;
        - the record path, or ``"None"`` when nothing is selected.
    """
    global df
    global models
    global testset

    if test_record_path is None:
        # Dash needs exactly one value per Output.
        return [], [], [], [], [], "None"

    # Unpacking assigns all three globals only if loading succeeds.
    df, models, testset = load_test_record(test_record_path)
    columns = list(df.columns)
    experiments = df["experiment"].drop_duplicates().tolist()
    return columns, columns, columns, columns, experiments, test_record_path

In [None]:
@app.callback(
    Output('raw-table', 'columns'),
    Input('t_current_test_record', 'value'),
)
def change_test_record_raw_tables(_):
    """Return hideable column definitions for the raw table, sorted by name."""
    # The callback also fires on page load, before any record is loaded.
    if df is None:
        return []
    return [{"name": i, "id": i, "hideable": True} for i in sorted(df.columns)]

In [None]:
# DataTable filter operators. Each group lists the accepted spellings; the
# first entry is the canonical name that split_filter_part returns. Longer
# symbols ('>=', '<=', '!=') are checked before their prefixes ('>', '<', '=').
operators = [['ge ', '>='],
             ['le ', '<='],
             ['lt ', '<'],
             ['gt ', '>'],
             ['ne ', '!='],
             ['eq ', '='],
             ['contains '],
             ['datestartswith ']]


def split_filter_part(filter_part):
    """Parse one DataTable filter expression such as ``{col} ge 3``.

    Returns:
        ``(column, operator, value)``:
        - ``operator`` is the canonical name from ``operators`` without its
          trailing space.
        - ``value`` is the unquoted text for a quoted value, otherwise a float
          when it parses as one, otherwise the raw (stripped) text.
        ``[None, None, None]`` is returned when no operator is found.
    """
    for operator_type in operators:
        for operator in operator_type:
            if operator not in filter_part:
                continue
            name_part, value_part = filter_part.split(operator, 1)
            name = name_part[name_part.find('{') + 1: name_part.rfind('}')]

            value_part = value_part.strip()
            # Slice instead of index so an empty value (e.g. "{col} eq "
            # while the user is still typing) does not raise IndexError.
            quote = value_part[:1]
            if quote and quote == value_part[-1] and quote in ("'", '"', '`'):
                value = value_part[1: -1].replace('\\' + quote, quote)
            else:
                try:
                    value = float(value_part)
                except ValueError:
                    value = value_part

            # word operators need spaces after them in the filter string,
            # but we don't want these later
            return name, operator_type[0].strip(), value

    return [None] * 3

@app.callback(
    Output('raw-table', 'data'),
    Input('raw-table', "page_current"),
    Input('raw-table', "page_size"),
    Input('raw-table', "filter_query"),
    Input('raw-table', 'columns'),
)
def update_table(page_current,page_size, filter, _):
    """Apply the raw table's custom filter query to ``df`` and return its rows.

    Paging is handled by the table itself (``page_action='native'``), so every
    matching row is returned; ``page_current``/``page_size`` only act as
    triggers. Expressions that cannot be parsed, or that name an unknown
    column, are ignored.
    """
    if df is None:
        return []

    dff = df
    for filter_part in (filter or "").split(' && '):
        col_name, operator, filter_value = split_filter_part(filter_part)
        if operator is None or col_name not in dff.columns:
            continue

        if operator in ('eq', 'ne', 'lt', 'le', 'gt', 'ge'):
            # these operators match pandas series operator method names
            dff = dff.loc[getattr(dff[col_name], operator)(filter_value)]
        elif operator == 'contains':
            # na=False: NaN entries would otherwise make the mask non-boolean.
            dff = dff.loc[dff[col_name].str.contains(filter_value, na=False)]
        elif operator == 'datestartswith':
            # this is a simplification of the front-end filtering logic,
            # only works with complete fields in standard format
            dff = dff.loc[dff[col_name].str.startswith(filter_value, na=False)]

    return dff.to_dict('records')

In [None]:
def _flatten_customdata(fig):
    """Reduce each trace's Plotly Express ``customdata`` matrix to its first column.

    With ``custom_data=["index"]`` px attaches an (n, k) matrix per trace;
    click handlers read ``points[0]["customdata"]`` as a scalar sample index.
    """
    for trace in fig.data:
        if trace.customdata is None:
            continue
        data = np.asarray(trace.customdata)
        if data.ndim == 2:
            trace.customdata = data[:, 0]


@app.callback(
    Output('main-plot', 'figure'),
    Input('plot-type', 'value'),
    Input('dim-x-axis', 'value'),
    Input('dim-y-axis', 'value'),
    Input('dim-color', 'value'),
    Input('dim-size', 'value'),
    Input('filter-exp', 'value'),
    Input('test-record', 'value'),
)
def update_plot_dims(plot_type, x_axis, y_axis, color, size, exp_filtered, _):
    """Draw the main plot for the selected dimensions and experiments.

    Each scatter/box point carries its sample ``index`` as ``customdata``,
    taken from the rows of that same trace, so a click maps to the right
    sample even when experiments are filtered or colour-grouped. Returns an
    empty scatter when no record is loaded or the plot type is unknown.
    """
    if df is None:
        return px.scatter()

    # A missing or empty multi-select means "no filter".
    _data = df[df["experiment"].isin(exp_filtered)] if exp_filtered else df

    if plot_type == "Scatter":
        fig = px.scatter(_data, x=x_axis, y=y_axis,
                         size=size, color=color, hover_name=None,
                         custom_data=["index"],
                         marginal_x="histogram", marginal_y="histogram",
                         log_x=False, size_max=55)
        _flatten_customdata(fig)
    elif plot_type == "Histogram":
        fig = px.histogram(_data, x=x_axis, y=y_axis,
                           color=color, hover_name=None,
                           log_x=False)
        fig.update_layout(barmode='overlay')
        # Reduce opacity to see both histograms
        fig.update_traces(opacity=0.75)
        # Bars aggregate many samples, so there is no per-point index; attach
        # the indices of the plotted (filtered) rows, not of the whole df.
        fig.update_traces(customdata=_data["index"].to_numpy())
    elif plot_type == "Box plot":
        fig = px.box(_data, x=x_axis, y=y_axis,
                     color=color, hover_name=None,
                     custom_data=["index"],
                     log_x=False)
        _flatten_customdata(fig)
    else:
        return px.scatter()

    fig.update_layout(clickmode='event+select')
    fig.update_layout(transition_duration=500)
    return fig

In [None]:
@app.callback(
    Output('statistic-table', 'data'),
    Output('statistic-table', 'columns'),
    Input('filter-exp', 'value'),
    Input('test-record', 'value'),
)
def update_stat_table(exp_filtered, _):
    """Summarise each experiment's numeric columns with ``describe()``.

    ``describe()`` yields two-level column keys (column, statistic). Each key
    becomes a flat ``"<column>_<statistic>"`` id, and the pair is kept as a
    multi-row header name for ``merge_duplicate_headers``.

    Returns:
        ``(rows, column_definitions)``; both are empty when no record is loaded.
    """
    if df is None:
        return [], []

    # A missing or empty selection means "all experiments"; isin(None) raises.
    _data = df[df["experiment"].isin(exp_filtered)] if exp_filtered else df

    rows = []
    headers = {}
    summary = _data.groupby("experiment").describe().reset_index()
    for record in summary.to_dict("records"):
        row = {}
        for key, value in record.items():
            col_id = "_".join(map(str, key))
            row[col_id] = value
            headers[col_id] = list(key)
        rows.append(row)

    return rows, [{"name": v, "id": k, "hideable": True} for k, v in headers.items()]

In [None]:
@app.callback(
    Output('inference-area', 'children'),
    Input('inference-button-state', 'n_clicks'),
    State('main-plot', 'clickData'),
    State('filter-exp', 'value'),
    Input('test-record', 'value'),
)
def show_inference(n_clicks, clickData, exp_filtered, _):
    """Show stored enhancement outputs and ASR results for the clicked sample.

    Every experiment in ``df`` is listed except those whose name contains
    "dataset_info". The experiment filter is deliberately not applied.
    For each experiment the section shows:
    - the enhanced audio and its spectrogram, when the file exists;
    - the ASR transcript and confidence from ``output_asr_dir``, when present.
    """
    message = ["Please select a data point from scatter plot first to see the sample detail."]
    if clickData is None or df is None:
        return message

    idx = clickData["points"][0].get("customdata")
    if idx is None:
        # e.g. a histogram bar, which does not identify a single sample.
        return message

    block = []
    for e in df["experiment"].drop_duplicates():
        if "dataset_info" in e:
            continue
        rows = df[(df["index"] == idx) & (df["experiment"] == e)]
        if rows.empty:
            # This experiment has no result for the selected sample.
            continue
        sample = rows.iloc[0]

        asr_pred = "(Not available)"
        asr_conf = "(Not available)"
        # Columns exist only when the record defines output_dir; values may be NaN.
        asr_path = sample.get("output_asr_dir")
        if isinstance(asr_path, str) and os.path.isfile(asr_path):
            with open(asr_path, "r") as f:
                t = json.load(f)
            asr_pred = t.get("transcript", asr_pred)
            asr_conf = t.get("confidence", asr_conf)

        wav_path = sample.get("output_enhance_dir")
        if isinstance(wav_path, str) and os.path.isfile(wav_path):
            media = [
                get_audio_block_from_file(wav_path),
                get_spectrogram_img_from_file(wav_path),
            ]
        else:
            media = [html.P("Enhanced audio: (Not available)")]

        block += [
            html.H3(f"Experiment: {e}"),
            html.Div(children=media, id=f"{e}-inference"),
            html.P(f"ASR inference: {asr_pred}"),
            html.P(f"ASR confidence: {asr_conf}")
        ]

    return block

In [None]:
@app.callback(
    Output('show-sample-area', 'children'),
    Output('sample-statistic-table', 'children'),
    Input('main-plot', 'clickData'),
    State('filter-exp', 'value'),
    Input('test-record', 'value'),
)
def show_sample(clickData, exp_filtered, _):
    """Show the clicked sample's audio, metrics table, index and prompt.

    Returns two values, one per Output:
    - players and spectrograms for the mixed, reference and target audio;
    - a table of this sample's rows for the selected experiments (all when
      none are selected), plus its index and prompt text.
    """
    message = ["Please select a data point from interactive plot first to see the sample detail."]
    if clickData is None or df is None or testset is None:
        # Dash needs one value per Output.
        return message, []

    idx = clickData["points"][0].get("customdata")
    if idx is None:
        # e.g. a histogram bar, which does not identify a single sample.
        return message, []

    # A missing or empty selection means "all experiments".
    if not exp_filtered:
        exp_filtered = df["experiment"].drop_duplicates()

    sample = testset[idx]
    sample_df = df[(df["index"] == idx) & (df["experiment"].isin(exp_filtered))]

    # The audio helpers work on NumPy arrays; convert tensors explicitly
    # instead of swallowing every exception.
    for key in ("mixed_wav", "target_wav", "dvec_wav"):
        if isinstance(sample[key], torch.Tensor):
            sample[key] = sample[key].detach().cpu().numpy()

    return [
        audio_sample_div(sample['mixed_wav'], "Mixed audio"),
        audio_sample_div(sample['dvec_wav'], "Reference audio"),
        audio_sample_div(sample['target_wav'], "Target audio"),
    ], [
        html.Div(dash_table.DataTable(
            sample_df.to_dict("records"),
            [{"name": i, "id": i} for i in sample_df.columns],
            fixed_columns={'headers': True, 'data': 1}, style_table={'minWidth': '100%'},
            export_format="csv"
        )),
        html.Div([
            html.H5("Data point index: "),
            html.P(idx)
        ]),
        html.Div([
            html.H5("Prompt: "),
            html.P(sample["target_text"])
        ]),
    ]

# Run

In [None]:
# Serve the dashboard on a fixed local port. To render it inside the
# notebook output instead, call app.run_server(mode='inline').
DASHBOARD_PORT = 8050
app.run_server(port=DASHBOARD_PORT)