In [5]:
import torch
from torch import nn
import os
import random
from torch.utils import data
from tqdm import tqdm
import numpy as np
from copy import deepcopy
from transformers import BertTokenizer, BertForSequenceClassification, BertConfig
import warnings
import csv
import re

warnings.filterwarnings('ignore')

# NOTE: a `global` statement at module level has no effect; the name below is
# already a module-level variable.
global extracted_grads

# Filled by extract_grad_hook() during backward passes and cleared at the
# start of every attack round in collection_attack().
extracted_grads = []
# Column index at which the trigger tokens are spliced into every input
# (evaluate(), get_gradient(), select_best_candid() and process_gradient()
# all read this global).
position = 1  # concatenation position
# the concatenation position of the BERT model is after the [CLS] token
# Random Concatenation Mode
# position = random.randint(1,500)

# Directory expected to contain vocab.txt, config.json and pytorch_model.bin.
BERT_path = 'PreTrainedModelBert'  # path to bert model
# Tokenizer used by load_sst_data() to encode the sentences.
tokenize = BertTokenizer.from_pretrained(os.path.join(BERT_path, 'vocab.txt'))
model_config = BertConfig.from_pretrained(os.path.join(BERT_path, 'config.json'))
# NOTE(review): `Model` is not referenced again in the visible code; the later
# cells use `model` obtained from torch.load('Bert_sst.bin') instead.
Model = BertForSequenceClassification.from_pretrained(os.path.join(BERT_path, 'pytorch_model.bin'), config=model_config)

# Load model related information

# Print the number of Total Parameters
# total = [param.nelement() for param in Model.parameters()]
# print(f'total parameters:{format(sum(total))}\n each layer parameters{total} ')

'''
SST-2 Data
'''


### Load data

def read_sst_data(data_dir):
    """Read an SST-2 style TSV file into parallel sentence/label lists.

    The first row is treated as a header and skipped. Every following
    non-blank row must carry the integer label in column 0 and the sentence
    in column 1.

    Args:
        data_dir: Path of the TSV file to read (a file path, despite the name).

    Returns:
        A ``(sentences, labels)`` tuple of two lists of equal length;
        ``labels`` holds ``int`` values.

    Raises:
        IndexError: If a non-blank row has fewer than two columns.
        ValueError: If a label cannot be converted to ``int``.
    """
    sentences, labels = [], []
    # newline='' is what the csv module expects from its input file, and an
    # explicit encoding keeps the result independent of the platform locale.
    with open(data_dir, encoding='utf-8', newline='') as tsvfile:
        # Pass the format directly instead of registering a process-global
        # dialect that would stay registered if parsing raised.
        reader = csv.reader(tsvfile, delimiter='\t', quoting=csv.QUOTE_ALL)
        next(reader, None)  # skip the header row (no-op on an empty file)
        for row in reader:
            if not row:
                # csv yields [] for blank lines; indexing it would crash.
                continue
            sentences.append(row[1])
            labels.append(int(row[0]))
    return sentences, labels


def read_sst_test_data(data_dir):
    """Read the rows labelled ``'0'`` from an SST-2 style TSV file.

    The first row is treated as a header and skipped. Of the remaining
    non-blank rows only those whose first column is exactly the string
    ``'0'`` (the negative class) are kept.

    Args:
        data_dir: Path of the TSV file to read (a file path, despite the name).

    Returns:
        A ``(sentences, labels)`` tuple of two lists of equal length; every
        entry of ``labels`` is the integer ``0``.

    Raises:
        IndexError: If a kept row has fewer than two columns.
    """
    sentences, labels = [], []
    # newline='' is what the csv module expects from its input file, and an
    # explicit encoding keeps the result independent of the platform locale.
    with open(data_dir, encoding='utf-8', newline='') as tsvfile:
        # Pass the format directly instead of registering a process-global
        # dialect that would stay registered if parsing raised.
        reader = csv.reader(tsvfile, delimiter='\t', quoting=csv.QUOTE_ALL)
        next(reader, None)  # skip the header row (no-op on an empty file)
        for row in reader:
            if not row:
                # csv yields [] for blank lines; indexing it would crash.
                continue
            if row[0] == '0':  # keep negative examples only
                sentences.append(row[1])
                labels.append(int(row[0]))
    return sentences, labels


