In [44]:
# %load metric_helpers
import json
import os
import sys
import re
import glob

import pandas as pd
import numpy as np
import jiwer

from difflib import SequenceMatcher
from sklearn.metrics import f1_score
from collections import Counter

def levenshtein(s1, s2):
    """Return the Levenshtein (edit) distance between two sequences.

    Works on any pair of sized, iterable sequences whose items can be
    compared with ``!=`` (strings, lists of labels, ...). Insertions,
    deletions and substitutions each cost 1.

    Args:
        s1: first sequence.
        s2: second sequence.

    Returns:
        int: minimum number of single-item edits turning one sequence into
        the other.
    """
    # The distance is symmetric, so always walk the longer sequence in the
    # outer loop and keep rows the size of the shorter one.
    longer, shorter = (s1, s2) if len(s1) >= len(s2) else (s2, s1)
    if len(shorter) == 0:
        return len(longer)

    # above[k] = distance between the processed prefix of `longer` and shorter[:k]
    above = list(range(len(shorter) + 1))
    for row_no, item_long in enumerate(longer, start=1):
        row = [row_no]
        for col, item_short in enumerate(shorter):
            row.append(min(
                above[col + 1] + 1,                   # insertion (row is one cell wider than `shorter`)
                row[col] + 1,                         # deletion
                above[col] + (item_long != item_short),  # substitution (free on a match)
            ))
        above = row
    return above[-1]


def instance_metrics(ref_labels, hyp_labels):
    """Score one turn of predicted labels against its reference labels.

    The two sequences are walked in parallel (``zip``, so the longer one is
    truncated). A reference segment ends at every reference label that starts
    with ``"E"``; tokens after the last such label are not counted.

    Args:
        ref_labels: sequence of reference label strings.
        hyp_labels: sequence of predicted label strings.

    Returns:
        dict with
          * ``"DSER"``: fraction of reference segments containing at least one
            token whose first character differs between hyp and ref;
          * ``"strict segmentation error"``: fraction of counted tokens that
            lie in such segments;
          * ``"DER"``: fraction of reference segments containing at least one
            token whose full label differs;
          * ``"strict joint error"``: fraction of counted tokens that lie in
            such segments;
          * ``"LWER"``: ``jiwer.wer`` over the labels that are not ``"I"``.

    Raises:
        ZeroDivisionError: if no reference label starts with ``"E"``.
    """
    segment_count = 0
    token_count = 0
    bad_seg_segments = bad_seg_tokens = 0
    bad_joint_segments = bad_joint_tokens = 0

    # State of the segment currently being read.
    open_len = 0
    open_has_seg_error = False
    open_has_joint_error = False

    for ref_tag, hyp_tag in zip(ref_labels, hyp_labels):
        open_len += 1
        if hyp_tag[0] != ref_tag[0]:
            open_has_seg_error = True
        if hyp_tag != ref_tag:
            open_has_joint_error = True

        if not ref_tag.startswith("E"):
            continue

        # Reference segment closes here: fold it into the totals and reset.
        segment_count += 1
        token_count += open_len
        if open_has_seg_error:
            bad_seg_segments += 1
            bad_seg_tokens += open_len
        if open_has_joint_error:
            bad_joint_segments += 1
            bad_joint_tokens += open_len
        open_len = 0
        open_has_seg_error = False
        open_has_joint_error = False

    scores = {
        "DSER": bad_seg_segments / segment_count,
        "strict segmentation error": bad_seg_tokens / token_count,
        "DER": bad_joint_segments / segment_count,
        "strict joint error": bad_joint_tokens / token_count,
    }

    # Label-only error rate: drop the "I" labels and compare what is left.
    ref_tags_only = [tag for tag in ref_labels if tag != "I"]
    hyp_tags_only = [tag for tag in hyp_labels if tag != "I"]
    scores["LWER"] = jiwer.wer(ref_tags_only, hyp_tags_only)
    return scores

