In [59]:
import json
import math
import os.path
import sys
from collections import Counter

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
from IPython.display import HTML, Image, clear_output, display
from torch.autograd import Variable
from tqdm import tqdm

sys.path.append('..')

In [2]:
import config
import data
import model_IG
import utils

  from ._conv import register_converters as _register_converters


In [3]:
%reload_ext autoreload
%autoreload 2

## Paths, parameters, etc.

In [4]:
# which GPU device to use?
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

# path to pretrained model
MODEL_FILE = '../logs/2017-08-04_00.55.19.pth'

# TSV file to write attributions 
ATTRS_TSV = '/scratch/pramodkm/acl18/vqa/tsv/attrs.tsv'

# HTML file to pretty display attributions
# The folder containing the images is assumed to be named "val" 
# and be in the same directory as this HTML
ATTRS_HTML = '/scratch/pramodkm/acl18/vqa/attrs.html'

# Number of steps in Riemann integral computation for Integrated Gradients
NUM_STEPS = 2000

# Sample size of dataset to use in all computations of this notebooks
MAX_NUM_BATCHES = 10000 

# File (EPS format) for writing the overstability curve
OVERSTABILITY_CURVE_FILE = '/scratch/pramodkm/acl18/vqa/overstability.eps'

In [5]:
# Load pretrained model
log = torch.load(MODEL_FILE)
tokens = len(log['vocab']['question']) + 1

net = torch.nn.DataParallel(model_IG.Net(tokens))
net.load_state_dict(log['weights'])

  init.xavier_uniform(w)
  init.xavier_uniform(self.embedding.weight)
  init.xavier_uniform(m.weight)


In [6]:
# Load vocabulary
with open(config.vocabulary_path, 'r') as fd:
    vocab_json = json.load(fd)
reverse_vocab_question = dict(
    [(v, k) for k, v in vocab_json['question'].items()])
reverse_vocab_answer = dict([(v, k) for k, v in vocab_json['answer'].items()])

In [7]:
# Extract embedding matrix for question words
question_emb_lookup = log['weights']['module.text.embedding.weight']
embedding = nn.Embedding(
    question_emb_lookup.shape[0], question_emb_lookup.shape[1], padding_idx=0)
embedding.weight.data = question_emb_lookup

In [8]:
# Load validation data
val_loader = data.get_loader(val=True)
LOADER = val_loader
PREFIX = "val"

In [9]:
# Image index dict
reverse_coco_idxs = dict(
    [(v, k) for k, v in val_loader.dataset.coco_id_to_index.items()])

In [10]:
def get_answer(a):
    """ Get the answer that at least 3 turkers have agreed on """
    indices = torch.nonzero(a >= 3)
    if len(indices) == 0:
        return ''
    return '|'.join([reverse_vocab_answer[int(index)] for index in indices])

In [11]:
# Setting up a few items
if not os.path.isdir(os.path.dirname(ATTRS_TSV)):
    os.makedirs(os.path.dirname(ATTRS_TSV))

var_params = {
    'requires_grad': False,
}

# number of batches for the integral summation for computing attributions
num_batches_ig = int(np.ceil(NUM_STEPS/val_loader.batch_size))

## Compute accuracy

In [40]:
# Script to compute attributions for a fixed number of batches
net.eval()
accs = []
num_batches = 0
# iterator over the validation dataset
tq = tqdm(LOADER, desc='{} E{:03d}'.format(PREFIX, 0), ncols=0)
for v, q, a, idx, q_len in tq:
    v = Variable(v.cuda(async=True), **var_params)
    q = Variable(q.cuda(async=True), **var_params)
    a = Variable(a.cuda(async=True), **var_params)
    q_len = Variable(q_len.cuda(async=True), **var_params)

    q_emb = embedding(q)

    out = net(v, q_emb, q_len)

    acc = utils.batch_accuracy(out.data, a.data).cpu()

    accs.append(np.array(acc.view(-1)))

    del v, q, a, idx, q_len, q_emb, acc, out

    if num_batches >= MAX_NUM_BATCHES:
        break
    num_batches += 1

accs = list(np.concatenate(accs, axis=0))


  
  if __name__ == '__main__':
  # Remove the CWD from sys.path while we load stuff.
  # This is added back by InteractiveShellApp.init_path()
  attention = F.softmax(attention)