def load_sst_array(data_arrays, batch_size, is_train=True):
    """Wrap aligned tensors in a ``DataLoader``.

    Args:
        data_arrays: Sequence of tensors sharing the same first dimension.
        batch_size: Number of samples per batch.
        is_train: When true the loader reshuffles the samples.

    Returns:
        A ``torch.utils.data.DataLoader`` over a ``TensorDataset``.
    """
    tensor_dataset = data.TensorDataset(*data_arrays)
    loader = data.DataLoader(tensor_dataset, batch_size, shuffle=is_train)
    return loader


def load_sst_data(batch_size, num_steps=500,
                  train_path="SST-2/train.tsv", test_path="SST-2/test.tsv"):
    """Tokenize the SST-2 splits and build the train/test data loaders.

    Args:
        batch_size: Batch size of the (shuffled) training loader.
        num_steps: Maximum token length; longer sentences are truncated.
        train_path: TSV file read with ``read_sst_data``.
        test_path: TSV file read with ``read_sst_test_data`` (which keeps
            only the rows labelled ``'0'``).

    Returns:
        ``(train_iter, test_iter)``. Each batch is a triple
        ``(input_ids, token_type_ids, labels)``. The test loader is not
        shuffled and always uses a batch size of 1.
    """
    train_sentences, train_labels = read_sst_data(train_path)
    test_sentences, test_labels = read_sst_test_data(test_path)
    # padding=True pads every sentence of a split to that split's longest one.
    train_encoding = tokenize(train_sentences, return_tensors="pt", padding=True,
                              truncation=True, max_length=num_steps)
    test_encoding = tokenize(test_sentences, return_tensors="pt", padding=True,
                             truncation=True, max_length=num_steps)
    train_iter = load_sst_array(
        (train_encoding['input_ids'], train_encoding['token_type_ids'],
         torch.tensor(train_labels)),
        batch_size)
    test_iter = load_sst_array(
        (test_encoding['input_ids'], test_encoding['token_type_ids'],
         torch.tensor(test_labels)),
        1,
        is_train=False)
    return train_iter, test_iter

def try_all_gpus():
    """Return every visible CUDA device, or ``[cpu]`` when there is none."""
    gpu_count = torch.cuda.device_count()
    if gpu_count == 0:
        return [torch.device('cpu')]
    return [torch.device(f'cuda:{index}') for index in range(gpu_count)]


### Train

def train(net, train_iter, lr, num_epochs, device):
    """Fine-tune ``net`` with AdamW, printing loss and accuracy per epoch.

    Args:
        net: Model called as ``net(input_ids=..., token_type_ids=...,
            labels=...)`` whose output exposes ``.logits`` and ``.loss``.
        train_iter: Iterable of ``(input_ids, token_type_ids, labels)`` batches.
        lr: Learning rate handed to ``torch.optim.AdamW``.
        num_epochs: Number of passes over ``train_iter``.
        device: Sequence of devices; only ``device[0]`` is used.

    Returns:
        None. Progress is reported through ``print``.
    """
    print('---------------------------start---------------------')
    target = device[0]
    # Move the model first so the optimizer is built over the parameters in
    # their final location.
    net = net.to(target)
    optimizer = torch.optim.AdamW(net.parameters(), lr=lr)
    train_loss = None  # stays None when num_epochs == 0
    for epoch in range(num_epochs):
        net.train()
        print(f' epoch {epoch + 1}')
        loss_sum = 0.0
        acc_sum = 0.0
        num_batches = 0
        for batch in tqdm(train_iter):
            a, b, y = batch
            a = a.to(target)
            b = b.to(target)
            y = y.to(target)
            outputs = net(input_ids=a, token_type_ids=b, labels=y)
            loss = outputs.loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Accumulate detached values so no autograd history is kept
            # alive across the whole epoch.
            loss_sum = loss_sum + loss.detach()
            acc_sum = acc_sum + (outputs.logits.argmax(dim=-1) == y).float().mean()
            num_batches += 1
        print("Learning rate for epoch %d：%f" % (epoch + 1, optimizer.param_groups[0]['lr']))
        # Mean of the per-batch values; guard against an empty iterator.
        denominator = max(num_batches, 1)
        train_loss = float(loss_sum) / denominator
        train_acc = float(acc_sum) / denominator
        print(f"[ Train | {epoch + 1:03d}/{num_epochs:03d} ] loss = {train_loss:.5f}   acc = {train_acc:.5f}")
    print('Training process has finished.')
    if train_loss is not None:
        print('the loss of model {:.3f}'.format(train_loss))


