In [1]:
import torch
from torch import nn
import os
import random
from torch.utils import data
from tqdm import tqdm
import numpy as np
from copy import deepcopy
from transformers import AlbertTokenizer, AlbertModel



# Seed every RNG in use so data shuffling, dropout and head initialisation are reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Gradients captured by the embedding-layer backward hook during the trigger search
# (appended by extract_grad_hook, cleared at the start of every attack round).
extracted_grads = []

# Index at which trigger tokens are spliced into every input sequence.
# With position = 1 the trigger goes directly after the leading [CLS] token.
position = 1
# Random concatenation mode:
# position = random.randint(1, 500)

# Local ALBERT checkpoint directory; override with the ALBERT_DIR environment variable.
ALBERT_DIR = os.environ.get("ALBERT_DIR", "/root/albert")
tokenize = AlbertTokenizer.from_pretrained(ALBERT_DIR)
Model = AlbertModel.from_pretrained(ALBERT_DIR)


# Load model related information

# Print the number of Total Parameters
# total = [param.nelement() for param in Model.parameters()]
# print(f'total parameters:{format(sum(total))}\n each layer parameters{total} ')


  return self.fget.__get__(instance, owner)()


In [2]:
### Load data

def read_data(data_dir, is_train):
    """Read every IMDB review (both labels) of one split.

    Args:
        data_dir: root of the extracted aclImdb dataset (contains ``train/`` and ``test/``).
        is_train: read the ``train`` split when True, otherwise ``test``.

    Returns:
        ``(reviews, labels)``: review texts with newlines replaced by spaces, and
        labels 0 for 'neg' / 1 for 'pos'. Negative reviews come first. Within each
        label, files are read in sorted filename order, so the result does not
        depend on the filesystem's directory ordering.
    """
    split = 'train' if is_train else 'test'
    reviews, labels = [], []
    for label in ('neg', 'pos'):
        label_dir = os.path.join(data_dir, split, label)
        for file_name in sorted(os.listdir(label_dir)):
            # Binary read + explicit UTF-8 decode keeps any '\r' bytes exactly as stored.
            with open(os.path.join(label_dir, file_name), 'rb') as f:
                reviews.append(f.read().decode('utf-8').replace('\n', ' '))
            labels.append(1 if label == 'pos' else 0)
    return reviews, labels


def read_test_data_pos(data_dir, is_train):
    """Read only the positive ('pos') reviews of one IMDB split; this is the label under attack.

    Args:
        data_dir: root of the extracted aclImdb dataset.
        is_train: read the ``train`` split when True, otherwise ``test``.

    Returns:
        ``(reviews, labels)``: texts with newlines replaced by spaces, and labels all
        equal to 1. Files are read in sorted filename order for reproducibility.
    """
    label = 'pos'  # choose a label to attack
    label_dir = os.path.join(data_dir, 'train' if is_train else 'test', label)
    reviews, labels = [], []
    for file_name in sorted(os.listdir(label_dir)):
        with open(os.path.join(label_dir, file_name), 'rb') as f:
            reviews.append(f.read().decode('utf-8').replace('\n', ' '))
        labels.append(1 if label == 'pos' else 0)
    return reviews, labels

def read_test_data_neg(data_dir, is_train):
    """Read only the negative ('neg') reviews of one IMDB split; this is the label under attack.

    Args:
        data_dir: root of the extracted aclImdb dataset.
        is_train: read the ``train`` split when True, otherwise ``test``.

    Returns:
        ``(reviews, labels)``: texts with newlines replaced by spaces, and labels all
        equal to 0. Files are read in sorted filename order for reproducibility.
    """
    label = 'neg'  # choose a label to attack
    label_dir = os.path.join(data_dir, 'train' if is_train else 'test', label)
    reviews, labels = [], []
    for file_name in sorted(os.listdir(label_dir)):
        with open(os.path.join(label_dir, file_name), 'rb') as f:
            reviews.append(f.read().decode('utf-8').replace('\n', ' '))
        labels.append(1 if label == 'pos' else 0)
    return reviews, labels


