In [1]:
import torch
from torch import nn
import os
import random
from torch.utils import data
from tqdm import tqdm
import numpy as np
from copy import deepcopy
import warnings
import csv
import re
from transformers import AlbertTokenizer, AlbertModel



global extracted_grads

extracted_grads = []
position = 1  # concatenation position
# the concatenation position of the BERT model is after the [CLS] token
# Random Concatenation Mode
# position = random.randint(1,500)

tokenize = AlbertTokenizer.from_pretrained("/root/albert")
Model = AlbertModel.from_pretrained("/root/albert")


# Load model related information

# Print the number of Total Parameters
# total = [param.nelement() for param in Model.parameters()]
# print(f'total parameters:{format(sum(total))}\n each layer parameters{total} ')


  return self.fget.__get__(instance, owner)()


In [2]:
'''
SST-2 Data
'''


### Load data

def read_sst_data(data_dir):
    data, labels = [], []
    csv.register_dialect('my', delimiter='\t', quoting=csv.QUOTE_ALL)
    with open(data_dir) as tsvfile:
        file_list = csv.reader(tsvfile, "my")
        first = True
        for line in file_list:
            if first:
                first = False
                continue
            data.append(line[1])
            labels.append(int(line[0]))
    csv.unregister_dialect('my')
    return data, labels


def read_sst_test_data_neg(data_dir):
    data, labels = [], []
    csv.register_dialect('my', delimiter='\t', quoting=csv.QUOTE_ALL)
    with open(data_dir) as tsvfile:
        file_list = csv.reader(tsvfile, "my")
        first = True
        for line in file_list:
            if first:
                first = False
                continue
            if line[0] == '0':  # neg
                data.append(line[1])
                labels.append(int(line[0]))
    csv.unregister_dialect('my')
    return data, labels

def read_sst_test_data_pos(data_dir):
    data, labels = [], []
    csv.register_dialect('my', delimiter='\t', quoting=csv.QUOTE_ALL)
    with open(data_dir) as tsvfile:
        file_list = csv.reader(tsvfile, "my")
        first = True
        for line in file_list:
            if first:
                first = False
                continue
            if line[0] == '1':  # pos
                data.append(line[1])
                labels.append(int(line[0]))
    csv.unregister_dialect('my')
    return data, labels


def load_sst_array(data_arrays, batch_size, is_train=True):
    """Constructs a PyTorch data iterator."""
    dataset = data.TensorDataset(*data_arrays)
    return data.DataLoader(dataset, batch_size, shuffle=is_train)


def load_sst_data_neg(batch_size, num_steps=500):
    train_data = read_sst_data("/root/SST-2/train.tsv")
    test_data = read_sst_test_data_neg("/root/SST-2/test.tsv")
    train_encoding = tokenize(train_data[0], return_tensors="pt", padding=True, truncation=True, max_length=num_steps)
    test_encoding = tokenize(test_data[0], return_tensors="pt", padding=True, truncation=True, max_length=num_steps)
    train_iter = load_sst_array(
        (train_encoding['input_ids'], train_encoding['token_type_ids'], torch.tensor(train_data[1])),
        batch_size)
    test_iter = load_sst_array(
        (test_encoding['input_ids'], test_encoding['token_type_ids'], torch.tensor(test_data[1])),
        1,
        is_train=False)
    return train_iter, test_iter


def load_sst_data_pos(batch_size, num_steps=500):
    train_data = read_sst_data("/root/SST-2/train.tsv")
    test_data = read_sst_test_data_pos("/root/SST-2/test.tsv")
    train_encoding = tokenize(train_data[0], return_tensors="pt", padding=True, truncation=True, max_length=num_steps)
    test_encoding = tokenize(test_data[0], return_tensors="pt", padding=True, truncation=True, max_length=num_steps)
    train_iter = load_sst_array(
        (train_encoding['input_ids'], train_encoding['token_type_ids'], torch.tensor(train_data[1])),
        batch_size)
    test_iter = load_sst_array(
        (test_encoding['input_ids'], test_encoding['token_type_ids'], torch.tensor(test_data[1])),
        1,
        is_train=False)
    return train_iter, test_iter

def try_all_gpus():
    devices = [torch.device(f'cuda:{i}')
               for i in range(torch.cuda.device_count())]
    return devices if devices else [torch.device('cpu')]

In [3]:
train_iter, test_iter_neg = load_sst_data_neg(10)
train_iter, test_iter_pos = load_sst_data_pos(10)
# Data preprocessing and loading
print("reading data finished\n")
print(len(train_iter))
print(len(test_iter_neg))
print(len(test_iter_pos))

reading data finished

6735
912
909


In [6]:
# Define the model architecture
class AlbertSentimentClassifier(nn.Module):
    def __init__(self, albert_model):
        super(AlbertSentimentClassifier, self).__init__()
        self.albert = albert_model
        self.dropout = nn.Dropout(0.1)
        self.fc = nn.Linear(self.albert.config.hidden_size, 2)  # Binary classification: positive or negative

    def forward(self, input_ids, token_type_ids):
        outputs = self.albert(input_ids=input_ids, token_type_ids=token_type_ids)
        pooled_output = outputs[1]  # Take the [CLS] token output
        pooled_output = self.dropout(pooled_output)
        logits = self.fc(pooled_output)
        return logits

# Instantiate the model
model = AlbertSentimentClassifier(Model)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6)

# Training loop
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.train()

num_epochs = 3  # Example, you can adjust this
for epoch in range(num_epochs):
    total_loss = 0
    for batch_idx, (input_ids, token_type_ids, labels) in enumerate(train_iter):
        input_ids, token_type_ids, labels = input_ids.to(device), token_type_ids.to(device), labels.to(device)
        
        optimizer.zero_grad()
        
        logits = model(input_ids, token_type_ids)
        loss = criterion(logits, labels)
        
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        
        if (batch_idx + 1) % 100 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx+1}/{len(train_iter)}, Loss: {total_loss / (batch_idx+1):.4f}")

