This notebook shows the prediction heatmap. 

- color: discourse type
- color intensity: score of prediction
- underlined and bold: start of the discourse unit


You can observe how the target text span is fragmented. 

In some cases, it is easy to think of post-processing or learning methods that could correct the fragmentation.

In [None]:
import numpy as np
import sys
import os
import pandas as pd
from IPython.core.display import display, HTML

# Integer class index for every discourse type. Index 0 ('O') is the value
# the per-word label arrays start with, i.e. a word covered by no unit.
discourse_type_to_label = dict(
    zip(
        [
            'O',
            'Lead',
            'Position',
            'Claim',
            'Counterclaim',
            'Rebuttal',
            'Evidence',
            'Concluding Statement',
        ],
        range(8),
    )
)

# Reverse lookup: class index -> discourse type name.
label_to_discourse_type = dict(
    (label, name) for name, label in discourse_type_to_label.items()
)

# Number of classes, counting 'O'.
num_discourse_type = len(discourse_type_to_label)


# RGB triples (one value per channel, 0-255) used as word background colours.
BLACK = [0, 0, 0]
WHITE = [255, 255, 255]
RED = [255, 0, 0]
GREEN = [0, 255, 0]
BLUE = [0, 0, 255]
YELLOW = [255, 255, 0]
MAGENTA = [255, 0, 255]
CYAN = [0, 255, 255]

# NOTE(review): ORANGE is numerically close to YELLOW, so Rebuttal and
# Counterclaim may be hard to tell apart in the output - confirm intended.
ORANGE = [255, 215, 0]
LIGHT_BLUE = [0, 196, 255]

# Background colour per discourse type. The entries are listed in class-index
# order ('O' first, 'Concluding Statement' last); keep that order when editing.
discourse_type_color = dict([
    ('O', WHITE),
    ('Lead', RED),
    ('Position', LIGHT_BLUE),
    ('Claim', GREEN),
    ('Counterclaim', YELLOW),
    ('Rebuttal', ORANGE),
    ('Evidence', CYAN),
    ('Concluding Statement', MAGENTA),
])

In [None]:
#helper function

def rgbstring(rgb, alpha):
    """Return the ANSI background-colour escape for ``rgb`` blended with white.

    Every channel is mixed as ``(1 - alpha) * 255 + channel * alpha`` and
    truncated with ``int``: ``alpha=1`` gives the colour itself, ``alpha=0``
    gives white.

    Example: ``rgbstring([255, 0, 0], 0.5)`` is ``"\\x1b[48;2;255;127;127m"``.
    """
    blended = []
    for channel in (rgb[0], rgb[1], rgb[2]):
        blended.append(str(int((1 - alpha) * 255 + channel * alpha)))
    return "\x1b[48;2;" + ";".join(blended) + "m"


def read_text_and_clean(text_file):
    """Read a text file and return its cleaned content.

    Non-breaking spaces (U+00A0) are replaced by ordinary spaces and trailing
    whitespace is removed; leading whitespace is kept.

    Args:
        text_file: Path of the file to read.

    Returns:
        The cleaned content as a ``str``.
    """
    # Decode explicitly so the result does not depend on the platform's
    # default locale encoding.
    # NOTE(review): assumes the text files are UTF-8 encoded - confirm.
    with open(text_file, 'r', encoding='utf-8') as f:
        text = f.read()

    return text.replace('\xa0', ' ').rstrip()


def text_to_word(text):
    """Split ``text`` on whitespace and locate every word in the original string.

    Args:
        text: The string to split.

    Returns:
        A ``(word, word_offset)`` tuple. ``word`` is ``text.split()`` and
        ``word_offset[i]`` is the ``(start, end)`` character span for which
        ``text[start:end] == word[i]``.

    Raises:
        NotImplementedError: If a word cannot be found after the previous
            one (exception type kept from the original implementation).
    """
    word = text.split()
    word_offset = []

    cursor = 0  # position right after the previously located word
    for w in word:
        # Search from the cursor rather than slicing ``text[cursor:]``:
        # the slice copied the remaining text once per word.
        start = text.find(w, cursor)
        if start == -1:
            raise NotImplementedError

        end = start + len(w)
        word_offset.append((start, end))
        cursor = end

    return word, word_offset


def pedictionstring_to_list(predictionstring):
    """Convert a whitespace-separated string of integers into a list of ints.

    Splitting on any run of whitespace means doubled, leading or trailing
    spaces no longer produce empty tokens that ``int`` rejects.

    Args:
        predictionstring: String such as ``"3 4 5"``.

    Returns:
        The parsed integers in order; an empty list for a blank string.

    Raises:
        ValueError: If a token is not a valid integer literal.
    """
    return [int(x) for x in predictionstring.split()]