def evaluate_no(net, test_iter):
    """Return the mean per-batch accuracy of ``net`` on unmodified inputs.

    Reads the module-level ``device`` list and runs on ``device[0]``.

    Args:
        net: Model called as ``net(input_ids=..., token_type_ids=...,
            labels=...)`` whose output exposes ``.logits``.
        test_iter: Sized iterable of ``(input_ids, token_type_ids, labels)``.

    Returns:
        Sum of the per-batch accuracies divided by ``len(test_iter)``.
    """
    net = net.to(device[0])
    net.eval()
    batch_accuracies = []
    with torch.no_grad():
        for input_ids, segment_ids, labels in tqdm(test_iter):
            input_ids = input_ids.to(device[0])
            segment_ids = segment_ids.to(device[0])
            labels = labels.to(device[0])
            result = net(input_ids=input_ids, token_type_ids=segment_ids, labels=labels)
            predicted = result.logits.argmax(dim=-1)
            batch_accuracies.append((predicted == labels).float().mean())
    return sum(batch_accuracies) / len(test_iter)


### Trigger Token

def init_trigger_tokens(trigger, num_trigger_tokens):
    """Build the initial trigger: ``num_trigger_tokens`` copies of id 1996.

    Args:
        trigger: Accepted for interface compatibility; the value is not read.
        num_trigger_tokens: Length of the returned trigger.

    Returns:
        A 1-D tensor holding the id 1996 ('the') repeated
        ``num_trigger_tokens`` times.
    """
    the_token_id = 1996  # 1996 means 'the'
    initial_ids = [the_token_id for _ in range(num_trigger_tokens)]
    return torch.tensor(initial_ids)


def evaluate(net, test_iter, trigger_token_tensor):
    """Return the mean per-batch accuracy with the trigger spliced into inputs.

    The trigger ids (and matching zero segment ids) are inserted at column
    ``position`` of every batch. Reads the module-level ``device`` and
    ``position``.

    Args:
        net: Model called as ``net(input_ids=..., token_type_ids=...,
            labels=...)`` whose output exposes ``.logits``.
        test_iter: Sized iterable of ``(input_ids, token_type_ids, labels)``.
        trigger_token_tensor: 1-D tensor of trigger token ids.

    Returns:
        Sum of the per-batch accuracies divided by ``len(test_iter)``.
    """
    def splice(batch_tensor, row):
        # Insert `row` (shape (1, k)) into every sample at column `position`.
        tiled = row.repeat_interleave(batch_tensor.shape[0], dim=0)
        return torch.cat((batch_tensor[:, :position], tiled, batch_tensor[:, position:]), dim=1)

    net = net.to(device[0])
    net.eval()
    segment_row = torch.tensor([0] * len(trigger_token_tensor)).unsqueeze(0)
    trigger_row = deepcopy(trigger_token_tensor).unsqueeze(0)
    batch_accuracies = []
    with torch.no_grad():
        for input_ids, segment_ids, labels in tqdm(test_iter):
            input_ids = splice(input_ids, trigger_row).to(device[0])
            segment_ids = splice(segment_ids, segment_row).to(device[0])
            labels = labels.to(device[0])
            result = net(input_ids=input_ids, token_type_ids=segment_ids, labels=labels)
            predicted = result.logits.argmax(dim=-1)
            batch_accuracies.append((predicted == labels).float().mean())
    return sum(batch_accuracies) / len(test_iter)


def extract_grad_hook(net, grad_in, grad_out):
    """Backward hook: record the first output gradient averaged over dim 0.

    The result is appended to the module-level ``extracted_grads`` list.
    """
    first_output_grad = grad_out[0]
    extracted_grads.append(first_output_grad.mean(dim=0))