def batch_metrics(refs, hyps):
    """Aggregate the transcript-side metrics over a collection of turns.

    Per-turn scores come from ``instance_metrics`` and are averaged with
    ``np.mean``; F1 and the micro LWER are computed once over the
    concatenation of all turns.

    Args:
        refs: sequence of reference label sequences, one per turn.
        hyps: sequence of predicted label sequences, aligned with ``refs``.

    Returns:
        dict with the per-turn means of ``"DSER"``, ``"strict segmentation
        error"``, ``"DER"`` and ``"strict joint error"``, the ``"Macro F1"`` /
        ``"Micro F1"`` scores over all labels, ``"Macro LWER"`` (mean of the
        per-turn LWER) and ``"Micro LWER"`` (``jiwer.wer`` over all labels
        that are not ``"I"``).
    """
    per_turn = [
        instance_metrics(ref_turn, hyp_turn)
        for ref_turn, hyp_turn in zip(refs, hyps)
    ]

    # Corpus-level view: every label of every turn, in order.
    all_refs, all_hyps = [], []
    for ref_turn in refs:
        all_refs.extend(ref_turn)
    for hyp_turn in hyps:
        all_hyps.extend(hyp_turn)

    f1_macro = f1_score(all_refs, all_hyps, average="macro")
    f1_micro = f1_score(all_refs, all_hyps, average="micro")

    ref_tags_only = [tag for tag in all_refs if tag != "I"]
    hyp_tags_only = [tag for tag in all_hyps if tag != "I"]
    corpus_lwer = jiwer.wer(ref_tags_only, hyp_tags_only)

    def turn_mean(key):
        return np.mean([scores[key] for scores in per_turn])

    return {
        "DSER": turn_mean("DSER"),
        "strict segmentation error": turn_mean("strict segmentation error"),
        "DER": turn_mean("DER"),
        "strict joint error": turn_mean("strict joint error"),
        "Macro F1": f1_macro,
        "Micro F1": f1_micro,
        "Macro LWER": turn_mean("LWER"),
        "Micro LWER": corpus_lwer,
    }

def instance_metrics_asr(ref_labels, hyp_labels, dist_fn=levenshtein):
    """Score one turn whose hypothesis may differ in length from the reference.

    Args:
        ref_labels: sequence of reference label strings.
        hyp_labels: sequence of predicted label strings (any length).
        dist_fn: accepted for interface compatibility; the body does not use
            it (all edit-distance based scores go through ``jiwer.wer``).

    Returns:
        dict with
          * ``"LWER"``: ``jiwer.wer`` over the labels that are not ``"I"``;
          * ``"LER"``: ``jiwer.wer`` over the full label sequences;
          * ``"SER"``: for every boundary (label containing ``"E"``) on either
            side, the distance in positions to the closest boundary on the
            other side, summed, halved, and divided by the number of
            non-``"I"`` reference labels;
          * ``"NSER"``: absolute difference between the numbers of non-``"I"``
            labels, relative to the reference count;
          * ``"DAER"``: ``jiwer.wer`` after every token up to and including a
            boundary is replaced by that boundary's label (tokens after the
            last boundary are dropped).

    Raises:
        ValueError: if exactly one side has no label containing ``"E"``.
        ZeroDivisionError: if the reference has no non-``"I"`` label.
    """
    def boundary_positions(labels):
        return [pos for pos, tag in enumerate(labels) if "E" in tag]

    def spread_over_segments(labels, ends):
        # Repeat each boundary label once per token of the segment it closes.
        spread = []
        segment_start = 0
        for end in ends:
            spread.extend([labels[end]] * (end - segment_start + 1))
            segment_start = end + 1
        return spread

    ref_tags_only = [tag for tag in ref_labels if tag != "I"]
    hyp_tags_only = [tag for tag in hyp_labels if tag != "I"]
    lwer = jiwer.wer(ref_tags_only, hyp_tags_only)
    ler = jiwer.wer(ref_labels, hyp_labels)

    ref_ends = boundary_positions(ref_labels)
    hyp_ends = boundary_positions(hyp_labels)

    # Symmetric nearest-boundary cost: reference -> hypothesis, then back.
    gap_total = sum(min(abs(h - r) for h in hyp_ends) for r in ref_ends)
    gap_total += sum(min(abs(h - r) for r in ref_ends) for h in hyp_ends)

    n_ref_tags = len(ref_tags_only)
    ser = gap_total / 2 / n_ref_tags
    nser = abs(n_ref_tags - len(hyp_tags_only)) / n_ref_tags

    ref_spread = spread_over_segments(ref_labels, ref_ends)
    hyp_spread = spread_over_segments(hyp_labels, hyp_ends)
    daer = jiwer.wer(ref_spread, hyp_spread)

    return {
        "LWER": lwer,
        "LER": ler,
        "SER": ser,
        "NSER": nser,
        "DAER": daer,
    }