print("Training finished.")


We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.


Epoch 1/3, Batch 100/6735, Loss: 0.6912
Epoch 1/3, Batch 200/6735, Loss: 0.6885
Epoch 1/3, Batch 300/6735, Loss: 0.6887
Epoch 1/3, Batch 400/6735, Loss: 0.6751
Epoch 1/3, Batch 500/6735, Loss: 0.6393
Epoch 1/3, Batch 600/6735, Loss: 0.6088
Epoch 1/3, Batch 700/6735, Loss: 0.5791
Epoch 1/3, Batch 800/6735, Loss: 0.5560
Epoch 1/3, Batch 900/6735, Loss: 0.5370
Epoch 1/3, Batch 1000/6735, Loss: 0.5183
Epoch 1/3, Batch 1100/6735, Loss: 0.5078
Epoch 1/3, Batch 1200/6735, Loss: 0.4959
Epoch 1/3, Batch 1300/6735, Loss: 0.4867
Epoch 1/3, Batch 1400/6735, Loss: 0.4762
Epoch 1/3, Batch 1500/6735, Loss: 0.4668
Epoch 1/3, Batch 1600/6735, Loss: 0.4592
Epoch 1/3, Batch 1700/6735, Loss: 0.4516
Epoch 1/3, Batch 1800/6735, Loss: 0.4455
Epoch 1/3, Batch 1900/6735, Loss: 0.4395
Epoch 1/3, Batch 2000/6735, Loss: 0.4358
Epoch 1/3, Batch 2100/6735, Loss: 0.4303
Epoch 1/3, Batch 2200/6735, Loss: 0.4259
Epoch 1/3, Batch 2300/6735, Loss: 0.4224
Epoch 1/3, Batch 2400/6735, Loss: 0.4175
Epoch 1/3, Batch 2500/673

In [9]:
torch.save(model, 'albert_SST.bin')

In [10]:
def evaluate_model(model, test_iter):
    model.eval()
    device = next(model.parameters()).device

    total_correct = 0
    total_samples = 0

    with torch.no_grad():
        for input_ids, token_type_ids, labels in test_iter:
            input_ids, token_type_ids, labels = input_ids.to(device), token_type_ids.to(device), labels.to(device)

            logits = model(input_ids, token_type_ids)
            _, predictions = torch.max(logits, 1)

            total_correct += (predictions == labels).sum().item()
            total_samples += labels.size(0)

    accuracy = total_correct / total_samples
    print(f"Accuracy on test set: {accuracy:.4f}")

# Evaluate the model
evaluate_model(model, test_iter_pos)
evaluate_model(model, test_iter_neg)

Accuracy on test set: 0.8515
Accuracy on test set: 0.9156


In [5]:
# Define the model architecture
class AlbertSentimentClassifier(nn.Module):
    def __init__(self, albert_model):
        super(AlbertSentimentClassifier, self).__init__()
        self.albert = albert_model
        self.dropout = nn.Dropout(0.1)
        self.fc = nn.Linear(self.albert.config.hidden_size, 2)  # Binary classification: positive or negative

    def forward(self, input_ids, token_type_ids):
        outputs = self.albert(input_ids=input_ids, token_type_ids=token_type_ids)
        pooled_output = outputs[1]  # Take the [CLS] token output
        pooled_output = self.dropout(pooled_output)
        logits = self.fc(pooled_output)
        return logits

device = try_all_gpus()
Model = torch.load('albert_SST.bin')

In [6]:
criterion = nn.CrossEntropyLoss()
### Trigger Token

def init_trigger_tokens(trigger, num_trigger_tokens):
    # Initialize trigger tokens, we use 'the' as initial trigger token
    trigger_token_ids = [0] * num_trigger_tokens  # 1996 means 'the'
    trigger_token_tensor = torch.tensor(trigger_token_ids)
    return trigger_token_tensor


def evaluate(net, test_iter, trigger_token_tensor):
    # evaluate the accuracy of the model after concatenating the initial trigger token
    net = net.to(device[0])
    net.eval()
    valid_accs = []
    n = torch.tensor([0] * len(trigger_token_tensor))
    m = deepcopy(trigger_token_tensor)
    m = m.unsqueeze(0)
    n = n.unsqueeze(0)
    with torch.no_grad():
        for batch in tqdm(test_iter):
            a, b, y = batch
            a = torch.cat((a[:, :position], m.repeat_interleave(a.shape[0], dim=0), a[:, position:]), dim=1)
            b = torch.cat((b[:, :position], n.repeat_interleave(b.shape[0], dim=0), b[:, position:]), dim=1)
            a = a.to(device[0])
            b = b.to(device[0])
            y = y.to(device[0])
            # outputs = net(input_ids=a, token_type_ids=b)
            # acc = (outputs.logits.argmax(dim=-1) == y).float().mean()
            logits = net(input_ids = a, token_type_ids = b)
            acc = (logits.argmax(dim=-1) == y).float().mean()
            valid_accs.append(acc)
    valid_acc = sum(valid_accs) / len(test_iter)
    return valid_acc

def extract_grad_hook(net, grad_in, grad_out):  # store the gradient in extracted_grads
    extracted_grads.append(grad_out[0].mean(dim=0))


def add_hook(net):
    for module in net.modules():
        if isinstance(module, nn.Embedding):
            hook = module.register_backward_hook(extract_grad_hook)
            break
    return hook


