# Introduction

In [2]:
import pandas as pd

from utils.hparams import HParam
from datasets.get_dataset import get_dataset

import torch
import numpy as np

from utils.audio import Audio

from model.get_model import get_vfmodel, get_embedder, get_forward
from loss.get_criterion import get_criterion

from torch_mir_eval import bss_eval_sources

import matplotlib.pylab as plt

import IPython.display

import json

from jupyter_dash import JupyterDash
from dash import dcc, html, Input, Output, State
import plotly.express as px

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def plot_spectrogram(spectrogram, range=None):
    """Render a spectrogram as an image on a new 12x3 inch matplotlib figure.

    Args:
        spectrogram: Image data handed straight to ``Axes.imshow``.
        range: Optional indexable pair; when truthy, ``range[0]`` and
            ``range[1]`` become the lower and upper colour limits.
    """
    figure, axes = plt.subplots(figsize=(12, 3))
    image = axes.imshow(
        spectrogram,
        origin='lower',
        aspect='auto',
        interpolation='none',
    )
    if range:
        lower, upper = range[0], range[1]
        image.set_clim(vmin=lower, vmax=upper)
    figure.colorbar(image, ax=axes)
    axes.set_xlabel('Frames')
    axes.set_ylabel('Channels')
    figure.tight_layout()

def plotly_spectrogram(spectrogram):
    """Return a Plotly Express image figure of ``spectrogram``.

    The figure uses a lower-left origin, automatic aspect ratio and the
    "Viridis" continuous colour scale.
    """
    return px.imshow(
        spectrogram,
        origin='lower',
        aspect='auto',
        color_continuous_scale="Viridis",
    )

In [4]:
# Load the saved test record. Later cells read its top-level "config" entry
# and the per-model "config" / "chkpt" paths from it.
with open("test_results/libri_test.json", "r") as f:
    test_record = json.load(f)

In [5]:
# Start from the default configuration and build the audio helper from its
# "experiment" section.
hp = HParam("config/default.yaml")
audio = Audio(hp["experiment"])
# Configuration referenced by the test record; only its dataset settings are
# copied into `hp` below.
data_config = HParam(test_record["config"])

# Turn CUDA off in the config and take the dataset settings from the recorded
# configuration before building the test set.
hp.experiment.use_cuda = False
hp.experiment.dataset = data_config.experiment.dataset
testset = get_dataset(hp, scheme="test")
# Per-sample table used by the dashboard (its columns feed the dropdowns and
# its index identifies samples) — presumably a pandas DataFrame; confirm.
df = testset.data
# Device string passed to the model/embedder factories and the forward call.
device = "cpu"

  for doc in docs:


