In [20]:
# %load metric_helpers
import json
import os
import sys
import re
import glob

import pandas as pd
import numpy as np
import jiwer

from difflib import SequenceMatcher
from sklearn.metrics import f1_score
from collections import Counter

def levenshtein(s1, s2):
    """Return the Levenshtein (edit) distance between two sequences.

    Both arguments may be any sized, indexable sequences whose items can be
    compared with ``!=`` (strings, lists of labels, ...). Insertions,
    deletions and substitutions each cost 1.
    """
    # Walk the longer sequence in the outer loop so every DP row only needs
    # len(shorter) + 1 cells.
    if len(s1) >= len(s2):
        longer, shorter = s1, s2
    else:
        longer, shorter = s2, s1

    if len(shorter) == 0:
        # Nothing to align against: every item of the longer sequence is an edit.
        return len(longer)

    prev_row = range(len(shorter) + 1)
    for row_no, item_long in enumerate(longer, start=1):
        row = [row_no]
        for col, item_short in enumerate(shorter):
            row.append(min(
                prev_row[col + 1] + 1,                      # drop item_long
                row[col] + 1,                               # drop item_short
                prev_row[col] + (item_long != item_short),  # keep or substitute
            ))
        prev_row = row
    return prev_row[-1]


def instance_metrics(ref_labels, hyp_labels):
    """Score one hypothesis label sequence against its reference.

    The reference is cut into segments: a segment ends at every reference
    label that starts with "E". Tokens left over after the last such label
    are scored as one final segment instead of being dropped. Labels are
    paired positionally with ``zip``, so pairs beyond the shorter sequence
    are ignored.

    Args:
        ref_labels: sequence of reference label strings.
        hyp_labels: sequence of hypothesis label strings.

    Returns:
        dict with the keys
        "DSER": fraction of segments containing at least one token whose
            first character differs between hypothesis and reference,
        "strict segmentation error": fraction of tokens lying in such segments,
        "DER": fraction of segments containing at least one token whose full
            label differs,
        "strict joint error": fraction of tokens lying in such segments,
        "LWER": ``jiwer.wer`` over the labels that are not "I".
        A value is ``nan`` when it is undefined: the four segment metrics when
        no token pair was scored, "LWER" when the reference has no non-"I"
        label.
    """
    nan = float("nan")

    # One (n_tokens, n_seg_errors, n_joint_errors) record per reference segment.
    segment_records = []
    n_segment_tokens = n_segment_seg_errors = n_segment_joint_errors = 0
    for ref, hyp in zip(ref_labels, hyp_labels):
        n_segment_tokens += 1
        # Segmentation is judged on the first character only; slicing (rather
        # than indexing) keeps an empty label from raising IndexError.
        if hyp[:1] != ref[:1]:
            n_segment_seg_errors += 1
        if hyp != ref:
            n_segment_joint_errors += 1
        if ref.startswith("E"):
            segment_records.append(
                (n_segment_tokens, n_segment_seg_errors, n_segment_joint_errors))
            n_segment_tokens = n_segment_seg_errors = n_segment_joint_errors = 0

    if n_segment_tokens:
        # The reference did not end on an "E" label: close the open segment so
        # its tokens and errors are still counted.
        segment_records.append(
            (n_segment_tokens, n_segment_seg_errors, n_segment_joint_errors))

    n_segments = len(segment_records)
    n_tokens = sum(n_tok for n_tok, _, _ in segment_records)
    wrong_seg_sizes = [n_tok for n_tok, n_seg, _ in segment_records if n_seg > 0]
    wrong_joint_sizes = [n_tok for n_tok, _, n_joint in segment_records if n_joint > 0]

    if n_segments:
        # Every record holds at least one token, so n_tokens > 0 here.
        DSER = len(wrong_seg_sizes) / n_segments
        strict_seg_err = sum(wrong_seg_sizes) / n_tokens
        DER = len(wrong_joint_sizes) / n_segments
        strict_joint_err = sum(wrong_joint_sizes) / n_tokens
    else:
        # Nothing was scored; the ratios are undefined rather than a crash.
        DSER = strict_seg_err = DER = strict_joint_err = nan

    ref_short = [x for x in ref_labels if x != "I"]
    hyp_short = [x for x in hyp_labels if x != "I"]
    # A word error rate has no meaning against an empty reference.
    lwer = jiwer.wer(ref_short, hyp_short) if ref_short else nan
    return {
        "DSER": DSER,
        "strict segmentation error": strict_seg_err,
        "DER": DER,
        "strict joint error": strict_joint_err,
        "LWER": lwer
    }