def show_item_in_teminal(start, end, word, word_label, word_start, word_score=None, border=4):
    """Render a window of words as one ANSI-coloured string.

    The window is the half-open index range
    ``[max(start - border, 0), min(end + border, len(word)))``.

    Args:
        start: Index of the first word of interest.
        end: Index bounding the words of interest (see the window above).
        word: Sequence of word strings.
        word_label: Per-word integer label; it indexes the values of
            ``discourse_type_color`` in insertion order.
        word_start: Per-word flag; words whose flag equals 1 are printed
            bold and underlined.
        word_score: Per-word blend factor passed to ``rgbstring``; when
            ``None`` every word uses 1.
        border: Number of context words added on each side of the window.

    Returns:
        The concatenated string; each word is preceded by a space and
        followed by a reset escape.
    """
    palette = list(discourse_type_color.values())
    first = max(start - border, 0)
    last = min(end + border, len(word))

    scores = [1] * len(word) if word_score is None else word_score

    pieces = []
    for i in range(first, last):
        base_color = palette[word_label[i]]

        token = word[i]
        if word_start[i] == 1:
            # bold + underline around the first word of a discourse unit
            token = '\033[1m' + '\033[4m' + token + '\033[0m' + '\033[0m'

        background = rgbstring(base_color, scores[i])
        pieces.append(background + ' ' + token + '\x1b[0m')

    return ''.join(pieces)


def prepare_display_data(predict_df, valid_df, text_dir='../input/feedback-prize-2021/train'):
    """Build per-word truth and prediction arrays for every essay id.

    Args:
        predict_df: DataFrame with columns ``id``, ``class``,
            ``predictionstring`` and ``score``. ``score`` is a string that
            evaluates to a sequence indexed in step with the word indices of
            ``predictionstring``.
        valid_df: DataFrame with columns ``id``, ``discourse_type`` and
            ``predictionstring``.
        text_dir: Directory that contains ``<id>.txt`` for every id. The
            default is the path that used to be hard-coded.

    Returns:
        A dict keyed by id. Each value is a dict with the keys ``word``
        (list of words), ``word_start_truth`` / ``word_label_truth`` /
        ``word_start_predict`` / ``word_label_predict`` (``np.int8`` arrays,
        one entry per word) and ``word_score_predict`` (``np.float32`` array).
        Words not covered by any row keep label 0 and score 0.
    """
    unique_id = set(predict_df['id'].unique().tolist() + valid_df['id'].unique().tolist())

    display_data = {}
    for essay_id in unique_id:
        valid_df1 = valid_df[valid_df['id'] == essay_id].reset_index(drop=True)
        predict_df1 = predict_df[predict_df['id'] == essay_id].reset_index(drop=True)

        text_file = '%s/%s.txt' % (text_dir, essay_id)
        text = read_text_and_clean(text_file)
        word, _ = text_to_word(text)  # character offsets are not needed here

        num_word = len(word)
        word_label_truth = np.zeros(num_word, np.int8)
        word_start_truth = np.zeros(num_word, np.int8)
        word_label_predict = np.zeros(num_word, np.int8)
        word_start_predict = np.zeros(num_word, np.int8)
        word_score_predict = np.zeros(num_word, np.float32)

        # ground truth: label every word of a unit and flag its first word
        for _, row in valid_df1.iterrows():
            predictionstring = pedictionstring_to_list(row.predictionstring)
            t = discourse_type_to_label[row.discourse_type]

            for p in predictionstring:
                word_label_truth[p] = t
            word_start_truth[predictionstring[0]] = 1

        # prediction: same as above, plus the per-word score
        for _, row in predict_df1.iterrows():
            predictionstring = pedictionstring_to_list(row.predictionstring)
            t = discourse_type_to_label[row['class']]
            # SECURITY: eval() runs whatever expression the CSV cell holds.
            # Only load prediction files you trust; ast.literal_eval would be
            # a safer parser if the cells are plain list literals.
            score = eval(row.score)

            # k is the position inside this row, p the word index in the essay
            for k, p in enumerate(predictionstring):
                word_label_predict[p] = t
                word_score_predict[p] = score[k]
            word_start_predict[predictionstring[0]] = 1

        display_data[essay_id] = {
            'word': word,
            'word_start_truth': word_start_truth,
            'word_label_truth': word_label_truth,
            'word_start_predict': word_start_predict,
            'word_label_predict': word_label_predict,
            'word_score_predict': word_score_predict,
        }
    return display_data