def get_gradient(net, test_iter, trigger_token_tensor):  # Calculate the loss to get the gradient
    net = net.to(device[0])
    net.train()
    m = deepcopy(trigger_token_tensor)
    m = m.unsqueeze(0)
    n = torch.tensor([0] * len(trigger_token_tensor))
    n = n.unsqueeze(0)
    optimizer = torch.optim.AdamW(net.parameters())
    for batch in tqdm(test_iter):
        a, b, y = batch
        a = torch.cat((a[:, :position], m.repeat_interleave(a.shape[0], dim=0), a[:, position:]), dim=1)
        b = torch.cat((b[:, :position], n.repeat_interleave(b.shape[0], dim=0), b[:, position:]), dim=1)
        a = a.to(device[0])
        b = b.to(device[0])
        y = y.to(device[0])
        '''
        outputs = net(input_ids=a, token_type_ids=b)
        l = outputs.loss
        optimizer.zero_grad()
        l.backward()
        '''
        logits = net(input_ids = a, token_type_ids = b)
        loss = criterion(logits, y)
        optimizer.zero_grad()
        loss.backward()


def process_gradient(length, num_trigger_tokens):  # Process the gradient to get the average gradient
    extracted_grads_copy = extracted_grads
    extracted_grads_copy[0] = extracted_grads_copy[0]
    temp = extracted_grads_copy[0]
    temp = temp.unsqueeze(0)
    for i in range(1, length - 1):
        extracted_grads_copy[i] = extracted_grads_copy[i]
        extracted_grads_copy[i] = extracted_grads_copy[i].unsqueeze(0)
        temp = torch.cat((temp, extracted_grads_copy[i]), dim=0)
    average_grad = temp.mean(dim=0)[position:position + num_trigger_tokens]
    return average_grad


def hotflip_attack(averaged_grad, embedding_matrix,
                   num_candidates=1, increase_loss=False):
    averaged_grad = averaged_grad.cpu()
    embedding_matrix = embedding_matrix.cpu()
    averaged_grad = averaged_grad.unsqueeze(0)
    gradient_dot_embedding_matrix = torch.einsum("bij,kj->bik",
                                                 (averaged_grad, embedding_matrix))
    if not increase_loss:
        gradient_dot_embedding_matrix *= -1
        # lower versus increase the class probability.
    if num_candidates > 1:  # get top k options
        _, best_k_ids = torch.topk(gradient_dot_embedding_matrix, num_candidates, dim=2)
        return best_k_ids.detach().cpu().numpy()[0]  # Return candidates
    _, best_at_each_step = gradient_dot_embedding_matrix.max(2)
    return best_at_each_step[0].detach().cpu().numpy()


def collection_attack(net, test_iter, num_candidates, num_epoch, trigger='the',  # Summarize each function
                      num_trigger_tokens=3):
    trigger_token_tensor = init_trigger_tokens(trigger, num_trigger_tokens)
    print(f'Concatenation location:{position}')
    valid_acc = evaluate(net, test_iter, trigger_token_tensor)
    print(f'Initial trigger tokens state：the accuracy {valid_acc:.5f}')
    embedding_weight = get_embedding_weight(net)
    for i in range(num_epoch):
        extracted_grads.clear()
        hook = add_hook(net)
        get_gradient(net, test_iter, trigger_token_tensor)
        hook.remove()
        average_grad = process_gradient(len(test_iter), num_trigger_tokens)
        hot_token = hotflip_attack(average_grad, embedding_weight, num_candidates, increase_loss=True)
        hot_token_tensor = torch.from_numpy(hot_token)
        trigger_token_tensor, valid_acc = select_best_candid(net, test_iter, hot_token_tensor, trigger_token_tensor,
                                                             valid_acc)
        print(f'after {i + 1} rounds of attacking\ntriggers: {trigger_token_tensor} \nthe accuracy :{valid_acc:.5f} ')
    return trigger_token_tensor, valid_acc  # Return the final trigger tokens (trigger length) and the accuracy after the attack


def get_embedding_weight(net):
    for module in net.modules():
        if isinstance(module, nn.Embedding):
            weight = module.weight
            break
    return weight


def select_best_candid(net, test_iter, candid_trigger, trigger_token, valid_acc):
    # Concatenate each candidate to each input to determine the final trigger token
    n = torch.tensor([0] * len(trigger_token))
    n = n.unsqueeze(0)
    trigger_token = trigger_token.unsqueeze(0)
    net.eval()
    valid_accs = []
    for i in range(candid_trigger.shape[0]):
        trigger_token_temp = deepcopy(trigger_token)
        for j in range(candid_trigger.shape[1]):
            trigger_token_temp[0, i] = candid_trigger[i, j]
            valid_accs = []
            for batch in tqdm(test_iter):
                a, b, y = batch
                a = torch.cat((a[:, :position], trigger_token_temp.repeat_interleave(a.shape[0], dim=0),
                               a[:, position:]), dim=1)
                b = torch.cat((b[:, :position], n.repeat_interleave(b.shape[0], dim=0),
                               b[:, position:]), dim=1)
                a = a.to(device[0])
                b = b.to(device[0])
                y = y.to(device[0])
                #outputs = net(input_ids=a, token_type_ids=b)
                #acc = (outputs.logits.argmax(dim=-1) == y).float().mean()
                logits = net(input_ids = a, token_type_ids = b)
                acc = (logits.argmax(dim=-1) == y).float().mean()
                valid_accs.append(acc)
            temp = sum(valid_accs) / len(test_iter)
            if temp < valid_acc:
                valid_acc = temp
                trigger_token[0, i] = candid_trigger[i, j]
    return trigger_token[0], valid_acc  # Return the final trigger token and the accuracy after the attack

collection_attack(Model, test_iter_pos, 5, 5, trigger='<pad>', num_trigger_tokens=1)
collection_attack(Model, test_iter_pos, 5, 5, trigger='<pad>', num_trigger_tokens=2)
collection_attack(Model, test_iter_pos, 5, 5, trigger='<pad>', num_trigger_tokens=3)

Concatenation location:1


  0%|          | 0/909 [00:00<?, ?it/s]We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.
100%|██████████| 909/909 [00:06<00:00, 144.62it/s]