def batch_metrics_asr(refs, hyps, dist_fn=levenshtein):
    """Aggregate the ASR-side metrics over a collection of turns.

    "Macro" scores are the ``np.mean`` of the per-turn values returned by
    ``instance_metrics_asr``; "Micro" scores are computed once over the
    concatenation of all turns.

    NOTE(review): micro SER compares boundary *positions* in the concatenated
    sequences. When hypothesis turns differ in length from reference turns,
    positions on the two sides drift apart from turn to turn, which may be why
    micro SER comes out far larger than macro SER. Confirm this is intended
    before reporting the number.

    Args:
        refs: sequence of reference label sequences, one per turn.
        hyps: sequence of predicted label sequences, aligned with ``refs``
            turn by turn (lengths inside a turn may differ).
        dist_fn: forwarded to ``instance_metrics_asr``; not used for the
            micro scores.

    Returns:
        dict with ``"Macro X"`` / ``"Micro X"`` entries for X in LWER, LER,
        SER, NSER, DAER.

    Raises:
        ValueError: if exactly one side has no label containing ``"E"``.
        ZeroDivisionError: if the references contain no label containing
            ``"E"``.
    """
    from bisect import bisect_left  # stdlib; local so the block stays self-contained

    def _nearest_gap_total(positions, others):
        """Sum over `positions` of the distance to the closest entry of sorted `others`."""
        if positions and not others:
            raise ValueError(
                "cannot match boundaries: one side has no label containing 'E'"
            )
        total = 0
        for pos in positions:
            # The closest entry of a sorted list is one of the two neighbours
            # of the insertion point, so no full scan is needed.
            k = bisect_left(others, pos)
            total += min(abs(other - pos) for other in others[max(k - 1, 0):k + 1])
        return total

    def _spread_over_segments(labels, ends):
        """Repeat each boundary label once per token of the segment it closes."""
        spread = []
        segment_start = 0
        for end in ends:
            spread += [labels[end]] * (end - segment_start + 1)
            segment_start = end + 1
        return spread

    per_turn = [
        instance_metrics_asr(ref_labels, hyp_labels, dist_fn=dist_fn)
        for ref_labels, hyp_labels in zip(refs, hyps)
    ]

    flattened_refs = [label for ref in refs for label in ref]
    flattened_hyps = [label for hyp in hyps for label in hyp]
    flat_ref_short = [x for x in flattened_refs if x != "I"]
    flat_hyp_short = [x for x in flattened_hyps if x != "I"]
    lwer = jiwer.wer(flat_ref_short, flat_hyp_short)
    ler = jiwer.wer(flattened_refs, flattened_hyps)

    # Boundary positions; both lists are ascending because they come from enumerate.
    t_ids = [i for i, t in enumerate(flattened_refs) if "E" in t]
    r_ids = [i for i, r in enumerate(flattened_hyps) if "E" in r]
    s = _nearest_gap_total(t_ids, r_ids) + _nearest_gap_total(r_ids, t_ids)
    ser = s / 2 / len(t_ids)
    nser = abs(len(t_ids) - len(r_ids)) / len(t_ids)

    new_ref = _spread_over_segments(flattened_refs, t_ids)
    new_hyp = _spread_over_segments(flattened_hyps, r_ids)
    daer = jiwer.wer(new_ref, new_hyp)

    def _turn_mean(key):
        return np.mean([scores[key] for scores in per_turn])

    return {
        "Macro LWER": _turn_mean("LWER"),
        "Micro LWER": lwer,
        "Macro LER": _turn_mean("LER"),
        "Micro LER": ler,
        "Macro SER": _turn_mean("SER"),
        "Micro SER": ser,
        "Macro NSER": _turn_mean("NSER"),
        "Micro NSER": nser,
        "Macro DAER": _turn_mean("DAER"),
        "Micro DAER": daer,
    }