def load_array(data_arrays, batch_size, is_train=True):
    """Wrap equally-sized tensors in a DataLoader.

    ``data_arrays`` are zipped row-wise by ``TensorDataset``. Batches are shuffled
    only when ``is_train`` is True.
    """
    tensor_dataset = data.TensorDataset(*data_arrays)
    loader = data.DataLoader(tensor_dataset, batch_size=batch_size, shuffle=is_train)
    return loader


def try_all_gpus():
    """Return a ``torch.device`` for every visible CUDA GPU, or ``[cpu]`` when there are none."""
    gpu_count = torch.cuda.device_count()
    if gpu_count == 0:
        return [torch.device('cpu')]
    return [torch.device(f'cuda:{idx}') for idx in range(gpu_count)]


def load_imdb_data_pos(batch_size, num_steps=500, data_dir='aclImdb'):
    """Build the IMDB training loader and a test loader holding only positive reviews.

    Args:
        batch_size: batch size for both loaders.
        num_steps: maximum token length passed to the tokenizer; longer reviews are truncated.
        data_dir: root of the extracted aclImdb dataset (default ``'aclImdb'``).

    Returns:
        ``(train_iter, test_iter)``. The train loader is shuffled and the test loader
        is not. Each batch is ``(input_ids, token_type_ids, labels)``.

    Note: the tokenizer's attention mask is not kept, so padding tokens are not
    masked when the model runs (transformers emits a warning about this).
    """
    train_texts, train_labels = read_data(data_dir, True)
    test_texts, test_labels = read_test_data_pos(data_dir, False)
    train_enc = tokenize(train_texts, return_tensors="pt", padding=True, truncation=True, max_length=num_steps)
    test_enc = tokenize(test_texts, return_tensors="pt", padding=True, truncation=True, max_length=num_steps)
    train_iter = load_array(
        (train_enc['input_ids'], train_enc['token_type_ids'], torch.tensor(train_labels)),
        batch_size)
    test_iter = load_array(
        (test_enc['input_ids'], test_enc['token_type_ids'], torch.tensor(test_labels)),
        batch_size,
        is_train=False)
    return train_iter, test_iter

def load_imdb_data_neg(batch_size, num_steps=500, data_dir='aclImdb'):
    """Build the IMDB training loader and a test loader holding only negative reviews.

    Args:
        batch_size: batch size for both loaders.
        num_steps: maximum token length passed to the tokenizer; longer reviews are truncated.
        data_dir: root of the extracted aclImdb dataset (default ``'aclImdb'``).

    Returns:
        ``(train_iter, test_iter)``. The train loader is shuffled and the test loader
        is not. Each batch is ``(input_ids, token_type_ids, labels)``.

    Note: the tokenizer's attention mask is not kept, so padding tokens are not
    masked when the model runs (transformers emits a warning about this).
    """
    train_texts, train_labels = read_data(data_dir, True)
    test_texts, test_labels = read_test_data_neg(data_dir, False)
    train_enc = tokenize(train_texts, return_tensors="pt", padding=True, truncation=True, max_length=num_steps)
    test_enc = tokenize(test_texts, return_tensors="pt", padding=True, truncation=True, max_length=num_steps)
    train_iter = load_array(
        (train_enc['input_ids'], train_enc['token_type_ids'], torch.tensor(train_labels)),
        batch_size)
    test_iter = load_array(
        (test_enc['input_ids'], test_enc['token_type_ids'], torch.tensor(test_labels)),
        batch_size,
        is_train=False)
    return train_iter, test_iter


# Both helpers build an identical training loader; keep the one from the second
# call and the two label-specific test loaders.
_, test_iter_pos = load_imdb_data_pos(10)
train_iter, test_iter_neg = load_imdb_data_neg(10)
# Data preprocessing and loading
print("reading data finished\n")

reading data finished