Initial trigger tokens state：the accuracy 0.85149


100%|██████████| 909/909 [00:18<00:00, 48.40it/s]
100%|██████████| 909/909 [00:07<00:00, 113.76it/s]
100%|██████████| 909/909 [00:07<00:00, 127.65it/s]
100%|██████████| 909/909 [00:07<00:00, 125.85it/s]
100%|██████████| 909/909 [00:07<00:00, 129.78it/s]
100%|██████████| 909/909 [00:06<00:00, 133.41it/s]


after 1 rounds of attacking
triggers: tensor([5663]) 
the accuracy :0.79538 


100%|██████████| 909/909 [00:18<00:00, 50.02it/s]
100%|██████████| 909/909 [00:06<00:00, 130.24it/s]
100%|██████████| 909/909 [00:06<00:00, 130.42it/s]
100%|██████████| 909/909 [00:06<00:00, 134.70it/s]
100%|██████████| 909/909 [00:06<00:00, 131.36it/s]
100%|██████████| 909/909 [00:07<00:00, 128.89it/s]


after 2 rounds of attacking
triggers: tensor([20380]) 
the accuracy :0.52255 


100%|██████████| 909/909 [00:19<00:00, 47.37it/s]
100%|██████████| 909/909 [00:07<00:00, 123.25it/s]
100%|██████████| 909/909 [00:06<00:00, 133.87it/s]
100%|██████████| 909/909 [00:07<00:00, 119.91it/s]
100%|██████████| 909/909 [00:06<00:00, 130.88it/s]
100%|██████████| 909/909 [00:06<00:00, 132.14it/s]


after 3 rounds of attacking
triggers: tensor([20380]) 
the accuracy :0.52255 


100%|██████████| 909/909 [00:17<00:00, 52.99it/s]
100%|██████████| 909/909 [00:07<00:00, 122.47it/s]
100%|██████████| 909/909 [00:06<00:00, 133.50it/s]
100%|██████████| 909/909 [00:07<00:00, 129.61it/s]
100%|██████████| 909/909 [00:07<00:00, 126.23it/s]
100%|██████████| 909/909 [00:06<00:00, 131.83it/s]


after 4 rounds of attacking
triggers: tensor([20380]) 
the accuracy :0.52255 


100%|██████████| 909/909 [00:19<00:00, 46.53it/s]
100%|██████████| 909/909 [00:07<00:00, 121.12it/s]
100%|██████████| 909/909 [00:06<00:00, 131.80it/s]
100%|██████████| 909/909 [00:06<00:00, 132.23it/s]
100%|██████████| 909/909 [00:06<00:00, 130.89it/s]
100%|██████████| 909/909 [00:06<00:00, 133.39it/s]


after 5 rounds of attacking
triggers: tensor([20380]) 
the accuracy :0.52255 
Concatenation location:1


100%|██████████| 909/909 [00:05<00:00, 154.12it/s]


Initial trigger tokens state：the accuracy 0.84598


100%|██████████| 909/909 [00:20<00:00, 45.11it/s]
100%|██████████| 909/909 [00:07<00:00, 123.15it/s]
100%|██████████| 909/909 [00:06<00:00, 131.99it/s]
100%|██████████| 909/909 [00:07<00:00, 127.28it/s]
100%|██████████| 909/909 [00:06<00:00, 134.29it/s]
100%|██████████| 909/909 [00:07<00:00, 115.53it/s]
100%|██████████| 909/909 [00:07<00:00, 125.89it/s]
100%|██████████| 909/909 [00:06<00:00, 133.28it/s]
100%|██████████| 909/909 [00:06<00:00, 133.86it/s]
100%|██████████| 909/909 [00:07<00:00, 123.13it/s]
100%|██████████| 909/909 [00:07<00:00, 129.04it/s]


after 1 rounds of attacking
triggers: tensor([ 5663, 26544]) 
the accuracy :0.77778 


100%|██████████| 909/909 [00:18<00:00, 47.88it/s]
100%|██████████| 909/909 [00:07<00:00, 122.31it/s]
100%|██████████| 909/909 [00:07<00:00, 129.37it/s]
100%|██████████| 909/909 [00:06<00:00, 131.09it/s]
100%|██████████| 909/909 [00:06<00:00, 131.95it/s]
100%|██████████| 909/909 [00:07<00:00, 120.76it/s]
100%|██████████| 909/909 [00:06<00:00, 133.54it/s]
100%|██████████| 909/909 [00:06<00:00, 132.56it/s]
100%|██████████| 909/909 [00:07<00:00, 122.14it/s]
100%|██████████| 909/909 [00:06<00:00, 133.41it/s]
100%|██████████| 909/909 [00:06<00:00, 132.23it/s]


after 2 rounds of attacking
triggers: tensor([27693, 22870]) 
the accuracy :0.31243 


100%|██████████| 909/909 [00:17<00:00, 52.97it/s]
100%|██████████| 909/909 [00:07<00:00, 120.93it/s]
100%|██████████| 909/909 [00:06<00:00, 131.27it/s]
100%|██████████| 909/909 [00:07<00:00, 123.37it/s]
100%|██████████| 909/909 [00:06<00:00, 133.13it/s]
100%|██████████| 909/909 [00:07<00:00, 127.82it/s]
100%|██████████| 909/909 [00:06<00:00, 134.99it/s]
100%|██████████| 909/909 [00:06<00:00, 134.42it/s]
100%|██████████| 909/909 [00:07<00:00, 127.26it/s]
100%|██████████| 909/909 [00:06<00:00, 131.51it/s]
100%|██████████| 909/909 [00:07<00:00, 116.00it/s]


after 3 rounds of attacking
triggers: tensor([ 5922, 22870]) 
the accuracy :0.20462 


