In [1]:
import pickle
import numpy as np
from re import split
import matplotlib.pyplot as plt

In [2]:
# NOTE: pickle.load can execute arbitrary code while unpickling, so only load
# pickle files that come from a trusted source.
# The context manager closes the file handle as soon as the data has been read.
with open("data/emotype_v12.p", "rb") as pickle_file:
    data = pickle.load(pickle_file)

## First, we'll explore the idea of color gradients in matplotlib

In [3]:
# The colormap is indexed with integers from 0 to 255, with 0 as the lightest
# (near-white) end of the gradient and 255 as the darkest red
# Checkout https://matplotlib.org/users/colormaps.html for other color maps
# e.g. bg_blue_to_red_grad = plt.get_cmap(name='bwr') for blue to red gradient
# (plt.get_cmap is used because plt.cm.get_cmap was removed in matplotlib 3.9)
bg_white_to_red_grad = plt.get_cmap(name='Reds')
print(bg_white_to_red_grad(0))   # Prints (R, G, B, alpha_transparency)
print(bg_white_to_red_grad(1))
# ...
print(bg_white_to_red_grad(254))
print(bg_white_to_red_grad(255))
print(bg_white_to_red_grad(256))  # Note that there's no change from 255 to 256

(1.0, 0.9607843137254902, 0.9411764705882353, 1.0)
(0.9998769703960015, 0.9582006920415225, 0.9374855824682814, 1.0)
(0.4115494040753557, 0.0018454440599769348, 0.05196462898885043, 1.0)
(0.403921568627451, 0.0, 0.05098039215686274, 1.0)
(0.403921568627451, 0.0, 0.05098039215686274, 1.0)


## Note that we can extract the (R, G, B) values for a specific point along the color gradient using the following function:

In [4]:
# We can extract the red, green, and blue values specifically
# e.g. for the last value in the gradient
def turn_attn_weight_into_color(weight, tmp_cmap):
    """Convert an attention weight into an integer (r, g, b) colour.

    Args:
        weight: Number scaled by 255 and truncated to an integer colormap
            index (so a weight of 0.5 selects index 127).
        tmp_cmap: Callable that takes an integer index and returns a sequence
            whose first three items are the red, green and blue channels as
            floats (e.g. a matplotlib colormap, which returns RGBA).

    Returns:
        Tuple ``(r, g, b)`` with each channel scaled by 255 and truncated
        to an int.
    """
    color_indx = int(255 * weight)
    # Look the colour up once and keep only R, G, B; alpha is not needed here.
    rgb_floats = tmp_cmap(color_indx)[:3]
    r, g, b = [int(255 * channel) for channel in rgb_floats]
    return (r, g, b)

In [5]:
# Sample the colour for a weight of 0.5 on the white-to-red gradient
# and show its integer channel values.
mid_rgb = turn_attn_weight_into_color(0.5, bg_white_to_red_grad)
r, g, b = mid_rgb
print(*mid_rgb)

251 106 74


## Finally, we just have to work some magic with ANSI Escape codes and voila!

In [6]:
# https://en.wikipedia.org/wiki/ANSI_escape_code
# See section on "Colors", specifically "24-bit"
# ESC[ … 38;2;<r>;<g>;<b> … m Select RGB foreground color
# ESC[ … 48;2;<r>;<g>;<b> … m Select RGB background color

def get_word_with_rgb(word, r, g, b):
    """Prefix ``word`` with an ANSI escape code selecting an RGB background.

    The result is ``ESC[2;48;2;<r>;<g>;<b>m`` followed by ``word``. No reset
    code is appended, so the selected style stays in effect for any text
    printed afterwards until another escape code changes it.
    """
    sgr_params = [
        "2",     # Text intensity: "2" is faint/dim in the SGR table ("1" is bold)
        "48",    # 48;2;... selects a 24-bit background (38;2;... is foreground)
        "2",
        str(r),  # Red
        str(g),  # Green
        str(b),  # Blue
    ]
    return "\x1b[" + ";".join(sgr_params) + "m" + str(word)