def convert_to_list(this_str, turn_float=False):
    """Turn the text form of a flat list back into a Python list.

    Square brackets, single quotes and commas are removed wherever they occur
    in the text, and what remains is split on whitespace.

    Args:
        this_str: text such as ``"['I', 'E_aa']"`` or ``"[0.5, 1.25]"``.
        turn_float: when true, convert every item with ``float``.

    Returns:
        list of ``str`` items, or of ``float`` items when ``turn_float`` is set.
    """
    strip_punctuation = str.maketrans("", "", "[]',")
    items = this_str.translate(strip_punctuation).split()
    if not turn_float:
        return items
    return [float(item) for item in items]


In [45]:
# Directories holding the model outputs on reference transcripts and on ASR output.
# NOTE(review): absolute, machine-specific paths; adjust when running elsewhere.
ref_dir = "/homes/ttmt001/transitory/dialog-act-prediction/data/joint/ref_out"
asr_dir = "/homes/ttmt001/transitory/dialog-act-prediction/data/joint/asr_out"

split_name = 'dev'
# Fixed: this used the undefined name `split`, which raises NameError on a fresh kernel.
filename = split_name + "_merged.tsv"
merged_df = pd.read_csv(filename, sep="\t")

# List-valued columns are stored as text in the TSV; parse them back into lists.
for column in ['joint_labels', 'da_turn_orig', 'da_turn_asr']:
    merged_df[column] = merged_df[column].apply(convert_to_list)
# Same for the per-token time stamps, additionally converted to float.
for column in ['start_times_orig', 'end_times_orig', 'start_times_asr', 'end_times_asr']:
    merged_df[column] = merged_df[column].apply(convert_to_list, turn_float=True)

In [46]:
def get_results_df(model_name, split_name, merged_df):
    """Load one model's predictions and attach per-turn metrics.

    Reads ``<SPLIT>_<model_name>.res`` (tab-separated) from the notebook-level
    ``ref_dir`` and ``asr_dir``, joins the two files on their row index,
    merges the result with ``merged_df`` on ``main_id`` and appends one column
    per metric.

    Args:
        model_name: model identifier used in the result file name.
        split_name: data split; upper-cased for the file name.
        merged_df: frame with a ``main_id`` column to merge the predictions with.

    Returns:
        pandas.DataFrame with the merged columns plus ``labels``,
        ``hyps_trans``, ``hyps_asr`` (token lists) and the metric columns
        ``DSER``, ``DER``, ``LWER_trans``, ``LER_trans``, ``SER_trans``,
        ``NSER_trans``, ``DAER_trans``, ``LWER_asr``, ``LER_asr``, ``SER_asr``,
        ``NSER_asr``, ``DAER_asr``.
    """
    res_name = split_name.upper() + '_' + model_name + '.res'

    trans_df = pd.read_csv(os.path.join(ref_dir, res_name), sep="\t")
    asr_df = pd.read_csv(os.path.join(asr_dir, res_name), sep="\t")

    # Keep the ASR predictions apart from the transcript ones and strip the " </t>" marker.
    asr_df = asr_df.rename(columns={'PREDS': 'PREDS_ASR'})
    asr_df = asr_df.assign(
        PREDS_ASR=asr_df['PREDS_ASR'].apply(lambda text: text.replace(" </t>", ""))
    )

    preds_df = trans_df.join(asr_df)
    for token_col, text_col in (('labels', 'LABELS'),
                                ('hyps_trans', 'PREDS'),
                                ('hyps_asr', 'PREDS_ASR')):
        preds_df[token_col] = preds_df[text_col].apply(lambda text: text.split())
    preds_df = (
        preds_df
        .rename(columns={'TURN_ID': 'main_id'})
        .drop(columns=['LABELS', 'PREDS', 'PREDS_ASR'])
    )
    res_df = preds_df.merge(merged_df, on='main_id')

    # One dict of scores per row, for each (metric family, hypothesis) pairing.
    trans_scores = res_df.apply(
        lambda row: instance_metrics(row.labels, row.hyps_trans), axis=1).tolist()
    asr_scores = res_df.apply(
        lambda row: instance_metrics_asr(row.labels, row.hyps_asr), axis=1).tolist()
    trans_scores_asr_style = res_df.apply(
        lambda row: instance_metrics_asr(row.labels, row.hyps_trans), axis=1).tolist()

    column_plan = (
        (trans_scores, (('DSER', 'DSER'), ('DER', 'DER'), ('LWER_trans', 'LWER'))),
        (trans_scores_asr_style, (('LER_trans', 'LER'), ('SER_trans', 'SER'),
                                  ('NSER_trans', 'NSER'), ('DAER_trans', 'DAER'))),
        (asr_scores, (('LWER_asr', 'LWER'), ('LER_asr', 'LER'), ('SER_asr', 'SER'),
                      ('NSER_asr', 'NSER'), ('DAER_asr', 'DAER'))),
    )
    for score_dicts, columns in column_plan:
        for column, metric in columns:
            res_df[column] = [scores[metric] for scores in score_dicts]

    return res_df

