In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModel, AutoTokenizer
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


# Data Structuring

In [2]:
train_df = pd.read_csv('data/train.csv')
test_df = pd.read_csv('data/test.csv')
test_labels_df = pd.read_csv('data/test_labels.csv')

train_df = train_df[['id', 'comment_text', 'toxic']]
negative_sample_train = train_df[train_df['toxic'] == 0].sample(frac=0.1)
positive_sample_train = train_df[train_df['toxic'] == 1]
train_df = pd.concat([negative_sample_train, positive_sample_train])
test_labels_df = test_labels_df[['id', 'toxic']]

test_df = pd.merge(test_df, test_labels_df, on='id', how='inner')
test_df = test_df[test_df['toxic'] != -1]

In [8]:
class TextDataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        text = self.data['comment_text'].iloc[idx]
        label = self.data['toxic'].iloc[idx]
        inputs = self.tokenizer(text, return_tensors='pt', padding='max_length', max_length=128, truncation=True)
        inputs['label'] = torch.tensor(label)
        return inputs

# Model

In [9]:
class TextClassifier(nn.Module):
    def __init__(self, transformer_model, freeze_transformer=True, normalize_PP_T=True):
        super(TextClassifier, self).__init__()
        self.model = AutoModel.from_pretrained(transformer_model)
        # Freeze the transformer model
        if freeze_transformer:
            for param in self.model.parameters():
                param.requires_grad = False
        self.tokenizer = AutoTokenizer.from_pretrained(transformer_model)
        self.fc = nn.Linear(768, 1)
        self.sigmoid = nn.Sigmoid()
        self.set_PP_T(normalize=normalize_PP_T)

    def set_PP_T(self, normalize=True):
        self.embeddings_matrix = self.get_static_embeddings_matrix()
        self.PP_T = self.embeddings_matrix.T @ self.embeddings_matrix
        if normalize:
            # normalize using frobenius norm
            self.PP_T = self.PP_T / torch.norm(self.PP_T, p='fro')

    def get_static_embeddings_matrix(self):
        return self.model.get_input_embeddings().weight
    
    def forward(self, input_ids=None, attention_mask=None, inputs_embeds=None):
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds)
        cls_token = outputs.last_hidden_state[:, 0]
        cls_token = self.fc(cls_token)
        return self.sigmoid(cls_token)

## Train the Classifier

In [10]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = TextClassifier('bert-base-uncased').to(device)

train_dataset = TextDataset(train_df, model.tokenizer)
test_dataset = TextDataset(test_df, model.tokenizer)

train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=512, shuffle=False)

criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-5)

In [11]:
for epoch in range(5):
    model.train()
    epoch_loss = 0
    for batch in tqdm(train_loader):
        input_ids = batch['input_ids'].to(device)
        input_ids = input_ids.squeeze(1)
        attention_mask = batch['attention_mask'].to(device)
        attention_mask = attention_mask.view(input_ids.shape)
        labels = batch['label'].to(device)
        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask)
        loss = criterion(outputs.squeeze(), labels.float())
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    print(f'Epoch {epoch+1}/{5}, Loss: {epoch_loss/len(train_loader)}')

100%|██████████| 59/59 [00:33<00:00,  1.74it/s]


Epoch 1/5, Loss: 0.6905189267659592


100%|██████████| 59/59 [00:34<00:00,  1.73it/s]


Epoch 2/5, Loss: 0.6834584090669277


100%|██████████| 59/59 [00:34<00:00,  1.71it/s]


Epoch 3/5, Loss: 0.6777997683670561


100%|██████████| 59/59 [00:34<00:00,  1.72it/s]


Epoch 4/5, Loss: 0.6711571216583252


100%|██████████| 59/59 [00:34<00:00,  1.70it/s]

Epoch 5/5, Loss: 0.6653854665109666





## Evaluate Classifier

In [12]:
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for batch in tqdm(test_loader):
        input_ids = batch['input_ids'].to(device)
        input_ids = input_ids.squeeze(1)
        attention_mask = batch['attention_mask'].to(device)
        attention_mask = attention_mask.view(input_ids.shape)
        labels = batch['label'].to(device)
        outputs = model(input_ids, attention_mask)
        predicted = torch.round(outputs.squeeze())
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    print(f'Accuracy: {correct/total}')

100%|██████████| 125/125 [01:18<00:00,  1.60it/s]

Accuracy: 0.7550720560192566





# Run Attack

In [91]:
class TextAdvDataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        text = self.data['comment_text'].iloc[idx]
        label = self.data['toxic'].iloc[idx]
        inputs = self.tokenizer(text, return_tensors='pt')
        inputs['label'] = torch.tensor(label)
        return inputs

In [92]:
adv_test_dataset = TextAdvDataset(test_df, model.tokenizer)
adv_test_loader = DataLoader(adv_test_dataset, batch_size=1, shuffle=False)

