In [4]:
%matplotlib inline
import matplotlib.pyplot as plt
plt.style.use('seaborn-white')
import random
import seaborn as sns; sns.set()
import numpy as np
import torch

In [166]:
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from transformers import *
model_path = './output/Yelp'
tokenizer = BertTokenizer.from_pretrained(model_path)
model = BertForSequenceClassification.from_pretrained(model_path,output_attentions=True)
model.cuda()
model.eval()

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [91]:
def send_sent_batch_to_model(model, sents):
    input_ids = []
    for ii in range(len(sents)):
        input_ids.append( tokenizer.convert_tokens_to_ids(["[CLS]"] + sents[ii] + ["[SEP]"]) )
    input_mask = [[1]*len(ii) for ii in input_ids]
    max_input_len = max(len(ii) for ii in input_ids)
    padded_input_ids = []
    padded_input_mask = []
    padded_token_type_ids = []
    for ii in range(len(input_ids)):
        pad_zeros = [0] * (max_input_len - len(input_ids[ii]))
        padded_input_ids.append(input_ids[ii] + pad_zeros)
        padded_input_mask.append(input_mask[ii] + pad_zeros)
        padded_token_type_ids.append([0] * len(padded_input_ids[ii]))
    padded_input_ids = torch.tensor(padded_input_ids, dtype=torch.long).to(device)
    padded_input_mask = torch.tensor(padded_input_mask, dtype=torch.long).to(device)
    padded_token_type_ids = torch.tensor(padded_token_type_ids, dtype=torch.long).to(device)
    

    logits = model(padded_input_ids, padded_input_mask, padded_token_type_ids)
    encoder_output = model.bert(padded_input_ids, padded_input_mask, padded_token_type_ids)

    return logits[0], encoder_output[1]

In [427]:
def attack_close_words_in_embedding_space(sent, original_label):
    distance_threshold_max = 0.8
    dist_func = torch.nn.L1Loss(reduction='none')

    sent = tokenizer.tokenize(sent)[:80]
    sent_len = len(sent)

    target_label = 1 if original_label == 0 else 0
    target_label_onehot = [0., 1.] if target_label == 1 else [1., 0.]

    model_out_logits, model_encoder_output = send_sent_batch_to_model(model, [sent])
    model_out_logits = model_out_logits.squeeze()
    model_encoder_output = model_encoder_output.squeeze()

    if torch.max(model_out_logits, -1)[1] == target_label:
        # print("Skipped.")
        return -1

    new_sents = []
    for target_word_pos in range(sent_len):
        n_sent = sent[:target_word_pos] + ['[PAD]'] + sent[target_word_pos+1:]
        new_sents.append(n_sent)

    model_out_batch_changed_logits, _ = send_sent_batch_to_model(model, new_sents)

    model_out_logits_dup = model_out_logits.expand(len(new_sents), model_out_logits.size()[0])

    y_target = torch.tensor([target_label_onehot * len(new_sents)]).reshape((len(new_sents), -1)).to(device)

    change_amount = torch.sum(torch.mul((model_out_batch_changed_logits - model_out_logits_dup), y_target), dim=-1)

    _, change_amount_pos = torch.sort(change_amount, dim=-1, descending=True)
    
    change_amount_pos = change_amount_pos.cpu().numpy().tolist()
