In [1]:
import numpy as np
import pandas as pd
import os
import datetime
import tqdm
import pickle
import re
import copy

# 1. Load MNLI data

In [2]:
# Download the MNLI data.
# The shell guard `[ ! -d 'data' ]` means the download + unzip chain only runs when
# no `data` directory exists yet; if the directory is already there the whole chain is skipped.
# NOTE(review): a failed download still leaves `data/` behind, so a later run would
# skip the download — remove the directory manually before retrying.
! [ ! -d 'data' ] && mkdir data && cd data && wget "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FMNLI.zip?alt=media&token=50329ea1-e339-40e2-809c-10c40afff3ce" && unzip data*
# List the extracted files as a sanity check.
! ls data/MNLI

dev_matched.tsv     original	test_matched.tsv     train.tsv
dev_mismatched.tsv  README.txt	test_mismatched.tsv


In [3]:
# Read the matched dev split into a list of rows; each row is the list of
# tab-separated fields of one line (row 0 is the header).
matched = []
with open('data/MNLI/dev_matched.tsv', 'r', encoding="utf8") as fi:
    matched.extend(row.replace("\n", "").split('\t') for row in fi)

len(matched), matched[0]

(9816,
 ['index',
  'promptID',
  'pairID',
  'genre',
  'sentence1_binary_parse',
  'sentence2_binary_parse',
  'sentence1_parse',
  'sentence2_parse',
  'sentence1',
  'sentence2',
  'label1',
  'label2',
  'label3',
  'label4',
  'label5',
  'gold_label'])

In [4]:
# Read the mismatched dev split into a list of rows; each row is the list of
# tab-separated fields of one line (row 0 is the header).
mismatched = []
with open('data/MNLI/dev_mismatched.tsv', 'r', encoding="utf8") as fi:
    mismatched.extend(row.replace("\n", "").split('\t') for row in fi)

len(mismatched), mismatched[0]

(9833,
 ['index',
  'promptID',
  'pairID',
  'genre',
  'sentence1_binary_parse',
  'sentence2_binary_parse',
  'sentence1_parse',
  'sentence2_parse',
  'sentence1',
  'sentence2',
  'label1',
  'label2',
  'label3',
  'label4',
  'label5',
  'gold_label'])

In [5]:
# Fetch the gdown.pl helper script and make it executable, then use it to download
# the Google Drive file into data_list-new.pkl (the file unpickled in the next cell).
! wget https://raw.githubusercontent.com/circulosmeos/gdown.pl/master/gdown.pl && chmod u+x gdown.pl
! ./gdown.pl "https://drive.google.com/file/d/1KWSX8myaF0P4texjPYV2SWb2vvLaBiqu/view?usp=sharing" data_list-new.pkl
# Show the working directory so the downloaded file can be checked.
! ls -la

--2020-07-29 05:18:39--  https://raw.githubusercontent.com/circulosmeos/gdown.pl/master/gdown.pl
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.0.133, 151.101.64.133, 151.101.128.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.0.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2072 (2.0K) [text/plain]
Saving to: ‘gdown.pl.1’


2020-07-29 05:18:39 (18.9 MB/s) - ‘gdown.pl.1’ saved [2072/2072]

Cannot open cookies file ‘gdown.cookie.temp’: No such file or directory
--2020-07-29 05:18:40--  https://docs.google.com/uc?id=1KWSX8myaF0P4texjPYV2SWb2vvLaBiqu&export=download
Resolving docs.google.com (docs.google.com)... 172.217.204.101, 172.217.204.139, 172.217.204.102, ...
Connecting to docs.google.com (docs.google.com)|172.217.204.101|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: unspecified [text/html]
Saving to: ‘data_list-new.pkl’

     0K                               

In [6]:
# Load the list of per-example records (one dict per example).
# NOTE(review): pickle.load can execute arbitrary code while unpickling; only run this
# on a file whose download source is trusted.
# The `with` block closes the file handle, which the bare open() call left open.
with open('data_list-new.pkl', 'rb') as fi:
    data = pickle.load(fi)