In [3]:
# Define the model architecture
class AlbertSentimentClassifier(nn.Module):
    """ALBERT encoder followed by dropout and a linear 2-way (negative/positive) head."""

    def __init__(self, albert_model):
        super().__init__()
        self.albert = albert_model
        self.dropout = nn.Dropout(0.1)
        # Binary classification: label 0 = negative, 1 = positive.
        self.fc = nn.Linear(self.albert.config.hidden_size, 2)

    def forward(self, input_ids, token_type_ids, attention_mask=None):
        """Return unnormalised class logits, one row of 2 scores per input.

        ``attention_mask`` is optional and forwarded to the encoder. When it is
        omitted (as everywhere else in this notebook), padding tokens are not masked.
        """
        outputs = self.albert(input_ids=input_ids, token_type_ids=token_type_ids,
                              attention_mask=attention_mask)
        # outputs[1] is the pooler output: the [CLS] state passed through ALBERT's pooler layer.
        pooled_output = self.dropout(outputs[1])
        return self.fc(pooled_output)

# Wrap the pretrained ALBERT encoder with a freshly initialised 2-way head.
model = AlbertSentimentClassifier(Model)

# Cross-entropy over the two logits; AdamW with a small fine-tuning learning rate.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6)

# Training loop
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.train()

num_epochs = 3  # Example, you can adjust this
for epoch in range(num_epochs):
    total_loss = 0
    for batch_idx, batch in enumerate(train_iter):
        input_ids, token_type_ids, labels = (tensor.to(device) for tensor in batch)

        optimizer.zero_grad()
        logits = model(input_ids, token_type_ids)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        # Every 100 batches, report the running mean loss of the current epoch.
        if (batch_idx + 1) % 100 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx+1}/{len(train_iter)}, Loss: {total_loss / (batch_idx+1):.4f}")

print("Training finished.")


We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.


Epoch 1/3, Batch 100/2500, Loss: 0.7082
Epoch 1/3, Batch 200/2500, Loss: 0.6946
Epoch 1/3, Batch 300/2500, Loss: 0.6867
Epoch 1/3, Batch 400/2500, Loss: 0.6794
Epoch 1/3, Batch 500/2500, Loss: 0.6744
Epoch 1/3, Batch 600/2500, Loss: 0.6690
Epoch 1/3, Batch 700/2500, Loss: 0.6620
Epoch 1/3, Batch 800/2500, Loss: 0.6603
Epoch 1/3, Batch 900/2500, Loss: 0.6557
Epoch 1/3, Batch 1000/2500, Loss: 0.6539
Epoch 1/3, Batch 1100/2500, Loss: 0.6496
Epoch 1/3, Batch 1200/2500, Loss: 0.6472
Epoch 1/3, Batch 1300/2500, Loss: 0.6429
Epoch 1/3, Batch 1400/2500, Loss: 0.6407
Epoch 1/3, Batch 1500/2500, Loss: 0.6398
Epoch 1/3, Batch 1600/2500, Loss: 0.6391
Epoch 1/3, Batch 1700/2500, Loss: 0.6381
Epoch 1/3, Batch 1800/2500, Loss: 0.6367
Epoch 1/3, Batch 1900/2500, Loss: 0.6356
Epoch 1/3, Batch 2000/2500, Loss: 0.6351
Epoch 1/3, Batch 2100/2500, Loss: 0.6339
Epoch 1/3, Batch 2200/2500, Loss: 0.6331
Epoch 1/3, Batch 2300/2500, Loss: 0.6331
Epoch 1/3, Batch 2400/2500, Loss: 0.6325
Epoch 1/3, Batch 2500/250

In [4]:
# Pickle the whole fine-tuned classifier (architecture + weights). Reloading it needs
# AlbertSentimentClassifier to be defined in the loading session.
torch.save(obj=model, f='albert_IMDB.bin')