In [6]:
def get_inference_module(config_p, chkpt_p):
    """Build an inference function for one model configuration.

    Loads the "experiment" section of the config at ``config_p``, sets its
    ``model.pretrained_chkpt`` to ``chkpt_p`` and creates the audio helper,
    embedder, model, forward function and criterion from it.

    Relies on the notebook-level globals ``device`` and ``testset``.

    Args:
        config_p: Path of the config file handed to ``HParam``.
        chkpt_p: Checkpoint path stored in ``config.model.pretrained_chkpt``.

    Returns:
        A function ``inference(index)``; see its docstring for the result.
    """
    config = HParam(config_p)["experiment"]
    config.model.pretrained_chkpt = chkpt_p

    # Init audio helper, embedder, model, forward function and criterion.
    _audio = Audio(config)
    embedder = get_embedder(config, train=False, device=device)
    model, _ = get_vfmodel(config, train=False, device=device)
    train_forward, _ = get_forward(config)
    criterion = get_criterion(config)

    # Sample entries that receive a leading batch dimension before the
    # forward pass.
    batched_keys = (
        "dvec_mel", "target_wav", "mixed_wav", "target_stft", "mixed_stft",
        "mixed_mag", "mixed_phase", "target_mag", "target_phase",
    )

    def inference(index):
        """Run the model on test sample ``index``.

        Returns:
            dict with
            ``"sdr"``  - SDR from ``bss_eval_sources`` as a Python number,
            ``"loss"`` - mean of the loss returned by the forward function,
            ``"spec"`` - Plotly figure of the estimated magnitude spectrogram,
            ``"wav"``  - estimated waveform as a NumPy array.
        """
        sample = testset.get_item(index, _audio)
        for key in batched_keys:
            sample[key] = sample[key].unsqueeze(0)
        sample["dvec"] = sample["dvec_mel"]

        with torch.no_grad():
            est_stft, _, loss = train_forward(model, embedder, sample, criterion, device)

        target_wav = sample["target_wav"]
        # Tensor.numpy() only works for CPU tensors; .cpu() is a no-op when the
        # tensor already lives on the CPU and keeps this working if `device`
        # is ever switched to a GPU.
        est_stft = est_stft.detach().cpu().numpy()[0]

        est_mag, _ = _audio.stft2spec(est_stft)
        est_wav = _audio._istft(est_stft.T, length=len(target_wav[0]))

        # Both signals are shaped (1, n_samples) and kept on the CPU so the
        # metric receives tensors on the same device.
        _est_wav = torch.from_numpy(est_wav).reshape(1, -1)
        _target_wav = target_wav.cpu().reshape(1, -1)

        # Only the SDR is reported; the remaining outputs are discarded.
        sdr, *_ = bss_eval_sources(_target_wav, _est_wav, compute_permutation=True)

        return {
            "sdr": sdr.item(),
            "loss": loss.mean().item(),
            "spec": plotly_spectrogram(est_mag),
            "wav": est_wav,
        }

    return inference

In [7]:
# One inference function per model shown in the dashboard.
# The first two take their config and checkpoint paths from the test record;
# the remaining three use hard-coded config / checkpoint paths.
pse_big_inference = get_inference_module(test_record["PSE_DCCRN_big"]["config"], test_record["PSE_DCCRN_big"]["chkpt"])
pse_inference = get_inference_module(test_record["PSE_DCCRN"]["config"], test_record["PSE_DCCRN"]["chkpt"])
vf_inference = get_inference_module("config/test_vf.yaml", "chkpt/powlaw_loss_finetune/chkpt_178000.pt")
pse_re_inference = get_inference_module("config/pse_dccrn_re.yaml", "chkpt/pse_dccrn_re/chkpt_120000.pt")
pse_stft_big_inference = get_inference_module("config/pse_dccrn_stft_big.yaml", "chkpt/pse_dccrn_stft_big/chkpt_70000.pt")

  for doc in docs:


In [8]:
import base64

def get_audio_file_b64(f):
    """Return the contents of file ``f`` as a base64-encoded string.

    Args:
        f: Path of the file to read (opened in binary mode).

    Returns:
        The base64 encoding of the file's bytes, decoded to ``str``.
    """
    # Use a context manager so the file handle is closed deterministically
    # instead of being left to the garbage collector.
    with open(f, "rb") as audio_file:
        raw_bytes = audio_file.read()
    return base64.b64encode(raw_bytes).decode()

def get_audio_b64(w, sr=16000):
    """Encode audio ``w`` as a base64 string via ``IPython.display.Audio``.

    Args:
        w: Audio data passed to ``IPython.display.Audio``.
        sr: Sampling rate passed as ``rate`` (default 16000).

    Returns:
        Base64 encoding (as ``str``) of the Audio object's ``data`` bytes.
    """
    audio_bytes = IPython.display.Audio(w, rate=sr).data
    encoded = base64.b64encode(audio_bytes)
    return encoded.decode()

In [9]:
def audio_sample_div(w, title):
    """Build a Dash ``Div`` with a heading and an audio player for ``w``.

    Args:
        w: Audio data forwarded to ``get_audio_b64``.
        title: Text of the ``H3`` heading shown above the player.

    Returns:
        ``html.Div`` containing the heading and an ``html.Audio`` element whose
        source is an inline base64 data URI.
    """
    return html.Div([
        html.H3(title),
        html.Audio(
            # get_audio_b64 encodes through IPython.display.Audio, which
            # produces WAV bytes for array input, so the data URI must declare
            # audio/wav rather than audio/mpeg.
            src=f"data:audio/wav;base64,{get_audio_b64(w)}",
            controls=True
        ),
        # Spectrogram panel, currently disabled:
        # dcc.Graph(figure=plotly_spectrogram(audio.wav2spec(w)[0].T))
    ])