100%|██████████| 909/909 [00:19<00:00, 46.55it/s]
100%|██████████| 909/909 [00:07<00:00, 129.75it/s]
100%|██████████| 909/909 [00:06<00:00, 132.01it/s]
100%|██████████| 909/909 [00:06<00:00, 132.07it/s]
100%|██████████| 909/909 [00:06<00:00, 132.40it/s]
100%|██████████| 909/909 [00:06<00:00, 131.91it/s]
100%|██████████| 909/909 [00:06<00:00, 131.01it/s]
100%|██████████| 909/909 [00:06<00:00, 131.26it/s]
100%|██████████| 909/909 [00:07<00:00, 124.49it/s]
100%|██████████| 909/909 [00:06<00:00, 132.58it/s]
100%|██████████| 909/909 [00:07<00:00, 122.51it/s]


after 4 rounds of attacking
triggers: tensor([ 5922, 22870]) 
the accuracy :0.20462 


100%|██████████| 909/909 [00:18<00:00, 48.35it/s]
100%|██████████| 909/909 [00:07<00:00, 120.29it/s]
100%|██████████| 909/909 [00:07<00:00, 119.99it/s]
100%|██████████| 909/909 [00:06<00:00, 134.36it/s]
100%|██████████| 909/909 [00:06<00:00, 132.16it/s]
100%|██████████| 909/909 [00:06<00:00, 133.89it/s]
100%|██████████| 909/909 [00:06<00:00, 132.94it/s]
100%|██████████| 909/909 [00:07<00:00, 124.77it/s]
100%|██████████| 909/909 [00:06<00:00, 132.79it/s]
100%|██████████| 909/909 [00:07<00:00, 116.25it/s]
100%|██████████| 909/909 [00:07<00:00, 120.09it/s]


after 5 rounds of attacking
triggers: tensor([ 5922, 22870]) 
the accuracy :0.20462 
Concatenation location:1


100%|██████████| 909/909 [00:05<00:00, 155.67it/s]


Initial trigger tokens state：the accuracy 0.85039


100%|██████████| 909/909 [00:19<00:00, 47.74it/s]
100%|██████████| 909/909 [00:06<00:00, 133.65it/s]
100%|██████████| 909/909 [00:07<00:00, 123.61it/s]
100%|██████████| 909/909 [00:07<00:00, 129.22it/s]
100%|██████████| 909/909 [00:06<00:00, 132.72it/s]
100%|██████████| 909/909 [00:07<00:00, 118.07it/s]
100%|██████████| 909/909 [00:06<00:00, 133.62it/s]
100%|██████████| 909/909 [00:06<00:00, 132.28it/s]
100%|██████████| 909/909 [00:06<00:00, 133.24it/s]
100%|██████████| 909/909 [00:06<00:00, 132.18it/s]
100%|██████████| 909/909 [00:07<00:00, 117.33it/s]
100%|██████████| 909/909 [00:06<00:00, 131.34it/s]
100%|██████████| 909/909 [00:06<00:00, 131.21it/s]
100%|██████████| 909/909 [00:07<00:00, 128.78it/s]
100%|██████████| 909/909 [00:06<00:00, 132.12it/s]
100%|██████████| 909/909 [00:06<00:00, 131.08it/s]


after 1 rounds of attacking
triggers: tensor([15039, 20330,  5663]) 
the accuracy :0.63366 


100%|██████████| 909/909 [00:19<00:00, 45.81it/s]
100%|██████████| 909/909 [00:07<00:00, 116.29it/s]
100%|██████████| 909/909 [00:06<00:00, 132.30it/s]
100%|██████████| 909/909 [00:06<00:00, 131.93it/s]
100%|██████████| 909/909 [00:07<00:00, 126.95it/s]
100%|██████████| 909/909 [00:07<00:00, 129.30it/s]
100%|██████████| 909/909 [00:06<00:00, 134.41it/s]
100%|██████████| 909/909 [00:07<00:00, 123.86it/s]
100%|██████████| 909/909 [00:07<00:00, 123.94it/s]
100%|██████████| 909/909 [00:06<00:00, 133.56it/s]
100%|██████████| 909/909 [00:07<00:00, 127.48it/s]
100%|██████████| 909/909 [00:06<00:00, 130.30it/s]
100%|██████████| 909/909 [00:07<00:00, 124.24it/s]
100%|██████████| 909/909 [00:07<00:00, 124.36it/s]
100%|██████████| 909/909 [00:07<00:00, 128.10it/s]
100%|██████████| 909/909 [00:06<00:00, 133.68it/s]


after 2 rounds of attacking
triggers: tensor([ 5643,  3625, 27693]) 
the accuracy :0.09351 


100%|██████████| 909/909 [00:18<00:00, 49.12it/s]
100%|██████████| 909/909 [00:07<00:00, 125.53it/s]
100%|██████████| 909/909 [00:06<00:00, 130.96it/s]
100%|██████████| 909/909 [00:06<00:00, 134.09it/s]
100%|██████████| 909/909 [00:06<00:00, 131.78it/s]
100%|██████████| 909/909 [00:07<00:00, 124.17it/s]
100%|██████████| 909/909 [00:06<00:00, 133.91it/s]
100%|██████████| 909/909 [00:06<00:00, 133.55it/s]
100%|██████████| 909/909 [00:06<00:00, 132.20it/s]
100%|██████████| 909/909 [00:06<00:00, 132.53it/s]
100%|██████████| 909/909 [00:06<00:00, 134.50it/s]
100%|██████████| 909/909 [00:06<00:00, 132.33it/s]
100%|██████████| 909/909 [00:06<00:00, 132.14it/s]
100%|██████████| 909/909 [00:06<00:00, 130.79it/s]
100%|██████████| 909/909 [00:06<00:00, 130.03it/s]
100%|██████████| 909/909 [00:07<00:00, 122.04it/s]


after 3 rounds of attacking
triggers: tensor([ 5643,  3625, 27693]) 
the accuracy :0.09351 