In [3]:
# Define the model architecture
class AlbertSentimentClassifier(nn.Module):
    """ALBERT encoder followed by dropout and a linear 2-way (negative/positive) head.

    Redefined here so the pickled checkpoint can be unpickled in a fresh session.
    """

    def __init__(self, albert_model):
        super().__init__()
        self.albert = albert_model
        self.dropout = nn.Dropout(0.1)
        # Binary classification: label 0 = negative, 1 = positive.
        self.fc = nn.Linear(self.albert.config.hidden_size, 2)

    def forward(self, input_ids, token_type_ids, attention_mask=None):
        """Return unnormalised class logits, one row of 2 scores per input.

        ``attention_mask`` is optional and forwarded to the encoder. When it is
        omitted (as everywhere else in this notebook), padding tokens are not masked.
        """
        outputs = self.albert(input_ids=input_ids, token_type_ids=token_type_ids,
                              attention_mask=attention_mask)
        # outputs[1] is the pooler output: the [CLS] state passed through ALBERT's pooler layer.
        pooled_output = self.dropout(outputs[1])
        return self.fc(pooled_output)

# Usable devices (all GPUs, or [cpu]); later cells run on device[0].
device = try_all_gpus()
# NOTE: `Model` is rebound here from the raw ALBERT encoder to the fine-tuned classifier.
# The checkpoint is a fully pickled module, so weights_only=False is required
# (PyTorch >= 2.6 defaults to weights_only=True and refuses such files).
# map_location lets a checkpoint saved on GPU load on whatever device is available.
# Only load checkpoints you created yourself: unpickling can execute arbitrary code.
Model = torch.load('albert_IMDB.bin', map_location=device[0], weights_only=False)

In [6]:
def evaluate_model(model, test_iter):
    """Compute, print and return the sample-level accuracy of ``model`` on ``test_iter``.

    Args:
        model: classifier called as ``model(input_ids, token_type_ids)`` that returns
            one row of class logits per example. It is switched to eval mode.
        test_iter: iterable of ``(input_ids, token_type_ids, labels)`` batches. Batches
            are moved to the device of the model's first parameter.

    Returns:
        Accuracy as a float: correct predictions divided by total examples.

    Raises:
        ValueError: if ``test_iter`` yields no examples.
    """
    model.eval()
    device = next(model.parameters()).device

    total_correct = 0
    total_samples = 0

    with torch.no_grad():
        for input_ids, token_type_ids, labels in test_iter:
            input_ids, token_type_ids, labels = input_ids.to(device), token_type_ids.to(device), labels.to(device)

            logits = model(input_ids, token_type_ids)
            _, predictions = torch.max(logits, 1)

            total_correct += (predictions == labels).sum().item()
            total_samples += labels.size(0)

    if total_samples == 0:
        raise ValueError("evaluate_model: test_iter yielded no samples")
    accuracy = total_correct / total_samples
    print(f"Accuracy on test set: {accuracy:.4f}")
    return accuracy

# Evaluate the fine-tuned classifier separately on the positive and the negative test reviews.
for label_iter in (test_iter_pos, test_iter_neg):
    evaluate_model(Model, label_iter)

Accuracy on test set: 0.9550
Accuracy on test set: 0.9193


In [4]:
# Loss used by get_gradient to obtain trigger gradients ('mean' is the default reduction).
criterion = nn.CrossEntropyLoss(reduction='mean')
### Trigger Token

def init_trigger_tokens(trigger, num_trigger_tokens):
    """Build the initial trigger: ``num_trigger_tokens`` copies of one token id.

    Args:
        trigger: an ``int`` token id to repeat, or a token string. Strings are not
            looked up in the vocabulary; every string maps to id 0, which is what the
            existing call sites (passing '<pad>') rely on. Id 0 is ``<pad>`` in the
            stock ALBERT SentencePiece vocabulary; confirm this for the checkpoint
            actually loaded.
        num_trigger_tokens: trigger length. Values <= 0 give an empty tensor.

    Returns:
        1-D ``torch.long`` tensor of length ``max(num_trigger_tokens, 0)``.
    """
    token_id = trigger if isinstance(trigger, int) else 0
    return torch.full((max(num_trigger_tokens, 0),), token_id, dtype=torch.long)


