# Correlation Analysis

In [33]:
import os
import json
import pandas as pd
import numpy as np
import scipy
from scipy.stats import wilcoxon
from scipy.spatial import distance
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline

# Show every float in DataFrame output with a thousands separator and 3 decimals.
pd.set_option("display.float_format", "{:,.3f}".format)


In [34]:
# Raw tool identifiers used for the error-model results -> short labels for tables/figures.
shorten_error_model_name = {
    "random": "RND",
    "error_model_triphone_rich": "TR",
    "error_model_pure_diversity": "PD",
    "error_model_without_diversity_enhancing": "IC-WDE",
    "error_model": "IC",
    "asrevolve_error_model_real": "ASR-EV",
    "word_error_predictor_real/no_word_enhance": "NWE",
    "word_error_predictor_real/word_enhance": "WE",
}

# Raw tool identifiers used for the fine-tuned-model results -> the same short labels.
shorten_finetuned_model_name = {
    "random": "RND",
    "triphone_rich": "TR",
    "pure_diversity": "PD",
    "icassp_without_diversity_enhancing_real_mix": "IC-WDE",
    "icassp_real_mix": "IC",
    "asrevolve_error_model_real": "ASR-EV",
    "word_error_real_mix/no_word_enhance": "NWE",
    "word_error_real_mix/word_enhance": "WE",
}


def _shorten(mapping, tools):
    """Translate each raw tool identifier in ``tools`` into its short label via ``mapping``.

    Args:
        mapping: dict from raw tool identifier to short label.
        tools: iterable of raw tool identifiers.

    Returns:
        A list of short labels, in the same order as ``tools``.

    Raises:
        KeyError: if a tool is not in ``mapping``; the message lists the valid names.
    """
    short_names = []
    for tool in tools:
        try:
            short_names.append(mapping[tool])
        except KeyError:
            # Same exception type as a plain dict lookup, but say which names are valid.
            raise KeyError(
                f"unknown tool {tool!r}; expected one of {sorted(mapping)}"
            ) from None
    return short_names


def shorten_em_name(tools):
    """Return the short labels for a sequence of error-model tool identifiers."""
    return _shorten(shorten_error_model_name, tools)


def shorten_ft_name(tools):
    """Return the short labels for a sequence of fine-tuned-model tool identifiers."""
    return _shorten(shorten_finetuned_model_name, tools)


# Canonical display order of the tools. Derived from the mapping rather than typed a
# third time, so the label lists cannot silently drift apart.
tool_short_names = list(shorten_finetuned_model_name.values())
assert tool_short_names == list(shorten_error_model_name.values()), (
    "error-model and fine-tuned-model mappings disagree on short labels or their order"
)




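### Relative WER improvement after fine-tuning

For each ASR model, dataset, and selection tool, the relative improvement is
$\frac{\mathrm{WER}_{\text{original}} - \mathrm{WER}_{\text{fine-tuned}}}{\mathrm{WER}_{\text{original}}}$.
Positive values mean fine-tuning lowered the WER.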

In [35]:
# Load, for every (ASR model, dataset, selection tool) triple, the CSV of results obtained
# after fine-tuning. result/RQ2.json is indexed as asr -> dataset -> raw tool id -> CSV path.
with open('result/RQ2.json', 'r') as f:
    data = json.load(f)

asrs = ["quartznet", "hubert", "wav2vec-base"]
datasets = ["YBAA", "ZHAA", "ASI", "TNI", "NCC",
            "TXHC", "EBVS", "ERMS", "YDCK", "YKWK", "THV", "TLV"]
tools = ["random", "triphone_rich", "pure_diversity", "icassp_without_diversity_enhancing_real_mix", "icassp_real_mix",
         "asrevolve_error_model_real", "word_error_real_mix/no_word_enhance", "word_error_real_mix/word_enhance"]

# Columns that the relative-improvement cell below reads from every result CSV.
REQUIRED_WER_COLUMNS = {"WER_Seed1", "WER_Seed2", "WER_Seed3", "WER_Avg"}

# asr -> dataset -> short tool label -> DataFrame with that run's results.
finetuned_model_performance_on_test_set = {}
for asr in asrs:
    finetuned_model_performance_on_test_set[asr] = {}
    for dataset in datasets:
        finetuned_model_performance_on_test_set[asr][dataset] = {}
        for tool in tools:
            result_csv_path = data[asr][dataset][tool]
            run_results = pd.read_csv(result_csv_path)
            # Fail here, naming the offending file, instead of with a bare KeyError later.
            missing_columns = REQUIRED_WER_COLUMNS - set(run_results.columns)
            assert not missing_columns, f"{result_csv_path} is missing columns {sorted(missing_columns)}"
            finetuned_model_performance_on_test_set[asr][dataset][shorten_finetuned_model_name[tool]] = run_results

finetuned_model_performance_on_test_set


{'quartznet': {'YBAA': {'RND':   Dataset  Size  WER_Seed1  WER_Seed2  WER_Seed3  WER_Avg  CER_Seed1  \
   0    YBAA   100     17.060     17.360     18.390   17.600      6.030   
   1    YBAA   200     16.240     16.240     16.650   16.380      5.810   
   2    YBAA   300     16.550     15.830     15.830   16.070      5.870   
   3    YBAA   400     16.240     15.020     17.160   16.140      5.560   
   
      CER_Seed2  CER_Seed3  CER_Avg  
   0      6.530      6.340    6.300  
   1      5.870      5.830    5.840  
   2      5.610      5.710    5.730  
   3      5.380      5.580    5.510  ,
   'TR':   Dataset  Size  WER_Seed1  WER_Seed2  WER_Seed3  WER_Avg  CER_Seed1  \
   0    YBAA   100     16.750     16.650     17.060   16.820      6.080   
   1    YBAA   200     16.960     16.850     17.160   16.990      5.850   
   2    YBAA   300     16.140     16.960     16.750   16.620      5.850   
   3    YBAA   400     15.530     14.610     15.530   15.220      5.460   
   
      CER_Seed2  

In [36]:
# WER of each original (pre-fine-tuning) ASR model on each dataset's test split.
# result/original.json is indexed as asr -> dataset -> "test" -> "wer".
with open('result/original.json', 'r') as f:
    original_data = json.load(f)

original_model_performance_on_test_set = {
    asr: {dataset: original_data[asr][dataset]["test"]["wer"] for dataset in datasets}
    for asr in asrs
}

original_model_performance_on_test_set



{'quartznet': {'YBAA': 20.84,
  'ZHAA': 20.75,
  'ASI': 15.38,
  'TNI': 18.28,
  'NCC': 30.35,
  'TXHC': 22.79,
  'EBVS': 36.49,
  'ERMS': 25.94,
  'YDCK': 16.84,
  'YKWK': 20.6,
  'THV': 39.57,
  'TLV': 46.39},
 'hubert': {'YBAA': 11.13,
  'ZHAA': 13.22,
  'ASI': 8.32,
  'TNI': 9.66,
  'NCC': 18.5,
  'TXHC': 12.31,
  'EBVS': 24.29,
  'ERMS': 14.24,
  'YDCK': 11.33,
  'YKWK': 12.17,
  'THV': 28.79,
  'TLV': 35.91},
 'wav2vec-base': {'YBAA': 17.98,
  'ZHAA': 20.35,
  'ASI': 15.7,
  'TNI': 19.63,
  'NCC': 31.81,
  'TXHC': 21.97,
  'EBVS': 37.22,
  'ERMS': 26.45,
  'YDCK': 18.5,
  'YKWK': 18.31,
  'THV': 39.78,
  'TLV': 45.57}}

In [37]:
# Relative WER improvement of every fine-tuned model over its original model:
#     (WER_original - WER_finetuned) / WER_original
# Positive values mean fine-tuning lowered the WER. Each leaf is a list with one value per
# row of the corresponding result CSV (the Size rows shown in the loading cell's output).
relative_improvement_of_finetuned_model = {
    asr: {
        dataset: {
            short_tool: {
                wer_column: (
                    (original_model_performance_on_test_set[asr][dataset]
                     - finetuned_model_performance_on_test_set[asr][dataset][short_tool][wer_column])
                    / original_model_performance_on_test_set[asr][dataset]
                ).to_list()
                for wer_column in ("WER_Seed1", "WER_Seed2", "WER_Seed3", "WER_Avg")
            }
            for short_tool in shorten_ft_name(tools)
        }
        for dataset in datasets
    }
    for asr in asrs
}

relative_improvement_of_finetuned_model


{'quartznet': {'YBAA': {'RND': {'WER_Seed1': [0.18138195777351254,
     0.22072936660268722,
     0.20585412667946254,
     0.22072936660268722],
    'WER_Seed2': [0.1669865642994242,
     0.22072936660268722,
     0.24040307101727446,
     0.2792706333973129],
    'WER_Seed3': [0.11756238003838768,
     0.20105566218809987,
     0.24040307101727446,
     0.1765834932821497],
    'WER_Avg': [0.15547024952015348,
     0.21401151631477933,
     0.2288867562380038,
     0.22552783109404986]},
   'TR': {'WER_Seed1': [0.19625719769673705,
     0.1861804222648752,
     0.22552783109404986,
     0.2547984644913628],
    'WER_Seed2': [0.20105566218809987,
     0.1914587332053742,
     0.1861804222648752,
     0.2989443378119002],
    'WER_Seed3': [0.18138195777351254,
     0.1765834932821497,
     0.19625719769673705,
     0.2547984644913628],
    'WER_Avg': [0.19289827255278308,
     0.1847408829174665,
     0.2024952015355086,
     0.2696737044145873]},
   'PD': {'WER_Seed1': [0.166986564299