In [181]:
import numpy as np
import pandas as pd
import os
import math
from scipy import stats
from sklearn.metrics import cohen_kappa_score

pd.set_option('display.max_rows', 500)

In [182]:
# # Results on the first 8 methods
# results = pd.read_csv('mturk/sc09-unconditional-exp-confident/Batch_4656948_batch_results.csv')

# uids = []
# for method in [
#     'mturk/sc09-unconditional-exp-confident-diffwave-1m/',
#     'mturk/sc09-unconditional-exp-confident-diffwave-500k/',
#     'mturk/sc09-unconditional-exp-confident-samplernn-3/',
#     'mturk/sc09-unconditional-exp-confident-sashimi-8-glu/',
#     'mturk/sc09-unconditional-exp-confident-sashimi-diffwave-500k/',
#     'mturk/sc09-unconditional-exp-confident-test/',
#     'mturk/sc09-unconditional-exp-confident-wavegan/',
#     'mturk/sc09-unconditional-exp-confident-wavenet-1024/',
# ]:
#     uids.append(pd.read_csv(
#         f'{method}/uids.txt', 
#         sep=' ', 
#         header=None, 
#         names=['method', 'filename','uid'],
#     ))
    
# uids = pd.concat(uids, axis=0)

In [183]:
# Results on the next 4 methods
results = pd.read_csv('Batch_4657899_batch_results.csv')

# Directories that each contain a `uids.txt` file for one method.
method_dirs = [
    '../final/sc09-unconditional-exp-confident-diffwave-small-500k/',
    '../final/sc09-unconditional-exp-confident-sashimi-diffwave-small-500k/',
    '../final/sc09-unconditional-exp-confident-sashimi-diffwave-800k/',
    '../final/sc09-unconditional-exp-confident-sashimi-diffwave-snet-uni-500k/',
    # replace or add your methods
]

# `uids.txt` is read as a header-less, space-separated table with three columns;
# the per-method tables are stacked row-wise into a single lookup frame.
uid_columns = ['method', 'filename', 'uid']
uids = pd.concat(
    [
        pd.read_csv(f'{method}/uids.txt', sep=' ', header=None, names=uid_columns)
        for method in method_dirs
    ],
    axis=0,
)

In [187]:
# Reshape the wide batch file into a list of per-slot frames. Each column whose name
# starts with 'Input' identifies one recording slot; the slot number is the
# second-to-last "_"-separated token of that column name. The un-indexed answer
# columns (worker, diversity, quality) are repeated alongside every slot.
shared_cols = ['WorkerId', 'Answer.diversity', 'Answer.quality']

responses = []
for col in results.columns:
    if not col.startswith('Input'):
        continue
    index = int(col.split("_")[-2])
    slot_cols = [
        col,
        f'Answer.recording_{index}_intelligibility',
        f'Answer.recording_{index}_digit',
    ]
    responses.append(results[slot_cols + shared_cols])

In [189]:
# Give every per-slot frame the same column names (renamed by position) and stack them.
response_columns = ['url', 'intelligibility', 'digit', 'worker', 'diversity', 'quality']
responses = pd.concat(
    [
        frame.rename(columns=dict(zip(frame.columns, response_columns)))
        for frame in responses
    ],
    axis=0,
)

# The uid is the last URL path component without its ".wav" suffix; the method is the
# third-from-last path component with the shared experiment prefix stripped.
url_parts = responses['url'].apply(lambda x: x.split("/"))
responses['uid'] = url_parts.apply(lambda parts: parts[-1].replace(".wav", ""))
responses['method'] = url_parts.apply(
    lambda parts: parts[-3].replace("sc09-unconditional-exp-confident-", "")
)

# Keep only the analysis columns, in a fixed order.
responses = responses[['uid', 'method', 'worker', 'digit', 'intelligibility', 'quality', 'diversity']]