def evaluate(net, test_iter, trigger_token_tensor):
    """Score ``net`` on ``test_iter`` with the trigger spliced into every input.

    The trigger ids are inserted at index ``position`` (module-level) of each
    ``input_ids`` row, and zeros are inserted at the same index of ``token_type_ids``.
    The model is moved to ``device[0]`` and put in eval mode.

    Returns the mean of the per-batch accuracies as a tensor. This equals
    sample-level accuracy only when all batches have the same size.
    """
    target = device[0]
    net = net.to(target)
    net.eval()
    batch_accs = []
    type_row = torch.tensor([0] * len(trigger_token_tensor)).unsqueeze(0)
    trigger_row = deepcopy(trigger_token_tensor).unsqueeze(0)
    with torch.no_grad():
        for ids, types, labels in tqdm(test_iter):
            ids = torch.cat((ids[:, :position],
                             trigger_row.repeat_interleave(ids.shape[0], dim=0),
                             ids[:, position:]), dim=1).to(target)
            types = torch.cat((types[:, :position],
                               type_row.repeat_interleave(types.shape[0], dim=0),
                               types[:, position:]), dim=1).to(target)
            labels = labels.to(target)
            logits = net(input_ids=ids, token_type_ids=types)
            batch_accs.append((logits.argmax(dim=-1) == labels).float().mean())
    return sum(batch_accs) / len(test_iter)

def extract_grad_hook(net, grad_in, grad_out):
    """Backward hook that stores the first output gradient, averaged over dim 0.

    The result is appended to the module-level ``extracted_grads`` list. For the
    embedding layer, dim 0 is presumably the batch axis (output assumed to be
    (batch, seq, emb); confirm). ``net`` (the hooked module) and ``grad_in`` are unused.
    """
    first_output_grad = grad_out[0]
    extracted_grads.append(first_output_grad.mean(dim=0))


def add_hook(net, hook_fn=None):
    """Register a backward hook on the first ``nn.Embedding`` in ``net.modules()`` order.

    Args:
        net: module to search. For the ALBERT classifier, the first embedding
            encountered is presumably the word-embedding table; confirm if the
            architecture changes.
        hook_fn: hook to register. Defaults to ``extract_grad_hook``.

    Returns:
        The ``RemovableHandle``; call ``.remove()`` when finished.

    Raises:
        ValueError: if ``net`` contains no ``nn.Embedding``.
    """
    if hook_fn is None:
        hook_fn = extract_grad_hook
    for module in net.modules():
        if isinstance(module, nn.Embedding):
            # Legacy (non-full) module backward hook; it receives the gradient w.r.t. the embedding output.
            return module.register_backward_hook(hook_fn)
    raise ValueError("add_hook: no nn.Embedding module found in net")


def get_gradient(net, test_iter, trigger_token_tensor):  # Calculate the loss to get the gradient
    """Run one forward/backward pass per batch of ``test_iter`` with the trigger inserted.

    Nothing is returned. The gradients of interest are captured by whatever backward
    hook the caller registered (``collection_attack`` uses ``add_hook`` /
    ``extract_grad_hook``). The trigger ids are inserted at index ``position`` of every
    ``input_ids`` row, and zeros at the same index of ``token_type_ids``. The loss is
    the module-level ``criterion``.

    The model is put in train mode here, so dropout is active. This differs from
    ``evaluate`` / ``select_best_candid``, which score candidates in eval mode.
    """
    target = device[0]
    net = net.to(target)
    net.train()
    trigger_row = trigger_token_tensor.clone().unsqueeze(0)
    type_row = torch.tensor([0] * len(trigger_token_tensor)).unsqueeze(0)
    for a, b, y in tqdm(test_iter):
        a = torch.cat((a[:, :position], trigger_row.repeat_interleave(a.shape[0], dim=0), a[:, position:]), dim=1)
        b = torch.cat((b[:, :position], type_row.repeat_interleave(b.shape[0], dim=0), b[:, position:]), dim=1)
        a, b, y = a.to(target), b.to(target), y.to(target)
        # Clear parameter grads so they do not accumulate across batches. No optimizer is
        # needed for this; the parameters themselves are never updated here.
        net.zero_grad(set_to_none=True)
        logits = net(input_ids=a, token_type_ids=b)
        loss = criterion(logits, y)
        loss.backward()
    # Release the last batch's parameter gradients; only the hooked gradients are used.
    net.zero_grad(set_to_none=True)


