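# Evaluation

Compares the GST style predicted by each `ChatSpeechGenerator` model (DialoGPT variants and Therapy-DLDLM, for each encoding approach) with the GST embeddings and scores extracted from the reference audio. It also compares the Mellotron mel-spectrograms synthesised from predicted and reference styles. Results are saved as a CSV table and as one figure per split and metric.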

## Preliminaries

### Imports

In [None]:
import os

import pickle
import bz2

In [None]:
import pandas as pd

In [None]:
import torch
import torch.nn.functional as F

In [None]:
%matplotlib inline
from matplotlib import pyplot as plt
import seaborn as sns

In [None]:
from mellotron_api import load_tts, load_vocoder, load_arpabet_dict
from mellotron_api import get_gst_embeddings, get_gst_scores
from mellotron_api.api import _synthesise_speech_mellotron

In [None]:
from gsttransformer.speech_api import ChatSpeechGenerator

### Constants

In [None]:
# Input/output locations, relative to the notebook's working directory.
RAW_DATA_PATH = '../resources/data/raw'  # not referenced in the visible cells
DATA_PATH = '../resources/data/cache'  # cached corpus splits and the audio files they reference
OUTPUT_PATH = '../resources/data/eval'  # results table and figures

# makedirs also creates missing parent directories, and exist_ok makes the call a
# no-op when the directory is already there, so re-running this cell is safe.
os.makedirs(OUTPUT_PATH, exist_ok=True)

In [None]:
# Corpus split names, in order (the data cell below loads SPLITS[1:]).
SPLITS = [
    'train',
    'validation',
    'test',
]

In [None]:
# Columns of the results table; the evaluation loop appends one row per
# (model, approach, split, sample) in exactly this order.
OUT_DF_COLUMNS = [
    'Model',
    'Params [M]',
    'Approach',
    'Split',
    'Audio file path',
    'MSE',
    'KL-Divergence',
    'Frobenius norm (embeddings)',
    'Frobenius norm (combination weights)',
]

In [None]:
# Root directories of the trained GST-predictor checkpoints (DialoGPT-based and
# Therapy-DLDLM-based). Checkpoints are looked up under
# <root>/model/best_checkpoint_*/gstt.pt. With the empty defaults, those paths
# resolve relative to the working directory.
# TODO: set these to the actual checkpoint roots before running.
MODEL_PATH, THERAPY_MODEL_PATH = '', ''

### Helper functions

In [None]:
def load_data(path):
    """Load a single pickled object from a bz2-compressed file.

    Args:
        path: Location of the compressed pickle (e.g. a ``.pbz2`` cache file).

    Returns:
        The unpickled object.
    """
    # NOTE: unpickling can execute arbitrary code; only use this on trusted,
    # locally generated cache files.
    with bz2.BZ2File(path, 'rb') as compressed_file:
        return pickle.load(compressed_file)

## Data

In [None]:
# Held-out evaluation data. SPLITS[1:] skips 'train', so only the remaining
# splits are loaded from the compressed cache, keyed by split name.
eval_splits = SPLITS[1:]
data = dict(zip(
    eval_splits,
    [load_data(os.path.join(DATA_PATH, f'gstt_corpus_{split_name}.pbz2')) for split_name in eval_splits],
))

## Models

In [None]:
# Speech-synthesis stack. Mellotron is used below both to extract reference GST
# embeddings/scores and to synthesise spectrograms. Tacotron 2 and WaveGlow (with
# its denoiser) are loaded but not referenced in the visible cells.
# NOTE(review): these paths start with 'resources/' while the data paths above use
# '../resources/'. Confirm they resolve from the notebook's working directory.
MELLOTRON_CKPT = 'resources/tts/mellotron/mellotron_libritts.pt'
TACOTRON2_CKPT = 'resources/tts/tacotron_2/tacotron2_statedict.pt'
WAVEGLOW_CKPT = 'resources/vocoder/waveglow/waveglow_256channels_universal_v4.pt'
CMU_DICT_PATH = 'mellotron/data/cmu_dictionary'

mellotron, mellotron_stft, mellotron_hparams = load_tts(MELLOTRON_CKPT)
tacotron2, tacotron2_stft, tacotron2_hparams = load_tts(TACOTRON2_CKPT, model='tacotron2')
waveglow, denoiser = load_vocoder(WAVEGLOW_CKPT)
arpabet_dict = load_arpabet_dict(CMU_DICT_PATH)

In [None]:
# Language-model backbones passed to ChatSpeechGenerator as `lm_path`. They are keyed
# by the same model ids used in `model_size_mapping` and as the values of
# `dgpt_mapping`, so `language_models[dgpt_mapping[suffix]]` resolves.
language_models = {
    'DialoGPT-117M': 'microsoft/DialoGPT-small',
    'DialoGPT-345M': 'microsoft/DialoGPT-medium',
    'DialoGPT-762M': 'microsoft/DialoGPT-large',
    'Therapy-DLDLM': ''  # TODO: set the Therapy-DLDLM model path before running
}

# Parameter counts in millions, written to the 'Params [M]' results column.
model_size_mapping = {
    'DialoGPT-117M': 117,
    'DialoGPT-345M': 345,
    'DialoGPT-762M': 762,
    'Therapy-DLDLM': 762
}

# Checkpoint-name suffix (as in best_checkpoint_{suffix}_{approach}) -> model id.
dgpt_mapping = {
    'lm_small': 'DialoGPT-117M',
    'lm_medium': 'DialoGPT-345M',
    'lm_large': 'DialoGPT-762M'
}

In [None]:
# `encoding_mode` values passed to ChatSpeechGenerator -> display labels used in the
# results table and plots.
approaches_mapping = {
    'resp': 'Response',
    'resp_from_ctx': 'Response (from context)',
    'ctx_resp': 'Context and response'
}

# (model id, approach label) -> (GST checkpoint path, LM path, tokeniser path, extra kwargs).
# The tuple key is unpacked as `(lm_id, dgst_approach)` by the evaluation loop, and it
# must include the approach so that approaches do not overwrite each other.
# Requires `language_models` to be keyed by the model ids in `dgpt_mapping.values()`.
dgst_models_dict = {
    **{
        (dgpt_mapping[model], approaches_mapping[approach]): (
            os.path.join(MODEL_PATH, 'model', f'best_checkpoint_{model}_{approach}', 'gstt.pt'),
            language_models[dgpt_mapping[model]],
            'gpt2',
            {'encoding_mode': approach, 'max_context_len': 256}
        )
        for model in dgpt_mapping for approach in approaches_mapping
    },
    **{
        ('Therapy-DLDLM', approaches_mapping[approach]): (
            os.path.join(THERAPY_MODEL_PATH, 'model', f'best_checkpoint_lm_large_{approach}', 'gstt.pt'),
            # Same path is used for both the LM and its tokeniser.
            language_models['Therapy-DLDLM'],
            language_models['Therapy-DLDLM'],
            {
                'encoding_mode': approach,
                'prefix_token': '<|prior|>',
                'suffix_token': '<|posterior|>',
                'max_context_len': 256
            }
        )
        for approach in approaches_mapping
    }
}

In [None]:
out_data = []

with torch.no_grad():
    # Reference targets keyed by audio file path. Each value holds, in this order:
    #   GST embeddings extracted from the recorded audio,
    #   GST scores extracted from the recorded audio,
    #   element [1] of Mellotron's output when synthesising the utterance with that
    #   audio as style reference.
    # NOTE(review): element [1] is treated as the mel-spectrogram; confirm against mellotron_api.
    tgt_data = {}
    for samples in data.values():
        for sample in samples:
            ref_audio_path = os.path.join(DATA_PATH, sample['audio_file_path'])
            tgt_data[sample['audio_file_path']] = (
                get_gst_embeddings(ref_audio_path, mellotron, mellotron_stft, mellotron_hparams),
                get_gst_scores(ref_audio_path, mellotron, mellotron_stft, mellotron_hparams),
                _synthesise_speech_mellotron(
                    sample['utterance'], mellotron, mellotron_stft, mellotron_hparams, arpabet_dict,
                    reference_audio_path=ref_audio_path
                )[1]
            )

    for (lm_id, dgst_approach), (dgst_path, lm_path, tokeniser_path, kwargs) in dgst_models_dict.items():
        dgst = ChatSpeechGenerator(dgst_path, lm_path, tokeniser_path, **kwargs)
        for split, samples in data.items():
            for sample in samples:
                ref_audio_path = os.path.join(DATA_PATH, sample['audio_file_path'])
                # Style predicted by the dialogue model from the utterance and its context.
                gst_embeddings, gst_scores = dgst._predict_gst(sample['utterance'], dialogue=sample['context'])
                # Synthesise twice: once conditioned on the predicted embeddings and
                # once on the predicted head scores.
                mel_spec_embeds = _synthesise_speech_mellotron(
                    sample['utterance'], mellotron, mellotron_stft, mellotron_hparams, arpabet_dict,
                    reference_audio_path=ref_audio_path,
                    gst_style_embedding=gst_embeddings
                )[1]
                mel_spec_weights = _synthesise_speech_mellotron(
                    sample['utterance'], mellotron, mellotron_stft, mellotron_hparams, arpabet_dict,
                    reference_audio_path=ref_audio_path,
                    gst_head_style_scores=gst_scores
                )[1]

                # Unpack in the same order the target tuple was built above.
                tgt_gst_embeddings, tgt_gst_scores, tgt_mel_spec = tgt_data[sample['audio_file_path']]

                mse = F.mse_loss(
                    torch.tensor(gst_embeddings), torch.tensor(tgt_gst_embeddings), reduction='none'
                ).mean(-1)
                # F.kl_div(input, target, log_target=True) computes KL(target || input)
                # pointwise: the divergence of the predicted scores from the reference scores.
                kl = F.kl_div(
                    torch.tensor(gst_scores).log(), torch.tensor(tgt_gst_scores).log(),
                    reduction='none', log_target=True
                ).sum(-1).mean(1)
                # Frobenius norm of the spectrogram difference.
                # NOTE(review): assumes both spectrograms have the same shape; confirm.
                frob_embeds = ((tgt_mel_spec - mel_spec_embeds) ** 2).sum() ** 0.5
                frob_weights = ((tgt_mel_spec - mel_spec_weights) ** 2).sum() ** 0.5

                out_data.append([
                    lm_id, model_size_mapping[lm_id], dgst_approach, split, sample['audio_file_path'],
                    mse.item(), kl.item(), frob_embeds.item(), frob_weights.item()
                ])
        # Drop the reference so this model can be freed before the next one is loaded.
        del dgst

In [None]:
# Assemble the per-sample metric rows into the results table and display it.
out_df = pd.DataFrame(data=out_data, columns=OUT_DF_COLUMNS)
out_df

In [None]:
# Persist the full results table, without the integer row index.
results_csv_path = os.path.join(OUTPUT_PATH, 'results.csv')
out_df.to_csv(results_csv_path, index=False)

In [None]:
# One figure per (split, metric): one panel per approach and one horizontal bar per
# model. Bars show the mean over samples; error bars show one standard deviation.
metrics = ['MSE', 'KL-Divergence', 'Frobenius norm (embeddings)', 'Frobenius norm (combination weights)']
approach_order = ['Response', 'Response (from context)', 'Context and response']
model_order = ['DialoGPT-117M', 'DialoGPT-345M', 'DialoGPT-762M', 'Therapy-DLDLM']

for split in out_df['Split'].unique():
    split_df = out_df[out_df['Split'] == split]
    for metric in metrics:
        fig, axes = plt.subplots(
            nrows=1,
            ncols=len(approach_order),
            figsize=(12, 5),
            sharex=True,
            sharey=True,
            squeeze=False
        )
        for ax, approach in zip(axes[0], approach_order):
            # barplot aggregates the numeric metric per model. countplot only counts
            # rows and does not accept both x and y.
            sns.barplot(
                data=split_df[split_df['Approach'] == approach],
                x=metric, y='Model', order=model_order, ax=ax,
                linewidth=1., edgecolor='0', orient='h', errorbar='sd'
            )
            ax.set_title(f'Approach: {approach}')
            ax.set_xscale('log')
            ax.set_xlabel(metric)

        plt.tight_layout()

        # Filename-safe version of the metric name,
        # e.g. 'Frobenius norm (embeddings)' -> 'frobenius_norm_embeddings'.
        metric_slug = (
            metric.lower()
            .replace('(', '').replace(')', '')
            .replace(' ', '_').replace('-', '_')
        )
        # Save before showing so the file does not depend on how the backend
        # handles the figure on show.
        fig.savefig(
            os.path.join(OUTPUT_PATH, f'dgst_results_{split.lower()}_{metric_slug}.pdf'),
            bbox_inches='tight'
        )
        plt.show()
        # Release the figure; this loop creates one per split and metric.
        plt.close(fig)