100%|██████████| 909/909 [00:18<00:00, 48.52it/s]
100%|██████████| 909/909 [00:07<00:00, 121.35it/s]
100%|██████████| 909/909 [00:07<00:00, 126.43it/s]
100%|██████████| 909/909 [00:07<00:00, 121.14it/s]
100%|██████████| 909/909 [00:06<00:00, 132.10it/s]
100%|██████████| 909/909 [00:06<00:00, 131.52it/s]
100%|██████████| 909/909 [00:06<00:00, 131.66it/s]
100%|██████████| 909/909 [00:06<00:00, 134.38it/s]
100%|██████████| 909/909 [00:06<00:00, 130.96it/s]
100%|██████████| 909/909 [00:06<00:00, 131.87it/s]
100%|██████████| 909/909 [00:07<00:00, 128.57it/s]
100%|██████████| 909/909 [00:07<00:00, 124.37it/s]
100%|██████████| 909/909 [00:07<00:00, 128.92it/s]
100%|██████████| 909/909 [00:07<00:00, 129.73it/s]
100%|██████████| 909/909 [00:06<00:00, 133.59it/s]
100%|██████████| 909/909 [00:07<00:00, 127.50it/s]


after 4 rounds of attacking
triggers: tensor([ 5643,  3625, 27693]) 
the accuracy :0.09351 


100%|██████████| 909/909 [00:19<00:00, 47.55it/s]
100%|██████████| 909/909 [00:07<00:00, 125.17it/s]
100%|██████████| 909/909 [00:07<00:00, 124.26it/s]
100%|██████████| 909/909 [00:07<00:00, 121.72it/s]
100%|██████████| 909/909 [00:07<00:00, 121.66it/s]
100%|██████████| 909/909 [00:07<00:00, 128.44it/s]
100%|██████████| 909/909 [00:07<00:00, 128.37it/s]
100%|██████████| 909/909 [00:07<00:00, 125.88it/s]
100%|██████████| 909/909 [00:07<00:00, 129.68it/s]
100%|██████████| 909/909 [00:07<00:00, 127.25it/s]
100%|██████████| 909/909 [00:06<00:00, 131.53it/s]
100%|██████████| 909/909 [00:06<00:00, 132.35it/s]
100%|██████████| 909/909 [00:06<00:00, 129.86it/s]
100%|██████████| 909/909 [00:06<00:00, 135.23it/s]
100%|██████████| 909/909 [00:07<00:00, 129.34it/s]
100%|██████████| 909/909 [00:06<00:00, 131.01it/s]

after 5 rounds of attacking
triggers: tensor([ 5643,  3625, 27693]) 
the accuracy :0.09351 





(tensor([ 5643,  3625, 27693]), tensor(0.0935, device='cuda:0'))

In [7]:
collection_attack(Model, test_iter_neg, 5, 5, trigger='<pad>', num_trigger_tokens=1)
collection_attack(Model, test_iter_neg, 5, 5, trigger='<pad>', num_trigger_tokens=2)
collection_attack(Model, test_iter_neg, 5, 5, trigger='<pad>', num_trigger_tokens=3)

Concatenation location:1


100%|██████████| 912/912 [00:05<00:00, 158.11it/s]


Initial trigger tokens state：the accuracy 0.91557


100%|██████████| 912/912 [00:19<00:00, 45.75it/s]
100%|██████████| 912/912 [00:07<00:00, 125.11it/s]
100%|██████████| 912/912 [00:06<00:00, 131.57it/s]
100%|██████████| 912/912 [00:06<00:00, 130.84it/s]
100%|██████████| 912/912 [00:06<00:00, 131.85it/s]
100%|██████████| 912/912 [00:07<00:00, 125.55it/s]


after 1 rounds of attacking
triggers: tensor([16173]) 
the accuracy :0.68202 


100%|██████████| 912/912 [00:19<00:00, 46.04it/s]
100%|██████████| 912/912 [00:07<00:00, 126.61it/s]
100%|██████████| 912/912 [00:06<00:00, 131.14it/s]
100%|██████████| 912/912 [00:06<00:00, 132.12it/s]
100%|██████████| 912/912 [00:07<00:00, 125.46it/s]
100%|██████████| 912/912 [00:06<00:00, 133.15it/s]


after 2 rounds of attacking
triggers: tensor([16173]) 
the accuracy :0.68202 


100%|██████████| 912/912 [00:19<00:00, 46.36it/s]
100%|██████████| 912/912 [00:07<00:00, 120.21it/s]
100%|██████████| 912/912 [00:06<00:00, 134.77it/s]
100%|██████████| 912/912 [00:06<00:00, 134.06it/s]
100%|██████████| 912/912 [00:07<00:00, 126.52it/s]
100%|██████████| 912/912 [00:06<00:00, 132.67it/s]


after 3 rounds of attacking
triggers: tensor([16173]) 
the accuracy :0.68202 


100%|██████████| 912/912 [00:19<00:00, 47.23it/s]
100%|██████████| 912/912 [00:07<00:00, 119.39it/s]
100%|██████████| 912/912 [00:07<00:00, 129.02it/s]
100%|██████████| 912/912 [00:06<00:00, 134.56it/s]
100%|██████████| 912/912 [00:07<00:00, 127.34it/s]
100%|██████████| 912/912 [00:06<00:00, 133.69it/s]


after 4 rounds of attacking
triggers: tensor([16173]) 
the accuracy :0.68202 


100%|██████████| 912/912 [00:17<00:00, 51.33it/s]
100%|██████████| 912/912 [00:07<00:00, 123.38it/s]
100%|██████████| 912/912 [00:06<00:00, 132.49it/s]
100%|██████████| 912/912 [00:06<00:00, 133.24it/s]
100%|██████████| 912/912 [00:06<00:00, 133.04it/s]
100%|██████████| 912/912 [00:07<00:00, 129.54it/s]