def process_gradient(length, num_trigger_tokens, grads=None, insert_at=None):  # Process the gradient to get the average gradient
    """Average the recorded gradients and slice out the trigger positions.

    Args:
        length: number of recorded gradients to average (normally ``len(test_iter)``,
            one per batch). All ``length`` entries are included, including the last.
        num_trigger_tokens: number of trigger positions to keep.
        grads: sequence of equally-shaped gradient tensors. Defaults to the
            module-level ``extracted_grads``. It is not modified.
        insert_at: index of the first trigger position along dim 0 of each gradient.
            Defaults to the module-level ``position``.

    Returns:
        ``mean(grads[:length])[insert_at:insert_at + num_trigger_tokens]``.

    Raises:
        ValueError: if there is nothing to average.
    """
    if grads is None:
        grads = extracted_grads
    if insert_at is None:
        insert_at = position
    selected = list(grads[:length])
    if not selected:
        raise ValueError("process_gradient: no gradients to average")
    # One stack + mean instead of growing a tensor with torch.cat in a loop.
    average_grad = torch.stack(selected, dim=0).mean(dim=0)
    return average_grad[insert_at:insert_at + num_trigger_tokens]


def hotflip_attack(averaged_grad, embedding_matrix,
                   num_candidates=1, increase_loss=False):
    """First-order (HotFlip) candidate search over the whole vocabulary.

    Every vocabulary row ``e_k`` is scored at every trigger position ``i`` by the dot
    product ``grad_i . e_k``. The scores are negated when ``increase_loss`` is False.
    The highest-scoring ids are returned.

    Args:
        averaged_grad: 2-D gradient, one row per trigger position.
        embedding_matrix: 2-D embedding table, one row per vocabulary id.
        num_candidates: number of ids to return per position.
        increase_loss: keep the raw scores (True) or negate them (False).

    Returns:
        numpy array of ids. Shape is (positions, num_candidates) when
        ``num_candidates > 1``, otherwise (positions,).
    """
    grad_batch = averaged_grad.cpu().unsqueeze(0)
    weights = embedding_matrix.cpu()
    scores = torch.einsum("bij,kj->bik", (grad_batch, weights))
    if not increase_loss:
        # lower versus increase the class probability.
        scores *= -1
    if num_candidates <= 1:
        best_ids = scores.max(2)[1]
        return best_ids[0].detach().cpu().numpy()
    top_ids = torch.topk(scores, num_candidates, dim=2)[1]
    return top_ids.detach().cpu().numpy()[0]