#     print(change_amount_pos)
    
    for max_change_pos in change_amount_pos:
        target_word = sent[max_change_pos]
        print(target_word)
        cand_word_ids = []
        cand_word_distances = []
        cand_words = []
        a_word_id = 106
        test_batch_size = 600
        target_label_tensor = torch.tensor([target_label] * test_batch_size).to(device)
        # -----
        new_sents = []
        this_batch_cand_word_ids = []
        while a_word_id < len(all_words):
            while len(new_sents) < test_batch_size:
                new_sents.append(sent[:max_change_pos] + [all_words[a_word_id]] + sent[max_change_pos+1:])
                this_batch_cand_word_ids.append(a_word_id)
                a_word_id += 1
                if a_word_id >= len(all_words): break

            model_changed_logit, model_changed_encoder_output = send_sent_batch_to_model(model, new_sents)
            model_changed_label = torch.max(model_changed_logit, -1)[1]

            if len(new_sents) < test_batch_size:                                                                                 
                target_label_tensor = torch.tensor([target_label] * len(new_sents)).to(device)
            attack_success = torch.eq(model_changed_label, target_label_tensor)
            if attack_success.sum() <= 0.:
                new_sents = []
                this_batch_cand_word_ids = []
                continue
            model_encoder_output_dup = model_encoder_output.expand(len(new_sents), model_encoder_output.size()[0])
            distance = dist_func(model_encoder_output_dup, model_changed_encoder_output)
            distance = distance.mean(dim=-1).cpu().numpy().tolist()
            pos_attack = 0
            for ii in attack_success.cpu().numpy().tolist():
                if ii == True:
                    # no distance limit pure greedy
                    cw = all_words[this_batch_cand_word_ids[pos_attack]]
#                     print("targ", target_label, ":", target_word, "->", cw)
#                     return 1, sent_len, target_label, max_change_pos, target_word, cw
                    # no distance limit pure greedy end
                    if distance[pos_attack] > distance_threshold_max: continue
                    print("targ", target_label, ":", target_word, "->", cw)
#                     return 1, sent_len, target_label, max_change_pos, target_word, cw
                    cand_word_ids.append(this_batch_cand_word_ids[pos_attack])
                    cand_word_distances.append(distance[pos_attack])
                pos_attack += 1
            new_sents = []
            this_batch_cand_word_ids = []

        if len(cand_word_ids) < 1:
            # ---- try next max_change_pos
            print("try next max_change")
#             continue
            # ---- try only one max_change_pos
            return 0
        tmp_cand_words = sorted(zip(cand_word_ids, cand_word_distances), key=lambda cw: cw[1])
        cand_word_ids, cand_word_distances = zip(*tmp_cand_words)
        cand_words = [all_words[cc] for cc in cand_word_ids]
        # print(len(cand_words))

        # ----- test candidate words by batch
        new_sents = []
        this_batch_cand_words = []

        while len(cand_words) > 1:
            while len(new_sents) < test_batch_size:
                cw = cand_words.pop(0)
                new_sents.append(sent[:max_change_pos] + [cw] + sent[max_change_pos+1:])
                this_batch_cand_words.append(cw)
                if len(cand_words) < 1: break
            
            model_adv_output, _ = send_sent_batch_to_model(model, new_sents)
            model_adv_output_label = torch.max(model_adv_output, -1)[1]

            if torch.sum(model_adv_output_label == target_label) > 0.:
                for lab_pos, lab in enumerate(model_adv_output_label.cpu().numpy().tolist()):
                    if lab == target_label:
                        cw = this_batch_cand_words[lab_pos]
                        _, model_changed_encoder_output = send_sent_batch_to_model(model, [new_sents[lab_pos]])
                        model_changed_encoder_output = model_changed_encoder_output.squeeze()
                        a_dist = dist_func(model_encoder_output, model_changed_encoder_output)
                        a_dist = a_dist.mean(dim=-1).item()
                        print("targ", target_label, ":", target_word, "->", cw, \
                              "dist", a_dist)
                        return 1, sent_len, target_label, max_change_pos, target_word, cw
            new_sents = []
            this_batch_cand_words = []

    # ----- try all max_change_pos
    # print("Failed.")
    return 0


In [134]:
test_sents = [l for l in open("./glue_data/Yelp/dev.tsv").readlines()]
test_sents = [tt.strip().split('\t') for tt in test_sents]
test_sents = np.random.permutation(test_sents)
test_sents = test_sents[:120]

skips = [25,82,87,108,71,56,52] # find only the sents that model correctly predicts
skipped_test_sents = []
for ii in range(len(test_sents)):
    if ii in skips: continue
    tt = test_sents[ii] 
    skipped_test_sents.append([int(tt[0]), tt[1]])
skipped_test_sents = skipped_test_sents[:100]