def batch_metrics(refs, hyps):
    """Aggregate label metrics over a batch of (reference, hypothesis) pairs.

    Per-instance scores come from ``instance_metrics`` and are averaged with
    ``np.mean``. F1 and the "Micro LWER" are computed once over all labels of
    the batch concatenated together.

    Args:
        refs: list of reference label sequences.
        hyps: list of hypothesis label sequences, paired positionally with refs.

    Returns:
        dict with the keys "DSER", "strict segmentation error", "DER",
        "strict joint error", "Macro F1", "Micro F1", "Macro LWER" and
        "Micro LWER".
    """
    averaged_keys = ("DSER", "strict segmentation error", "DER", "strict joint error")
    collected = {key: [] for key in averaged_keys + ("LWER",)}
    for ref_seq, hyp_seq in zip(refs, hyps):
        for metric_name, value in instance_metrics(ref_seq, hyp_seq).items():
            collected[metric_name].append(value)

    # Pool every label of the batch into one long sequence per side.
    pooled_refs, pooled_hyps = [], []
    for ref_seq in refs:
        pooled_refs.extend(ref_seq)
    for hyp_seq in hyps:
        pooled_hyps.extend(hyp_seq)

    macro_f1 = f1_score(pooled_refs, pooled_hyps, average="macro")
    micro_f1 = f1_score(pooled_refs, pooled_hyps, average="micro")

    # The pooled LWER ignores the "I" labels, like the per-instance one.
    pooled_lwer = jiwer.wer(
        [label for label in pooled_refs if label != "I"],
        [label for label in pooled_hyps if label != "I"],
    )

    summary = {key: np.mean(collected[key]) for key in averaged_keys}
    summary["Macro F1"] = macro_f1
    summary["Micro F1"] = micro_f1
    summary["Macro LWER"] = np.mean(collected["LWER"])
    summary["Micro LWER"] = pooled_lwer
    return summary

def _expand_segment_labels(labels, boundary_ids):
    """Repeat each boundary token's label over all tokens of its segment."""
    expanded = []
    segment_start = 0
    for boundary in boundary_ids:
        expanded += [labels[boundary]] * (boundary - segment_start + 1)
        segment_start = boundary + 1
    return expanded