def add_hook(net):
    """Attach ``extract_grad_hook`` to the first ``nn.Embedding`` in ``net``.

    Args:
        net: Module whose sub-modules are searched in ``net.modules()`` order.

    Returns:
        The handle returned by ``register_backward_hook``; call ``.remove()``
        on it to detach the hook.

    Raises:
        ValueError: If ``net`` contains no ``nn.Embedding`` (previously this
            surfaced as an ``UnboundLocalError``).
    """
    embedding = next(
        (module for module in net.modules() if isinstance(module, nn.Embedding)),
        None,
    )
    if embedding is None:
        raise ValueError('net contains no nn.Embedding module to hook')
    return embedding.register_backward_hook(extract_grad_hook)


def get_gradient(net, test_iter, trigger_token_tensor):
    """Run forward/backward over ``test_iter`` with the trigger spliced in.

    No parameter update is performed; the backward passes exist so that a
    previously registered hook can observe the gradients. The model is put
    in training mode and is left in that mode. Reads the module-level
    ``device`` and ``position``.

    Args:
        net: Model called as ``net(input_ids=..., token_type_ids=...,
            labels=...)`` whose output exposes ``.loss``.
        test_iter: Iterable of ``(input_ids, token_type_ids, labels)``.
        trigger_token_tensor: 1-D tensor of trigger token ids.
    """
    def splice(batch_tensor, row):
        # Insert `row` (shape (1, k)) into every sample at column `position`.
        tiled = row.repeat_interleave(batch_tensor.shape[0], dim=0)
        return torch.cat((batch_tensor[:, :position], tiled, batch_tensor[:, position:]), dim=1)

    net = net.to(device[0])
    net.train()
    trigger_row = deepcopy(trigger_token_tensor).unsqueeze(0)
    segment_row = torch.tensor([0] * len(trigger_token_tensor)).unsqueeze(0)
    # Only used to reset gradients between batches; step() is never called.
    optimizer = torch.optim.AdamW(net.parameters())
    for input_ids, segment_ids, labels in tqdm(test_iter):
        input_ids = splice(input_ids, trigger_row).to(device[0])
        segment_ids = splice(segment_ids, segment_row).to(device[0])
        labels = labels.to(device[0])
        result = net(input_ids=input_ids, token_type_ids=segment_ids, labels=labels)
        batch_loss = result.loss
        optimizer.zero_grad()
        batch_loss.backward()


def process_gradient(length, num_trigger_tokens):
    """Average the collected gradients and keep the rows of the trigger.

    Reads the module-level ``extracted_grads`` (without modifying it) and
    ``position``.

    Args:
        length: Number of leading entries of ``extracted_grads`` to average.
            At least the first entry is always used.
        num_trigger_tokens: Number of rows to keep, starting at ``position``.

    Returns:
        ``mean(extracted_grads[:length])[position:position + num_trigger_tokens]``.
        Presumably each entry is (seq_len, hidden), giving a
        (num_trigger_tokens, hidden) result -- TODO confirm against the hook.

    Raises:
        ValueError: If no gradient has been collected.
    """
    if not extracted_grads:
        raise ValueError('no gradients were collected; run get_gradient() first')
    # The previous loop ran over range(1, length - 1) and therefore silently
    # dropped the last gradient; average all `length` of them instead.
    selected = extracted_grads[:max(length, 1)]
    # One stack instead of a torch.cat per gradient (which re-copied the
    # growing buffer on every iteration and rewrote the global list entries).
    stacked = torch.stack(selected, dim=0)
    average_grad = stacked.mean(dim=0)[position:position + num_trigger_tokens]
    return average_grad