# Main app

In [10]:
# Create the Dash application that hosts the dashboard.
# NOTE(review): the first positional argument of a Dash app is its name
# (conventionally __name__), not the page title. If this string was meant as
# the browser title, it presumably belongs in the `title` keyword — confirm.
app = JupyterDash('Explore test result')

In [11]:
# Page layout: a top row with the dimension selectors and the scatter plot,
# followed by the selected sample's audio and the per-model inference results.
# Style dictionaries use camelCase keys, the form Dash/React expects for CSS
# properties (kebab-case keys trigger "unsupported style property" warnings).
app.layout = html.Div([
    # --- Top row: dimension selectors (left) and scatter plot (right) ---
    html.Div([

        html.Div(children=[
            html.H2("Dimensions"),

            # X / Y axis selectors side by side in a two-column grid.
            html.Div(children=[
                html.Div(children=[
                    html.Label('X-axis'),
                    dcc.Dropdown(df.columns, None, id="dim-x-axis"),
                ], style={'display': 'inline-block'}),

                html.Div(children=[
                    html.Br(),
                    html.Label('Y-axis'),
                    dcc.Dropdown(df.columns, None, id="dim-y-axis"),
                ], style={'display': 'inline-block'}),
            ], style={
                'display': 'grid',
                'gridTemplateColumns': '1fr 1fr',
            }),

            html.Br(),
            html.Label('Color'),
            dcc.Dropdown(df.columns, None, id="dim-color"),

            html.Br(),
            html.Label('Size'),
            dcc.Dropdown(df.columns, None, id="dim-size"),
        ], style={'padding': 10, 'flex': 1}),

        html.Div(children=[
            html.H2("Interactive plot"),
            dcc.Graph(id='scatter-plot')
        ], style={'padding': 10, 'flex': 3}),

    ], style={'display': 'flex', 'flexDirection': 'row'}),

    # --- Audio of the sample selected in the scatter plot (2x2 grid) ---
    html.Div(children=[
        html.H2("Sample"),
        html.Div(
            id="show-sample-area",
            style={
                'display': 'grid',
                'gridTemplateColumns': '1fr 1fr',
                'gridTemplateRows': '1fr 1fr',
            },
        )
    ]),

    # --- Per-model inference output, filled in when "Run" is pressed ---
    html.Div(children=[
        html.H2("Inference"),
        html.Button(id='inference-button-state', n_clicks=0, children='Run'),
        html.H3("PSE-DCCRN"),
        html.Div(id="pse-dccrn-inference"),
        html.H3("PSE-DCCRN (big)"),
        html.Div(id="pse-dccrn-big-inference"),
        html.H3("PSE-DCCRN (STFT, big)"),
        html.Div(id="pse-dccrn-stft-big-inference"),
        html.H3("PSE-DCCRN (re-train)"),
        html.Div(id="pse-dccrn-re-inference"),
        html.H3("VoiceFilter"),
        html.Div(id="vf-inference"),
    ], style={'padding': 10}),
])