In [146]:
skipped_test_sents[87]

[1,
 'One of the best habachi and sushi restaurant in vegas!! EXCELLENT food and GREAT customer service.']

In [138]:
results = []
with torch.no_grad():
    for t_id, tt in enumerate(skipped_test_sents):
        res = attack_close_words_in_embedding_space(tt[1], tt[0])
        results.append(res)

horrible
targ 1 : horrible -> !
zero
try next max_change
disappointed
try next max_change
##y
try next max_change
have
targ 0 : have -> worse
definitely
targ 0 : definitely -> longer
surprised
targ 0 : surprised -> until
at
targ 1 : at -> then
good
targ 0 : good -> failed
!
try next max_change
like
targ 0 : like -> Not
only
try next max_change
disappointed
try next max_change
got
targ 1 : got -> Love
surprisingly
targ 0 : surprisingly -> against
happy
targ 0 : happy -> 0
enjoyed
targ 0 : enjoyed -> poor
heaven
targ 0 : heaven -> 0
loved
try next max_change
non
try next max_change
no
try next max_change
!
targ 0 : ! -> until
pretty
targ 1 : pretty -> t
like
targ 1 : like -> all
:
targ 0 : : -> group
words
targ 0 : words -> worst
not
targ 1 : not -> !
ok
targ 1 : ok -> !
love
targ 0 : love -> hard
rude
try next max_change
failed
targ 1 : failed -> !
delighted
targ 0 : delighted -> 0
great
targ 0 : great -> 0
definitely
targ 0 : definitely -> 0
much
try next max_change
two
targ 1 : two ->

In [139]:
from IPython.display import HTML as html_print
from html import escape # Python 3 only :-)

def cstr(s, color='black'):
    s = s.replace(" ##", '')
    return "<span class='tex2jax_ignore'><text style=color:{}>{}</text></span>".format(color, s)

In [145]:
count = 0
for t_data_id, res in enumerate(results):
    print(t_data_id)
    if res == 0 or res == -1: continue
    if res[0] == 0: continue
    sent = tokenizer.tokenize(skipped_test_sents[t_data_id][1])[:80]

    changed_num, sent_len, target_label, max_change_pos, target_word, new_word = res
    original_label = 1 if target_label == 0 else 0

    print("{} to {}".format(original_label, target_label))
    display(html_print(cstr(' '.join(sent[:max_change_pos])) + ' ' +
                           cstr(target_word, color='red') + ' ' +
                           cstr(new_word, color='blue') + ' ' +
                           cstr(' '.join(sent[max_change_pos+1:]))))

0
1
0 to 1


2
3
4
5
1 to 0


6
1 to 0


7
1 to 0


8
0 to 1


9
1 to 0


10
11
1 to 0


12
13
14
0 to 1


15
1 to 0


16
1 to 0


17
18
1 to 0


19
1 to 0


20
21
22
23
1 to 0


24
25
0 to 1


26
0 to 1


27
1 to 0


28
1 to 0


29
0 to 1


30
0 to 1


31
1 to 0


32
33
0 to 1


34
1 to 0


35
1 to 0


36
1 to 0


37
38
0 to 1


39
40
0 to 1


41
0 to 1


42
1 to 0


43
1 to 0


44
45
1 to 0


46
0 to 1


47
1 to 0


48
1 to 0


49
0 to 1


50
1 to 0


51
52
53
54
0 to 1


55
1 to 0


56
1 to 0


57
1 to 0


58
1 to 0


59
1 to 0


60
61
1 to 0


62
0 to 1


63
1 to 0


64
65
66
67
0 to 1


68
1 to 0


69
70
0 to 1


71
72
1 to 0


73
74
75
76
0 to 1


77
78
0 to 1


79
80
81
82
0 to 1


83
1 to 0


84
1 to 0


85
1 to 0


86
87
1 to 0


88
89
0 to 1


90
0 to 1


91
1 to 0


92
1 to 0


93
1 to 0


94
0 to 1


95
0 to 1


96
97
0 to 1


98
0 to 1


99
0 to 1