def hotflip_attack(averaged_grad, embedding_matrix,
                   num_candidates=1, increase_loss=False):
    """Rank vocabulary entries by their dot product with the gradient.

    Args:
        averaged_grad: 2-D tensor with one gradient row per trigger position.
        embedding_matrix: 2-D tensor with one embedding row per vocabulary id.
        num_candidates: How many ids to return per trigger position.
        increase_loss: When false the scores are negated before ranking
            (lower versus increase the class probability).

    Returns:
        A NumPy array of vocabulary ids. With ``num_candidates > 1`` it has
        one row of ``num_candidates`` ids per trigger position; otherwise it
        holds the single best id per trigger position.
    """
    grad_batch = averaged_grad.cpu().unsqueeze(0)
    vocab_embeddings = embedding_matrix.cpu()
    scores = torch.einsum("bij,kj->bik", grad_batch, vocab_embeddings)
    if not increase_loss:
        scores *= -1
    if num_candidates <= 1:
        best_ids = scores.max(2).indices
        return best_ids[0].detach().cpu().numpy()
    top_ids = torch.topk(scores, num_candidates, dim=2).indices
    return top_ids.detach().cpu().numpy()[0]


def collection_attack(net, test_iter, num_candidates, num_epoch, trigger='the',
                      num_trigger_tokens=3):
    """Run the full trigger search: gradient, candidates, greedy selection.

    Each round collects embedding gradients with the current trigger,
    asks ``hotflip_attack`` for ``num_candidates`` replacement ids per trigger
    position, and lets ``select_best_candid`` keep those that lower accuracy.

    Args:
        net: Model under attack.
        test_iter: Sized iterable of ``(input_ids, token_type_ids, labels)``.
        num_candidates: Candidate ids per trigger position. The selection step
            indexes the candidates as a 2-D array, so use a value > 1.
        num_epoch: Number of attack rounds.
        trigger: Forwarded to ``init_trigger_tokens``.
        num_trigger_tokens: Length of the trigger.

    Returns:
        ``(trigger_token_tensor, valid_acc)``: the final trigger ids and the
        accuracy measured with that trigger.
    """
    trigger_token_tensor = init_trigger_tokens(trigger, num_trigger_tokens)
    print(f'Concatenation location:{position}')
    valid_acc = evaluate(net, test_iter, trigger_token_tensor)
    print(f'Initial trigger tokens state：the accuracy {valid_acc:.5f}')
    embedding_weight = get_embedding_weight(net)
    for i in range(num_epoch):
        extracted_grads.clear()
        hook = add_hook(net)
        try:
            get_gradient(net, test_iter, trigger_token_tensor)
        finally:
            # Always detach the hook: a leaked hook would keep appending to
            # extracted_grads on every later backward pass.
            hook.remove()
        average_grad = process_gradient(len(test_iter), num_trigger_tokens)
        hot_token = hotflip_attack(average_grad, embedding_weight, num_candidates, increase_loss=True)
        hot_token_tensor = torch.from_numpy(hot_token)
        trigger_token_tensor, valid_acc = select_best_candid(net, test_iter, hot_token_tensor, trigger_token_tensor,
                                                             valid_acc)
        print(f'after {i + 1} rounds of attacking\ntriggers: {trigger_token_tensor} \nthe accuracy :{valid_acc:.5f} ')
    return trigger_token_tensor, valid_acc


def get_embedding_weight(net):
    """Return the weight of the first ``nn.Embedding`` found in ``net``.

    Args:
        net: Module whose sub-modules are searched in ``net.modules()`` order.

    Returns:
        The ``weight`` parameter of the first embedding layer (the same
        object, not a copy).

    Raises:
        ValueError: If ``net`` contains no ``nn.Embedding`` (previously this
            surfaced as an ``UnboundLocalError``).
    """
    embedding = next(
        (module for module in net.modules() if isinstance(module, nn.Embedding)),
        None,
    )
    if embedding is None:
        raise ValueError('net contains no nn.Embedding module')
    return embedding.weight