after 5 rounds of attacking
triggers: tensor([16173]) 
the accuracy :0.68202 
Concatenation location:1


100%|██████████| 912/912 [00:05<00:00, 165.75it/s]


Initial trigger tokens state：the accuracy 0.91009


100%|██████████| 912/912 [00:19<00:00, 46.69it/s]
100%|██████████| 912/912 [00:07<00:00, 121.93it/s]
100%|██████████| 912/912 [00:06<00:00, 134.76it/s]
100%|██████████| 912/912 [00:06<00:00, 136.05it/s]
100%|██████████| 912/912 [00:07<00:00, 129.03it/s]
100%|██████████| 912/912 [00:06<00:00, 134.43it/s]
100%|██████████| 912/912 [00:06<00:00, 134.70it/s]
100%|██████████| 912/912 [00:06<00:00, 132.51it/s]
100%|██████████| 912/912 [00:06<00:00, 134.33it/s]
100%|██████████| 912/912 [00:06<00:00, 133.54it/s]
100%|██████████| 912/912 [00:06<00:00, 132.89it/s]


after 1 rounds of attacking
triggers: tensor([23403, 24464]) 
the accuracy :0.65461 


100%|██████████| 912/912 [00:18<00:00, 48.45it/s]
100%|██████████| 912/912 [00:07<00:00, 120.92it/s]
100%|██████████| 912/912 [00:07<00:00, 127.73it/s]
100%|██████████| 912/912 [00:06<00:00, 134.55it/s]
100%|██████████| 912/912 [00:06<00:00, 135.79it/s]
100%|██████████| 912/912 [00:06<00:00, 133.43it/s]
100%|██████████| 912/912 [00:06<00:00, 131.20it/s]
100%|██████████| 912/912 [00:06<00:00, 134.39it/s]
100%|██████████| 912/912 [00:07<00:00, 129.07it/s]
100%|██████████| 912/912 [00:06<00:00, 134.43it/s]
100%|██████████| 912/912 [00:06<00:00, 134.13it/s]


after 2 rounds of attacking
triggers: tensor([28125, 19103]) 
the accuracy :0.24232 


100%|██████████| 912/912 [00:19<00:00, 46.59it/s]
100%|██████████| 912/912 [00:07<00:00, 118.94it/s]
100%|██████████| 912/912 [00:06<00:00, 134.95it/s]
100%|██████████| 912/912 [00:06<00:00, 134.52it/s]
100%|██████████| 912/912 [00:06<00:00, 133.57it/s]
100%|██████████| 912/912 [00:06<00:00, 134.05it/s]
100%|██████████| 912/912 [00:06<00:00, 133.13it/s]
100%|██████████| 912/912 [00:06<00:00, 134.40it/s]
100%|██████████| 912/912 [00:06<00:00, 135.53it/s]
100%|██████████| 912/912 [00:06<00:00, 133.71it/s]
100%|██████████| 912/912 [00:06<00:00, 134.80it/s]


after 3 rounds of attacking
triggers: tensor([27775, 21202]) 
the accuracy :0.17982 


100%|██████████| 912/912 [00:17<00:00, 51.23it/s]
100%|██████████| 912/912 [00:06<00:00, 131.94it/s]
100%|██████████| 912/912 [00:06<00:00, 133.27it/s]
100%|██████████| 912/912 [00:06<00:00, 135.22it/s]
100%|██████████| 912/912 [00:06<00:00, 135.62it/s]
100%|██████████| 912/912 [00:06<00:00, 133.05it/s]
100%|██████████| 912/912 [00:06<00:00, 135.65it/s]
100%|██████████| 912/912 [00:06<00:00, 135.57it/s]
100%|██████████| 912/912 [00:06<00:00, 130.91it/s]
100%|██████████| 912/912 [00:06<00:00, 133.94it/s]
100%|██████████| 912/912 [00:06<00:00, 134.04it/s]


after 4 rounds of attacking
triggers: tensor([28125, 21202]) 
the accuracy :0.11513 


100%|██████████| 912/912 [00:17<00:00, 50.85it/s]
100%|██████████| 912/912 [00:07<00:00, 125.24it/s]
100%|██████████| 912/912 [00:06<00:00, 135.22it/s]
100%|██████████| 912/912 [00:06<00:00, 133.84it/s]
100%|██████████| 912/912 [00:06<00:00, 134.24it/s]
100%|██████████| 912/912 [00:06<00:00, 133.33it/s]
100%|██████████| 912/912 [00:06<00:00, 134.21it/s]
100%|██████████| 912/912 [00:06<00:00, 132.67it/s]
100%|██████████| 912/912 [00:06<00:00, 134.82it/s]
100%|██████████| 912/912 [00:07<00:00, 127.67it/s]
100%|██████████| 912/912 [00:07<00:00, 129.80it/s]


after 5 rounds of attacking
triggers: tensor([28125, 21202]) 
the accuracy :0.11513 
Concatenation location:1


100%|██████████| 912/912 [00:05<00:00, 168.73it/s]


Initial trigger tokens state：the accuracy 0.90899


100%|██████████| 912/912 [00:17<00:00, 51.51it/s]
100%|██████████| 912/912 [00:07<00:00, 127.60it/s]
100%|██████████| 912/912 [00:06<00:00, 134.95it/s]
100%|██████████| 912/912 [00:06<00:00, 135.45it/s]
100%|██████████| 912/912 [00:06<00:00, 132.37it/s]
100%|██████████| 912/912 [00:06<00:00, 134.27it/s]
100%|██████████| 912/912 [00:06<00:00, 133.85it/s]
100%|██████████| 912/912 [00:06<00:00, 133.72it/s]
100%|██████████| 912/912 [00:06<00:00, 136.32it/s]
100%|██████████| 912/912 [00:06<00:00, 136.09it/s]
100%|██████████| 912/912 [00:07<00:00, 128.65it/s]
100%|██████████| 912/912 [00:06<00:00, 135.31it/s]
100%|██████████| 912/912 [00:06<00:00, 133.99it/s]
100%|██████████| 912/912 [00:06<00:00, 133.05it/s]
100%|██████████| 912/912 [00:06<00:00, 135.37it/s]
100%|██████████| 912/912 [00:06<00:00, 132.99it/s]