In [107]:
class AdvRunner:
    def __init__(self, model, criterion, optimizer, device, alpha=1e-3, suffix_len=20, suffix_char='!'):
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.device = device
        self.alpha = alpha
        self.suffix_len = suffix_len

        self.embeddings_matrix = self.model.embeddings_matrix
        self.PP_T = self.model.PP_T.to(self.device)
        self.tokenizer = self.model.tokenizer
        self.model.eval()

        self.suffix = suffix_char * suffix_len

    # TODO: check if this is the correct projection (for now we ignore it)
    def project(self, inputs_embeds):
        inputs_embeds[:, -(self.suffix_len+1):-1] = torch.einsum(
            'ij,bsj->bsi', self.PP_T, inputs_embeds[:, -(self.suffix_len+1):-1])
        return inputs_embeds

    def decode(self, input_ids, skip_special_tokens=True):
        return self.tokenizer.decode(input_ids, skip_special_tokens=skip_special_tokens)
    
    def FGSM_step(self, inputs):
        # Decode original input text and append the suffix
        original_text = self.decode(inputs['input_ids'][0])
        perturbed_text = original_text + self.suffix

        # Tokenize perturbed text
        tokenized = self.tokenizer(perturbed_text, return_tensors='pt').to(self.device)
        input_ids = tokenized['input_ids']
        attention_mask = tokenized['attention_mask']
        label = inputs['label'].float().to(self.device)

        print(f'Starting Attack: {perturbed_text}')
        print(f'Original Label: {label.item()}')

        # Get original text loss
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        original_text_pred = outputs.item()
        print(f'Original Prediction: {original_text_pred}')
        text_loss = self.criterion(outputs, label.unsqueeze(0).unsqueeze(0))
        print(f'Original text loss: {text_loss.item()}')

        # Get embeddings and enable gradient tracking
        inputs_embeds = self.embeddings_matrix[input_ids].clone().detach().to(self.device)
        inputs_embeds.requires_grad_(True)

        # Forward pass with embeddings
        outputs = self.model(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
        original_emb_pred = outputs.item()
        print(f'Original Embedding Prediction: {original_emb_pred}')
        loss = self.criterion(outputs, label.unsqueeze(0).unsqueeze(0))

        # Compute gradients
        loss.backward()
        print(f'Original embedding loss: {loss.item()}')

        # Extract gradients corresponding to suffix tokens only
        perturbation = inputs_embeds.grad[:, -(self.suffix_len):-1].sign()

        # Apply perturbation
        with torch.no_grad():
            inputs_embeds[:, -(self.suffix_len):-1] += self.alpha * perturbation
            # inputs_embeds = self.project(inputs_embeds)

        # Map perturbed embeddings back to discrete tokens (nearest embeddings)
        with torch.no_grad():
            distances = torch.cdist(inputs_embeds, self.embeddings_matrix)
            perturbed_input_ids = distances.argmin(dim=-1)

        # Decode new perturbed text
        new_text = self.decode(perturbed_input_ids[0])
        print(f'Perturbed Text: {new_text}')

        # Get perturbed text loss
        perturbed_outputs = self.model(input_ids=perturbed_input_ids, attention_mask=attention_mask)
        perturbed_text_pred = perturbed_outputs.item()
        print(f'Perturbed Prediction: {perturbed_text_pred}')
        perturbed_loss = self.criterion(perturbed_outputs, label.unsqueeze(0).unsqueeze(0))
        print(f'Perturbed text loss: {perturbed_loss.item()}')

        # Evaluate perturbed inputs_embeds on the model
        perturbed_emb_outputs = self.model(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
        perturbed_emb_pred = perturbed_emb_outputs.item()
        print(f'Perturbed Embedding Prediction: {perturbed_emb_pred}')
        perturbed_emb_loss = nn.BCELoss()(perturbed_emb_outputs, label.unsqueeze(0).unsqueeze(0))
        print(f'Perturbed embedding loss: {perturbed_emb_loss.item()}')

        return perturbed_input_ids
    
    def PGD_step(self):
        pass


## Run Attack on Example

In [111]:
# Take example from test dataset
inputs = next(iter(adv_test_loader))
advrunner = AdvRunner(model, criterion, optimizer, device, alpha=1)

single_input = {key: inputs[key][0] for key in inputs.keys()}
advrunner.FGSM_step(single_input)

Starting Attack: thank you for understanding. i think very highly of you and would not revert without discussion.!!!!!!!!!!!!!!!!!!!!
Original Label: 0.0
Original Prediction: 0.5034220814704895
Original text loss: 0.7000148892402649
Original Embedding Prediction: 0.5034220814704895
Original embedding loss: 0.7000148892402649
Perturbed Text: thank you for understanding. i think very highly of you and would not revert without discussion.! instead instead but some some but and and and and and. assert lay of donovan mace escaped an
Perturbed Prediction: 0.5381074547767639
Perturbed text loss: 0.7724230289459229
Perturbed Embedding Prediction: 0.6985443830490112
Perturbed embedding loss: 1.1991324424743652


tensor([[  101,  4067,  2017,  2005,  4824,  1012,  1045,  2228,  2200,  3811,
          1997,  2017,  1998,  2052,  2025,  7065,  8743,  2302,  6594,  1012,
           999,  2612,  2612,  2021,  2070,  2070,  2021,  1998,  1998,  1998,
          1998,  1998,  1012, 20865,  3913,  1997, 12729, 19382,  6376,  2019,
           102]], device='cuda:0')