def select_best_candid(net, test_iter, candid_trigger, trigger_token, valid_acc):
    """Greedily replace trigger tokens with candidates that lower accuracy.

    For each trigger position ``i`` every candidate ``candid_trigger[i, j]``
    is tried in turn; a candidate is kept when the resulting accuracy is
    strictly lower than the best accuracy seen so far. Reads the module-level
    ``device`` and ``position``.

    Args:
        net: Model called as ``net(input_ids=..., token_type_ids=...,
            labels=...)`` whose output exposes ``.logits``.
        test_iter: Sized iterable of ``(input_ids, token_type_ids, labels)``.
        candid_trigger: 2-D tensor of candidate ids, one row per trigger
            position.
        trigger_token: 1-D tensor with the current trigger ids (not modified).
        valid_acc: Accuracy obtained with ``trigger_token``.

    Returns:
        ``(best_trigger, best_acc)``: the selected trigger ids and the
        accuracy measured with them.
    """
    segment_row = torch.tensor([0] * len(trigger_token)).unsqueeze(0)
    # Work on a copy: unsqueeze() returns a view, so writing into it would
    # silently modify the caller's tensor.
    best_trigger = trigger_token.clone().unsqueeze(0)
    net.eval()
    # Pure evaluation: skip autograd bookkeeping, as evaluate() does.
    with torch.no_grad():
        for i in range(candid_trigger.shape[0]):
            trial_trigger = best_trigger.clone()
            for j in range(candid_trigger.shape[1]):
                trial_trigger[0, i] = candid_trigger[i, j]
                batch_accuracies = []
                for batch in tqdm(test_iter):
                    a, b, y = batch
                    a = torch.cat((a[:, :position], trial_trigger.repeat_interleave(a.shape[0], dim=0),
                                   a[:, position:]), dim=1)
                    b = torch.cat((b[:, :position], segment_row.repeat_interleave(b.shape[0], dim=0),
                                   b[:, position:]), dim=1)
                    a = a.to(device[0])
                    b = b.to(device[0])
                    y = y.to(device[0])
                    outputs = net(input_ids=a, token_type_ids=b, labels=y)
                    acc = (outputs.logits.argmax(dim=-1) == y).float().mean()
                    batch_accuracies.append(acc)
                candidate_acc = sum(batch_accuracies) / len(test_iter)
                if candidate_acc < valid_acc:
                    valid_acc = candidate_acc
                    best_trigger[0, i] = candid_trigger[i, j]
    return best_trigger[0], valid_acc


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at PreTrainedModelBert/pytorch_model.bin and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
# Build the data loaders; 10 is the batch size of the training loader
# (load_sst_data always gives the test loader a batch size of 1).
train_iter, test_iter = load_sst_data(10)
# Data preprocessing and loading
print("reading data finished\n")
# Number of batches in each loader.
print(len(train_iter))
print(len(test_iter))

reading data finished

6735
912


In [7]:
# Module-level device list; evaluate_no(), evaluate(), get_gradient() and
# select_best_candid() read it as a global and use device[0].
device = try_all_gpus()
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only load checkpoints from a trusted source.
model = torch.load('Bert_sst.bin')
# Baseline accuracy on the test loader without any trigger.
evaluate_no(model, test_iter)

100%|██████████| 912/912 [00:05<00:00, 163.88it/s]


tensor(0.9068, device='cuda:0')

In [10]:
# Run the attack: 5 candidate tokens per trigger position, 10 rounds,
# a trigger of 2 tokens.
collection_attack(model, test_iter, 5, 10, trigger='the', num_trigger_tokens=2)

Concatenation location:1


100%|██████████| 912/912 [00:05<00:00, 157.14it/s]


Initial trigger tokens state：the accuracy 0.91447


100%|██████████| 912/912 [00:47<00:00, 19.33it/s]
100%|██████████| 912/912 [00:07<00:00, 128.11it/s]
100%|██████████| 912/912 [00:07<00:00, 127.92it/s]
100%|██████████| 912/912 [00:07<00:00, 128.87it/s]
100%|██████████| 912/912 [00:06<00:00, 130.40it/s]
100%|██████████| 912/912 [00:07<00:00, 127.40it/s]
100%|██████████| 912/912 [00:07<00:00, 127.48it/s]
100%|██████████| 912/912 [00:07<00:00, 129.09it/s]
100%|██████████| 912/912 [00:07<00:00, 127.95it/s]
100%|██████████| 912/912 [00:07<00:00, 129.75it/s]
100%|██████████| 912/912 [00:06<00:00, 131.48it/s]


after 1 rounds of attacking
triggers: tensor([18143, 27098]) 
the accuracy :0.89474 