len(data), data[0].keys()

(19647,
 dict_keys(['final_prob', 'answer', 'prediction', 'text_lst', 'text', 'question', 'sentence', 'answer_list', 'reward0', 'reward_norm0', 'reward_norm_dict0', 'reward1', 'reward_norm1', 'reward_norm_dict1', 'reward2', 'reward_norm2', 'reward_norm_dict2']))

## 2.1 Merge with MNLI data

In [7]:
# Attach each matched TSV row (as a {column name: value} dict) to the record at the
# same position, stopping once the records run out.
lbl = matched[0]
for i, m in enumerate(matched[1:]):
    d = {name: m[col] for col, name in enumerate(lbl)}
    if i >= len(data):
        break
    data[i]['src'] = d

In [8]:
# Attach each mismatched TSV row (as a {column name: value} dict) to its record.
# The mismatched records are stored right after the matched ones, so the write
# position is shifted by the number of matched data rows (header excluded).
lbl = mismatched[0]
offset = len(matched) - 1  # replaces the hard-coded 9815
for i, m in enumerate(mismatched[1:]):
    d = {name: m[col] for col, name in enumerate(lbl)}
    # Bounds-check the index that is actually written; the previous guard compared
    # only `i` against len(data) and so could not prevent an out-of-range write.
    if offset + i >= len(data):
        break
    data[offset + i]['src'] = d

In [9]:
# Spot-check one merged record (the last one): its source row, the model-side
# question/sentence text, the length of the first reward0 row and the available keys.
i = -1
rec = data[i]
rec['src'], '||||', rec['question'], rec['sentence'], len(rec['reward0'][0]), rec.keys()