def instance_metrics_asr(ref_labels, hyp_labels, dist_fn=levenshtein):
    """Score a hypothesis whose length may differ from the reference.

    A "boundary" is the index of a label that contains "E".

    Args:
        ref_labels: sequence of reference label strings.
        hyp_labels: sequence of hypothesis label strings.
        dist_fn: edit-distance function taking two sequences.

    Returns:
        dict with the keys
        "LWER": ``jiwer.wer`` over the labels that are not "I",
        "LER": ``dist_fn`` over the full label sequences divided by the
            reference length,
        "SER": sum, taken in both directions, of each boundary's distance to
            the nearest boundary on the other side, halved and divided by the
            number of non-"I" reference labels,
        "NSER": absolute difference between the numbers of non-"I" labels,
            divided by the reference count,
        "DAER": ``dist_fn`` over the two sequences after every token up to
            the last boundary has been given its segment's boundary label,
            divided by the expanded reference length.
        A value is ``nan`` when it is undefined for this pair (empty
        reference, no non-"I" reference label, or boundaries on only one
        side for "SER"). Previously these cases raised ValueError or
        ZeroDivisionError. Aggregations that skip NaN will leave such rows
        out, so check for them when averaging.
    """
    nan = float("nan")

    ref_short = [x for x in ref_labels if x != "I"]
    hyp_short = [x for x in hyp_labels if x != "I"]
    lwer = jiwer.wer(ref_short, hyp_short) if ref_short else nan

    ler = dist_fn(ref_labels, hyp_labels) / len(ref_labels) if len(ref_labels) else nan

    t_ids = [i for i, t in enumerate(ref_labels) if "E" in t]
    r_ids = [i for i, r in enumerate(hyp_labels) if "E" in r]

    if not ref_short or bool(t_ids) != bool(r_ids):
        # Either there is nothing to normalise by, or one side has boundaries
        # and the other has none, so "nearest boundary" does not exist.
        ser = nan
    else:
        s = sum(min(abs(r - t) for r in r_ids) for t in t_ids)
        s += sum(min(abs(r - t) for t in t_ids) for r in r_ids)
        ser = s / 2 / len(ref_short)

    if ref_short:
        nser = abs(len(ref_short) - len(hyp_short)) / len(ref_short)
    else:
        nser = nan

    new_ref = _expand_segment_labels(ref_labels, t_ids)
    new_hyp = _expand_segment_labels(hyp_labels, r_ids)
    daer = dist_fn(new_ref, new_hyp) / len(new_ref) if new_ref else nan

    return {"LWER": lwer,
            "LER": ler,
            "SER": ser,
            "NSER": nser,
            "DAER": daer}

def convert_to_list(this_str, turn_float=False):
    """Turn the printed form of a flat list back into a Python list.

    Square brackets, single quotes and commas are removed and what is left is
    split on whitespace.

    Args:
        this_str: text such as "['a', 'b']" or "[0.5, 1.0]".
        turn_float: when true, every item is converted with ``float``.

    Returns:
        list of str, or list of float when ``turn_float`` is true.
    """
    cleaned = this_str
    for punctuation in ("[", "]", "'", ","):
        cleaned = cleaned.replace(punctuation, "")
    items = cleaned.split()
    if not turn_float:
        return items
    return [float(item) for item in items]


In [21]:
# Directories that get_results_df reads the per-model "<SPLIT>_<model>.res"
# files from.
ref_dir = "/homes/ttmt001/transitory/dialog-act-prediction/data/joint/ref_out"
asr_dir = "/homes/ttmt001/transitory/dialog-act-prediction/data/joint/asr_out"

split_name = 'dev'
# Build the file name from `split_name`; the earlier `split + ...` referred to
# a name that is never defined and raised NameError on a fresh kernel.
filename = split_name + "_merged.tsv"
merged_df = pd.read_csv(filename, sep="\t")
# convert_to_list parses each cell's text into a list of strings ...
for column in ['joint_labels', 'da_turn_orig', 'da_turn_asr']:
    merged_df[column] = merged_df[column].apply(convert_to_list)
# ... and, with turn_float=True, into a list of floats for the time columns.
for column in ['start_times_orig', 'end_times_orig', 'start_times_asr', 'end_times_asr']:
    merged_df[column] = merged_df[column].apply(convert_to_list, turn_float=True)