100%|██████████| 912/912 [00:46<00:00, 19.49it/s]
100%|██████████| 912/912 [00:06<00:00, 130.39it/s]
100%|██████████| 912/912 [00:06<00:00, 132.80it/s]
100%|██████████| 912/912 [00:06<00:00, 131.59it/s]
100%|██████████| 912/912 [00:06<00:00, 130.98it/s]
100%|██████████| 912/912 [00:06<00:00, 131.51it/s]
100%|██████████| 912/912 [00:06<00:00, 132.03it/s]
100%|██████████| 912/912 [00:06<00:00, 130.48it/s]
100%|██████████| 912/912 [00:06<00:00, 130.80it/s]
100%|██████████| 912/912 [00:06<00:00, 132.53it/s]
100%|██████████| 912/912 [00:07<00:00, 126.07it/s]


after 2 rounds of attacking
triggers: tensor([12342, 23157]) 
the accuracy :0.82785 


100%|██████████| 912/912 [00:47<00:00, 19.40it/s]
100%|██████████| 912/912 [00:06<00:00, 130.81it/s]
100%|██████████| 912/912 [00:06<00:00, 130.63it/s]
100%|██████████| 912/912 [00:07<00:00, 128.36it/s]
100%|██████████| 912/912 [00:07<00:00, 130.21it/s]
100%|██████████| 912/912 [00:07<00:00, 129.08it/s]
100%|██████████| 912/912 [00:07<00:00, 129.99it/s]
100%|██████████| 912/912 [00:06<00:00, 130.62it/s]
100%|██████████| 912/912 [00:07<00:00, 128.38it/s]
100%|██████████| 912/912 [00:07<00:00, 128.95it/s]
100%|██████████| 912/912 [00:06<00:00, 130.62it/s]


after 3 rounds of attacking
triggers: tensor([17441, 17950]) 
the accuracy :0.20285 


100%|██████████| 912/912 [00:48<00:00, 18.78it/s]
100%|██████████| 912/912 [00:07<00:00, 128.55it/s]
100%|██████████| 912/912 [00:07<00:00, 127.53it/s]
100%|██████████| 912/912 [00:07<00:00, 127.55it/s]
100%|██████████| 912/912 [00:07<00:00, 128.21it/s]
100%|██████████| 912/912 [00:06<00:00, 130.53it/s]
100%|██████████| 912/912 [00:07<00:00, 129.06it/s]
100%|██████████| 912/912 [00:07<00:00, 128.92it/s]
100%|██████████| 912/912 [00:06<00:00, 130.58it/s]
100%|██████████| 912/912 [00:07<00:00, 129.79it/s]
100%|██████████| 912/912 [00:07<00:00, 129.37it/s]


after 4 rounds of attacking
triggers: tensor([12090, 17950]) 
the accuracy :0.14912 


100%|██████████| 912/912 [00:48<00:00, 18.81it/s]
100%|██████████| 912/912 [00:06<00:00, 130.70it/s]
100%|██████████| 912/912 [00:06<00:00, 130.40it/s]
100%|██████████| 912/912 [00:06<00:00, 130.82it/s]
100%|██████████| 912/912 [00:07<00:00, 129.67it/s]
100%|██████████| 912/912 [00:06<00:00, 131.65it/s]
100%|██████████| 912/912 [00:07<00:00, 127.28it/s]
100%|██████████| 912/912 [00:07<00:00, 129.84it/s]
100%|██████████| 912/912 [00:07<00:00, 130.22it/s]
100%|██████████| 912/912 [00:07<00:00, 129.06it/s]
100%|██████████| 912/912 [00:07<00:00, 130.25it/s]


after 5 rounds of attacking
triggers: tensor([12090, 17950]) 
the accuracy :0.14912 


100%|██████████| 912/912 [00:47<00:00, 19.03it/s]
100%|██████████| 912/912 [00:06<00:00, 130.43it/s]
100%|██████████| 912/912 [00:06<00:00, 131.39it/s]
100%|██████████| 912/912 [00:06<00:00, 130.53it/s]
100%|██████████| 912/912 [00:06<00:00, 131.49it/s]
100%|██████████| 912/912 [00:06<00:00, 131.84it/s]
100%|██████████| 912/912 [00:07<00:00, 130.04it/s]
100%|██████████| 912/912 [00:06<00:00, 130.58it/s]
100%|██████████| 912/912 [00:07<00:00, 129.14it/s]
100%|██████████| 912/912 [00:07<00:00, 126.14it/s]
100%|██████████| 912/912 [00:07<00:00, 126.76it/s]