In [None]:
# Input files for the heatmap demo.

# Predictions.
# Columns: id, class, predictionstring, score (one score per word).
# Prefer the output from before any post-processing (e.g. thresholding), so
# that the errors of your BERT model stay visible.
predict_file = '../input/demo-data-for-show-heatmap/predict_df.csv'

# Ground truth.
# Columns: id, discourse_type, predictionstring.
truth_file = '../input/demo-data-for-show-heatmap/valid_df.csv'

predict_df = pd.read_csv(predict_file)
valid_df = pd.read_csv(truth_file)

# Per-essay word arrays consumed by the display cells.
display_data = prepare_display_data(predict_df=predict_df, valid_df=valid_df)

print('load data ok!')


In [None]:

# Keep every coloured line on a single row of the notebook output
# (white-space: pre disables wrapping).
display(HTML("<style>div.output_area pre {white-space: pre;}</style>"))
if 1:
    # discourse_id for demo
    discourse_id_demo = [
        # can be improved by post-processing/learning
        1614906538599,
        1621020033000,

        # other
        1614807435037,
        1615658454847,
        1615776542260,
        1618939965893,
        1618430029590,
        1619400419495,
        1619042045390,
        1615233650061,
        1621032842264,
        1617299648938,
    ]
    for discourse_id in discourse_id_demo:
        matches = valid_df[valid_df['discourse_id'] == discourse_id]
        if matches.empty:
            # An unknown id used to fail with an opaque IndexError from
            # .iloc[0]; report it and carry on with the remaining ids.
            print(discourse_id, ': not found in valid_df, skipped')
            print('')
            continue

        d = matches.iloc[0]
        # Named essay_id so the builtin id() is not overwritten for the
        # rest of the notebook session.
        essay_id = d['id']
        predictionstring = pedictionstring_to_list(d.predictionstring)
        start = predictionstring[0]
        end = predictionstring[-1]

        item = display_data[essay_id]

        # (T): ground truth, drawn at full colour intensity
        truth_line = show_item_in_teminal(
            start,
            end,
            item['word'],
            item['word_label_truth'],
            item['word_start_truth'],
        )
        # (P): prediction, colour intensity scaled by the per-word score
        predict_line = show_item_in_teminal(
            start,
            end,
            item['word'],
            item['word_label_predict'],
            item['word_start_predict'],
            item['word_score_predict'],
        )

        print(int(d.discourse_id), ':', str((start, end)))
        print(essay_id + ' (T) ' + truth_line)
        print(essay_id + ' (P) ' + predict_line)
        print('')




In [None]:

# Interactive browsing of every ground-truth unit of the selected discourse
# types. Disabled by default; change the guard to 1 to run it.
if 0:
    # Every essay id present in either frame. It does not depend on the
    # discourse type, so it is built once instead of once per type.
    unique_id = list(set(predict_df['id'].unique().tolist() + valid_df['id'].unique().tolist()))

    for discourse_type in [
        #'Lead'                 ,
        #'Position'             ,
        #'Claim'                ,
        #'Evidence'             ,
        #'Concluding Statement' ,
        #'Counterclaim'         ,
        'Rebuttal'             ,
    ]:
        valid_df1 = valid_df[valid_df.discourse_type == discourse_type].reset_index(drop=True)

        # essay_id rather than id, so the builtin id() stays usable
        for essay_id in unique_id:
            rows = valid_df1[valid_df1['id'] == essay_id].reset_index(drop=True)
            for _, d in rows.iterrows():
                predictionstring = pedictionstring_to_list(d.predictionstring)
                start = predictionstring[0]
                end = predictionstring[-1]

                item = display_data[essay_id]

                # (T): ground truth, drawn at full colour intensity
                truth_line = show_item_in_teminal(
                    start,
                    end,
                    item['word'],
                    item['word_label_truth'],
                    item['word_start_truth'],
                )
                # (P): prediction, colour intensity scaled by the per-word score
                predict_line = show_item_in_teminal(
                    start,
                    end,
                    item['word'],
                    item['word_label_predict'],
                    item['word_start_predict'],
                    item['word_score_predict'],
                )

                print(int(d.discourse_id), ':', str((start, end)))
                print(essay_id, '(T)', truth_line)
                print(essay_id, '(P)', predict_line)
                print('')
                input('press enter to continue...')