def collection_attack(net, test_iter, num_candidates, num_epoch, trigger='the',
                      num_trigger_tokens=3):
    """Greedy universal-trigger search: HotFlip candidates re-scored on ``test_iter``.

    Each round:
      1. record embedding gradients over ``test_iter`` with the current trigger;
      2. average them (``process_gradient``);
      3. take ``num_candidates`` ids per trigger position (``hotflip_attack``);
      4. keep any candidate that lowers accuracy (``select_best_candid``).

    Args:
        net: classifier under attack.
        test_iter: loader of ``(input_ids, token_type_ids, labels)`` batches.
        num_candidates: candidate ids per trigger position per round.
        num_epoch: number of search rounds.
        trigger: initial trigger token (see ``init_trigger_tokens``).
        num_trigger_tokens: trigger length.

    Returns:
        ``(trigger_token_tensor, valid_acc)``: the final trigger ids and the
        accuracy measured with them.
    """
    trigger_token_tensor = init_trigger_tokens(trigger, num_trigger_tokens)
    print(f'Concatenation location:{position}')
    valid_acc = evaluate(net, test_iter, trigger_token_tensor)
    print(f'Initial trigger tokens state：the accuracy {valid_acc:.5f}')
    embedding_weight = get_embedding_weight(net)
    for i in range(num_epoch):
        extracted_grads.clear()
        hook = add_hook(net)
        try:
            get_gradient(net, test_iter, trigger_token_tensor)
        finally:
            # Always detach the hook, even on KeyboardInterrupt. A leaked hook would
            # record every later backward pass twice.
            hook.remove()
        average_grad = process_gradient(len(test_iter), num_trigger_tokens)
        # The per-batch gradients are no longer needed; free their memory now.
        extracted_grads.clear()
        hot_token = hotflip_attack(average_grad, embedding_weight, num_candidates, increase_loss=True)
        hot_token_tensor = torch.from_numpy(hot_token)
        if hot_token_tensor.dim() == 1:
            # num_candidates == 1 yields one id per position; make it (positions, 1).
            hot_token_tensor = hot_token_tensor.unsqueeze(1)
        trigger_token_tensor, valid_acc = select_best_candid(net, test_iter, hot_token_tensor, trigger_token_tensor,
                                                             valid_acc)
        print(f'after {i + 1} rounds of attacking\ntriggers: {trigger_token_tensor} \nthe accuracy :{valid_acc:.5f} ')
    return trigger_token_tensor, valid_acc  # Return the final trigger tokens (trigger length) and the accuracy after the attack


def get_embedding_weight(net):
    """Return the weight of the first ``nn.Embedding`` in ``net.modules()`` order.

    For the ALBERT classifier this is presumably the word-embedding table (one row
    per vocabulary id); confirm if the architecture changes. The live parameter is
    returned, not a copy.

    Raises:
        ValueError: if ``net`` contains no ``nn.Embedding``.
    """
    embedding = next((m for m in net.modules() if isinstance(m, nn.Embedding)), None)
    if embedding is None:
        raise ValueError("get_embedding_weight: no nn.Embedding module found in net")
    return embedding.weight


def select_best_candid(net, test_iter, candid_trigger, trigger_token, valid_acc):
    """Greedy coordinate search over candidate trigger tokens.

    Trigger positions are processed in order. For each candidate
    ``candid_trigger[i, j]``, the trigger with position ``i`` replaced is scored on
    ``test_iter``: the mean of per-batch accuracies, with the trigger inserted at
    ``position``. A candidate is kept only if it gives strictly lower accuracy than
    the best so far. Later positions are searched with earlier improvements
    already applied.

    Args:
        net: classifier under attack (put in eval mode).
        test_iter: loader of ``(input_ids, token_type_ids, labels)`` batches.
        candid_trigger: (positions, candidates) id tensor. A 1-D tensor (one
            candidate per position) is also accepted.
        trigger_token: current 1-D trigger ids. This tensor is not modified.
        valid_acc: accuracy of the current trigger, the bar a candidate must beat.

    Returns:
        ``(best_trigger, best_acc)``.
    """
    if candid_trigger.dim() == 1:
        # hotflip_attack returns a 1-D array when num_candidates == 1.
        candid_trigger = candid_trigger.unsqueeze(1)
    # Clone first: unsqueeze() alone returns a view, so writing into it would
    # modify the caller's tensor in place.
    best_trigger = trigger_token.clone().unsqueeze(0)
    type_row = torch.zeros((1, len(trigger_token)), dtype=torch.long)
    target = device[0]
    net.eval()
    # Candidates are only scored, so no autograd graph is needed.
    with torch.no_grad():
        for i in range(candid_trigger.shape[0]):
            trial = best_trigger.clone()
            for j in range(candid_trigger.shape[1]):
                trial[0, i] = candid_trigger[i, j]
                batch_accs = []
                for a, b, y in tqdm(test_iter):
                    a = torch.cat((a[:, :position], trial.repeat_interleave(a.shape[0], dim=0),
                                   a[:, position:]), dim=1)
                    b = torch.cat((b[:, :position], type_row.repeat_interleave(b.shape[0], dim=0),
                                   b[:, position:]), dim=1)
                    a, b, y = a.to(target), b.to(target), y.to(target)
                    logits = net(input_ids=a, token_type_ids=b)
                    batch_accs.append((logits.argmax(dim=-1) == y).float().mean())
                # Mean of per-batch accuracies, the same metric evaluate() reports.
                candidate_acc = sum(batch_accs) / len(test_iter)
                if candidate_acc < valid_acc:
                    valid_acc = candidate_acc
                    best_trigger[0, i] = candid_trigger[i, j]
    return best_trigger[0], valid_acc  # Return the final trigger token and the accuracy after the attack