after 1 rounds of attacking
triggers: tensor([17772, 23403, 29701]) 
the accuracy :0.50439 


100%|██████████| 912/912 [00:19<00:00, 47.64it/s]
100%|██████████| 912/912 [00:07<00:00, 122.74it/s]
100%|██████████| 912/912 [00:06<00:00, 134.23it/s]
100%|██████████| 912/912 [00:07<00:00, 123.94it/s]
100%|██████████| 912/912 [00:06<00:00, 131.56it/s]
100%|██████████| 912/912 [00:06<00:00, 131.84it/s]
100%|██████████| 912/912 [00:07<00:00, 123.63it/s]
100%|██████████| 912/912 [00:06<00:00, 134.04it/s]
100%|██████████| 912/912 [00:07<00:00, 128.94it/s]
100%|██████████| 912/912 [00:06<00:00, 134.13it/s]
100%|██████████| 912/912 [00:06<00:00, 134.61it/s]
100%|██████████| 912/912 [00:06<00:00, 133.61it/s]
100%|██████████| 912/912 [00:06<00:00, 133.92it/s]
100%|██████████| 912/912 [00:06<00:00, 134.62it/s]
100%|██████████| 912/912 [00:07<00:00, 129.10it/s]
100%|██████████| 912/912 [00:06<00:00, 134.75it/s]


after 2 rounds of attacking
triggers: tensor([27134, 24557, 29701]) 
the accuracy :0.03509 


100%|██████████| 912/912 [00:18<00:00, 48.30it/s]
100%|██████████| 912/912 [00:07<00:00, 122.13it/s]
100%|██████████| 912/912 [00:06<00:00, 132.81it/s]
100%|██████████| 912/912 [00:06<00:00, 130.94it/s]
100%|██████████| 912/912 [00:06<00:00, 134.15it/s]
100%|██████████| 912/912 [00:06<00:00, 134.31it/s]
100%|██████████| 912/912 [00:06<00:00, 135.66it/s]
100%|██████████| 912/912 [00:06<00:00, 133.48it/s]
100%|██████████| 912/912 [00:07<00:00, 129.10it/s]
100%|██████████| 912/912 [00:06<00:00, 135.48it/s]
100%|██████████| 912/912 [00:06<00:00, 134.32it/s]
100%|██████████| 912/912 [00:06<00:00, 134.93it/s]
100%|██████████| 912/912 [00:07<00:00, 126.51it/s]
100%|██████████| 912/912 [00:07<00:00, 126.97it/s]
100%|██████████| 912/912 [00:06<00:00, 131.09it/s]
100%|██████████| 912/912 [00:06<00:00, 133.84it/s]


after 3 rounds of attacking
triggers: tensor([27134, 12950, 29701]) 
the accuracy :0.02741 


100%|██████████| 912/912 [00:18<00:00, 49.22it/s]
100%|██████████| 912/912 [00:07<00:00, 121.64it/s]
100%|██████████| 912/912 [00:06<00:00, 134.16it/s]
100%|██████████| 912/912 [00:06<00:00, 133.09it/s]
100%|██████████| 912/912 [00:07<00:00, 129.40it/s]
100%|██████████| 912/912 [00:06<00:00, 135.13it/s]
100%|██████████| 912/912 [00:06<00:00, 133.46it/s]
100%|██████████| 912/912 [00:06<00:00, 134.16it/s]
100%|██████████| 912/912 [00:06<00:00, 134.76it/s]
100%|██████████| 912/912 [00:06<00:00, 133.12it/s]
100%|██████████| 912/912 [00:06<00:00, 134.63it/s]
100%|██████████| 912/912 [00:06<00:00, 133.87it/s]
100%|██████████| 912/912 [00:06<00:00, 132.04it/s]
100%|██████████| 912/912 [00:06<00:00, 134.65it/s]
100%|██████████| 912/912 [00:06<00:00, 134.85it/s]
100%|██████████| 912/912 [00:06<00:00, 133.33it/s]


after 4 rounds of attacking
triggers: tensor([27134, 12950, 29701]) 
the accuracy :0.02741 


100%|██████████| 912/912 [00:18<00:00, 48.81it/s]
100%|██████████| 912/912 [00:06<00:00, 134.61it/s]
100%|██████████| 912/912 [00:06<00:00, 133.80it/s]
100%|██████████| 912/912 [00:06<00:00, 131.07it/s]
100%|██████████| 912/912 [00:06<00:00, 134.62it/s]
100%|██████████| 912/912 [00:06<00:00, 133.26it/s]
100%|██████████| 912/912 [00:07<00:00, 128.97it/s]
100%|██████████| 912/912 [00:06<00:00, 136.36it/s]
100%|██████████| 912/912 [00:06<00:00, 134.01it/s]
100%|██████████| 912/912 [00:06<00:00, 135.91it/s]
100%|██████████| 912/912 [00:06<00:00, 135.36it/s]
100%|██████████| 912/912 [00:06<00:00, 134.38it/s]
100%|██████████| 912/912 [00:06<00:00, 136.49it/s]
100%|██████████| 912/912 [00:06<00:00, 135.74it/s]
100%|██████████| 912/912 [00:06<00:00, 134.11it/s]
100%|██████████| 912/912 [00:06<00:00, 136.45it/s]

after 5 rounds of attacking
triggers: tensor([27134, 12950, 29701]) 
the accuracy :0.02741 





(tensor([27134, 12950, 29701]), tensor(0.0274, device='cuda:0'))