In [7]:
# A bare expression displays the string's repr, so the escape code shows up as raw characters
get_word_with_rgb("hello", r, g, b)

'\x1b[2;48;2;251;106;74mhello'

In [8]:
# print() writes the escape code itself; a front end that interprets ANSI codes
# then renders "hello" on the selected background colour
print(get_word_with_rgb("hello", r, g, b))

[2;48;2;251;106;74mhello


## Now, turning attention weights into highlights

### NOTE: The number of text tokens doesn't always match the length of the attention-weights vector (e.g. 32 tokens vs. 28 weights for the first example below). As a result, a highlight color may not line up exactly with the word it is attributed to.

In [9]:
# Shape of the attention weights stored for entry 0; the last axis is compared with the token count below
data[0]['attention_weights'].shape

(1, 1, 28)

In [10]:
# Number of space-separated tokens in entry 0's token string
len(data[0]['text_tokens'].split(' '))

32

In [11]:
# Same shape check for entry 2
data[2]['attention_weights'].shape

(1, 1, 23)

In [12]:
# Number of space-separated tokens in entry 2's token string
len(data[2]['text_tokens'].split(' '))

24

### Proceeding regardless for the purpose of demonstration...

In [15]:
# List the fields available on entry 0
data[0].keys()

dict_keys(['text', 'attention_weights', 'prediction', 'label', 'outputs', 'text_tokens', 'encoding'])

In [20]:
# plt.get_cmap is used because plt.cm.get_cmap was removed in matplotlib 3.9.
my_cmap = plt.get_cmap(name='Reds')

# Print the first five entries with each token highlighted by its attention weight.
for i in range(5):
    true_label = data[i]['label']
    pred_label = data[i]['prediction']
    tokens = split(' ', data[i]['text_tokens'])
    attn_weights = list(data[i]['attention_weights'][0][0])
    # TODO! The token count and the number of attention weights can differ, so
    # tokens and weights are paired by position and zip() stops at the shorter
    # of the two; any leftover tokens are not printed.
    highlighted = []
    for word, weight in zip(tokens, attn_weights):
        r, g, b = turn_attn_weight_into_color(weight, my_cmap)
        highlighted.append(get_word_with_rgb(word, r, g, b))
    # Each highlighted word is followed by a single space (including the last one)
    op_txt = ''.join(piece + ' ' for piece in highlighted)
    # Note that we have to explicitly set the normal background to plain ol' white
    print("\x1b[2;48;2;255;255;255mTrue Label = {}\nPredicted Label = {}".format(true_label, pred_label))
    print(op_txt + '\n')

[2;48;2;255;255;255mTrue Label = anxiety
Predicted Label = anxiety
[2;48;2;255;245;240mhes [2;48;2;255;245;240mso [2;48;2;254;237;229mfucking [2;48;2;255;245;240mcool [2;48;2;255;245;240mand [2;48;2;255;245;240mwell [2;48;2;255;245;240mrespected [2;48;2;255;245;240m. [2;48;2;254;244;239meverybody [2;48;2;255;245;240mloves [2;48;2;255;245;240mhim [2;48;2;255;245;240mand [2;48;2;254;241;234mhes [2;48;2;254;243;238msuper [2;48;2;254;243;238mpopular [2;48;2;255;245;240m. [2;48;2;254;244;239mim [2;48;2;255;245;240ma [2;48;2;255;245;240mnobody [2;48;2;255;245;240m. [2;48;2;254;244;239mi [2;48;2;252;168;139malready [2;48;2;255;245;240mhave [2;48;2;254;244;239msevere [2;48;2;254;244;239manxiety [2;48;2;255;245;240m. [2;48;2;254;232;222mhe [2;48;2;251;112;80mbrings 

[2;48;2;255;255;255mTrue Label = anxiety
Predicted Label = anxiety
[2;48;2;255;245;240mi [2;48;2;255;245;240mweirdly [2;48;2;255;245;240mget [2;48;2;255;245;240mreally [2;48;2;254;237;228manxious 