# Run the 2-token attack (5 candidates per position, 5 rounds) on each label's test reviews.
# Other configurations previously left commented out: num_epoch=10 with num_trigger_tokens=1 or 3.
pos_trigger, pos_acc = collection_attack(Model, test_iter_pos, num_candidates=5, num_epoch=5,
                                         trigger='<pad>', num_trigger_tokens=2)
neg_trigger, neg_acc = collection_attack(Model, test_iter_neg, num_candidates=5, num_epoch=5,
                                         trigger='<pad>', num_trigger_tokens=2)

Concatenation location:1


  0%|          | 0/1250 [00:00<?, ?it/s]We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.
100%|██████████| 1250/1250 [01:06<00:00, 18.89it/s]


Initial trigger tokens state：the accuracy 0.95448


100%|██████████| 1250/1250 [03:04<00:00,  6.77it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.63it/s]
100%|██████████| 1250/1250 [01:06<00:00, 18.66it/s]
100%|██████████| 1250/1250 [01:06<00:00, 18.66it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.65it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.64it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.65it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.66it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.65it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.65it/s]
100%|██████████| 1250/1250 [01:06<00:00, 18.66it/s]


after 1 rounds of attacking
triggers: tensor([29269, 29269]) 
the accuracy :0.90680 


100%|██████████| 1250/1250 [03:04<00:00,  6.77it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.64it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.64it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.61it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.61it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.60it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.59it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.60it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.61it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.64it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.62it/s]


after 2 rounds of attacking
triggers: tensor([9148,    9]) 
the accuracy :0.87352 


100%|██████████| 1250/1250 [03:05<00:00,  6.75it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.64it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.64it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.65it/s]
100%|██████████| 1250/1250 [01:06<00:00, 18.66it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.63it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.62it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.61it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.64it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.65it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.64it/s]


after 3 rounds of attacking
triggers: tensor([8554,    9]) 
the accuracy :0.69752 


 23%|██▎       | 284/1250 [00:42<02:23,  6.74it/s]


KeyboardInterrupt: 

In [5]:
# Attack the negative test reviews with the same settings as the positive run.
collection_attack(Model, test_iter_neg, num_candidates=5, num_epoch=5, trigger='<pad>', num_trigger_tokens=2)

Concatenation location:1


100%|██████████| 1250/1250 [01:06<00:00, 18.77it/s]


Initial trigger tokens state：the accuracy 0.91936


100%|██████████| 1250/1250 [03:04<00:00,  6.77it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.62it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.62it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.62it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.62it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.53it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.55it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.55it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.57it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.56it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.55it/s]


after 1 rounds of attacking
triggers: tensor([29376, 27135]) 
the accuracy :0.91776 


100%|██████████| 1250/1250 [03:05<00:00,  6.75it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.64it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.65it/s]
100%|██████████| 1250/1250 [01:06<00:00, 18.66it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.65it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.65it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.65it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.63it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.63it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.63it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.64it/s]


after 2 rounds of attacking
triggers: tensor([29376, 27135]) 
the accuracy :0.91776 


100%|██████████| 1250/1250 [03:04<00:00,  6.76it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.64it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.65it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.62it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.63it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.64it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.64it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.65it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.65it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.64it/s]
100%|██████████| 1250/1250 [01:07<00:00, 18.64it/s]


after 3 rounds of attacking
triggers: tensor([29376, 27135]) 
the accuracy :0.91776 


  7%|▋         | 86/1250 [00:12<02:53,  6.69it/s]


KeyboardInterrupt: 