after 6 rounds of attacking
triggers: tensor([12090, 17950]) 
the accuracy :0.14912 


100%|██████████| 912/912 [00:48<00:00, 18.82it/s]
100%|██████████| 912/912 [00:07<00:00, 127.29it/s]
100%|██████████| 912/912 [00:07<00:00, 128.98it/s]
100%|██████████| 912/912 [00:06<00:00, 130.33it/s]
100%|██████████| 912/912 [00:07<00:00, 129.24it/s]
100%|██████████| 912/912 [00:07<00:00, 129.82it/s]
100%|██████████| 912/912 [00:07<00:00, 129.52it/s]
100%|██████████| 912/912 [00:07<00:00, 128.15it/s]
100%|██████████| 912/912 [00:07<00:00, 130.20it/s]
100%|██████████| 912/912 [00:06<00:00, 130.67it/s]
100%|██████████| 912/912 [00:07<00:00, 130.26it/s]


after 7 rounds of attacking
triggers: tensor([22249, 17950]) 
the accuracy :0.09320 


100%|██████████| 912/912 [00:46<00:00, 19.53it/s]
100%|██████████| 912/912 [00:07<00:00, 129.08it/s]
100%|██████████| 912/912 [00:07<00:00, 128.32it/s]
100%|██████████| 912/912 [00:07<00:00, 128.94it/s]
100%|██████████| 912/912 [00:07<00:00, 129.95it/s]
100%|██████████| 912/912 [00:07<00:00, 128.96it/s]
100%|██████████| 912/912 [00:07<00:00, 129.34it/s]
100%|██████████| 912/912 [00:07<00:00, 128.94it/s]
100%|██████████| 912/912 [00:07<00:00, 129.83it/s]
100%|██████████| 912/912 [00:07<00:00, 128.46it/s]
100%|██████████| 912/912 [00:07<00:00, 129.59it/s]


after 8 rounds of attacking
triggers: tensor([22249, 17950]) 
the accuracy :0.09320 


100%|██████████| 912/912 [00:43<00:00, 20.76it/s]
100%|██████████| 912/912 [00:06<00:00, 130.42it/s]
100%|██████████| 912/912 [00:06<00:00, 132.72it/s]
100%|██████████| 912/912 [00:06<00:00, 133.21it/s]
100%|██████████| 912/912 [00:06<00:00, 132.43it/s]
100%|██████████| 912/912 [00:06<00:00, 131.23it/s]
100%|██████████| 912/912 [00:06<00:00, 133.07it/s]
100%|██████████| 912/912 [00:06<00:00, 130.96it/s]
100%|██████████| 912/912 [00:07<00:00, 129.29it/s]
100%|██████████| 912/912 [00:07<00:00, 129.50it/s]
100%|██████████| 912/912 [00:07<00:00, 129.23it/s]


after 9 rounds of attacking
triggers: tensor([22249, 17950]) 
the accuracy :0.09320 


100%|██████████| 912/912 [00:48<00:00, 18.74it/s]
100%|██████████| 912/912 [00:07<00:00, 129.01it/s]
100%|██████████| 912/912 [00:07<00:00, 128.85it/s]
100%|██████████| 912/912 [00:07<00:00, 129.35it/s]
100%|██████████| 912/912 [00:07<00:00, 130.04it/s]
100%|██████████| 912/912 [00:07<00:00, 129.72it/s]
100%|██████████| 912/912 [00:07<00:00, 129.60it/s]
100%|██████████| 912/912 [00:06<00:00, 130.74it/s]
100%|██████████| 912/912 [00:07<00:00, 129.27it/s]
100%|██████████| 912/912 [00:07<00:00, 129.44it/s]
100%|██████████| 912/912 [00:06<00:00, 130.43it/s]

after 10 rounds of attacking
triggers: tensor([22249, 17950]) 
the accuracy :0.09320 





(tensor([22249, 17950]), tensor(0.0932, device='cuda:0'))