In [190]:
# Attach the filename of each rated clip by matching on (uid, method); responses
# without a matching uid entry are dropped by the inner join. Rows are then ordered
# by filename and method with a fresh 0..n-1 index.
data = (
    responses
    .merge(uids, on=['uid', 'method'], how='inner')
    .sort_values(['filename', 'method'])
    .reset_index(drop=True)
)

In [191]:
# Filenames are sorted by digit class (e.g. 0.wav, ..., 49.wav are digit zero and so on),
# so the label of file i is simply i // 50.
file_ids = range(500)
classes = pd.DataFrame({
    'filename': [f'{file_id}.wav' for file_id in file_ids],
    'label': [file_id // 50 for file_id in file_ids],
})

In [192]:
# Attach the ground-truth digit label to every response via its filename.
data = data.merge(classes, on=['filename'], how='inner')

In [195]:
# Calculate accuracy: a response "agrees" when the worker's digit equals the true label.
data['agreement'] = data['label'].eq(data['digit'])

def kappa(row):
    """Add Cohen's kappa between the `label` and `digit` columns as a `kappa` column.

    Despite the parameter name, this is meant to be passed to
    `DataFrame.groupby(...).apply`, in which case `row` is the sub-frame of one
    group and every row of that group receives the same kappa value.

    The frame is modified in place and also returned.
    """
    score = cohen_kappa_score(row['label'], row['digit'])
    row['kappa'] = score
    return row

In [197]:
# Drop responses where the worker did not give a digit answer.
# A boolean mask on the column itself is used instead of np.where(...): np.where
# yields *positional* indices, and comparing those against index *labels* is only
# correct while the index happens to be a fresh 0..n-1 range.
data = data[data['digit'].notna()]

In [198]:
# Inspect the merged, labelled responses after rows without a digit answer were dropped.
data

Unnamed: 0,uid,method,worker,digit,intelligibility,quality,diversity,filename,label,agreement
0,dc0520a487ba3b901e415c4e57030ede,diffwave-small-500k,A1NF6PELRKACS9,0,1,2,3,0.wav,0,True
1,dc0520a487ba3b901e415c4e57030ede,diffwave-small-500k,A26RPQDD0RQEHL,0,1,3,3,0.wav,0,True
2,dc0520a487ba3b901e415c4e57030ede,diffwave-small-500k,A3CFNUD7VR2E1E,3,1,2,4,0.wav,0,False
3,dc0520a487ba3b901e415c4e57030ede,diffwave-small-500k,A3CJVRJ34U70Y9,0,1,2,4,0.wav,0,True
4,dc0520a487ba3b901e415c4e57030ede,diffwave-small-500k,A3DU2EWFUGQCX4,0,2,2,4,0.wav,0,True
...,...,...,...,...,...,...,...,...,...,...
19995,4cea2df00a66dc4e21681081399f8a8f,sashimi-diffwave-snet-uni-500k,A3M3HUU77NKTES,1,4,5,5,99.wav,1,True
19996,4cea2df00a66dc4e21681081399f8a8f,sashimi-diffwave-snet-uni-500k,A3QZMGTVA4VO44,1,5,5,5,99.wav,1,True
19997,4cea2df00a66dc4e21681081399f8a8f,sashimi-diffwave-snet-uni-500k,AEF601SQFOSBL,1,4,3,3,99.wav,1,True
19998,4cea2df00a66dc4e21681081399f8a8f,sashimi-diffwave-snet-uni-500k,AM9XH69KBK5X5,1,3,3,3,99.wav,1,True


In [200]:
# Per-method summary: mean ratings, digit accuracy ('agreement') and Cohen's kappa.
# The numeric columns are selected before averaging because `data` also holds string
# columns (uid, worker, filename) that cannot be averaged, and kappa is computed once
# per method instead of being broadcast onto every row and averaged back.
rating_cols = ['intelligibility', 'quality', 'diversity', 'agreement']
method_summary = data.groupby('method')[rating_cols].mean()
method_summary['kappa'] = (
    data.groupby('method')[['label', 'digit']]
    .apply(lambda group: cohen_kappa_score(group['label'], group['digit']))
)
method_summary.sort_values('intelligibility')

Unnamed: 0_level_0,intelligibility,quality,diversity,agreement,kappa
method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
diffwave-small-500k,1.895,1.85,3.032,0.5014,0.446
sashimi-diffwave-snet-uni-500k,3.2916,3.076,3.26,0.8462,0.829111
sashimi-diffwave-small-500k,4.0034,3.832,3.338,0.9406,0.934
sashimi-diffwave-800k,4.3292,4.2,3.284,0.9578,0.953111


In [204]:
def calc_stats(df, col, div=1.):
    """Summarise `col` per method with a normal-approximation 95% confidence interval.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a 'method' column and the column named by `col`.
    col : str
        Name of the numeric column to summarise.
    div : float, default 1.
        Each group's row count is divided by this factor before the standard
        error is computed, i.e. the interval uses ``count / div`` as the sample
        size. Use a value > 1 when rows are replicated and therefore not
        independent observations.

    Returns
    -------
    pandas.DataFrame
        Indexed by method, with columns 'mean', 'count', 'std', 'ci95_hi' and
        'ci95_lo'. 'count' is the raw (undivided) number of non-null rows.

    Raises
    ------
    ValueError
        If `div` is not strictly positive.
    """
    if div <= 0:
        raise ValueError(f"div must be positive, got {div!r}")

    z_95 = 1.96  # two-sided 95% quantile of the standard normal distribution

    # Named `summary` rather than `stats` so the imported scipy `stats` module is not shadowed.
    summary = df.groupby(['method'])[col].agg(['mean', 'count', 'std'])

    # Half-width of the interval: z * std / sqrt(effective sample size).
    half_width = z_95 * summary['std'] / np.sqrt(summary['count'] / div)

    summary['ci95_hi'] = summary['mean'] + half_width
    summary['ci95_lo'] = summary['mean'] - half_width
    return summary

In [205]:
# 'quality' comes from a single un-indexed answer column that is repeated on every
# recording row, so rows are not independent; div=10. shrinks the effective sample
# size accordingly (presumably 10 recordings per HIT -- confirm against the HIT layout).
calc_stats(data, col='quality', div=10.).round(2)

Unnamed: 0_level_0,mean,count,std,ci95_hi,ci95_lo
method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
diffwave-small-500k,1.85,5000,0.88,1.93,1.77
sashimi-diffwave-800k,4.2,5000,0.67,4.26,4.14
sashimi-diffwave-small-500k,3.83,5000,0.77,3.9,3.76
sashimi-diffwave-snet-uni-500k,3.08,5000,0.86,3.15,3.0


In [206]:
# 'intelligibility' is answered separately for each recording, so every row counts
# as its own observation (div=1.).
calc_stats(data, col='intelligibility', div=1.).round(2)

Unnamed: 0_level_0,mean,count,std,ci95_hi,ci95_lo
method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
diffwave-small-500k,1.9,5000,1.12,1.93,1.86
sashimi-diffwave-800k,4.33,5000,0.94,4.36,4.3
sashimi-diffwave-small-500k,4.0,5000,1.08,4.03,3.97
sashimi-diffwave-snet-uni-500k,3.29,5000,1.31,3.33,3.26


In [207]:
# 'diversity' comes from a single un-indexed answer column that is repeated on every
# recording row, so rows are not independent; div=10. shrinks the effective sample
# size accordingly (presumably 10 recordings per HIT -- confirm against the HIT layout).
calc_stats(data, col='diversity', div=10.).round(2)

Unnamed: 0_level_0,mean,count,std,ci95_hi,ci95_lo
method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
diffwave-small-500k,3.03,5000,1.17,3.13,2.93
sashimi-diffwave-800k,3.28,5000,1.16,3.39,3.18
sashimi-diffwave-small-500k,3.34,5000,1.02,3.43,3.25
sashimi-diffwave-snet-uni-500k,3.26,5000,0.86,3.34,3.18