In [26]:
def get_results_df(model_name, split_name, merged_df):
    """Load one model's predictions and attach per-row metrics.

    Reads "<SPLIT>_<model_name>.res" (tab-separated) from both ``ref_dir`` and
    ``asr_dir``, joins the two tables on their row index, merges the result
    with ``merged_df`` on "main_id" and adds one column per metric.

    Args:
        model_name: model identifier used in the file name.
        split_name: data split; upper-cased for the file name.
        merged_df: DataFrame with a "main_id" column to merge on.

    Returns:
        DataFrame with the columns "labels", "hyps_trans", "hyps_asr", the
        columns of ``merged_df``, and the metric columns DSER, DER,
        LWER/LER/SER/NSER/DAER with "_trans" and "_asr" suffixes.
    """
    res_name = "_".join([split_name.upper(), model_name]) + ".res"

    trans_df = pd.read_csv(os.path.join(ref_dir, res_name), sep="\t")
    asr_df = pd.read_csv(os.path.join(asr_dir, res_name), sep="\t")

    # Keep the ASR predictions apart from the transcript ones and strip the
    # " </t>" marker from them.
    asr_df = asr_df.rename(columns={'PREDS': 'PREDS_ASR'})
    asr_df['PREDS_ASR'] = asr_df['PREDS_ASR'].apply(lambda text: text.replace(" </t>", ""))

    preds_df = trans_df.join(asr_df)
    # Space-separated label strings -> lists of labels.
    for list_col, text_col in (("labels", "LABELS"),
                               ("hyps_trans", "PREDS"),
                               ("hyps_asr", "PREDS_ASR")):
        preds_df[list_col] = preds_df[text_col].apply(lambda text: text.split())
    preds_df = preds_df.rename(columns={'TURN_ID': 'main_id'})
    preds_df = preds_df.drop(columns=['LABELS', 'PREDS', 'PREDS_ASR'])
    res_df = preds_df.merge(merged_df, on='main_id')

    # Each apply yields one metrics dict per row.
    results = res_df.apply(lambda row: instance_metrics(row.labels, row.hyps_trans), axis=1)
    results_asr = res_df.apply(lambda row: instance_metrics_asr(row.labels, row.hyps_asr), axis=1)
    results2 = res_df.apply(lambda row: instance_metrics_asr(row.labels, row.hyps_trans), axis=1)

    trans_scores = results.tolist()
    trans_asr_style_scores = results2.tolist()
    asr_scores = results_asr.tolist()

    # (output column, per-row score dicts, key inside each dict), in output order.
    column_plan = [
        ("DSER", trans_scores, "DSER"),
        ("DER", trans_scores, "DER"),
        ("LWER_trans", trans_scores, "LWER"),
    ]
    column_plan += [(key + "_trans", trans_asr_style_scores, key)
                    for key in ("LER", "SER", "NSER", "DAER")]
    column_plan += [(key + "_asr", asr_scores, key)
                    for key in ("LWER", "LER", "SER", "NSER", "DAER")]
    for column, scores, key in column_plan:
        res_df[column] = [score[key] for score in scores]

    return res_df

In [None]:

# Build the per-row results table for each of the two models on the same split.
sp10004_df = get_results_df(
    model_name="sp10004", split_name=split_name, merged_df=merged_df)
tt10004_df = get_results_df(
    model_name="tt10004", split_name=split_name, merged_df=merged_df)


In [None]:
# Row-specific inspection (debug).
# NOTE(review): `row` is not assigned in any visible cell; presumably it is a
# single row of a results frame left over in the kernel — TODO bind it
# explicitly so this cell survives Restart & Run All.
# Builds a SequenceMatcher over row.da_turn_asr and row.da_turn_orig; `sseq`
# is not used again in the visible cells.
sseq = SequenceMatcher(None, row.da_turn_asr, row.da_turn_orig)

# One tuple per position: (index, label, start time, end time, da_turn item).
# zip stops at the shortest of the zipped sequences.
ref_side = list(zip(range(len(row.labels)),row.labels, row.start_times_orig, row.end_times_orig, row.da_turn_orig))
# Keep only the positions whose label contains "E".
ref_segments = [x for x in ref_side if "E" in x[1]]
hyp_side = list(zip(range(len(row.hyps_asr)), row.hyps_asr, row.start_times_asr, row.end_times_asr, row.da_turn_asr))
hyp_segments = [x for x in hyp_side if "E" in x[1]]


# NOTE(review): `res_df` only exists as a local inside get_results_df; at
# notebook scope the visible cells define sp10004_df and tt10004_df instead.
# Presumably one of those was meant — TODO confirm.
ref_list = res_df.labels.tolist()
trans_list = res_df.hyps_trans.tolist()
asr_list = res_df.hyps_asr.tolist()

# Not the last expression of the cell, so its return value is not displayed.
batch_metrics(ref_list, trans_list)

# NOTE(review): `batch_metrics_asr` is not defined in the visible cells (only
# batch_metrics and instance_metrics_asr are) — TODO confirm where it comes from.
batch_metrics_asr(ref_list, asr_list)