In [47]:

# Per-turn results for the two models under comparison.
sp10004_df, tt1000_df = (
    get_results_df(model_name, split_name, merged_df)
    for model_name in ("sp10004", "tt1000")
)


In [50]:
# Transcript-side batch metrics for the sp10004 model.
batch_metrics(
    sp10004_df["labels"].tolist(),
    sp10004_df["hyps_trans"].tolist(),
)

{'DSER': 0.09619378655376215,
 'strict segmentation error': 0.08684014825474667,
 'DER': 0.2746146013839729,
 'strict joint error': 0.2638527217940235,
 'Macro F1': 0.44831366975399234,
 'Micro F1': 0.9594272076372314,
 'Macro LWER': 0.26291343223924063,
 'Micro LWER': 0.28832116788321166}

In [51]:
# Transcript-side batch metrics for the tt1000 model.
batch_metrics(
    tt1000_df["labels"].tolist(),
    tt1000_df["hyps_trans"].tolist(),
)

{'DSER': 0.09823221825967403,
 'strict segmentation error': 0.09068847729502526,
 'DER': 0.2873828513700387,
 'strict joint error': 0.27953951620043216,
 'Macro F1': 0.42853709134282647,
 'Micro F1': 0.9575319387898358,
 'Macro LWER': 0.27096132800830786,
 'Micro LWER': 0.29105839416058393}

In [52]:
# ASR-side batch metrics for the sp10004 model.
batch_metrics_asr(
    sp10004_df["labels"].tolist(),
    sp10004_df["hyps_asr"].tolist(),
)

{'Macro LWER': 0.33321915187381934,
 'Micro LWER': 0.33728710462287104,
 'Macro LER': 0.21200899990605945,
 'Micro LER': 0.10420468903551874,
 'Macro SER': 1.2919854025711257,
 'Micro SER': 10.773874695863746,
 'Macro NSER': 0.11498445964583852,
 'Micro NSER': 0.11344282238442822,
 'Macro DAER': 0.3469696528989709,
 'Micro DAER': 0.2585287098132809}

In [53]:
# ASR-side batch metrics for the tt1000 model.
batch_metrics_asr(
    tt1000_df["labels"].tolist(),
    tt1000_df["hyps_asr"].tolist(),
)

{'Macro LWER': 0.4048026764774782,
 'Micro LWER': 0.37195863746958635,
 'Macro LER': 0.2512693816481347,
 'Micro LER': 0.10866208058402359,
 'Macro SER': 1.2789210061174672,
 'Micro SER': 10.78132603406326,
 'Macro NSER': 0.10952720251072905,
 'Micro NSER': 0.10462287104622871,
 'Macro DAER': 0.4183378454485582,
 'Micro DAER': 0.26881229818896535}

