# Code for plotting the informative prediction-error table in the report

In [None]:
import plotly.graph_objects as go
import plotly.express as px
import pandas as pd
from torchaudio.functional import edit_distance as edit_dist

def _missing_to_empty(value):
    '''Return '' for a missing scalar cell (None, NaN, pd.NA, ...); otherwise return value unchanged.'''
    # Only scalars are tested: a cell holding a token list must be passed through
    # untouched (pd.isna on a list would return an array, not a bool).
    if pd.api.types.is_scalar(value) and pd.isna(value):
        return ''
    return value


def color_code(pd_df, distance_fn=None):
    '''
    Assign a background color to every cell of a prediction table.

    ``pd_df`` is laid out as (rows = test examples, cols = (x, y_true, y_pred1, y_pred2, ...)).
    Every cell from column index 1 onward is colored by the edit distance between its
    value and the true word in column 1. With an edit-distance metric, the y_true column
    itself is therefore always colored as an exact match. Column 0 (the input x) stays white.
    Missing cells (None / NaN) are compared as the empty string instead of raising.
    The result is used as the ``fill_color`` of a plotly table in the report.

    Args:
        pd_df: DataFrame with the layout described above.
        distance_fn: optional callable ``(a, b) -> int`` giving the distance between two
            sequences. Defaults to torchaudio's ``edit_distance`` (imported as ``edit_dist``).

    Returns:
        A column-major nested list, ``colors[col][row]``, one color name per cell.
    '''
    if distance_fn is None:
        distance_fn = edit_dist
    colorlist = ["mediumseagreen", "lightgreen", "yellow", "orange", "tomato"]
    # Distances >= the last bucket all share the last color.
    max_bucket = len(colorlist) - 1

    n_rows, n_cols = len(pd_df), len(pd_df.columns)
    colors = [['#FFFFFF'] * n_rows for _ in range(n_cols)]
    if n_cols < 2:
        # No y_true column: nothing to compare against, everything stays white.
        return colors

    for i in range(n_rows):
        true_word = _missing_to_empty(pd_df.iat[i, 1])
        for j in range(1, n_cols):
            pred_word = _missing_to_empty(pd_df.iat[i, j])
            distance = distance_fn(true_word, pred_word)
            colors[j][i] = colorlist[min(distance, max_bucket)]
    return colors

def generate_table_and_legend(pd_df, locs):
    '''
    Show the color-coded prediction table for selected rows, followed by a color legend.

    Args:
        pd_df: DataFrame whose 1st column is X, 2nd column is y (the true word), and whose
            remaining columns are y_pred1, y_pred2, ...
        locs: index labels of the rows to display. Rows are selected with
            ``DataFrame.filter(items=locs, axis=0)``, so labels absent from the index
            are skipped.

    Both figures are displayed with ``fig.show()``; nothing is returned.
    '''
    df_fil = pd_df.filter(items=locs, axis=0)
    colors = color_code(df_fil)
    # Positional column access keeps values aligned with color_code's positional
    # (iat-based) colors, even when column labels are duplicated.
    cell_values = [list(df_fil.iloc[:, j]) for j in range(df_fil.shape[1])]
    table = go.Table(header=dict(values=list(df_fil.columns)),
                     cells=dict(values=cell_values, fill_color=colors))
    fig = go.Figure(data=[table])
    fig.show() # replace with wandb log

    colorlist = ["mediumseagreen", "lightgreen", "yellow", "orange", "tomato"]
    last = len(colorlist) - 1
    # One legend row per color: exact distances 0..last-1, then a catch-all ">= last" row.
    # Labels are derived from the palette so the two legend columns always have equal length.
    distance_labels = list(range(last)) + [f'>= {last}']
    blank_cells = [''] * len(colorlist)
    white_cells = ['#FFFFFF'] * len(colorlist)

    legend = go.Table(header=dict(values=['Color', 'Levenshtein distance']),
                      cells=dict(values=[blank_cells, distance_labels],
                                 fill_color=[colorlist, white_cells]))
    fig = go.Figure(data=[legend])
    fig.show() # replace with wandb log

In [None]:
# Toy example: five identical rows (apart from X), displaying rows 0, 2 and 4.
demo_rows = [[f"{i}a", 'abc', '1abcde2', 'abc', 'ab'] for i in range(5)]
df = pd.DataFrame(demo_rows, columns=list('12345'))
generate_table_and_legend(df, [0, 2, 4])