val E000:   0% 1/950 [00:01<22:28,  1.42s/it][A
val E000:   0% 2/950 [00:01<12:11,  1.30it/s][A
val E000:   0% 3/950 [00:01<08:59,  1.76it/s][A
val E000:   1% 5/950 [00:01<05:51,  2.69it/s][A
val E000:   1% 6/950 [00:02<05:16,  2.99it/s][A
val E000:   1% 7/950 [00:02<05:08,  3.06it/s][A
val E000:   1% 9/950 [00:02<04:26,  3.54it/s][A
val E000:   1% 10/950 [00:02<04:26,  3.53it/s][A
val E000:   1% 12/950 [00:02<03:49,  4.08it/s][A
val E000:   2% 15/950 [00:03<03:14,  4.80it/s][A
val E000:   2% 17/950 [00:03<03:18,  4.69it/s][A
val E000:   2% 20/950 [00:03<02:54,  5.32it/s][A
val E000:   2% 22/950 [00:03<02:43,  5.67it/s][A
val E000:   3% 24/950 [00:03<02:34,  6.01it/s][A
val E000:   3% 26/950 [00:04<02:28,  6.24it/s][A
val E000:   3% 28/950 [00:04<02:26,  6.30it/s][A
val E000:   3% 30/950 [00:

In [41]:
# Accuracy 
print('Accuracy over',len(accs),'inputs:',np.mean(np.array(accs) >= 0.3))
print('Mean turker agreement over',len(accs),'inputs:',np.mean(np.array(accs)))

Accuracy over 121512 inputs: 0.6734972677595629
Mean turker agreement over 121512 inputs: 0.6112985


## Compute attributions
Approx running time: 21 hours on GeForce GTX 1080 Ti (12 GB), Intel(R) Xeon(R) Silver 4110 CPU @ 2.10GHz, 64GB RAM

In [12]:
padding_embedding = embedding.weight.data[0, :]


def scale_input(q_emb, num_batches=1):
    """ Create scaled versions of input and stack along batch dimension
    q_emb shape = (q_length, emb_dim)
    """
    num_points = config.batch_size*num_batches
    scale = 1.0/num_points
    step = (q_emb.unsqueeze(0) -
            padding_embedding.unsqueeze(0).unsqueeze(0)) * scale
    ans = torch.cat([torch.add(padding_embedding, step*i)
                     for i in range(num_points)], dim=0)
    return ans, step.squeeze()

In [33]:
def compute_attributions(q_emb, q_len, v, idx, num_batches=5, answer=None):
    """ compute attributions for all examples in a given batch """
    ans = ''
    for batch_i in range(int(q_emb.shape[0])):
        scaled_q_emb, step = scale_input(
            q_emb[batch_i, :, :], num_batches=num_batches)
        diff = 0
        total_grads = 0
        repeated_q_len = (torch.ones([config.batch_size] + list(q_len.shape[1:]), dtype=torch.long).cuda(async=True)*q_len[batch_i])
        repeated_v = (torch.ones([config.batch_size] + list(v.shape[1:])).cuda(async=True)*v[batch_i])
        for j in range(num_batches):
            batch_scaled_q_emb = scaled_q_emb[j*config.batch_size:(
                j+1)*config.batch_size]
            with torch.autograd.set_grad_enabled(True):
                scaled_answer, gradients = net(
                    repeated_v, batch_scaled_q_emb, repeated_q_len, compute_gradient=True, ans_index=int(answer[batch_i]))
            # at this point, shape(gradients) = 128 x 23 x 300
            total_grads += torch.sum(gradients, dim=0)
            if j == 0:
                diff -= scaled_answer[0, answer[batch_i]]
                baseline_softmax = scaled_answer[0, :]
            if j == num_batches - 1:
                diff += scaled_answer[-1, answer[batch_i]]
        del scaled_q_emb, repeated_q_len, repeated_v, batch_scaled_q_emb, gradients
        attributions = torch.sum(total_grads * step, dim=1)
        area = torch.sum(attributions, dim=0)
        #print('--------------------------')
        #print(('diff: ', float(diff)))
        #print(('area: ', float(area)))
        if abs(float(diff) - float(area)) > 0.001:
            print(('WARNING: attribution sanity check not matching up!! Diff = ', abs(
                float(diff) - float(area))))

        predicted_answer = reverse_vocab_answer[int(answer[batch_i])]
        correct_answer = get_answer(a[batch_i, :])
        _, baseline_topk_answers = baseline_softmax.topk(1)
        baseline_topk_answers = ', '.join(
            [reverse_vocab_answer[int(i)] for i in baseline_topk_answers])

        if baseline_topk_answers[0] == predicted_answer:
            attributions = attributions*0

        #print('Predicted answer: ', predicted_answer)
        #print('Baseline top k answers : ', ' | '.join(baseline_topk_answers))
        #print('Prediction is correct?: ', int(acc[batch_i]))
        #print('Image ID: ', val_loader.dataset.coco_ids[int(idx[batch_i])])
        question_attrs = []
        for j, w in enumerate(q[batch_i, :]):
            if int(w) != 0:
                #print(reverse_vocab_question[int(w)], ': ', float(attributions[j]))
                question_attrs.append(
                    '|'.join([str(reverse_vocab_question[int(w)]), str(float(attributions[j]))]))
        tsv_string = ['||'.join(question_attrs), baseline_topk_answers, predicted_answer, correct_answer, str(
            int(acc[batch_i])), str(val_loader.dataset.coco_ids[int(idx[batch_i])])]
        ans += '\t'.join(tsv_string) + '\n'
        del attributions, area
    return ans

In [None]:
# Script to compute attributions for a fixed number of batches
net.train()
accs = []
num_batches = 0
with open(ATTRS_TSV, 'a') as outf:
    # iterator over the validation dataset
    tq = tqdm(LOADER, desc='{} E{:03d}'.format(PREFIX, 0), ncols=0)
    for v, q, a, idx, q_len in tq:
        if num_batches < 949:
            num_batches += 1
            continue

        v = Variable(v.cuda(async=True), **var_params)
        q = Variable(q.cuda(async=True), **var_params)
        a = Variable(a.cuda(async=True), **var_params)
        q_len = Variable(q_len.cuda(async=True), **var_params)

        q_emb = embedding(q)

        out = net(v, q_emb, q_len)

        acc = utils.batch_accuracy(out.data, a.data).cpu()

        _, answer = out.data.cpu().max(dim=1)

        attrs_tsv_string = compute_attributions(
            q_emb, q_len, v, idx, num_batches=num_batches_ig, answer=answer)

        outf.write(attrs_tsv_string)
        outf.flush()

        accs.append(np.array(acc.view(-1)))
        
        del v, q, a, idx, q_len, q_emb, acc, out, answer

        if num_batches >= MAX_NUM_BATCHES:
            break
        num_batches += 1

accs = list(np.concatenate(accs, axis=0))

## Visualization

In [None]:
def visualize_attrs(tokens, attrs):
    html_text = ""
    for i, tok in enumerate(tokens):
        r, g, b = get_color(attrs[i])
        html_text += " <strong><span style='size:16;color:rgb(%d,%d,%d)'>%s</span></strong>" % (
            r, g, b, tok)
    return html_text


def get_latex(tokens, attrs):
    ans = ""
    for i, tok in enumerate(tokens):
        [r, g, b] = [w/256.0 for w in get_color(attrs[i])]
        ans += " {\color[rgb]{%f,%f,%f}%s}" % (r, g, b, tok)
    return ans


def normalize_attrs(attrs):
    """ normalize attributions to between -1 and 1 """
    bound = max(abs(attrs.max()), abs(attrs.min()))
    return attrs/bound


def get_color(attr):
    """ attr is assumed to be between -1 and 1 """
    if attr > 0:
        return int(128*attr) + 127, 128 - int(64*attr), 128 - int(64*attr)
    return 128 + int(64*attr), 128 + int(64*attr), int(-128*attr) + 127

In [None]:
def make_visualization_html(tsv_filename, html_filename):
    html_str = '<head><link rel="stylesheet" href="https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous"></head>'
    html_str += '<body> <div class="container"> <h3> Visualizations of the attributions for the Visual QA network <br> <small> Red indicates high values, blue and gray indicates low values <br> A green (or red) block before the question indicates whether the network got the answer right (or wrong)</small></h3></div><br>'
    with open(tsv_filename) as f, open(html_filename, 'w') as outf:
        html_str += '<div class="container">'
        html_str += '-'*40 + '<br>'
        outf.write(html_str)
        for line in f:
            line = line.strip()
            question_attrs, baseline_topk_answers, predicted_answer, correct_answer, is_correct, image_id = line.split(
                '\t')
            question_tokens = []
            attrs = []
            for word_attr in question_attrs.split(','):
                word, attr = word_attr.split('|')
                question_tokens.append(word)
                attrs.append(float(attr))
            html_str = visualize_attrs(
                question_tokens, normalize_attrs(np.array(attrs)))

            if is_correct == '1':
                html_str = '<span style="background-color:green">&nbsp&nbsp</span> ' + html_str
            else:
                html_str = '<span style="background-color:red">&nbsp&nbsp</span> ' + html_str
            html_str += '<br>(prediction, ground truth) = (' + predicted_answer + ', ' + correct_answer + ')'
            html_str += '<br>prediction :' + predicted_answer
            html_str += '<br>baseline topk answers: ' + baseline_topk_answers
            html_str += '<br>image ID: ' + str(image_id)
            html_str += '<br><img src="val2014/COCO_val2014_' + '0' * \
                (12 - len(str(image_id))) + str(image_id) + \
                '.jpg" width="256" height="256"></img><br><br>'
            #display(Image('/scratch/pramodkm/vqa/data_vqa1.0/val2014/COCO_val2014_' + '0'*(12 - len(str(image_id))) + str(image_id) + '.jpg', width=256, height=256))
            outf.write(html_str + '\n')
        outf.write('</div></body>')

In [None]:
make_visualization_html(tsv_filename = ATTRS_TSV, 
                       html_filename = ATTRS_HTML)

## Attack by prefixing sentences

In [45]:
def question_concatenation_accuracy(net, phrase, suffix=False):
    """ compute accuracy when phrases are prefixed/suffixed """
    net.eval()
    
    prefix = []
    for word in phrase.split():
        prefix.append(vocab_json['question'][word])
    prefix = torch.LongTensor(prefix) 
    prefix = prefix.unsqueeze(0).repeat(config.batch_size,1)
    accs = []
    num_batches = 0
    # iterator over the validation dataset
    tq = tqdm(LOADER, desc='{} E{:03d}'.format(PREFIX, 0), ncols=0)
    for v, q, a, idx, q_len in tq:
        curr_batch_size = int(q.shape[0])
        if not suffix:
            q = torch.cat([prefix[:curr_batch_size], q], dim=1)[:,:23] ## 23 because question_length is configured to cap at 23
        else:
            for i in range(curr_batch_size):
                nnz_ix = int((torch.nonzero(q[i,:].cpu())).squeeze().max()) + 1
                if nnz_ix >= q.shape[1]:
                    continue
                q[i, nnz_ix:] = prefix[0, :23-nnz_ix]
        q = q.contiguous()
        
        q_len = q_len + prefix.shape[1]
        q_len = torch.min(q_len, torch.LongTensor([23]).expand_as(q_len))


        v = Variable(v.cuda(async=True), **var_params)
        q = Variable(q.cuda(async=True), **var_params)
        a = Variable(a.cuda(async=True), **var_params)
        q_len = Variable(q_len.cuda(async=True), **var_params)

        q_emb = embedding(q)

        out = net(v, q_emb, q_len)

        acc = utils.batch_accuracy(out.data, a.data).cpu()

        accs.append(np.array(acc.view(-1)))
        del v, q, a, idx, q_len, q_emb, acc, out

        if num_batches >= MAX_NUM_BATCHES:
            break
        num_batches += 1

    accs = list(np.concatenate(accs, axis=0))
    return np.mean(np.array(accs) >= 0.3)

In [46]:
PHRASES = [
    'in not a lot of words',
    'in not many words',
    'what is the answer to',
    'tell me',
    'answer this',
    'answer this for me'
]

prefix_attack_accs = []
for phrase in PHRASES:
    prefix_attack_accs.append(question_concatenation_accuracy(net, phrase))


  attention = F.softmax(attention)

val E000:   0% 1/950 [00:02<32:01,  2.03s/it][A
val E000:   0% 3/950 [00:02<11:11,  1.41it/s][A
val E000:   1% 5/950 [00:02<07:03,  2.23it/s][A
val E000:   1% 7/950 [00:02<05:16,  2.98it/s][A
val E000:   1% 9/950 [00:02<04:57,  3.16it/s][A
val E000:   1% 11/950 [00:02<04:12,  3.73it/s][A
val E000:   1% 13/950 [00:03<03:41,  4.24it/s][A
val E000:   2% 15/950 [00:03<03:19,  4.70it/s][A
val E000:   2% 18/950 [00:03<02:53,  5.36it/s][A
val E000:   2% 21/950 [00:03<02:36,  5.94it/s][A
val E000:   2% 23/950 [00:03<02:35,  5.95it/s][A
val E000:   3% 25/950 [00:04<02:30,  6.14it/s][A
val E000:   3% 27/950 [00:04<02:31,  6.10it/s][A
val E000:   3% 29/950 [00:04<02:24,  6.37it/s][A
val E000:   3% 31/950 [00:04<02:23,  6.39it/s][A
val E000:   3% 33/950 [00:05<02:20,  6.52it/s][A
val E000:   4% 35/950 [00:05<02:18,  6.63it/s][A
val E000:   4% 36/950 [00:05<02:17,  6.66it/s][A
val E000:   4% 37/950 [00:05<02:16,  6.69it/s][A
val E000:   4% 39/

val E000:  19% 181/950 [00:22<01:36,  7.96it/s][A
val E000:  19% 182/950 [00:22<01:36,  7.95it/s][A
val E000:  19% 183/950 [00:23<01:36,  7.95it/s][A
val E000:  19% 184/950 [00:23<01:36,  7.96it/s][A
val E000:  19% 185/950 [00:23<01:36,  7.97it/s][A
val E000:  20% 187/950 [00:23<01:35,  7.98it/s][A
val E000:  20% 188/950 [00:23<01:35,  7.99it/s][A
val E000:  20% 189/950 [00:23<01:35,  7.98it/s][A
val E000:  20% 190/950 [00:23<01:35,  7.98it/s][A
val E000:  20% 191/950 [00:23<01:35,  7.98it/s][A
val E000:  20% 193/950 [00:24<01:34,  8.00it/s][A
val E000:  20% 194/950 [00:24<01:34,  7.99it/s][A
val E000:  21% 196/950 [00:24<01:34,  8.00it/s][A
val E000:  21% 197/950 [00:24<01:34,  8.00it/s][A
val E000:  21% 198/950 [00:24<01:33,  8.00it/s][A
val E000:  21% 199/950 [00:24<01:33,  8.01it/s][A
val E000:  21% 201/950 [00:25<01:33,  8.01it/s][A
val E000:  21% 202/950 [00:25<01:33,  8.02it/s][A
val E000:  21% 203/950 [00:25<01:33,  8.03it/s][A
val E000:  21% 204/950 [00:25<0

val E000:  47% 444/950 [01:13<01:24,  6.01it/s][A
val E000:  47% 445/950 [01:13<01:23,  6.01it/s][A
val E000:  47% 446/950 [01:14<01:23,  6.02it/s][A
val E000:  47% 447/950 [01:14<01:23,  6.02it/s][A
val E000:  47% 448/950 [01:14<01:23,  6.02it/s][A
val E000:  47% 450/950 [01:14<01:22,  6.03it/s][A
val E000:  48% 452/950 [01:14<01:22,  6.04it/s][A
val E000:  48% 454/950 [01:15<01:21,  6.05it/s][A
val E000:  48% 455/950 [01:15<01:21,  6.05it/s][A
val E000:  48% 456/950 [01:15<01:21,  6.06it/s][A
val E000:  48% 457/950 [01:15<01:21,  6.06it/s][A
val E000:  48% 458/950 [01:15<01:21,  6.06it/s][A
val E000:  48% 459/950 [01:15<01:20,  6.07it/s][A
val E000:  48% 460/950 [01:15<01:20,  6.07it/s][A
val E000:  49% 461/950 [01:15<01:20,  6.07it/s][A
val E000:  49% 462/950 [01:16<01:20,  6.07it/s][A
val E000:  49% 463/950 [01:16<01:20,  6.07it/s][A
val E000:  49% 464/950 [01:16<01:20,  6.07it/s][A
val E000:  49% 466/950 [01:16<01:19,  6.08it/s][A
val E000:  49% 467/950 [01:16<0

val E000:  70% 664/950 [01:39<00:42,  6.68it/s][A
val E000:  70% 666/950 [01:39<00:42,  6.66it/s][A
val E000:  70% 668/950 [01:40<00:42,  6.68it/s][A
val E000:  71% 670/950 [01:40<00:41,  6.67it/s][A
val E000:  71% 672/950 [01:40<00:41,  6.68it/s][A
val E000:  71% 674/950 [01:40<00:41,  6.67it/s][A
val E000:  71% 676/950 [01:41<00:40,  6.69it/s][A
val E000:  71% 678/950 [01:41<00:40,  6.70it/s][A
val E000:  72% 680/950 [01:41<00:40,  6.71it/s][A
val E000:  72% 682/950 [01:42<00:40,  6.69it/s][A
val E000:  72% 684/950 [01:42<00:39,  6.70it/s][A
val E000:  72% 686/950 [01:42<00:39,  6.71it/s][A
val E000:  73% 689/950 [01:42<00:38,  6.69it/s][A
val E000:  73% 692/950 [01:43<00:38,  6.71it/s][A
val E000:  73% 694/950 [01:43<00:38,  6.72it/s][A
val E000:  73% 697/950 [01:43<00:37,  6.73it/s][A
val E000:  74% 699/950 [01:43<00:37,  6.74it/s][A
val E000:  74% 702/950 [01:43<00:36,  6.76it/s][A
val E000:  74% 704/950 [01:43<00:36,  6.77it/s][A
val E000:  74% 706/950 [01:44<0

val E000:   3% 26/950 [00:04<02:45,  5.60it/s][A
val E000:   3% 28/950 [00:04<02:38,  5.81it/s][A
val E000:   3% 31/950 [00:05<02:38,  5.81it/s][A
val E000:   3% 33/950 [00:05<02:33,  5.97it/s][A
val E000:   4% 35/950 [00:05<02:29,  6.13it/s][A
val E000:   4% 37/950 [00:05<02:23,  6.36it/s][A
val E000:   4% 39/950 [00:06<02:22,  6.37it/s][A
val E000:   4% 41/950 [00:06<02:20,  6.48it/s][A
val E000:   5% 43/950 [00:06<02:17,  6.59it/s][A
val E000:   5% 45/950 [00:06<02:17,  6.59it/s][A
val E000:   5% 47/950 [00:06<02:13,  6.76it/s][A
val E000:   5% 49/950 [00:07<02:11,  6.83it/s][A
val E000:   5% 51/950 [00:07<02:14,  6.71it/s][A
val E000:   6% 53/950 [00:07<02:10,  6.87it/s][A
val E000:   6% 55/950 [00:07<02:08,  6.97it/s][A
val E000:   6% 57/950 [00:08<02:05,  7.11it/s][A
val E000:   6% 59/950 [00:08<02:05,  7.08it/s][A
val E000:   6% 61/950 [00:08<02:03,  7.21it/s][A
val E000:   7% 63/950 [00:08<02:04,  7.15it/s][A
val E000:   7% 65/950 [00:08<02:01,  7.28it/s][A


val E000:  33% 309/950 [00:47<01:38,  6.51it/s][A
val E000:  33% 311/950 [00:47<01:37,  6.53it/s][A
val E000:  33% 313/950 [00:47<01:37,  6.54it/s][A
val E000:  33% 315/950 [00:48<01:36,  6.56it/s][A
val E000:  33% 317/950 [00:48<01:36,  6.55it/s][A
val E000:  33% 318/950 [00:48<01:36,  6.55it/s][A
val E000:  34% 319/950 [00:48<01:36,  6.56it/s][A
val E000:  34% 320/950 [00:48<01:36,  6.55it/s][A
val E000:  34% 321/950 [00:48<01:35,  6.55it/s][A
val E000:  34% 323/950 [00:49<01:35,  6.57it/s][A
val E000:  34% 324/950 [00:49<01:35,  6.57it/s][A
val E000:  34% 325/950 [00:49<01:35,  6.53it/s][A
val E000:  34% 327/950 [00:49<01:35,  6.55it/s][A
val E000:  35% 329/950 [00:50<01:34,  6.57it/s][A
val E000:  35% 331/950 [00:50<01:33,  6.59it/s][A
val E000:  35% 333/950 [00:50<01:33,  6.62it/s][A
val E000:  35% 335/950 [00:50<01:32,  6.63it/s][A
val E000:  35% 337/950 [00:50<01:32,  6.63it/s][A
val E000:  36% 339/950 [00:51<01:32,  6.62it/s][A
val E000:  36% 341/950 [00:51<0

val E000:  66% 629/950 [01:28<00:45,  7.08it/s][A
val E000:  66% 631/950 [01:29<00:44,  7.09it/s][A
val E000:  67% 633/950 [01:29<00:44,  7.09it/s][A
val E000:  67% 635/950 [01:29<00:44,  7.10it/s][A
val E000:  67% 637/950 [01:29<00:44,  7.11it/s][A
val E000:  67% 639/950 [01:29<00:43,  7.11it/s][A
val E000:  67% 641/950 [01:30<00:43,  7.11it/s][A
val E000:  68% 642/950 [01:30<00:43,  7.11it/s][A
val E000:  68% 643/950 [01:30<00:43,  7.11it/s][A
val E000:  68% 644/950 [01:30<00:42,  7.12it/s][A
val E000:  68% 645/950 [01:30<00:42,  7.12it/s][A
val E000:  68% 646/950 [01:30<00:42,  7.12it/s][A
val E000:  68% 647/950 [01:30<00:42,  7.12it/s][A
val E000:  68% 648/950 [01:31<00:42,  7.12it/s][A
val E000:  68% 649/950 [01:31<00:42,  7.12it/s][A
val E000:  68% 650/950 [01:31<00:42,  7.11it/s][A
val E000:  69% 652/950 [01:31<00:41,  7.12it/s][A
val E000:  69% 654/950 [01:31<00:41,  7.14it/s][A
val E000:  69% 656/950 [01:31<00:41,  7.14it/s][A
val E000:  69% 658/950 [01:32<0

val E000:  89% 846/950 [01:54<00:14,  7.41it/s][A
val E000:  89% 847/950 [01:54<00:13,  7.40it/s][A
val E000:  89% 848/950 [01:54<00:13,  7.40it/s][A
val E000:  89% 849/950 [01:54<00:13,  7.40it/s][A
val E000:  90% 851/950 [01:54<00:13,  7.41it/s][A
val E000:  90% 853/950 [01:55<00:13,  7.41it/s][A
val E000:  90% 855/950 [01:55<00:12,  7.42it/s][A
val E000:  90% 856/950 [01:55<00:12,  7.42it/s][A
val E000:  90% 857/950 [01:55<00:12,  7.42it/s][A
val E000:  90% 859/950 [01:55<00:12,  7.42it/s][A
val E000:  91% 860/950 [01:55<00:12,  7.43it/s][A
val E000:  91% 861/950 [01:55<00:11,  7.43it/s][A
val E000:  91% 862/950 [01:56<00:11,  7.43it/s][A
val E000:  91% 863/950 [01:56<00:11,  7.43it/s][A
val E000:  91% 865/950 [01:56<00:11,  7.43it/s][A
val E000:  91% 866/950 [01:56<00:11,  7.44it/s][A
val E000:  91% 867/950 [01:56<00:11,  7.44it/s][A
val E000:  91% 869/950 [01:56<00:10,  7.44it/s][A
val E000:  92% 870/950 [01:56<00:10,  7.44it/s][A
val E000:  92% 871/950 [01:57<0

val E000:  11% 108/950 [00:13<01:48,  7.75it/s][A
val E000:  11% 109/950 [00:14<01:48,  7.76it/s][A
val E000:  12% 110/950 [00:14<01:48,  7.77it/s][A
val E000:  12% 111/950 [00:14<01:47,  7.77it/s][A
val E000:  12% 112/950 [00:14<01:48,  7.75it/s][A
val E000:  12% 113/950 [00:14<01:47,  7.76it/s][A
val E000:  12% 114/950 [00:14<01:47,  7.77it/s][A
val E000:  12% 115/950 [00:14<01:47,  7.78it/s][A
val E000:  12% 117/950 [00:14<01:46,  7.81it/s][A
val E000:  12% 118/950 [00:15<01:46,  7.82it/s][A
val E000:  13% 119/950 [00:15<01:46,  7.83it/s][A
val E000:  13% 120/950 [00:15<01:45,  7.84it/s][A
val E000:  13% 121/950 [00:15<01:45,  7.84it/s][A
val E000:  13% 122/950 [00:15<01:45,  7.83it/s][A
val E000:  13% 123/950 [00:15<01:45,  7.81it/s][A
val E000:  13% 124/950 [00:15<01:45,  7.80it/s][A
val E000:  13% 125/950 [00:16<01:45,  7.80it/s][A
val E000:  13% 126/950 [00:16<01:45,  7.78it/s][A
val E000:  13% 127/950 [00:16<01:46,  7.72it/s][A
val E000:  13% 128/950 [00:16<0

val E000:  31% 293/950 [00:37<01:23,  7.86it/s][A
val E000:  31% 294/950 [00:38<01:25,  7.68it/s][A
val E000:  31% 295/950 [00:39<01:26,  7.53it/s][A
val E000:  31% 298/950 [00:39<01:25,  7.58it/s][A
val E000:  32% 301/950 [00:39<01:25,  7.63it/s][A
val E000:  32% 303/950 [00:40<01:26,  7.45it/s][A
val E000:  32% 305/950 [00:40<01:26,  7.48it/s][A
val E000:  32% 308/950 [00:40<01:25,  7.53it/s][A
val E000:  33% 311/950 [00:41<01:24,  7.52it/s][A
val E000:  33% 313/950 [00:41<01:24,  7.54it/s][A
val E000:  33% 315/950 [00:41<01:23,  7.57it/s][A
val E000:  33% 318/950 [00:41<01:23,  7.61it/s][A
val E000:  34% 320/950 [00:42<01:23,  7.56it/s][A
val E000:  34% 322/950 [00:42<01:22,  7.59it/s][A
val E000:  34% 324/950 [00:42<01:22,  7.62it/s][A
val E000:  34% 326/950 [00:42<01:21,  7.64it/s][A
val E000:  35% 328/950 [00:43<01:22,  7.58it/s][A
val E000:  35% 331/950 [00:43<01:21,  7.62it/s][A
val E000:  35% 333/950 [00:43<01:20,  7.65it/s][A
val E000:  35% 335/950 [00:44<0

val E000:  59% 565/950 [01:10<00:47,  8.05it/s][A
val E000:  60% 567/950 [01:10<00:47,  8.04it/s][A
val E000:  60% 569/950 [01:10<00:47,  8.06it/s][A
val E000:  60% 571/950 [01:10<00:47,  8.06it/s][A
val E000:  60% 573/950 [01:11<00:46,  8.03it/s][A
val E000:  61% 575/950 [01:11<00:46,  8.04it/s][A
val E000:  61% 577/950 [01:11<00:46,  8.06it/s][A
val E000:  61% 579/950 [01:11<00:46,  8.05it/s][A
val E000:  61% 581/950 [01:12<00:46,  8.02it/s][A
val E000:  61% 583/950 [01:12<00:45,  8.03it/s][A
val E000:  62% 585/950 [01:12<00:45,  8.05it/s][A
val E000:  62% 587/950 [01:13<00:45,  8.03it/s][A
val E000:  62% 589/950 [01:13<00:44,  8.03it/s][A
val E000:  62% 591/950 [01:13<00:44,  8.05it/s][A
val E000:  62% 593/950 [01:13<00:44,  8.06it/s][A
val E000:  63% 595/950 [01:14<00:44,  8.04it/s][A
val E000:  63% 597/950 [01:14<00:43,  8.04it/s][A
val E000:  63% 599/950 [01:14<00:43,  8.06it/s][A
val E000:  63% 601/950 [01:14<00:43,  8.07it/s][A
val E000:  63% 603/950 [01:15<0

val E000:  87% 822/950 [01:40<00:15,  8.17it/s][A
val E000:  87% 824/950 [01:40<00:15,  8.17it/s][A
val E000:  87% 825/950 [01:41<00:15,  8.17it/s][A
val E000:  87% 826/950 [01:41<00:15,  8.17it/s][A
val E000:  87% 827/950 [01:41<00:15,  8.17it/s][A
val E000:  87% 828/950 [01:41<00:14,  8.17it/s][A
val E000:  87% 829/950 [01:41<00:14,  8.17it/s][A
val E000:  87% 830/950 [01:41<00:14,  8.17it/s][A
val E000:  87% 831/950 [01:41<00:14,  8.18it/s][A
val E000:  88% 832/950 [01:41<00:14,  8.18it/s][A
val E000:  88% 833/950 [01:41<00:14,  8.18it/s][A
val E000:  88% 834/950 [01:41<00:14,  8.18it/s][A
val E000:  88% 835/950 [01:42<00:14,  8.18it/s][A
val E000:  88% 836/950 [01:42<00:13,  8.18it/s][A
val E000:  88% 837/950 [01:42<00:13,  8.17it/s][A
val E000:  88% 838/950 [01:42<00:13,  8.18it/s][A
val E000:  88% 839/950 [01:42<00:13,  8.17it/s][A
val E000:  89% 841/950 [01:42<00:13,  8.18it/s][A
val E000:  89% 842/950 [01:42<00:13,  8.18it/s][A
val E000:  89% 843/950 [01:43<0

val E000:  10% 97/950 [00:12<01:48,  7.84it/s][A
val E000:  10% 98/950 [00:12<01:48,  7.85it/s][A
val E000:  10% 99/950 [00:12<01:48,  7.86it/s][A
val E000:  11% 100/950 [00:12<01:48,  7.86it/s][A
val E000:  11% 101/950 [00:12<01:47,  7.87it/s][A
val E000:  11% 102/950 [00:13<01:49,  7.77it/s][A
val E000:  11% 104/950 [00:13<01:47,  7.86it/s][A
val E000:  11% 106/950 [00:13<01:47,  7.88it/s][A
val E000:  11% 108/950 [00:13<01:47,  7.87it/s][A
val E000:  11% 109/950 [00:13<01:47,  7.85it/s][A
val E000:  12% 110/950 [00:14<01:47,  7.83it/s][A
val E000:  12% 111/950 [00:14<01:47,  7.83it/s][A
val E000:  12% 112/950 [00:14<01:46,  7.84it/s][A
val E000:  12% 113/950 [00:14<01:46,  7.82it/s][A
val E000:  12% 114/950 [00:14<01:46,  7.83it/s][A
val E000:  12% 115/950 [00:14<01:46,  7.84it/s][A
val E000:  12% 116/950 [00:14<01:46,  7.85it/s][A
val E000:  12% 117/950 [00:14<01:46,  7.84it/s][A
val E000:  13% 119/950 [00:15<01:45,  7.87it/s][A
val E000:  13% 120/950 [00:15<01:4

val E000:  33% 316/950 [00:38<01:17,  8.20it/s][A
val E000:  33% 317/950 [00:38<01:17,  8.18it/s][A
val E000:  34% 319/950 [00:38<01:16,  8.20it/s][A
val E000:  34% 321/950 [00:39<01:16,  8.21it/s][A
val E000:  34% 323/950 [00:39<01:16,  8.21it/s][A
val E000:  34% 324/950 [00:39<01:16,  8.21it/s][A
val E000:  34% 325/950 [00:39<01:16,  8.18it/s][A
val E000:  34% 327/950 [00:39<01:15,  8.20it/s][A
val E000:  35% 328/950 [00:40<01:15,  8.19it/s][A
val E000:  35% 330/950 [00:40<01:15,  8.20it/s][A
val E000:  35% 331/950 [00:40<01:15,  8.20it/s][A
val E000:  35% 332/950 [00:40<01:15,  8.21it/s][A
val E000:  35% 333/950 [00:40<01:15,  8.17it/s][A
val E000:  35% 335/950 [00:40<01:15,  8.20it/s][A
val E000:  35% 337/950 [00:41<01:14,  8.22it/s][A
val E000:  36% 339/950 [00:41<01:14,  8.23it/s][A
val E000:  36% 341/950 [00:41<01:13,  8.24it/s][A
val E000:  36% 343/950 [00:41<01:13,  8.22it/s][A
val E000:  36% 344/950 [00:41<01:13,  8.22it/s][A
val E000:  36% 345/950 [00:41<0

val E000:  61% 575/950 [01:13<00:48,  7.81it/s][A
val E000:  61% 576/950 [01:13<00:47,  7.80it/s][A
val E000:  61% 578/950 [01:13<00:47,  7.82it/s][A
val E000:  61% 579/950 [01:14<00:47,  7.81it/s][A
val E000:  61% 580/950 [01:14<00:47,  7.81it/s][A
val E000:  61% 581/950 [01:14<00:47,  7.81it/s][A
val E000:  61% 582/950 [01:14<00:47,  7.82it/s][A
val E000:  61% 583/950 [01:14<00:46,  7.82it/s][A
val E000:  61% 584/950 [01:14<00:46,  7.82it/s][A
val E000:  62% 585/950 [01:14<00:46,  7.82it/s][A
val E000:  62% 586/950 [01:14<00:46,  7.83it/s][A
val E000:  62% 587/950 [01:14<00:46,  7.83it/s][A
val E000:  62% 588/950 [01:15<00:46,  7.83it/s][A
val E000:  62% 589/950 [01:15<00:46,  7.83it/s][A
val E000:  62% 590/950 [01:15<00:45,  7.83it/s][A
val E000:  62% 591/950 [01:15<00:45,  7.83it/s][A
val E000:  62% 592/950 [01:15<00:45,  7.83it/s][A
val E000:  62% 593/950 [01:15<00:45,  7.83it/s][A
val E000:  63% 594/950 [01:15<00:45,  7.83it/s][A
val E000:  63% 595/950 [01:15<0

val E000:  85% 811/950 [01:41<00:17,  8.03it/s][A
val E000:  86% 813/950 [01:41<00:17,  8.03it/s][A
val E000:  86% 814/950 [01:41<00:16,  8.03it/s][A
val E000:  86% 816/950 [01:41<00:16,  8.04it/s][A
val E000:  86% 817/950 [01:41<00:16,  8.04it/s][A
val E000:  86% 818/950 [01:41<00:16,  8.02it/s][A
val E000:  86% 820/950 [01:42<00:16,  8.03it/s][A
val E000:  87% 822/950 [01:42<00:15,  8.04it/s][A
val E000:  87% 824/950 [01:42<00:15,  8.04it/s][A
val E000:  87% 826/950 [01:42<00:15,  8.05it/s][A
val E000:  87% 828/950 [01:42<00:15,  8.04it/s][A
val E000:  87% 829/950 [01:43<00:15,  8.05it/s][A
val E000:  87% 831/950 [01:43<00:14,  8.05it/s][A
val E000:  88% 832/950 [01:43<00:14,  8.05it/s][A
val E000:  88% 833/950 [01:43<00:14,  8.05it/s][A
val E000:  88% 834/950 [01:43<00:14,  8.05it/s][A
val E000:  88% 836/950 [01:43<00:14,  8.06it/s][A
val E000:  88% 838/950 [01:43<00:13,  8.06it/s][A
val E000:  88% 839/950 [01:44<00:13,  8.06it/s][A
val E000:  89% 841/950 [01:44<0

val E000:   9% 85/950 [00:11<01:52,  7.69it/s][A
val E000:   9% 87/950 [00:11<01:51,  7.72it/s][A
val E000:   9% 89/950 [00:11<01:51,  7.75it/s][A
val E000:  10% 91/950 [00:11<01:50,  7.78it/s][A
val E000:  10% 93/950 [00:11<01:49,  7.81it/s][A
val E000:  10% 94/950 [00:12<01:49,  7.81it/s][A
val E000:  10% 95/950 [00:12<01:49,  7.82it/s][A
val E000:  10% 96/950 [00:12<01:49,  7.82it/s][A
val E000:  10% 97/950 [00:12<01:48,  7.83it/s][A
val E000:  10% 98/950 [00:12<01:49,  7.76it/s][A
val E000:  11% 101/950 [00:12<01:47,  7.88it/s][A
val E000:  11% 103/950 [00:13<01:47,  7.91it/s][A
val E000:  11% 104/950 [00:13<01:47,  7.90it/s][A
val E000:  11% 105/950 [00:13<01:46,  7.91it/s][A
val E000:  11% 106/950 [00:13<01:46,  7.93it/s][A
val E000:  11% 108/950 [00:13<01:45,  7.96it/s][A
val E000:  12% 110/950 [00:13<01:45,  7.99it/s][A
val E000:  12% 111/950 [00:13<01:44,  8.00it/s][A
val E000:  12% 113/950 [00:14<01:44,  8.03it/s][A
val E000:  12% 114/950 [00:14<01:43,  8.0

val E000:  31% 296/950 [00:37<01:23,  7.82it/s][A
val E000:  31% 297/950 [00:37<01:23,  7.82it/s][A
val E000:  31% 298/950 [00:38<01:23,  7.83it/s][A
val E000:  31% 299/950 [00:38<01:23,  7.83it/s][A
val E000:  32% 300/950 [00:38<01:22,  7.84it/s][A
val E000:  32% 301/950 [00:38<01:22,  7.84it/s][A
val E000:  32% 302/950 [00:38<01:22,  7.83it/s][A
val E000:  32% 303/950 [00:38<01:22,  7.83it/s][A
val E000:  32% 304/950 [00:38<01:22,  7.83it/s][A
val E000:  32% 306/950 [00:39<01:22,  7.84it/s][A
val E000:  32% 307/950 [00:39<01:22,  7.84it/s][A
val E000:  32% 308/950 [00:39<01:21,  7.84it/s][A
val E000:  33% 310/950 [00:39<01:21,  7.85it/s][A
val E000:  33% 312/950 [00:39<01:21,  7.86it/s][A
val E000:  33% 313/950 [00:39<01:20,  7.87it/s][A
val E000:  33% 314/950 [00:39<01:20,  7.87it/s][A
val E000:  33% 315/950 [00:40<01:20,  7.87it/s][A
val E000:  33% 317/950 [00:40<01:20,  7.89it/s][A
val E000:  33% 318/950 [00:40<01:20,  7.89it/s][A
val E000:  34% 319/950 [00:40<0

val E000:  52% 490/950 [01:00<00:56,  8.11it/s][A
val E000:  52% 491/950 [01:00<00:56,  8.11it/s][A
val E000:  52% 492/950 [01:00<00:56,  8.11it/s][A
val E000:  52% 494/950 [01:00<00:56,  8.12it/s][A
val E000:  52% 495/950 [01:00<00:56,  8.12it/s][A
val E000:  52% 496/950 [01:01<00:55,  8.11it/s][A
val E000:  52% 497/950 [01:01<00:55,  8.11it/s][A
val E000:  52% 498/950 [01:01<00:55,  8.12it/s][A
val E000:  53% 499/950 [01:01<00:55,  8.12it/s][A
val E000:  53% 501/950 [01:01<00:55,  8.13it/s][A
val E000:  53% 503/950 [01:01<00:54,  8.13it/s][A
val E000:  53% 505/950 [01:02<00:54,  8.14it/s][A
val E000:  53% 507/950 [01:02<00:54,  8.14it/s][A
val E000:  54% 509/950 [01:02<00:54,  8.14it/s][A
val E000:  54% 510/950 [01:02<00:54,  8.14it/s][A
val E000:  54% 511/950 [01:02<00:53,  8.14it/s][A
val E000:  54% 512/950 [01:02<00:53,  8.14it/s][A
val E000:  54% 513/950 [01:03<00:53,  8.14it/s][A
val E000:  54% 514/950 [01:03<00:53,  8.14it/s][A
val E000:  54% 515/950 [01:03<0

val E000:  76% 724/950 [01:28<00:27,  8.17it/s][A
val E000:  76% 725/950 [01:28<00:27,  8.17it/s][A
val E000:  76% 726/950 [01:28<00:27,  8.17it/s][A
val E000:  77% 727/950 [01:28<00:27,  8.17it/s][A
val E000:  77% 728/950 [01:29<00:27,  8.18it/s][A
val E000:  77% 729/950 [01:29<00:27,  8.18it/s][A
val E000:  77% 730/950 [01:29<00:26,  8.18it/s][A
val E000:  77% 731/950 [01:29<00:26,  8.18it/s][A
val E000:  77% 732/950 [01:29<00:26,  8.18it/s][A
val E000:  77% 733/950 [01:29<00:26,  8.18it/s][A
val E000:  77% 734/950 [01:29<00:26,  8.18it/s][A
val E000:  77% 735/950 [01:29<00:26,  8.18it/s][A
val E000:  77% 736/950 [01:29<00:26,  8.19it/s][A
val E000:  78% 737/950 [01:30<00:26,  8.19it/s][A
val E000:  78% 738/950 [01:30<00:25,  8.19it/s][A
val E000:  78% 739/950 [01:30<00:25,  8.19it/s][A
val E000:  78% 740/950 [01:30<00:25,  8.19it/s][A
val E000:  78% 741/950 [01:30<00:25,  8.19it/s][A
val E000:  78% 742/950 [01:30<00:25,  8.20it/s][A
val E000:  78% 744/950 [01:30<0

val E000:  96% 914/950 [01:50<00:04,  8.30it/s][A
val E000:  96% 915/950 [01:50<00:04,  8.30it/s][A
val E000:  96% 916/950 [01:50<00:04,  8.30it/s][A
val E000:  97% 917/950 [01:50<00:03,  8.30it/s][A
val E000:  97% 918/950 [01:50<00:03,  8.30it/s][A
val E000:  97% 919/950 [01:50<00:03,  8.30it/s][A
val E000:  97% 920/950 [01:50<00:03,  8.30it/s][A
val E000:  97% 921/950 [01:50<00:03,  8.30it/s][A
val E000:  97% 922/950 [01:51<00:03,  8.30it/s][A
val E000:  97% 923/950 [01:51<00:03,  8.30it/s][A
val E000:  97% 925/950 [01:51<00:03,  8.30it/s][A
val E000:  98% 927/950 [01:51<00:02,  8.31it/s][A
val E000:  98% 928/950 [01:51<00:02,  8.31it/s][A
val E000:  98% 929/950 [01:51<00:02,  8.31it/s][A
val E000:  98% 930/950 [01:51<00:02,  8.31it/s][A
val E000:  98% 931/950 [01:52<00:02,  8.31it/s][A
val E000:  98% 932/950 [01:52<00:02,  8.31it/s][A
val E000:  98% 933/950 [01:52<00:02,  8.30it/s][A
val E000:  98% 935/950 [01:52<00:01,  8.31it/s][A
val E000:  99% 936/950 [01:52<0

Prefix attacks:
[('in not a lot of words', 0.4433883073276713), ('in not many words', 0.39215056949107907), ('what is the answer to', 0.3814191190993482), ('tell me', 0.6119148726051747), ('answer this', 0.6465534268220423)]


In [63]:
print("Prefix attacks:")
pd.DataFrame([w for w in zip(PHRASES, prefix_attack_accs)], columns=['prefix phrase', 'accuracy'])

Prefix attacks:


Unnamed: 0,prefix phrase,accuracy
0,in not a lot of words,0.443388
1,in not many words,0.392151
2,what is the answer to,0.381419
3,tell me,0.611915
4,answer this,0.646553
5,answer this for me,0.600385


## Overstability analysis

In [64]:
counts_list = []
top_k = 1
with open(ATTRS_TSV) as f:
    for line in f:
        line = line.strip()
        question_attrs = line.split('\t')[0]
        question_tokens = []
        attrs = []
        for word_attr in question_attrs.split('||'): 
            word, attr = word_attr.split('|')
            question_tokens.append(word)
            attrs.append(float(attr))
        k = min(top_k, len(question_tokens))
        # get top k words by attribution 
        counts_list.extend([question_tokens[i].strip() for i in np.argpartition(attrs, -k)[-k:]])

In [65]:
paper_whitelist = [vocab_json['question'][w] for w in 'the, is, what, are, this, in, on, a, of, how, many, color, there, people, where'.split(', ')]

In [66]:
Counter(counts_list).most_common(10)

[('color', 9252),
 ('many', 7274),
 ('what', 4675),
 ('is', 3119),
 ('there', 2321),
 ('how', 2213),
 ('doing', 2073),
 ('or', 2049),
 ('where', 1692),
 ('are', 1455)]

In [None]:
curve_data = {}
all_accs = []
for K in np.unique(np.floor(np.geomspace(1, len(Counter(counts_list)), 50))):
    # take K most top attributed words
    if K in curve_data:
        continue
    whitelist = set([vocab_json['question'][w] for w, c in Counter(counts_list).most_common(int(K))])
    print(len(whitelist))
    accs = []
    num_batches = 0
    avg_question_length_orig = 0
    avg_question_length_new = 0
    num_questions = 0
    # iterator over the validation dataset
    tq = tqdm(LOADER, desc='{} E{:03d}'.format(PREFIX, 0), ncols=0)
    for v, q, a, idx, q_len in tq:

        old_q = np.asarray(q).copy()
        old_q_len = np.asarray(q_len).copy()

        new_q = np.zeros([config.batch_size, 23])
        curr_batch_size = int(q.shape[0])
        for batch_i in range(curr_batch_size):
            len_counter = 0
            avg_question_length_orig += int(q_len[batch_i])
            for word_i, w in enumerate(q[batch_i,:int(q_len[batch_i])]):
                if int(w) in whitelist:
                    new_q[batch_i, len_counter] = int(w)
                    #new_q[batch_i, word_i] = int(w)
                    len_counter += 1
            if len_counter == 0:
                len_counter = 1
            avg_question_length_new += int(len_counter)
            num_questions += 1
            q_len[batch_i] = len_counter
        q_len, sorted_idxs = torch.sort(q_len, descending=True)
        new_q = new_q[sorted_idxs, :]
        idx = idx[sorted_idxs]
        v = v[sorted_idxs,:,:,:]
        a = a[sorted_idxs, :]
        old_q = old_q[sorted_idxs, :]
        old_q_len = old_q_len[sorted_idxs]
        q = torch.LongTensor(new_q)

        v = Variable(v.cuda(async=True), **var_params)
        q = Variable(q.cuda(async=True), **var_params)
        a = Variable(a.cuda(async=True), **var_params)
        q_len = Variable(q_len.cuda(async=True), **var_params)

        q_emb = embedding(q)

        out = net(v, q_emb, q_len)

        acc = utils.batch_accuracy(out.data, a.data).cpu()

        accs.append(np.array(acc.view(-1)))
        del v, q, a, idx, q_len, q_emb, acc, out, old_q, sorted_idxs, old_q_len, new_q

        if num_batches >= MAX_NUM_BATCHES:
            break
        num_batches += 1

    accs = list(np.concatenate(accs, axis=0))
    print("avg question length orig: ", float(avg_question_length_orig)/num_questions)
    print("avg question length new: ", float(avg_question_length_new)/num_questions)
    print("accuracy for ", K, " is", np.mean(accs))
    curve_data[K] = np.mean(accs)


val E000:   0% 0/950 [00:00<?, ?it/s][A

1


  attention = F.softmax(attention)

val E000:   0% 1/950 [00:02<40:43,  2.57s/it][A
val E000:   0% 2/950 [00:03<24:57,  1.58s/it][A
val E000:   0% 3/950 [00:03<19:29,  1.23s/it][A
val E000:   0% 4/950 [00:04<16:40,  1.06s/it][A
val E000:   1% 5/950 [00:04<14:54,  1.06it/s][A
val E000:   1% 6/950 [00:05<13:50,  1.14it/s][A
val E000:   1% 7/950 [00:05<12:58,  1.21it/s][A
val E000:   1% 8/950 [00:06<12:21,  1.27it/s][A
val E000:   1% 9/950 [00:06<11:53,  1.32it/s][A
val E000:   1% 10/950 [00:07<11:30,  1.36it/s][A
val E000:   1% 11/950 [00:07<11:11,  1.40it/s][A
val E000:   1% 12/950 [00:08<10:54,  1.43it/s][A
val E000:   1% 13/950 [00:08<10:40,  1.46it/s][A
val E000:   1% 14/950 [00:09<10:27,  1.49it/s][A
val E000:  96% 916/950 [52:23<01:56,  3.43s/it]

In [None]:
print("Accuracies by size of vocab")
print(curve_data)

In [None]:
plt.plot([w[0] for w in curve_data], [w[1] for w in curve_data])
plt.xscale('log')
plt.xlabel('num. words in vocab')
plt.ylabel('accuracy')
plt.savefig(OVERSTABILITY_CURVE_FILE, format='eps')
plt.show()

## Subject ablation attack

In [None]:
counts_list = []
top_k = 5
with open(ATTRS_TSV) as f:
    for line in f:
        line = line.strip()
        question_attrs, predicted_answer, correct_answer, is_correct, image_id = line.split('\t')
        question_tokens = []
        attrs = []
        for word_attr in question_attrs.split(','):
            if len(word_attr.split('|')) < 2:
                print('skipped')
                continue
            word, attr = word_attr.split('|')
            question_tokens.append(word)
            attrs.append(float(attr))
        k = min(top_k, len(question_tokens))
        counts_list.extend([question_tokens[i].strip() for i in np.argpartition(attrs, -k)[-k:]])

In [None]:
unattributed_words = set(vocab_json['question'].keys()) - set(counts_list)

In [None]:
unattributed_words

In [None]:
list(unattributed_words)[-10:]

In [None]:
nlp = spacy.load('en')
sent = "how symmetrical are the white bricks on either side of the building"
doc=nlp(sent)

sub_toks = [tok for tok in doc if (tok.dep_ == "nobj") ]

print(sub_toks)

In [None]:
[d.dep_ for d in doc]

In [None]:
tq = tqdm(LOADER, desc='{} E{:03d}'.format(PREFIX, 0), ncols=0)
net.eval()
answ = []
idxs = []
accs = []
num_iters = 0
batch_id = 0
attrs_tsv_string = ''
for v, q, a, idx, q_len in tq:
        
        var_params = {
            'volatile': False,
            'requires_grad': False,        
        }
        out_string = ''
        for i in range(config.batch_size):
            if len(np.nonzero(a[i, :] >= 3))==0:
                continue
            answers = [reverse_vocab_answer[int(w)] for w in np.nonzero(a[i, :] >= 3)]
            if 'yes' in answers or 'no' in answers:
                continue
            string_question = [reverse_vocab_question[int(w)] if int(w) != 0 else '' for w in q[i, :]]
            out_string += '-'*50 + '\n'
            out_string += 'orig: ' + ' '.join(string_question) + '\n'
            out_string += 'answers: ' + ' '.join(answers) + '\n'
            doc = nlp(' '.join(string_question))
            pos_tags = [d.dep_ for d in doc]
            #print(pos_tags)
            subject_index = [i for i, t in enumerate(pos_tags) if 'nsubj' in t]
            if len(subject_index) == 0:
                continue
            q[i, subject_index[0]] = vocab_json['question']['civilian']
            string_question = [reverse_vocab_question[int(w)] if int(w) != 0 else '' for w in q[i, :]]
            out_string += 'ablated: ' + ' '.join(string_question) + '\n'

            
        v = Variable(v.cuda(async=True), **var_params)
        q = Variable(q.cuda(async=True), **var_params)
        a = Variable(a.cuda(async=True), **var_params)
        q_len = Variable(q_len.cuda(async=True), **var_params)
        
        q_emb = embedding(q)
        
        out = net(v, q_emb, q_len)            
        
        acc = utils.batch_accuracy(out.data, a.data).cpu()
        
        _, answer = out.data.cpu().max(dim=1)
        
#        attrs_tsv_string = compute_attributions(q_emb, q_len, v, idx, num_batches=5)
        
#        outf.write(attrs_tsv_string)
        #for i in range(config.batch_size):
            #if int(acc[i]) >= 1.0:
                #print(out_string)
        
        answ.append(answer.view(-1))
        accs.append(acc.view(-1))
        idxs.append(idx.view(-1).clone())
        print(acc.mean())
        num_iters += 1
        batch_id += 1
        #print(' '.join([reverse_vocab_question[int(w)] for w in q[3,:] if int(w)!=0]))
        #print(acc[3])
        if num_iters == 5:
            break
            
#outf.close()

answ = list(torch.cat(answ, dim=0))
accs = list(torch.cat(accs, dim=0))
idxs = list(torch.cat(idxs, dim=0))

print('final: ' + str(np.mean(accs)))

In [None]:
reverse_vocab_answer[int(np.nonzero(a[0, :] > 3))]

In [None]:
np.nonzero(a[i, :] >= 3)

## Image specific bias

In [None]:
import json
import scipy.stats as stats
from collections import Counter

In [None]:
json_data=open('/scratch/pramodkm/vqa/data_vqa1.0/OpenEnded_mscoco_val2014_questions.json').read()
data = json.loads(json_data)

In [None]:
data['questions']

In [None]:
json_data=open('/scratch/pramodkm/vqa/data_vqa1.0/mscoco_val2014_annotations.json').read()
annot_data = json.loads(json_data)

In [None]:
annot_data['annotations']

In [None]:
image_ans = dict()
for ans_annot in annot_data['annotations']:
    turk_answers = [ans['answer'] for ans in ans_annot['answers']]
    if ans_annot['image_id'] not in image_ans:
        image_ans[ans_annot['image_id']] = [turk_answers]
    else:
        image_ans[ans_annot['image_id']].append(turk_answers)

In [None]:
image_ans

In [None]:
def visualize_baseline_answers(tokens, attrs, image_ans):
    html_text = ""
    count = 0
    for i, tok in enumerate(tokens):
        r,g,b = get_color(attrs[i])
        val = []
        for ans in image_ans:
            val += [sum(tok == np.array(ans))]
        if sum(np.array(val)>=3)>0:
            tok = '<u>' + tok + '</u>'
            count += 1
        html_text += "<span style='size:16;color:rgb(%d,%d,%d)'>%s</span>, " % (r, g, b, tok)
    return html_text, count

In [None]:
tq = tqdm(LOADER, desc='{} E{:03d}'.format(PREFIX, 0), ncols=0)
net.eval()
answ = []
idxs = []
accs = []
num_iters = 0
batch_id = 0
outf = open('/scratch/pramodkm/vqa/tsv/baseline_answers.html','w')
html_str = '<html><head><link rel="stylesheet" href="https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous"></head>'
html_str += '<body><div class="container"> <h3> Top 15 answer classes for each image </h3><br>Generated by passing an empty question to the network. Underlined classes appear as answers to questions on the image. <br>'
html_str += '<table class="table">'
outf.write(html_str)

question = 'what color besides blue is there'
question_tokens = torch.LongTensor([vocab_json['question'][w] for w in question.strip().split()] + [0]*(23-len(question.strip().split())))
question_tokens = question_tokens.unsqueeze(0).repeat(config.batch_size,1)
new_q_len = torch.LongTensor([len(question.strip().split())])
new_q_len = new_q_len.repeat(config.batch_size)

covered_image_ids = set()

avg_count = []
for v, q, a, idx, q_len in tq:
        
        var_params = {
            'volatile': False,
            'requires_grad': False,        
        }
        
        q = question_tokens
            
        q_len = new_q_len
        
        v = Variable(v.cuda(async=True), **var_params)
        q = Variable(q.cuda(async=True), **var_params)
        a = Variable(a.cuda(async=True), **var_params)
        q_len = Variable(q_len.cuda(async=True), **var_params)
        
        q_emb = embedding(q)
                
        out = net(v, q_emb, q_len)    
        
        softmax = torch.nn.functional.softmax(out)
        
        acc = utils.batch_accuracy(out.data, a.data).cpu()
        
        _, answer = out.data.cpu().max(dim=1)

        # for baseline answers
        baseline_q = q * 0   
            
        baseline_q_len = q_len/q_len
        
        baseline_q_emb = embedding(baseline_q)
                
        baseline_out = net(v, baseline_q_emb, baseline_q_len)    
        
        baseline_softmax = torch.nn.functional.softmax(baseline_out)


        for batch_i in range(config.batch_size):
            baseline_probs, baseline_idxs = baseline_softmax[batch_i, :].sort(descending=True)
            baseline_answers = [reverse_vocab_answer[int(ix)] for ix in baseline_idxs]
            
            baseline_probs = [float(prob) for prob in baseline_probs]
            print_k = 15
            #outf.write('... ' + visualize_baseline_answers(baseline_answers[-10:], baseline_probs[-10:]))
            image_id = str(val_loader.dataset.coco_ids[int(idx[batch_i])])
            if image_id in covered_image_ids:
                continue
            covered_image_ids.add(image_id)
            #outf.write('<br>Question: ' + ' '.join([reverse_vocab_question[int(w)] for w in q[batch_i, :] if int(w)!=0]))
            #outf.write('<br>Pred. ans.: ' + reverse_vocab_answer[answer[batch_i]])
            outf.write('<br><tr><td><img src="val2014/COCO_val2014_' + '0'*(12 - len(str(image_id))) + str(image_id) + '.jpg" width="256" height="256"></img></td>')
            vis_string, count = visualize_baseline_answers(baseline_answers[:print_k], baseline_probs[:print_k], image_ans[int(image_id)])
            avg_count.append(count)
            outf.write('<td>' + vis_string + '<br> #classes appearing as answers: ' + str(count) + '</td></tr>')
            
            outf.write('<hr>')
                
        answ.append(answer.view(-1))
        accs.append(acc.view(-1))
        idxs.append(idx.view(-1).clone())
        num_iters += 1
        batch_id += 1
        if num_iters == 1:
            break
            
outf.write('</table></div></body></html>')
outf.close()

print(np.mean(avg_count))
answ = list(torch.cat(answ, dim=0))
accs = list(torch.cat(accs, dim=0))
idxs = list(torch.cat(idxs, dim=0))

In [None]:
vocab_json['answer']['wood']

In [None]:
tq = tqdm(LOADER, desc='{} E{:03d}'.format(PREFIX, 0), ncols=0)
net.eval()
answ = []
idxs = []
accs = []
num_iters = 0
batch_id = 0
outf = open('/scratch/pramodkm/vqa/tsv/baseline_answers.html','w')
html_str = '<head><link rel="stylesheet" href="https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous"></head>'
html_str += '<body><div class="container"> '
outf.write(html_str)

question = 'how many are not'
question_tokens = torch.LongTensor([vocab_json['question'][w] for w in question.strip().split()] + [0]*(23-len(question.strip().split())))
question_tokens = question_tokens.unsqueeze(0).repeat(config.batch_size,1)
new_q_len = torch.LongTensor([len(question.strip().split())])
new_q_len = new_q_len.repeat(config.batch_size)

batch_percentages = np.zeros(3000)
for v, q, a, idx, q_len in tq:
        
        var_params = {
            'volatile': False,
            'requires_grad': False,        
        }
                
        v = Variable(v.cuda(async=True), **var_params)
        q = Variable(q.cuda(async=True), **var_params)
        question_tokens = Variable(question_tokens.cuda(async=True), **var_params)
        a = Variable(a.cuda(async=True), **var_params)
        q_len = Variable(q_len.cuda(async=True), **var_params)
        new_q_len = Variable(new_q_len.cuda(async=True), **var_params)
        
        
        # for baseline answers
        baseline_q = q * 0   
            
        baseline_q_len = q_len/q_len
        
        baseline_q_emb = embedding(baseline_q)
                
        baseline_out = net(v, baseline_q_emb, baseline_q_len)    
        
        baseline_softmax = torch.nn.functional.softmax(baseline_out)

        
        test_k = 300          
        test_q = question_tokens


        batch_baseline_answers = []
        for batch_i in range(config.batch_size):
            baseline_probs, baseline_idxs = baseline_softmax[batch_i, :].sort(descending=True)
            baseline_answers = [reverse_vocab_answer[int(ix)] for ix in baseline_idxs]
            
            counter = 4
            for ba in baseline_answers[2:]:
                if counter == 7:
                    break
                if ba not in vocab_json['question']:
                    continue
                test_q[batch_i, counter] = vocab_json['question'][ba]
                counter += 1
            test_out = net(v, embedding(test_q), new_q_len)
            _, answer = test_out.data.cpu().max(dim=1)

        for batch_i in range(config.batch_size):
            baseline_probs, baseline_idxs = baseline_softmax[batch_i, :].sort(descending=True)
            baseline_answers = [reverse_vocab_answer[int(ix)] for ix in baseline_idxs]
            
            baseline_probs = [float(prob) for prob in baseline_probs]
            print_k = 100
            outf.write(visualize_baseline_answers(baseline_answers[:print_k], baseline_probs[:print_k]))
            outf.write('... ' + visualize_baseline_answers(baseline_answers[-10:], baseline_probs[-10:]))
            image_id = str(val_loader.dataset.coco_ids[int(idx[batch_i])])
            outf.write('<br>Question: ' + ' '.join([reverse_vocab_question[int(w)] for w in test_q[batch_i, :] if int(w)!=0]))
            outf.write('<br>Pred. ans.: ' + reverse_vocab_answer[answer[batch_i]])
            outf.write('<br><img src="val2014/COCO_val2014_' + '0'*(12 - len(str(image_id))) + str(image_id) + '.jpg" width="256" height="256"></img><br><br>')
            outf.write('<hr>')
                      
        

        answ.append(answer.view(-1))
        accs.append(acc.view(-1))
        idxs.append(idx.view(-1).clone())
        num_iters += 1
        batch_id += 1
        if num_iters == 1:
            break
            
outf.write('</div></body>')
outf.close()
answ = list(torch.cat(answ, dim=0))
accs = list(torch.cat(accs, dim=0))
idxs = list(torch.cat(idxs, dim=0))

In [None]:
plt.plot(batch_percentages[:300]/300)

In [None]:
batch_percentages/300

In [None]:
vocab_json['question']['rooster']

In [None]:
set(vocab_json['question'].keys()) - set(vocab_json['answer'].keys())