In [48]:
# Peek at the first three result rows for the sp10004 model.
sp10004_df.iloc[:3]

Unnamed: 0,main_id,labels,hyps_trans,hyps_asr,joint_labels,da_turn_orig,start_times_orig,end_times_orig,da_turn_asr,start_times_asr,...,LWER_trans,LER_trans,SER_trans,NSER_trans,DAER_trans,LWER_asr,LER_asr,SER_asr,NSER_asr,DAER_asr
0,2347_B_0001,"[I, I, E_fo_o_fw_""_by_bc]","[I, I, E_fo_o_fw_""_by_bc]","[E_fo_o_fw_""_by_bc]","[I, I, E_fo_o_fw_""_by_bc]","[um, all, right]","[0.005875, 1.087875, 1.25325]","[0.21975, 1.25325, 1.484375]",[alright],[1.11],...,0.0,0.0,0.0,0.0,0.0,0.0,0.666667,2.0,0.0,0.666667
1,2347_A_0002,"[I, I, I, E_%, I, I, I, I, I, I, I, I, I, I, I...","[I, I, I, I, I, I, I, I, I, I, I, I, I, I, I, ...","[I, I, I, I, I, I, I, I, I, I, I, I, I, I, I, ...","[I, I, I, E_%, I, I, I, I, I, I, I, I, I, I, I...","[i, """", ve, uh, as, far, as, i, """", m, concern...","[1.060625, 1.060625, 1.060625, 1.632625, 1.833...","[1.476, 1.476, 1.476, 1.833625, 2.083625, 2.36...","[i, """", ve, uh, as, far, as, i, """", m, concern...","[1.090625, 1.090625, 1.090625, 1.660625, 1.930...",...,0.333333,0.021277,4.333333,0.333333,0.085106,0.333333,0.021277,4.333333,0.333333,0.085106
2,2347_B_0003,"[I, E_aa, I, I, I, I, I, I, I, E_aa, I, I, I, ...","[I, E_b, I, I, I, I, I, I, I, I, I, I, I, I, I...","[I, I, I, I, I, I, I, I, I, I, I, I, I, I, I, ...","[I, E_aa, I, I, I, I, I, I, I, E_aa, I, I, I, ...","[uh, huh, oh, well, i, tend, i, tend, to, agre...","[16.56525, 16.56525, 17.74175, 17.915125, 18.0...","[17.0535, 17.0535, 17.915125, 18.085125, 18.15...","[m, ##hm, i, tend, i, tend, to, agree, ah, in,...","[16.56525, 16.56525, 18.09525, 18.18525, 18.54...",...,0.619048,0.052817,3.047619,0.47619,0.383803,0.571429,0.066901,4.785714,0.428571,0.355634


In [49]:
# Peek at the first three result rows for the tt1000 model.
tt1000_df.iloc[:3]