In [12]:
@app.callback(
    Output('scatter-plot', 'figure'),
    Input('dim-x-axis', 'value'),
    Input('dim-y-axis', 'value'),
    Input('dim-color', 'value'),
    Input('dim-size', 'value'),
)
def update_plot_dims(x_axis, y_axis, color, size):
    """Rebuild the scatter plot whenever one of the dimension dropdowns changes.

    Every scatter point carries the index label of its row in ``df`` as a
    scalar ``customdata`` value, which the click callbacks read to identify
    the selected sample.

    Args:
        x_axis: Column of ``df`` for the x axis (or None).
        y_axis: Column of ``df`` for the y axis (or None).
        color: Column of ``df`` mapped to marker colour (or None).
        size: Column of ``df`` mapped to marker size (or None).

    Returns:
        The Plotly figure for the ``scatter-plot`` graph.
    """
    # Attach the row index through `custom_data` so Plotly Express keeps it
    # aligned with the points of each trace. Assigning the full `df.index` to
    # every trace afterwards mislabels points as soon as the data is split
    # into several traces (e.g. one per category of a discrete colour column).
    row_id_col = "__row_id__"
    plot_df = df.assign(**{row_id_col: df.index})

    fig = px.scatter(plot_df, x=x_axis, y=y_axis,
                     size=size, color=color, hover_name=None,
                     custom_data=[row_id_col],
                     marginal_x="histogram", marginal_y="histogram",
                     log_x=False, size_max=55)

    def _flatten_customdata(trace):
        # Plotly Express stores custom_data as one row per point; reduce it to
        # the scalar row id that the click callbacks expect.
        values = trace.customdata
        if values is None:
            return
        values = np.asarray(values)
        if values.ndim == 2:
            trace.customdata = values[:, 0]

    fig.for_each_trace(_flatten_customdata)
    fig.update_layout(clickmode='event+select', transition_duration=500)
    return fig

In [13]:
@app.callback(
    Output('pse-dccrn-big-inference', 'children'),
    Input('inference-button-state', 'n_clicks'),
    State('scatter-plot', 'clickData'),
)
def show_inference_pd_big(n_clicks, clickData):
    """Run the "PSE-DCCRN (big)" model on the clicked sample when "Run" fires.

    Args:
        n_clicks: Click count of the "Run" button; only acts as the trigger.
        clickData: Click event data of the scatter plot, or None when no
            point has been clicked yet.

    Returns:
        Children for the output div: a hint message when nothing is selected,
        otherwise an audio player and the estimated spectrogram.
    """
    if clickData is None:
        return [
            "Please select a data point from scatter plot first to see the sample detail."
        ]

    idx = clickData["points"][0]["customdata"]
    result = pse_big_inference(idx)
    return [
        html.Audio(
            # The payload is WAV data (encoded through IPython.display.Audio),
            # so the data URI declares audio/wav instead of audio/mpeg.
            src=f"data:audio/wav;base64,{get_audio_b64(result['wav'])}",
            controls=True
        ),
        dcc.Graph(figure=result["spec"])
    ]

In [14]:
@app.callback(
    Output('pse-dccrn-inference', 'children'),
    Input('inference-button-state', 'n_clicks'),
    State('scatter-plot', 'clickData'),
)
def show_inference_pd(n_clicks, clickData):
    """Run the "PSE-DCCRN" model on the clicked sample when "Run" fires.

    Args:
        n_clicks: Click count of the "Run" button; only acts as the trigger.
        clickData: Click event data of the scatter plot, or None when no
            point has been clicked yet.

    Returns:
        Children for the output div: a hint message when nothing is selected,
        otherwise an audio player and the estimated spectrogram.
    """
    if clickData is None:
        return [
            "Please select a data point from scatter plot first to see the sample detail."
        ]

    idx = clickData["points"][0]["customdata"]
    result = pse_inference(idx)
    return [
        html.Audio(
            # The payload is WAV data (encoded through IPython.display.Audio),
            # so the data URI declares audio/wav instead of audio/mpeg.
            src=f"data:audio/wav;base64,{get_audio_b64(result['wav'])}",
            controls=True
        ),
        dcc.Graph(figure=result["spec"])
    ]

In [15]:
@app.callback(
    Output('pse-dccrn-re-inference', 'children'),
    Input('inference-button-state', 'n_clicks'),
    State('scatter-plot', 'clickData'),
)
def show_inference_pdr(n_clicks, clickData):
    """Run the "PSE-DCCRN (re-train)" model on the clicked sample when "Run" fires.

    Args:
        n_clicks: Click count of the "Run" button; only acts as the trigger.
        clickData: Click event data of the scatter plot, or None when no
            point has been clicked yet.

    Returns:
        Children for the output div: a hint message when nothing is selected,
        otherwise an audio player and the estimated spectrogram.
    """
    if clickData is None:
        return [
            "Please select a data point from scatter plot first to see the sample detail."
        ]

    idx = clickData["points"][0]["customdata"]
    result = pse_re_inference(idx)
    return [
        html.Audio(
            # The payload is WAV data (encoded through IPython.display.Audio),
            # so the data URI declares audio/wav instead of audio/mpeg.
            src=f"data:audio/wav;base64,{get_audio_b64(result['wav'])}",
            controls=True
        ),
        dcc.Graph(figure=result["spec"])
    ]