({'genre': 'verbatim',
  'gold_label': 'entailment',
  'index': '9831',
  'label1': 'entailment',
  'label2': 'entailment',
  'label3': 'entailment',
  'label4': 'entailment',
  'label5': 'entailment',
  'pairID': '8693e',
  'promptID': '8693',
  'sentence1': "Bloomer (for `flower'), butter (for `ram'), or even flower (for `river') are recurrent examples, but solvers must always be on the alert for new traps of this ",
  'sentence1_binary_parse': "( ( ( ( ( ( ( ( ( ( ( Bloomer ( -LRB- ( ( for ( ` ( flower ' ) ) ) -RRB- ) ) ) , ) ( butter ( -LRB- ( ( for ( ` ( ram ' ) ) ) -RRB- ) ) ) ) , ) or ) ( even flower ) ) ( -LRB- ( ( for ( ` ( river ' ) ) ) -RRB- ) ) ) ( are ( recurrent examples ) ) ) , ) but ) ( solvers ( ( must always ) ( be ( on ( ( the alert ) ( for ( ( new traps ) ( of this ) ) ) ) ) ) ) ) )",
  'sentence1_parse': "(ROOT (FRAG (S (S (NP (NP (NP (NP (NNP Bloomer)) (PRN (-LRB- -LRB-) (PP (IN for) (NP (`` `) (NN flower) ('' '))) (-RRB- -RRB-))) (, ,) (NP (NP (NN butter)) (PRN (

## 2.2 Reconnect split words

In [10]:
def add_space_to_punc(txt):
    """Pad punctuation and two English clitics with spaces so `.split()` isolates them.

    Every occurrence of one of the characters ``. / : { } ( )`` gets a space on
    both sides. Afterwards ``'s`` is rewritten to `` 's `` and ``n't`` to
    ``n ' t ``. The padded string is returned (runs of spaces are not collapsed).
    """
    # One-pass character mapping; the inserted padding is only spaces, so this is
    # the same as replacing the characters one after another.
    padding = str.maketrans({ch: ' ' + ch + ' ' for ch in './:{}()'})
    spaced = txt.translate(padding)
    for old, new in (("'s", " 's "), ("n't", "n ' t ")):
        spaced = spaced.replace(old, new)
    return spaced

In [11]:
# Align the tokens in data[i]['text_lst'] with the whitespace-split words of the
# source sentences. For every record this stores:
#   question_ids / sentence_ids   - one list of text_lst indices per matched source word
#   question_list / sentence_list - the split words of sentence1 / sentence2
# NOTE(review): text_lst looks like BERT word-piece output ([CLS]/[SEP] markers,
# '#'-prefixed continuation pieces) — confirm against the code that produced the pickle.
debug = False

for i in tqdm.tqdm(range(len(data))):
# for i in tqdm.tqdm(range(20, 21)):
    question = add_space_to_punc(data[i]['src']['sentence1']).split()
    sentence = add_space_to_punc(data[i]['src']['sentence2']).split()
    question_id = []
    sentence_id = []
    cnt = 0  # index of the next source word to match in `src`
    is_q = True  # assigned here and at [SEP] but never read within this cell
    ii = -1  # index into text_lst; advanced manually so the prefix branch can skip pieces

    while ii < len(data[i]['text_lst']) - 1:
        ii += 1
        w = data[i]['text_lst'][ii]
        if w == '[CLS]':
            # Start of the pair: match against sentence1 words.
            src = question
            target = question_id
            cnt = 0
            continue

        if debug is True:
            print(w, ii)

        if w == '[SEP]':
            # Every [SEP] switches to sentence2 words and restarts the word counter.
            is_q = False
            src = sentence
            target = sentence_id

            if debug is True:
                print('-' * 70)

            cnt = 0
        # NOTE(review): src[cnt] is not bounds-checked; a non-special token arriving
        # after every source word has been consumed would raise IndexError.
        elif w == src[cnt].lower().strip():
            # The token is exactly the (lower-cased) source word: one index for this word.
            if debug is True:
                target.append([ii, w])
            else:
                target.append([ii])

            cnt +=1
        elif w == src[cnt].lower().strip()[:len(w)]:
            # The token is only a prefix of the source word: collect the following
            # tokens that, joined with '#' removed, still form a prefix of the word.
            if debug is True:
                print(w, src[cnt])

            target.append([])
            for iii in range(1, 8):
                w_new = ''.join(data[i]['text_lst'][ii:ii+iii]).replace("#", "")
                if w_new != src[cnt].lower().strip()[:len(w_new)]:
                    # The span of iii tokens overshoots the word. Move ii to the last
                    # token that still matched (the while loop's `ii += 1` then lands
                    # on the first unmatched token) and go on to the next source word.
                    ii += (iii - 2)
                    cnt += 1
                    if debug is True:
                        print(ii, cnt)
                        if ii < len(data[i]['text_lst']) and cnt < len(src):
                            print("===>", w_new, data[i]['text_lst'][ii] , src[cnt].lower().strip())

                    break
                else:
                    # The span still matches: record the index of its last token.
                    if debug is True:
                        target[-1] += [ii + iii - 1, w_new]
                    else:
                        target[-1] += [ii + iii - 1]
            # NOTE(review): if all 7 candidate spans match, the loop ends without a
            # break and neither ii nor cnt is advanced for this word.
        # Tokens matching neither branch are skipped without being recorded.
    
    data[i]['question_ids'] = question_id
    data[i]['sentence_ids'] = sentence_id
    data[i]['question_list'] = question
    data[i]['sentence_list'] = sentence
    

100%|██████████| 19647/19647 [00:02<00:00, 7174.18it/s]


## 2.3 Extract POS

In [12]:
# For both sentences of every record, read the POS tag of each word from the
# bracketed constituency parse and align the tags with the split word list.
# Results are stored as data[i]['question_pos'] (sentence1) and
# data[i]['answer_pos'] (sentence2): one tag string per word, '' where no tag was aligned.
#
# Fixes relative to the previous version of this cell:
#   * the aligned tags are now written back to the record (previously only a local
#     name was rebound, so both *_pos lists always stayed empty);
#   * sentence2 is aligned against data[i]['sentence_list'] (its split words) rather
#     than 'answer_list';
#   * the parse-leaf cursor is bounds-checked and whitespace-only groups are ignored.
debug = False

for i in tqdm.tqdm(range(len(data))):
    # (parse column in the source row, word-list key, output key)
    for parse_key, tokens_key, pos_key in (
        ('sentence1_parse', 'question_list', 'question_pos'),
        ('sentence2_parse', 'sentence_list', 'answer_pos'),
    ):
        txt = data[i]['src'][parse_key]
        txt_list = data[i][tokens_key]

        # Collect the innermost "(TAG word)" groups: text seen since the latest '('
        # is emitted at the next ')'. Groups starting with '-' (e.g. -LRB-) and
        # groups whose last item has no [0-9a-z] character are dropped.
        pos_list2 = []
        st = ''
        for t in txt:
            if t == '(':
                st = ''
            elif t == ')':
                parts = st.split()
                if parts and st[0] != '-' and len(re.sub("[^0-9a-z]", "", parts[-1].lower())) > 0:
                    pos_list2.append(parts)
                st = ''
            else:
                st += t

        # Walk the words and the parse leaves in parallel; a leaf is consumed only
        # when its alphanumeric form equals the word's, so after a mismatch the
        # remaining words keep the '' placeholder.
        aligned_pos = [''] * len(txt_list)
        iii = 0
        for ii, w in enumerate(txt_list):
            if iii >= len(pos_list2):
                # Every parse leaf has been consumed; nothing left to align.
                break

            w_new = re.sub("[^0-9a-z]", "", w.lower())
            if debug is True:
                print(ii, iii, w, w_new)

            if len(w_new) > 0:
                pos_w = re.sub("[^0-9a-z]", "", pos_list2[iii][-1].lower())
                if debug is True:
                    print('===>', w_new, pos_w)

                if w_new == pos_w:
                    aligned_pos[ii] = pos_list2[iii][0]
                    iii += 1

        data[i][pos_key] = aligned_pos

100%|██████████| 19647/19647 [00:04<00:00, 4158.24it/s]


## 2.4 Recalculate rewards

In [13]:
# Aggregate token-level rewards to word level: for each index group in
# question_ids / sentence_ids, sum the first row ([0]) of reward0/1/2 over the
# indices in that group and store the sums under
# '<question|sentence>_id_reward<k>'.
debug = False

for i in tqdm.tqdm(range(len(data))):
    if debug is True:
        print(i)

    record = data[i]
    for side in ('question', 'sentence'):
        groups = record[side + '_ids']
        for k in range(3):
            reward_key = 'reward' + str(k)
            record[side + '_id_reward' + str(k)] = [
                sum([record[reward_key][0][l] for l in l1]) for l1 in groups
            ]

        

100%|██████████| 19647/19647 [00:01<00:00, 12119.22it/s]


# 3. Visualization

In [14]:
# JavaScript source for the reward visualization, kept as a Python string.
# It loads jquery and d3 through requirejs, reads the records from
# window.data_listNNAME and draws, for each of the three classes, the token column
# plus a grid of coloured reward boxes inside the #visNNAME element.
# "NNAME" is a literal placeholder: model_view replaces it with its `nname`
# argument before the script is displayed.
js_code = """


requirejs(['jquery', 'd3'], function($, d3) {

    const TEXT_SIZE = 15;
    const BOXWIDTH = 110;
    const BOXHEIGHT = 22.5;
    const MATRIX_WIDTH = 150;
    const CHECKBOX_SIZE = 20;
    const TEXT_TOP = 50;
    const HEAD_COLORS = d3.scaleOrdinal(d3.schemeCategory10);

    var data_list = window.data_listNNAME;
    var config = {};
    var questionID = 0;

    function renderText(svg, text, isLeft, leftPos, needHeader=false, header="NOTHING", shift_x=0) {
        var id = isLeft ? "left" : "right";
        var textContainer = svg.append("svg:g").attr("id", id);

        var tokenContainer = textContainer.append("g").selectAll("g")
                                        .data(text)
                                        .enter()
                                        .append("g");

        tokenContainer.append("rect")
                    .classed("background", true)
                    .style("opacity", 0.0)
                    .attr("fill", "lightgray")
                    .attr("x", leftPos + shift_x)
                    .attr("y", function(d, i) {
                    return TEXT_TOP + i * BOXHEIGHT;
                    })
                    .attr("width", BOXWIDTH)
                    .attr("height", BOXHEIGHT);

        var textEl = tokenContainer.append("text")
                                    .text(function(d) { return d; })
                                    .attr("font-size", TEXT_SIZE + "px")
                                    .style("cursor", "default")
                                    .style("-webkit-user-select", "none")
                                    .attr("x", leftPos + shift_x)
                                    .attr("y", function(d, i) {
                                      return TEXT_TOP + i * BOXHEIGHT;
                                    });

        if (isLeft) {
            textEl.style("text-anchor", "end")
                    .attr("dx", BOXWIDTH - 0.5 * TEXT_SIZE)
                    .attr("dy", TEXT_SIZE);
        } else {
            textEl.style("text-anchor", "start")
                    .attr("dx", + 0.5 * TEXT_SIZE)
                    .attr("dy", TEXT_SIZE);
        }

        if (needHeader){
            var head_x = 55 + shift_x;

            textContainer.append("g")
                        .append('text')
                        .text(header)
                        .attr("font-size", "20px")
                        .attr("fill", "black")
                        .style("cursor", "default")
                        .style("-webkit-user-select", "none")
                        .attr("x", head_x)
                        .attr("y", TEXT_TOP - 20);
        }
    }

    function renderDataBox(svg, data, id, col=d3.interpolateRdYlGn, shift_x=0) {
        var boxContainer = svg.append("svg:g").attr("id", id);
        var tinyBoxWidth = MATRIX_WIDTH / config.reward0.length;
        var Tooltip = d3.select("#visNNAME")
                    .append("div")
                    .style("opacity", 0)
                    .attr("class", "tooltip")
                    .style("background-color", "white")
                    .style("border", "solid")
                    .style("border-width", "2px")
                    .style("border-radius", "5px")
                    .style("padding", "5px");

        boxContainer.append("g").classed("dataBoxes", true)
                    .selectAll("g")
                    .data(data)
                    .enter()
                    .append("rect")
                    .attr("x", function(d) {
                        return shift_x + BOXWIDTH + d.l * tinyBoxWidth;
                    })
                    .attr("y", function(d) {
                        return TEXT_TOP + (d.t) * BOXHEIGHT;
                    })
                    .attr("width", tinyBoxWidth)
                    .attr("height", function() { return BOXHEIGHT; })
                    .attr("fill", function(d) {
                        return col(d.v*0.5 + 0.5);
                    })
                    .style("opacity", 100.0)
                    .on("mouseover", function(d) {
                        Tooltip.style("opacity", 1);
                        d3.select(this)
                            .style("stroke", "black");
                    })
                    .on("mousemove", function(d) {
                        var htm = "Layer Name: " + d.layer_name + "</br> ";
                        htm += "Reward Value: " + d.v_r + "</br> ";
                        htm += "Normalized Reward Value: " + d.v;

                        Tooltip
                            .html(htm)
                            .style("left", (d3.mouse(this)[0]+70) + "px")
                            .style("top", (d3.mouse(this)[1] + 110) + "px");
                    })
                    .on("mouseout", function(d) {
                        Tooltip.style("opacity", 0);
                        d3.select(this)
                            .style("stroke", "none");
                    });



        var lst = [];
        for (var i = 1; i <= config.reward_norm0.length; i++) {
            lst.push(i);
        }

        console.log("lst:", lst, "!!!", tinyBoxWidth);
        boxContainer.append("g").classed("layer_num", true)
                    .selectAll("g")
                    .data(lst)
                    .enter()
                    .append("text")
                    .text(function(d) { return d; })
                    .attr("font-size", "10px")
                    .attr("fill", "brown")
                    .style("cursor", "default")
                    .style("-webkit-user-select", "none")
                    .attr("x", function(d) {
                        return shift_x + BOXWIDTH + Math.round((d-1) * (tinyBoxWidth-0.15)) + 3;
                    })
                    .attr("y", TEXT_TOP - 3);
    }

    function init(){
        config = data_list[questionID];

        console.log('run init function');
        console.log("text", config.text);
        console.log("question", config.question);
        console.log("answer", config.answer);

        $("#sentenceNNAME").text(config.sentence);
        $("#answerNNAME").text(config.answer);
        $("#predictionNNAME").text(config.prediction);
        $("#final_probNNAME").text(config.final_prob);
    }

    function render() {
        init()
        var leftText = config.text_lst;
        var rightText = config.text_lst;

        $("#visNNAME svg").empty();
        $("#visNNAME").empty();

        var height = (config.reward0[0].length + 2) * BOXHEIGHT + TEXT_TOP;
        console.log(height);
        console.log(config.reward0[0].length);
        var svg = d3.select("#visNNAME")
            .append('svg')
            .attr("width", "100%")
            .attr("height", height + "px");

        // Visualize class 0
        renderText(svg, leftText, true, 0, true, config.answer_list[0]);
        renderDataBox(svg, config.reward_norm_dict0, "reward_norm_dict0");

        // // Visualize class 1
        renderText(svg, leftText, true, 0, true, config.answer_list[1], 260);
        renderDataBox(svg, config.reward_norm_dict1, "reward_norm_dict1", d3.interpolateRdYlGn, 260);
        // renderDataBox(svg, config.not_entail_reward_norm_dict, "entail_data", d3.interpolateReds, 450);

        // // Visualize class 2
        renderText(svg, leftText, true, 0, true, config.answer_list[2], 530);
        renderText(svg, rightText, false, MATRIX_WIDTH + BOXWIDTH, false, "", 530);
        renderDataBox(svg, config.reward_norm_dict2, "reward_norm_dict2", d3.interpolateRdYlGn, 530);
    }

    $("#questionsNNAME").on('change', function (e) {
        questionID = e.currentTarget.value;
        render();
    });

    // Render the view
    render();

});

"""

In [25]:
# prepare scores
def _layer_name(layer_idx):
    """Return the tooltip name for reward row `layer_idx`.

    Row 0 is "Input", row 1 "BertEmbeddings", rows 2..13 "BertLayer1".."BertLayer12",
    and every later row "mnli_mdl.classifier".
    """
    if layer_idx == 0:
        return "Input"
    if layer_idx == 1:
        return "BertEmbeddings"
    if layer_idx < 14:
        return "BertLayer" + str(layer_idx - 1)
    return "mnli_mdl.classifier"


def prep_data_list(s, e, records=None):
    """Build the visualization payload for records ``s`` (inclusive) to ``e`` (exclusive).

    Each selected record is deep-copied (the source list is never modified) and then:

    1. the rewards of the '[CLS]' and '[SEP]' tokens are set to 0 in every row of
       reward0/reward1/reward2;
    2. every reward row is normalised into reward_norm<k>: positive values are
       divided by the row maximum, all other values by the absolute row minimum
       (1e-8 stands in when the row has no positive / no negative value);
    3. reward_norm_dict<k> is rebuilt as a flat list with one dict per
       (row, token) cell: {"l": row, "t": token, "v": normalised value,
       "v_r": raw value, "layer_name": display name of the row}.

    Parameters:
        s, e: start and end index of the slice of records to prepare.
        records: list of record dicts to read from; when omitted, the
            notebook-level ``data`` list is used.

    Returns:
        A list with the prepared copies, in index order.
    """
    if records is None:
        records = data

    answer_list = ['Neutral', 'Entailment', 'Contradiction',]
    n_classes = len(answer_list)
    data_list = []
    # Distinct names for every loop level: earlier, inner loops reused the outer
    # loop variable, which only worked because the range iterator ignored it.
    for rec_idx in range(s, e):
        d = copy.deepcopy(records[rec_idx])
        n_rows = len(d['reward0'])

        # 1. Silence the special tokens.
        for tok_idx, tok in enumerate(d['text_lst']):
            if tok in ['[CLS]', '[SEP]']:
                for row in range(n_rows):
                    for c in range(n_classes):
                        d['reward' + str(c)][row][tok_idx] = 0

        # 2. Row-wise normalisation (min and max computed once per row).
        for row in range(n_rows):
            for c in range(n_classes):
                r = d['reward' + str(c)][row]
                row_min = np.min(r)
                row_max = np.max(r)
                neg_scale = abs(row_min) if row_min < 0 else 0.00000001
                pos_scale = row_max if row_max > 0 else 0.00000001
                # .tolist() turns the numpy scalars into plain floats so the
                # result can be serialised with json.
                d['reward_norm' + str(c)][row] = np.array(
                    [a / pos_scale if a > 0 else a / neg_scale for a in r]
                ).tolist()

        # 3. Flat per-cell list.
        for c in range(n_classes):
            raw = d['reward' + str(c)]
            norm = d['reward_norm' + str(c)]
            cells = []
            for row in range(len(raw)):
                layer_name = _layer_name(row)
                for tok_idx in range(len(raw[row])):
                    cells.append({"l": row, "t": tok_idx,
                                  "v": norm[row][tok_idx],
                                  "v_r": raw[row][tok_idx],
                                  "layer_name": layer_name})
            d['reward_norm_dict' + str(c)] = cells

        data_list.append(d)
    return data_list


In [26]:
import json
from IPython.core.display import display, HTML, Javascript
import os


def model_view(data_list, nname="01"):
    """Render the interactive reward visualization for `data_list` in the notebook.

    Three outputs are displayed in order: the HTML scaffold (a <select> with one
    option per record, the summary fields and the #vis container), a script that
    publishes the records as ``window.data_list<nname>``, and the visualization
    script from ``js_code``.

    Parameters:
        data_list: list of JSON-serialisable record dicts; each needs a string
            under 'question' (its first 100 characters label the option).
        nname: suffix substituted for the "NNAME" placeholder in element ids and
            in the JavaScript, so several views can coexist on one page.
    """
    from html import escape

    vis_html = """
        <script src="/static/components/requirejs/require.js"></script>
        <script>
          requirejs.config({
            paths: {
              base: '/static/base',
              "d3": "//cdnjs.cloudflare.com/ajax/libs/d3/5.7.0/d3.min",
              jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
            },
          });
        </script>
        <div class="panel panel-default">
        <div class="panel-heading">
            <strong> Question: </strong> <select id="questionsNNAME">
    """.replace("NNAME", nname)
    # Escape the question text: it is free text and may contain '<', '>' or '&',
    # which would otherwise be parsed as markup inside the <select>.
    vis_html += "".join(
        '<option value="' + str(i) + '">' + escape(d['question'][:100]) + "</option> \n"
        for i, d in enumerate(data_list)
    )
    vis_html += """
        </select>
        </div>
        <div class="panel-body">
            <div class="well">
                <strong> Sentence: </strong> <span id="sentenceNNAME"> TTT</span> </br>
                <strong> Answer: </strong> <span id="answerNNAME"> TTT</span> </br>
                <strong> Prediction: </strong> <span id="predictionNNAME"> TTT</span> </br>
                <strong> Final Prob: </strong> <span id="final_probNNAME"> TTT</span> </br>

            </div>
        </div>
        <div id='visNNAME'></div>
    </div>

    """.replace("NNAME", nname)


    display(HTML(vis_html))
    vis_js = js_code.replace("NNAME", nname)
    # Build the variable name directly instead of running the NNAME substitution
    # over the serialised records, which would rewrite any data containing "NNAME".
    display(Javascript('window.data_list' + nname + ' = ' + json.dumps(data_list)))
    display(Javascript(vis_js))

In [27]:
# NOTE(review): `reload` is imported here but not used within this cell.
from importlib import reload  
# Prepare the first 20 records (indices 0..19) and render them with the default view suffix.
data_list = prep_data_list(0, 20)
model_view(data_list)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>