Unnamed: 0,main_id,labels,hyps_trans,hyps_asr,joint_labels,da_turn_orig,start_times_orig,end_times_orig,da_turn_asr,start_times_asr,...,LWER_trans,LER_trans,SER_trans,NSER_trans,DAER_trans,LWER_asr,LER_asr,SER_asr,NSER_asr,DAER_asr
0,2347_B_0001,"[I, I, E_fo_o_fw_""_by_bc]","[I, I, E_b]",[E_b],"[I, I, E_fo_o_fw_""_by_bc]","[um, all, right]","[0.005875, 1.087875, 1.25325]","[0.21975, 1.25325, 1.484375]",[alright],[1.11],...,1.0,0.333333,0.0,0.0,1.0,1.0,1.0,2.0,0.0,1.0
1,2347_A_0002,"[I, I, I, E_%, I, I, I, I, I, I, I, I, I, I, I...","[I, I, I, I, I, I, I, I, I, I, I, I, I, I, I, ...","[I, I, I, I, I, I, I, I, I, I, I, I, I, I, I, ...","[I, I, I, E_%, I, I, I, I, I, I, I, I, I, I, I...","[i, """", ve, uh, as, far, as, i, """", m, concern...","[1.060625, 1.060625, 1.060625, 1.632625, 1.833...","[1.476, 1.476, 1.476, 1.833625, 2.083625, 2.36...","[i, """", ve, uh, as, far, as, i, """", m, concern...","[1.090625, 1.090625, 1.090625, 1.660625, 1.930...",...,0.333333,0.021277,4.333333,0.333333,0.085106,0.333333,0.021277,4.333333,0.333333,0.085106
2,2347_B_0003,"[I, E_aa, I, I, I, I, I, I, I, E_aa, I, I, I, ...","[I, E_b, I, I, I, I, I, I, I, E_sd, I, I, I, I...","[I, E_aa, I, I, I, I, I, E_aa, I, I, I, I, I, ...","[I, E_aa, I, I, I, I, I, I, I, E_aa, I, I, I, ...","[uh, huh, oh, well, i, tend, i, tend, to, agre...","[16.56525, 16.56525, 17.74175, 17.915125, 18.0...","[17.0535, 17.0535, 17.915125, 18.085125, 18.15...","[m, ##hm, i, tend, i, tend, to, agree, ah, in,...","[16.56525, 16.56525, 18.09525, 18.18525, 18.54...",...,0.428571,0.056338,1.761905,0.142857,0.260563,0.238095,0.070423,2.666667,0.047619,0.197183


In [None]:
# Row specific (debug)
# `res_df` and `row` are not bound by any earlier cell shown in this notebook,
# so the cell only ran with leftover kernel state. Bind them explicitly so it
# survives Restart & Run All.
res_df = sp10004_df    # model under inspection; use tt1000_df for the other model
row = res_df.iloc[0]   # turn under inspection; change the position to look at another turn

# Matcher between the ASR words and the transcript words of this turn.
sseq = SequenceMatcher(None, row.da_turn_asr, row.da_turn_orig)

# (position, label, start time, end time, word) per token, and the subset of
# tokens whose label contains "E", for each side.
ref_side = list(zip(range(len(row.labels)), row.labels, row.start_times_orig, row.end_times_orig, row.da_turn_orig))
ref_segments = [x for x in ref_side if "E" in x[1]]
hyp_side = list(zip(range(len(row.hyps_asr)), row.hyps_asr, row.start_times_asr, row.end_times_asr, row.da_turn_asr))
hyp_segments = [x for x in hyp_side if "E" in x[1]]


ref_list = res_df.labels.tolist()
trans_list = res_df.hyps_trans.tolist()
asr_list = res_df.hyps_asr.tolist()

# Only a cell's last expression is displayed, so return both results together
# instead of silently discarding the first one.
{
    "transcript": batch_metrics(ref_list, trans_list),
    "asr": batch_metrics_asr(ref_list, asr_list),
}


In [38]:
# Sanity check: length-normalised levenshtein() against jiwer.wer() on a toy pair.
a = ["aa"] + ["sv"] * 4
b = ["aa"] + ["sd"] * 3

for rate in (levenshtein(a, b) / len(a), jiwer.wer(a, b)):
    print(rate)

0.8
0.8


In [40]:
# Recompute the label error rate straight from jiwer, to compare with LER_trans.
tt1000_df['check'] = [
    jiwer.wer(ref_turn, hyp_turn)
    for ref_turn, hyp_turn in zip(tt1000_df['labels'], tt1000_df['hyps_trans'])
]

In [42]:
# Side-by-side view of the recomputed rate and the stored LER_trans column.
tt1000_df.loc[:, ['check', 'LER_trans']]

Unnamed: 0,check,LER_trans
0,0.333333,0.333333
1,0.021277,0.021277
2,0.056338,0.056338
3,0.095238,0.095238
4,0.000000,0.000000
...,...,...
1634,0.148148,0.148148
1635,0.000000,0.000000
1636,0.067797,0.067797
1637,0.038168,0.038168