In [16]:
@app.callback(
    Output('pse-dccrn-stft-big-inference', 'children'),
    Input('inference-button-state', 'n_clicks'),
    State('scatter-plot', 'clickData'),
)
def show_inference_pdsb(n_clicks, clickData):
    """Run the "PSE-DCCRN (STFT, big)" model on the clicked sample when "Run" fires.

    Args:
        n_clicks: Click count of the "Run" button; only acts as the trigger.
        clickData: Click event data of the scatter plot, or None when no
            point has been clicked yet.

    Returns:
        Children for the output div: a hint message when nothing is selected,
        otherwise an audio player and the estimated spectrogram.
    """
    if clickData is None:
        return [
            "Please select a data point from scatter plot first to see the sample detail."
        ]

    idx = clickData["points"][0]["customdata"]
    result = pse_stft_big_inference(idx)
    return [
        html.Audio(
            # The payload is WAV data (encoded through IPython.display.Audio),
            # so the data URI declares audio/wav instead of audio/mpeg.
            src=f"data:audio/wav;base64,{get_audio_b64(result['wav'])}",
            controls=True
        ),
        dcc.Graph(figure=result["spec"])
    ]

In [17]:
@app.callback(
    Output('vf-inference', 'children'),
    Input('inference-button-state', 'n_clicks'),
    State('scatter-plot', 'clickData'),
)
def show_inference_vf(n_clicks, clickData):
    """Run the "VoiceFilter" model on the clicked sample when "Run" fires.

    Args:
        n_clicks: Click count of the "Run" button; only acts as the trigger.
        clickData: Click event data of the scatter plot, or None when no
            point has been clicked yet.

    Returns:
        Children for the output div: a hint message when nothing is selected,
        otherwise an audio player and the estimated spectrogram.
    """
    if clickData is None:
        return [
            "Please select a data point from scatter plot first to see the sample detail."
        ]

    idx = clickData["points"][0]["customdata"]
    result = vf_inference(idx)
    return [
        html.Audio(
            # The payload is WAV data (encoded through IPython.display.Audio),
            # so the data URI declares audio/wav instead of audio/mpeg.
            src=f"data:audio/wav;base64,{get_audio_b64(result['wav'])}",
            controls=True
        ),
        dcc.Graph(figure=result["spec"])
    ]

In [18]:
@app.callback(
    Output('show-sample-area', 'children'),
    Input('scatter-plot', 'clickData'),
)
def show_sample(clickData):
    """Show audio players for the sample clicked in the scatter plot.

    Returns a one-item list with a hint message while nothing is selected;
    otherwise four audio panels (mixed, reference, target, interference)
    built from ``testset[idx]``.
    """
    if clickData is None:
        hint = "Please select a data point from scatter plot first to see the sample detail."
        return [hint]

    clicked_point = clickData["points"][0]
    sample = testset[clicked_point["customdata"]]

    mixed_panel = audio_sample_div(sample['mixed_wav'].numpy(), "Mixed audio")
    # 'dvec_wav' is passed through as-is, unlike the other entries.
    reference_panel = audio_sample_div(sample['dvec_wav'], "Reference audio")
    target_panel = audio_sample_div(sample['target_wav'].numpy(), "Target audio")
    interference_panel = audio_sample_div(sample['interf_wav'].numpy(), "Interference audio")
    return [mixed_panel, reference_panel, target_panel, interference_panel]

In [19]:
# Start the Dash server using the default display mode.
# To render the app inline in the notebook instead, use:
# app.run_server(mode='inline')
app.run_server()

Dash app running on http://127.0.0.1:8050/
