## Data Load and Process

In [1]:
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer

In [2]:
# Set seeds
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

# demo tokenization settings
MAX_SEQUENCE_LENGTH = 512
BATCH_SIZE = 32
NAME = "distilbert-base-uncased"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

DATA_PATH = "data/"
NUM_SAMPLES = -1  # -1 to run on all data


class IMDBDataset:
    """
    Dataset class to handle the IMDB movie reviews dataset.
    """
    def __init__(self, path, split="train", num_samples=-1):
        self.path = path
        self.split = split
        self.num_samples = num_samples

        # Read the CSV file. Assume the format has 'review' and 'sentiment' columns

        if split == "train":
            self.df = pd.read_csv(f"{self.path}/imdb_train.csv")
        elif split == "test":
            self.df = pd.read_csv(f"{self.path}/imdb_test.csv")
        else:
            self.df = pd.read_csv(f"{self.path}/imdb_validation.csv")

        if num_samples != -1:
            self.df = self.df.sample(n=num_samples, random_state=SEED)  # Sample a subset if required

        self.df.reset_index(drop=True, inplace=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        return {
            "review": self.df.iloc[idx]['review'],
            "sentiment": self.df.iloc[idx]['sentiment'],
        }

In [3]:
class IMDBDataProcessor:
    """
    Data processor class for the IMDB dataset to tokenize and create dataloaders.
    """
    def __init__(self, name=NAME, max_seq_length=MAX_SEQUENCE_LENGTH, batch_size=BATCH_SIZE):
        self.tokenizer = AutoTokenizer.from_pretrained(name)
        self.max_seq_length = max_seq_length
        self.batch_size = batch_size

    def tokenize(self, texts):
        """
        Tokenize the input texts for BERT model.
        """
        return self.tokenizer(texts, padding=True, truncation=True, max_length=self.max_seq_length, return_tensors="pt")

    def create_dataloader(self, dataset, shuffle=False):
        """
        Create a PyTorch DataLoader.
        """
        texts = dataset.df['review'].tolist()
        sentiments = dataset.df['sentiment'].tolist()

        # Tokenize the texts
        encoding = self.tokenize(texts)

        # Convert the sentiments to tensors
        labels = torch.tensor(sentiments)

        # Create a dataset
        dataset = torch.utils.data.TensorDataset(encoding['input_ids'], encoding['attention_mask'], labels)

        # Create a dataloader
        return torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle)

    def process(self, data_path):
        """
        Process the IMDB dataset and create dataloaders for train and test sets.
        """
        train_dataset = IMDBDataset(path=data_path, split="train")
        test_dataset = IMDBDataset(path=data_path, split="test")
        valid_dataset = IMDBDataset(path=data_path, split="valid")

        # Create the dataloaders
        train_dataloader = self.create_dataloader(train_dataset, shuffle=True)
        test_dataloader = self.create_dataloader(test_dataset, shuffle=False)
        valid_dataloader = self.create_dataloader(valid_dataset, shuffle=False)

        # Return the dataloaders
        return train_dataloader, test_dataloader, valid_dataloader

In [4]:
data_processor = IMDBDataProcessor()
train_dataloader, test_dataloader, valid_dataloader = data_processor.process(DATA_PATH)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

In [5]:
input_ids, attention_mask, labels = next(iter(train_dataloader))

## Bert Model

In [6]:
import os

import torch
import torch.nn as nn
from transformers import AutoModel, logging

logging.set_verbosity_error()

# model settings
NUM_LABELS = 2
BERT_ENCODER_OUTPUT_SIZE = 768
CLF_LAYER_1_DIM = 64
CLF_DROPOUT_PROB = 0.4
MODE = "fine-tune"  # pre-train
NAME = "distilbert-base-uncased"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class BertClassifier(nn.Module):
    def __init__(self, name=NAME, mode=MODE, pretrained_checkpoint=None):
        super(BertClassifier, self).__init__()
        self.mode = mode
        D_in, H, D_out = BERT_ENCODER_OUTPUT_SIZE, CLF_LAYER_1_DIM, NUM_LABELS
        if pretrained_checkpoint is None:
            self.bert = AutoModel.from_pretrained(NAME)
        else:
            state_dict = torch.load(pretrained_checkpoint, map_location=device)
            self.bert = AutoModel.from_pretrained(
                NAME, state_dict={k: v for k, v in state_dict.items() if "bert" in k}
            )

        self.classifier = nn.Sequential(
            nn.Linear(D_in, H),
            nn.ReLU(),
            nn.Dropout(CLF_DROPOUT_PROB),
            nn.Linear(H, D_out),
        )

        if self.mode == "pre-train":
            freeze_bert = True
        else:
            freeze_bert = False

        # Freeze the BERT model
        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state_cls = outputs[0][:, 0, :]
        logits = self.classifier(last_hidden_state_cls)
        return logits





In [7]:
model = BertClassifier()
model = model.to(device)

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

## Train and Evaluate Model

In [8]:
import random
import time
from datetime import datetime

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from transformers import AutoTokenizer, get_linear_schedule_with_warmup

In [9]:
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

In [10]:
# paths for IO
MODELS_PATH = "../models"
OUTPUTS_PATH = "../outputs"
LOGS_PATH = "../logs"

NAME = "distilbert-base-uncased"
DATASET_NAME = "imdb"


NUM_EPOCHS = 4
LEARNING_RATE = 5e-6
EPS = 1e-8

In [11]:
def train_model(
    model,
    train_dataloader,
    valid_dataloader,
    criterion,
    optimizer,
    scheduler,
    num_epochs,
    models_path,
    save_intermediate=False,
):
    model_name = f"{NAME.replace('/', '-')}_model"
    writer = SummaryWriter(log_dir=LOGS_PATH)
    for epoch in tqdm(range(num_epochs), desc="Epochs", unit="epoch", total=num_epochs):
        train_loss, valid_loss = 0, 0
        train_acc, valid_acc = 0, 0

        model.train()

        for i, data in tqdm(
            enumerate(train_dataloader),
            desc="Batches",
            unit="batch",
            total=len(train_dataloader),
        ):
            input_ids, attention_mask, labels = data

            optimizer.zero_grad()

            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            optimizer.step()
            scheduler.step()

            train_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            train_acc += (predicted == labels).sum().item()

        train_loss /= len(train_dataloader)
        train_acc /= len(train_dataloader.dataset)

        writer.add_scalar("Train Loss", train_loss, epoch)
        writer.add_scalar("Train Accuracy", train_acc, epoch)

        with torch.no_grad():
            model.eval()
            for i, data in enumerate(valid_dataloader):
                input_ids, attention_mask, labels = data
                outputs = model(input_ids, attention_mask)
                loss = criterion(outputs, labels)
                valid_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                valid_acc += (predicted == labels).sum().item()

            valid_loss /= len(valid_dataloader)
            valid_acc /= len(valid_dataloader.dataset)

            writer.add_scalar("Validation Loss", valid_loss, epoch)
            writer.add_scalar("Validation Accuracy", valid_acc, epoch)

        print(
            f"Epoch: {epoch+1} | "
            f"Train Loss: {train_loss:.3f} | "
            f"Train Accuracy: {train_acc*100:.2f}% | "
            f"Validation Loss: {valid_loss:.3f} | "
            f"Validation Accuracy: {valid_acc*100:.2f}%"
        )

        if save_intermediate:
            # save intermediate models after each epoch if needed
            filename = DATASET_NAME
            filename += datetime.now().strftime(
                f"_%d-%m-%y-%H_%M_{MODE}_{model_name}_epoch{epoch}.pt"
            )
            torch.save(model.state_dict(), f"{models_path}/{filename}")

    filename = DATASET_NAME
    filename += datetime.now().strftime(f"_%d-%m-%y-%H_%M_{MODE}_{model_name}_final.pt")
    torch.save(model.state_dict(), f"{models_path}/{filename}")
    writer.close()

In [12]:
def evaluate_model(model, dataloader, split):
    model.eval()
    test_acc = 0
    batch_count = 0
    all_texts, all_labels, all_preds = [], [], []
    tokenizer = AutoTokenizer.from_pretrained(NAME)
    for i, data in enumerate(dataloader):
        input_ids, attention_mask, labels = data
        all_labels.append(labels.cpu().numpy())
        all_texts.append(tokenizer.batch_decode(input_ids, skip_special_tokens=True))

        with torch.no_grad():
            outputs = model(input_ids, attention_mask)
            _, preds = torch.max(outputs, 1)
            all_preds.append(preds.cpu().numpy())
            test_acc += (preds == labels).sum().item()
            batch_count += 1

    test_acc /= batch_count * dataloader.batch_size
    if split == "test":
        print(f"Test Accuracy: {test_acc*100:.2f}% \n")
    elif split == "valid":
        print(f"Validation Accuracy: {test_acc*100:.2f}% \n")
    return all_texts, all_labels, all_preds

In [13]:
def save_test_as_dataframe(all_texts, all_labels, all_preds, split):
    labels_df = pd.DataFrame(
        {
            "content": [text for batch in all_texts for text in batch],
            "true_labels": [label for batch in all_labels for label in batch],
            "predicted_labels": [pred for batch in all_preds for pred in batch],
        }
    )
    print(labels_df.head())
    if NUM_SAMPLES == -1:
        sample_size = "all"
    else:
        sample_size = NUM_SAMPLES

    filename = DATASET_NAME
    if split == "test":
        filename += datetime.now().strftime(
            f"_%d-%m-%y-%H_%M_{MODE}_{NAME.replace('/', '-')}_test_results_{sample_size}.csv"
        )
    elif split == "valid":
        filename += datetime.now().strftime(
            f"_%d-%m-%y-%H_%M_{MODE}_{NAME.replace('/', '-')}_valid_results_{sample_size}.csv"
        )
    labels_df.to_csv(f"{OUTPUTS_PATH}/{filename}", index=False)

In [14]:
criterion = torch.nn.CrossEntropyLoss()

optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, eps=EPS)

total_steps = len(train_dataloader) * NUM_EPOCHS

scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=0, num_training_steps=total_steps)

In [15]:
start_time = time.time()
print("Starting train/val loop now : \n")
train_model(
    model,
    train_dataloader,
    test_dataloader,
    criterion,
    optimizer,
    scheduler,
    NUM_EPOCHS,
    MODELS_PATH,
    save_intermediate=False,
)
end_time = time.time()
time_elapsed = end_time - start_time
print(f"Training Complete : {time_elapsed//60:.0f}m {time_elapsed%60:.0f}s")


Starting train/val loop now : 



Epochs:   0%|          | 0/4 [00:00<?, ?epoch/s]
Batches:   0%|          | 0/19 [00:00<?, ?batch/s][A
Batches:   5%|▌         | 1/19 [00:33<10:10, 33.92s/batch][A
Batches:  11%|█         | 2/19 [01:03<08:49, 31.17s/batch][A
Batches:  16%|█▌        | 3/19 [01:33<08:10, 30.68s/batch][A
Batches:  21%|██        | 4/19 [02:02<07:33, 30.23s/batch][A
Batches:  26%|██▋       | 5/19 [02:32<07:00, 30.04s/batch][A
Batches:  32%|███▏      | 6/19 [03:01<06:26, 29.74s/batch][A
Batches:  37%|███▋      | 7/19 [03:31<05:57, 29.81s/batch][A
Batches:  42%|████▏     | 8/19 [04:01<05:26, 29.70s/batch][A
Batches:  47%|████▋     | 9/19 [04:40<05:26, 32.64s/batch][A
Batches:  53%|█████▎    | 10/19 [05:11<04:49, 32.13s/batch][A
Batches:  58%|█████▊    | 11/19 [05:40<04:11, 31.40s/batch][A
Batches:  63%|██████▎   | 12/19 [06:11<03:37, 31.02s/batch][A
Batches:  68%|██████▊   | 13/19 [06:42<03:06, 31.07s/batch][A
Batches:  74%|███████▎  | 14/19 [07:12<02:34, 30.92s/batch][A
Batches:  79%|███████▉  

Epoch: 1 | Train Loss: 0.693 | Train Accuracy: 51.83% | Validation Loss: 0.690 | Validation Accuracy: 60.50%



Batches:   0%|          | 0/19 [00:00<?, ?batch/s][A
Batches:   5%|▌         | 1/19 [00:30<09:05, 30.32s/batch][A
Batches:  11%|█         | 2/19 [01:00<08:37, 30.44s/batch][A
Batches:  16%|█▌        | 3/19 [01:30<07:59, 29.99s/batch][A
Batches:  21%|██        | 4/19 [02:00<07:29, 29.96s/batch][A
Batches:  26%|██▋       | 5/19 [02:30<06:59, 29.96s/batch][A
Batches:  32%|███▏      | 6/19 [03:00<06:29, 29.99s/batch][A
Batches:  37%|███▋      | 7/19 [03:29<05:58, 29.83s/batch][A
Batches:  42%|████▏     | 8/19 [04:00<05:29, 29.99s/batch][A
Batches:  47%|████▋     | 9/19 [04:29<04:58, 29.88s/batch][A
Batches:  53%|█████▎    | 10/19 [04:59<04:29, 29.91s/batch][A
Batches:  58%|█████▊    | 11/19 [05:29<03:58, 29.81s/batch][A
Batches:  63%|██████▎   | 12/19 [05:58<03:28, 29.73s/batch][A
Batches:  68%|██████▊   | 13/19 [06:28<02:58, 29.78s/batch][A
Batches:  74%|███████▎  | 14/19 [06:57<02:28, 29.62s/batch][A
Batches:  79%|███████▉  | 15/19 [07:28<01:59, 29.89s/batch][A
Batches: 

Epoch: 2 | Train Loss: 0.688 | Train Accuracy: 53.50% | Validation Loss: 0.683 | Validation Accuracy: 66.50%



Batches:   0%|          | 0/19 [00:00<?, ?batch/s][A
Batches:   5%|▌         | 1/19 [00:31<09:21, 31.19s/batch][A
Batches:  11%|█         | 2/19 [01:02<08:53, 31.40s/batch][A
Batches:  16%|█▌        | 3/19 [01:34<08:28, 31.78s/batch][A
Batches:  21%|██        | 4/19 [02:04<07:45, 31.01s/batch][A
Batches:  26%|██▋       | 5/19 [02:34<07:09, 30.69s/batch][A
Batches:  32%|███▏      | 6/19 [03:05<06:36, 30.52s/batch][A
Batches:  37%|███▋      | 7/19 [03:35<06:06, 30.53s/batch][A
Batches:  42%|████▏     | 8/19 [04:05<05:34, 30.43s/batch][A
Batches:  47%|████▋     | 9/19 [04:35<05:01, 30.11s/batch][A
Batches:  53%|█████▎    | 10/19 [05:05<04:32, 30.23s/batch][A
Batches:  58%|█████▊    | 11/19 [05:35<04:00, 30.12s/batch][A
Batches:  63%|██████▎   | 12/19 [06:05<03:30, 30.12s/batch][A
Batches:  68%|██████▊   | 13/19 [06:35<03:00, 30.01s/batch][A
Batches:  74%|███████▎  | 14/19 [07:05<02:29, 30.00s/batch][A
Batches:  79%|███████▉  | 15/19 [07:35<01:59, 29.85s/batch][A
Batches: 

Epoch: 3 | Train Loss: 0.681 | Train Accuracy: 60.50% | Validation Loss: 0.677 | Validation Accuracy: 69.50%



Batches:   0%|          | 0/19 [00:00<?, ?batch/s][A
Batches:   5%|▌         | 1/19 [00:29<08:44, 29.16s/batch][A
Batches:  11%|█         | 2/19 [00:58<08:21, 29.48s/batch][A
Batches:  16%|█▌        | 3/19 [01:28<07:54, 29.64s/batch][A
Batches:  21%|██        | 4/19 [01:58<07:27, 29.84s/batch][A
Batches:  26%|██▋       | 5/19 [02:28<06:58, 29.88s/batch][A
Batches:  32%|███▏      | 6/19 [02:58<06:29, 29.95s/batch][A
Batches:  37%|███▋      | 7/19 [03:28<05:58, 29.87s/batch][A
Batches:  42%|████▏     | 8/19 [04:05<05:54, 32.26s/batch][A
Batches:  47%|████▋     | 9/19 [04:40<05:29, 32.95s/batch][A
Batches:  53%|█████▎    | 10/19 [05:11<04:50, 32.30s/batch][A
Batches:  58%|█████▊    | 11/19 [05:40<04:11, 31.50s/batch][A
Batches:  63%|██████▎   | 12/19 [06:11<03:37, 31.11s/batch][A
Batches:  68%|██████▊   | 13/19 [06:40<03:04, 30.69s/batch][A
Batches:  74%|███████▎  | 14/19 [07:11<02:32, 30.58s/batch][A
Batches:  79%|███████▉  | 15/19 [07:40<02:01, 30.25s/batch][A
Batches: 

Epoch: 4 | Train Loss: 0.676 | Train Accuracy: 62.00% | Validation Loss: 0.674 | Validation Accuracy: 70.50%





Training Complete : 42m 17s


In [16]:
print("Model Evaluation : \n")
all_texts_test, all_labels_test, all_preds_test = evaluate_model(model, test_dataloader, split="test")

Model Evaluation : 

Test Accuracy: 62.95% 



In [17]:
save_test_as_dataframe(all_texts_test, all_labels_test, all_preds_test, split="test")

                                             content  true_labels  \
0  mexican werewolf in texas is set in the small ...            0   
1  although promoted as one of the most sincere t...            0   
2  such a joyous world has been created for us in...            1   
3  don't torture a duckling is one of fulci's ear...            1   
4  this movie really woke me up, like it wakes up...            1   

   predicted_labels  
0                 0  
1                 0  
2                 1  
3                 1  
4                 1  


In [18]:
all_texts_valid, all_labels_valid, all_preds_valid = evaluate_model(
    model, valid_dataloader, split="valid"
)
save_test_as_dataframe(
    all_texts_valid, all_labels_valid, all_preds_valid, split="valid"
)

Validation Accuracy: 69.64% 

                                             content  true_labels  \
0  seven ups has been compared to bullitt for the...            1   
1  crossfire remains one of the best hollywood me...            1   
2  i saw the movie " hoot " and then i immediatel...            1   
3  ye lou's film purple butterfly pits a secret o...            0   
4  apparently, the mutilation man is about a guy ...            0   

   predicted_labels  
0                 0  
1                 1  
2                 1  
3                 1  
4                 0  


## SHAPIPY

In [19]:
!pip install shap

Collecting shap
  Downloading shap-0.45.0-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (538 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m538.2/538.2 kB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m
Collecting slicer==0.0.7 (from shap)
  Downloading slicer-0.0.7-py3-none-any.whl (14 kB)
Installing collected packages: slicer, shap
Successfully installed shap-0.45.0 slicer-0.0.7


In [20]:
!pip install swifter

Collecting swifter
  Downloading swifter-1.4.0.tar.gz (1.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m8.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: swifter
  Building wheel for swifter (setup.py) ... [?25l[?25hdone
  Created wheel for swifter: filename=swifter-1.4.0-py3-none-any.whl size=16507 sha256=1b25d40e0575f34395df68ef3e29f9fd14a3e35b0a9e2a48cb6ddb3c3e8ff63a
  Stored in directory: /root/.cache/pip/wheels/e4/cf/51/0904952972ee2c7aa3709437065278dc534ec1b8d2ad41b443
Successfully built swifter
Installing collected packages: swifter
Successfully installed swifter-1.4.0


In [21]:
import random
import re

import nltk  # for word tokenization during preprocessing
import numpy as np
import pandas as pd
import shap
import spacy  # for NER
import swifter
import torch
from nltk.corpus import stopwords
from tqdm import tqdm
from transformers import AutoTokenizer

In [23]:
# Arguments
DATASET_NAME = "imdb"

# Tokenizer and Model settings
MAX_SEQUENCE_LENGTH = 128
NAME = "distilbert-base-uncased"
MODE = "fine-tune"

# Model checkpoint
MODEL_CKPT = "../models/imdb_26-04-24-13_55_fine-tune_distilbert-base-uncased_model_final.pt"

# Data path - save results in input df
TEST_PATH = "../outputs/imdb_26-04-24-13_56_fine-tune_distilbert-base-uncased_test_results_all.csv"
VALID_PATH = "../outputs/imdb_26-04-24-13_57_fine-tune_distilbert-base-uncased_valid_results_all.csv"

In [24]:
nltk.download('stopwords')

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.


True

In [25]:
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

In [26]:
# NLTK stopwords and spacy NER
stop_words = set(stopwords.words("english"))
nlp = spacy.load("en_core_web_sm")

In [27]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(NAME, use_fast=True)
model = BertClassifier(name=NAME, mode=MODE, pretrained_checkpoint=None)
model.load_state_dict(torch.load(MODEL_CKPT, map_location=device))
model.eval()
model.to(device)

BertClassifier(
  (bert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
            (lin1): Linear(i

In [28]:
def tokenize(tokenizer, sentences, padding="max_length"):
    encoded = tokenizer.batch_encode_plus(
        sentences, max_length=MAX_SEQUENCE_LENGTH, truncation=True, padding=padding
    )
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]
    return torch.tensor(input_ids).to(device), torch.tensor(attention_mask).to(device)


def get_model_output(sentences):
    sentences = list(sentences)
    input_ids, attention_mask = tokenize(tokenizer, sentences)
    with torch.no_grad():
        output = model(input_ids, attention_mask)
        probabilities = torch.softmax(output, dim=-1)
    return probabilities.cpu().numpy()


def preprocess_text(text):
    text = re.sub(r"#", "", text.lower())
    tokens = nltk.word_tokenize(text)
    tokens = [t for t in tokens if t.isalpha() and t not in stop_words]
    text = " ".join(tokens)
    return text


def shapper(sentence, output_class):
    explainer = shap.Explainer(
        lambda x: get_model_output(x),
        shap.maskers.Text(tokenizer),
        silent=False,
    )
    shap_values = explainer([sentence])
    importance_values = shap_values[:, :, output_class].values
    tokenized_sentence = tokenizer.tokenize(sentence)
    token_importance = list(zip(tokenized_sentence, importance_values[0]))

    # Perform NER
    doc = nlp(sentence)

    # Aggregate salience scores for named entities
    aggregated_token_importance = []
    token_scores = {
        token: score for token, score in token_importance if not token.startswith("##")
    }

    token_scores_aggregated = token_scores.copy()

    for ent in doc.ents:
        scores = [token_scores.get(token, 0) for token in ent.text.split()]
        aggregated_score = sum(scores)
        average_score = aggregated_score / len(scores) if scores else 0
        aggregated_token_importance.append((ent.text, average_score))

        for token in ent.text.split():
            if token in token_scores_aggregated:
                del token_scores_aggregated[token]

    for token, score in token_scores_aggregated.items():
        aggregated_token_importance.append((token, score))

    # Split positive and negative scores
    shap_neg_outs = [item for item in aggregated_token_importance if item[1] < 0]
    # sort by score - largest mod value first
    shap_neg_outs = sorted(shap_neg_outs, key=lambda x: x[1])

    shap_pos_outs = [item for item in aggregated_token_importance if item[1] > 0]
    # sort by score - largest value first
    shap_pos_outs = sorted(shap_pos_outs, key=lambda x: x[1], reverse=True)
    return shap_neg_outs, shap_pos_outs



In [29]:
df = pd.read_csv(TEST_PATH)
df = df.dropna(subset=["content"])

df["processed_content"] = df["content"].apply(preprocess_text)
print(f"Running shap on {len(df)} samples...")
print("-" * 80)

df[["shap_neg_outs", "shap_pos_outs"]] = df.swifter.apply(
    lambda row: pd.Series(
        shapper(row["processed_content"], row["predicted_labels"])
    ),
    axis=1,
)

print(df.head())
print("-" * 80)
# save df to csv at TEST_PATH
df.to_csv(TEST_PATH, index=False)


Running shap on 200 samples...
--------------------------------------------------------------------------------


  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

Pandas Apply:   0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:32, 32.28s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:32, 32.53s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:32, 32.21s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:32, 32.60s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:33, 33.47s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:32, 32.04s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:33, 33.02s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:33, 33.10s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:33, 33.49s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.14s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.63s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:32, 32.36s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.26s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:32, 32.30s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.78s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.58s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:32, 32.00s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:32, 32.17s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.04s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.66s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.75s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.99s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.59s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.05s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.01s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.55s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.77s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:32, 32.43s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.26s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.12s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.12s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.23s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.92s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.34s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.11s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.46s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.11s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.14s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.76s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.13s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.82s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:32, 32.30s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:33, 33.15s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:33, 33.56s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:32, 32.40s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.73s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:32, 32.06s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.65s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:32, 32.81s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:32, 32.77s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:32, 32.44s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:32, 32.08s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:32, 32.68s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:32, 32.10s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.71s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.23s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.65s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.19s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.92s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.59s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.98s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.97s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.36s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.89s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.95s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.32s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.13s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.58s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.30s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.79s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.89s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.01s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.04s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.37s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.92s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.61s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.06s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.96s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.42s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.87s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.19s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.33s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.10s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.70s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:32, 32.52s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.30s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.14s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.23s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.28s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.87s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.82s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.88s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.16s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.85s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.71s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.29s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.73s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.50s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.92s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.65s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.32s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.73s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.78s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.27s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.47s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.43s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.30s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.78s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:32, 32.08s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:32, 32.11s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.57s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.20s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.28s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.86s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.89s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.20s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.04s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.63s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.92s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.22s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.36s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.96s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.01s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.20s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.50s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.92s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:32, 32.76s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:32, 32.18s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.89s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.30s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:32, 32.63s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:34, 34.16s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.17s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.15s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.46s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.15s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.05s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.36s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:33, 33.37s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.02s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.81s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.81s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.92s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.02s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.12s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.14s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 31.00s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.46s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.39s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.94s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.20s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.16s/it]               


  0%|          | 0/462 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:22, 22.50s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.25s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.24s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.63s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:29, 29.89s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.90s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.53s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.54s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.95s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.85s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.27s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.82s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.07s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.82s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.01s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.68s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.92s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.34s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.17s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.90s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.28s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.31s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.81s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.70s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.14s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.61s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.94s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.93s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.38s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.28s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.82s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.57s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.20s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.65s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.26s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.68s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.20s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.19s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.51s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.62s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.26s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.49s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.12s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.65s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.45s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.07s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.62s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.39s/it]               

                                             content  true_labels  \
0  mexican werewolf in texas is set in the small ...            0   
1  although promoted as one of the most sincere t...            0   
2  such a joyous world has been created for us in...            1   
3  don't torture a duckling is one of fulci's ear...            1   
4  this movie really woke me up, like it wakes up...            1   

   predicted_labels                                  processed_content  \
0                 0  mexican werewolf texas set small border town f...   
1                 0  although promoted one sincere turkish films am...   
2                 1  joyous world created us pixar bug life immerse...   
3                 1  torture duckling one fulci earlier honestly te...   
4                 1  movie really woke like wakes main male charact...   

                                       shap_neg_outs  \
0  [(legends, -0.0005464410938906904), (speak, -0...   
1  [(films, -0.0084853386506




In [30]:
df = pd.read_csv(VALID_PATH)
df = df.dropna(subset=["content"])

df["processed_content"] = df["content"].apply(preprocess_text)
print(f"Running shap on {len(df)} samples...")
print("-" * 80)

df[["shap_neg_outs", "shap_pos_outs"]] = df.swifter.apply(
    lambda row: pd.Series(
        shapper(row["processed_content"], row["predicted_labels"])
    ),
    axis=1,
)

print(df.head())
print("-" * 80)
# save df to csv at VALID_PATH
df.to_csv(VALID_PATH, index=False)

Running shap on 200 samples...
--------------------------------------------------------------------------------


  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

Pandas Apply:   0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.92s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.18s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.58s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.85s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.06s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.02s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.15s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.24s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.07s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.49s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.29s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.79s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.09s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.42s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.34s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.45s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.40s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.24s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.15s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.79s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.53s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.36s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.34s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.91s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.71s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.53s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.51s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.03s/it]               


  0%|          | 0/462 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:24, 24.91s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.90s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.84s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.61s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.24s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.13s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:32, 32.09s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.66s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.74s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.75s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.48s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.35s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:29, 29.47s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.10s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.40s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:29, 29.75s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.03s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:29, 29.72s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.20s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:29, 29.69s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.31s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:29, 30.00s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:32, 32.35s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:32, 32.11s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.54s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.07s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.54s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.03s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.13s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.58s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.87s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.82s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.80s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.01s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.51s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.13s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.26s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.73s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.93s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.27s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.76s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.37s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.20s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.75s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.12s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.25s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.99s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.77s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:25, 25.08s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.60s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.80s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.62s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.21s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:32, 32.40s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:32, 32.32s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:32, 32.91s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.64s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.79s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.30s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:32, 32.07s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.36s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.33s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.38s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.95s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.92s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.37s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.58s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:29, 29.78s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.97s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.58s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.65s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.51s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.37s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.62s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.39s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.06s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.65s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.08s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.31s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:29, 29.75s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.15s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.70s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:29, 29.76s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.35s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:29, 29.60s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.22s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:29, 29.90s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.21s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.32s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.60s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.28s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:29, 29.53s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.21s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.07s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:29, 29.94s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.86s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:29, 29.87s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.07s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.40s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.04s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.15s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:29, 29.85s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.10s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:29, 29.47s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.13s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.94s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:29, 29.76s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.46s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.15s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.27s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.28s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:29, 29.91s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.65s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.20s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.60s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.29s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:29, 29.52s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.12s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:29, 29.89s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.41s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.06s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:29, 29.41s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.73s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:29, 29.34s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.07s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:29, 29.50s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:29, 29.99s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:29, 29.97s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:29, 29.36s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:29, 29.97s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:29, 29.97s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:29, 29.70s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:29, 29.52s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.14s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:29, 29.36s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:29, 29.86s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:29, 29.69s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:29, 29.72s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.80s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:29, 29.87s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.27s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:29, 29.48s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.10s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.04s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:29, 29.50s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:29, 29.98s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:29, 29.32s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.37s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:29, 29.45s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.76s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.20s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.14s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.32s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:29, 29.87s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.29s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.60s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:29, 29.63s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.03s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:29, 29.51s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.19s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:29, 29.79s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.73s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.32s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.40s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.56s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:29, 29.77s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.38s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.38s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.39s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:30, 30.25s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:29, 29.57s/it]               


  0%|          | 0/498 [00:00<?, ?it/s]


PartitionExplainer explainer: 2it [00:31, 31.02s/it]               

                                             content  true_labels  \
0  seven ups has been compared to bullitt for the...            1   
1  crossfire remains one of the best hollywood me...            1   
2  i saw the movie " hoot " and then i immediatel...            1   
3  ye lou's film purple butterfly pits a secret o...            0   
4  apparently, the mutilation man is about a guy ...            0   

   predicted_labels                                  processed_content  \
0                 0  seven ups compared bullitt chase scene come an...   
1                 1  crossfire remains one best hollywood message m...   
2                 1  saw movie hoot immediately decided comment tru...   
3                 1  ye lou film purple butterfly pits secret organ...   
4                 0  apparently mutilation man guy wanders land per...   

                                       shap_neg_outs  \
0  [(builds builds, -0.00022008394907143983), (be...   
1  [(little, -0.000958580507




## Attack Prep

In [31]:
import ast
import calendar
import json
import random

import numpy as np
import pandas as pd
import spacy
import swifter
from nltk.corpus import words
from tqdm import tqdm


In [33]:
!python -m spacy download en_core_web_lg

Collecting en-core-web-lg==3.7.1
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_lg-3.7.1/en_core_web_lg-3.7.1-py3-none-any.whl (587.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m587.7/587.7 MB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: en-core-web-lg
Successfully installed en-core-web-lg-3.7.1
[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('en_core_web_lg')
[38;5;3m⚠ Restart to reload dependencies[0m
If you are in a Jupyter or Colab notebook, you may need to restart Python in
order to load all the package's dependencies. You can do this by selecting the
'Restart kernel' or 'Restart runtime' option.


In [34]:
nlp = spacy.load("en_core_web_lg", disable=["parser", "ner"])

In [35]:
# Arguments
NER = "default"  # "default", "other"
CANDIDATES_COUNT = 100  # 25, 50, 100, 150, 200
IMPORTANT_WORDS_COUNT = 20  # 5, 10, 20, 30, 40

In [36]:
def compute_attack_candidates(row, candidates, labels, shap_score):
    if row["predicted_labels"] == labels:
        for word, score in ast.literal_eval(row[shap_score]):
            if word not in candidates:
                candidates[word] = abs(score)
            else:
                if score > candidates[word]:
                    candidates[word] = abs(score)

In [37]:
nltk.download('words')

[nltk_data] Downloading package words to /root/nltk_data...
[nltk_data]   Unzipping corpora/words.zip.


True

In [38]:
# Define filters
filters = set(calendar.day_name).union(set(calendar.month_name)).union({""})
# also include day names in lower case
filters = filters.union({day.lower() for day in calendar.day_name}).union(
    {month.lower() for month in calendar.month_name}
)
# also include short forms of day names and months
filters = filters.union({day[:3].lower() for day in calendar.day_name}).union(
    {month[:3].lower() for month in calendar.month_name}
)
# Convert list of words to set for faster lookup
english_words = set(words.words())

In [39]:
def is_valid(word):
    if word in filters:
        return False
    if len(word) <= 3:
        if word in ["cnn"]:
            return True
        return False
    if word in english_words:
        return False
    return True


def filter_candidates(candidates):
    df = pd.Series(candidates).to_frame().reset_index()
    df.columns = ["word", "count"]
    df["is_valid"] = df["word"].swifter.apply(is_valid)
    valid_candidates = df[df["is_valid"]].set_index("word")["count"].to_dict()
    return valid_candidates


def extract_pos(words):
    noun_dict = {}
    verb_dict = {}
    adj_adv_dict = {}

    for word, score in words.items():
        # get the pos tag of each word
        token = nlp(word)[0]
        if token.pos_ == "NOUN" or token.pos_ == "PROPN":
            noun_dict[word] = score
        elif token.pos_ == "VERB":
            verb_dict[word] = score
        elif token.pos_ == "ADJ" or token.pos_ == "ADV":
            adj_adv_dict[word] = score

    return noun_dict, verb_dict, adj_adv_dict


In [41]:
df = pd.read_csv(VALID_PATH)
# real attack candidates = (pos scores from 0) + (neg scores from 1)
# we want true label real to be predicted as fake
real_attack_candidates = {}
df.swifter.apply(
    compute_attack_candidates,
    args=(real_attack_candidates, 0, "shap_pos_outs"),
    axis=1,
)
df.swifter.apply(
    compute_attack_candidates,
    args=(real_attack_candidates, 1, "shap_neg_outs"),
    axis=1,
)
# sort and store by abs shap score
real_attack_candidates = {
    k: v
    for k, v in sorted(
        real_attack_candidates.items(), key=lambda item: item[1], reverse=True
    )
}

# fake attack candidates = (pos scores from 1) + (neg scores from 0)
# we want true label fake to be predicted as real
fake_attack_candidates = {}
df.swifter.apply(
    compute_attack_candidates,
    args=(fake_attack_candidates, 1, "shap_pos_outs"),
    axis=1,
)
df.swifter.apply(
    compute_attack_candidates,
    args=(fake_attack_candidates, 0, "shap_neg_outs"),
    axis=1,
)
# sort and store by abs shap score
fake_attack_candidates = {
    k: v
    for k, v in sorted(
        fake_attack_candidates.items(), key=lambda item: item[1], reverse=True
    )
}

# Assuming 'fake_attack_candidates' and 'real_attack_candidates' are your input dictionaries
fake_attack_candidates_filtered = filter_candidates(fake_attack_candidates)
real_attack_candidates_filtered = filter_candidates(real_attack_candidates)


# store first 100 candidates of each
fake_attack_candidates = {
    k: v
    for k, v in list(fake_attack_candidates_filtered.items())[
        :CANDIDATES_COUNT
    ]
}
real_attack_candidates = {
    k: v
    for k, v in list(real_attack_candidates_filtered.items())[
        :CANDIDATES_COUNT
    ]
}

# remove the symbol "ƒ†" and "Ġ" from each word
fake_attack_candidates = {
    word.replace("ƒ†", "").replace("Ġ", ""): score
    for word, score in fake_attack_candidates.items()
}
real_attack_candidates = {
    word.replace("ƒ†", "").replace("Ġ", ""): score
    for word, score in real_attack_candidates.items()
}

with open(
    f"./outputs/shap_outputs/{DATASET_NAME}_fake_attack_candidates.json",
    "w",
) as f:
    json.dump(fake_attack_candidates, f)
with open(
    f"./outputs/shap_outputs/{DATASET_NAME}_real_attack_candidates.json",
    "w",
) as f:
    json.dump(real_attack_candidates, f)

Pandas Apply:   0%|          | 0/200 [00:00<?, ?it/s]

Pandas Apply:   0%|          | 0/200 [00:00<?, ?it/s]

Pandas Apply:   0%|          | 0/200 [00:00<?, ?it/s]

Pandas Apply:   0%|          | 0/200 [00:00<?, ?it/s]

Pandas Apply:   0%|          | 0/2900 [00:00<?, ?it/s]

Pandas Apply:   0%|          | 0/3962 [00:00<?, ?it/s]

In [42]:
df = pd.read_csv(TEST_PATH)
df["important_words"] = None

# storing switching tokens
for idx, row in tqdm(
    df.iterrows(), total=len(df), desc="Storing switching tokens"
):
    pos_scores = ast.literal_eval(
        row["shap_pos_outs"]
    )  # already sorted in order of abs score
    neg_scores = ast.literal_eval(
        row["shap_neg_outs"]
    )  # already sorted in order of abs score

    # apply ths is_valid function to filter out switching tokens
    pos_scores = [(word, score) for word, score in pos_scores if is_valid(word)]
    neg_scores = [(word, score) for word, score in neg_scores if is_valid(word)]

    # keep unique words in each list
    pos_scores = list(dict.fromkeys(pos_scores))
    neg_scores = list(dict.fromkeys(neg_scores))

    # fake attack
    if row["true_labels"] == 0:
        if row["predicted_labels"] == 0:
            # need to reverse neg scores
            neg_scores.reverse()

            important_words = []
            for word, score in pos_scores:
                if len(important_words) < IMPORTANT_WORDS_COUNT:
                    important_words.append(word)
                else:
                    break
            for word, score in neg_scores:
                if len(important_words) < IMPORTANT_WORDS_COUNT:
                    important_words.append(word)
                else:
                    break

        elif row["predicted_labels"] == 1:
            # need to reverse pos scores
            pos_scores.reverse()

            important_words = []
            for word, score in neg_scores:
                if len(important_words) < IMPORTANT_WORDS_COUNT:
                    important_words.append(word)
                else:
                    break

            for word, score in pos_scores:
                if len(important_words) < IMPORTANT_WORDS_COUNT:
                    important_words.append(word)
                else:
                    break

    elif row["true_labels"] == 1:
        if row["predicted_labels"] == 0:
            # need to reverse pos scores
            pos_scores.reverse()

            important_words = []
            for word, score in neg_scores:
                if len(important_words) < IMPORTANT_WORDS_COUNT:
                    important_words.append(word)
                else:
                    break

            for word, score in pos_scores:
                if len(important_words) < IMPORTANT_WORDS_COUNT:
                    important_words.append(word)
                else:
                    break

        elif row["predicted_labels"] == 1:
            # need to reverse neg scores
            neg_scores.reverse()

            important_words = []
            for word, score in pos_scores:
                if len(important_words) < IMPORTANT_WORDS_COUNT:
                    important_words.append(word)
                else:
                    break
            for word, score in neg_scores:
                if len(important_words) < IMPORTANT_WORDS_COUNT:
                    important_words.append(word)
                else:
                    break

    # some sentences may have less than specified number of important words
    # use all tokens in the sentence if this is the case
    if len(important_words) < IMPORTANT_WORDS_COUNT:
        important_words = str(row["content"]).split()
        if len(important_words) > IMPORTANT_WORDS_COUNT:
            important_words = important_words[:IMPORTANT_WORDS_COUNT]
    if len(important_words) > IMPORTANT_WORDS_COUNT:
        important_words = important_words[:IMPORTANT_WORDS_COUNT]

    df.at[idx, "important_words"] = important_words
    # remove the symbol "ƒ†" and "Ġ" from each word in the list of important
    # words - weirdly only happened only for roberta preds, roberta shap
    df.at[idx, "important_words"] = [
        word.replace("ƒ†", "").replace("Ġ", "") for word in important_words
    ]

df.to_csv(TEST_PATH, index=False)

Storing switching tokens: 100%|██████████| 200/200 [00:00<00:00, 987.11it/s] 


## Attack

In [84]:
import ast
import json
import os
import random
import re
from concurrent.futures import ThreadPoolExecutor, as_completed

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import swifter
import torch
import yaml
from nltk.tokenize import word_tokenize
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
from transformers import (
    AutoModel,
    AutoTokenizer,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    logging,
)

In [44]:
logging.set_verbosity_error()

In [90]:
METHOD = "salience"
DATASET_NAME = "imdb"
INSERT_POSITION = "random"
FALSE_CATEGORY = "fp"
BATCH_SIZE = 16

DATA_PATH = "./outputs/imdb_26-04-24-13_56_fine-tune_distilbert-base-uncased_test_results_all.csv"
DATA_FOLDER = "./data/attack_files"
NER = "default"

TOKENS_JSON_PATH = f"./outputs/shap_outputs/{DATASET_NAME}_fake_attack_candidates.json"
OUTPUT_PATH = f"{DATA_FOLDER}/{DATASET_NAME}/{INSERT_POSITION}_{FALSE_CATEGORY}_inject_test_data_{METHOD}.csv"

In [91]:
# Constants
LABEL_TO_FILTER = (
    0 if FALSE_CATEGORY == "fp" else 1
)  # Filter news items (0 for fp, 1 for fn)
NUM_WORDS_TO_INJECT = 10  # Number of words to inject in each article
RATIO_TO_MODIFY = 1.0  # Randomly select x% of the news items to modify
NUMBER_OF_CANDIDATE_TOKENS = 200  # chosen from TOKENS_JSON_PATH

In [72]:
# Adverbs
BOOSTER_DICT = [
    "absolutely",
    "amazingly",
    "awfully",
    "barely",
    "completely",
    "considerably",
    "decidedly",
    "deeply",
    "enormously",
    "entirely",
    "especially",
    "exceptionally",
    "exclusively",
    "extremely",
    "fully",
    "greatly",
    "hardly",
    "hella",
    "highly",
    "hugely",
    "incredibly",
    "intensely",
    "majorly",
    "overwhelmingly",
    "really",
    "remarkably",
    "substantially",
    "thoroughly",
    "totally",
    "tremendously",
    "unbelievably",
    "unusually",
    "utterly",
    "very",
]

# Negation dictionary
negate_dict = {
    "isn't": "is",
    "isn't": "is",
    "is not ": "is ",
    "is ": "is not ",
    "didn't": "did",
    "didn't": "did",
    "did not ": "did",
    "does not have": "has",
    "doesn't have": "has",
    "doesn't have": "has",
    "has ": "does not have ",
    "shouldn't": "should",
    "shouldn't": "should",
    "should not": "should",
    "should": "should not",
    "wouldn't": "would",
    "wouldn't": "would",
    "would not": "would",
    "would": "would not",
    "mustn't": "must",
    "mustn't": "must",
    "must not": "must",
    "must ": "must not ",
    "can't": "can",
    "can't": "can",
    "cannot": "can",
    " can ": " cannot ",
}

IRREGULAR_ES_VERB_ENDINGS = ["ss", "x", "ch", "sh", "o"]

# Cache for embeddings (stss)
embedding_cache = {}


In [73]:
# STSS
def switch_words_stss(row, candidate_tokens):
    sentence = row["content"]
    important_words = row["important_words"]

    if not important_words:
        return sentence

    switched_sentence = sentence
    all_words = important_words + candidate_tokens

    # Generate embeddings for important words and candidate tokens
    # Use cached embeddings if available
    all_embeddings = [embedding_cache.get(word) for word in all_words]
    words_to_encode = [
        word for word, embedding in zip(all_words, all_embeddings) if embedding is None
    ]

    if words_to_encode:
        new_embeddings = get_embedding_batch(words_to_encode)
        embedding_cache.update(
            {
                word: embedding
                for word, embedding in zip(words_to_encode, new_embeddings)
            }
        )
        all_embeddings = [embedding_cache[word] for word in all_words]

    important_words_embeddings = all_embeddings[: len(important_words)]
    candidate_tokens_embeddings = all_embeddings[len(important_words) :]

    # Calculate cosine similarities
    similarities = cosine_similarity(
        candidate_tokens_embeddings, important_words_embeddings
    )

    assert similarities.shape[0] == len(candidate_tokens)
    assert similarities.shape[1] == len(important_words)

    # iterate over the important words
    replacements = []
    for idx_sim, word in enumerate(important_words):
        # get the most similar candidate token
        similar_word_indices = np.argsort(similarities[:, idx_sim])[::-1]
        similar_word_index = similar_word_indices[0]
        candidate_token = candidate_tokens[similar_word_index]
        # switch the word
        if word.lower() != candidate_token.lower():
            replacements.append((word, candidate_token))
            switched_sentence = re.sub(
                r"\b" + re.escape(word) + r"\b", candidate_token, switched_sentence
            )
        elif word.lower() == candidate_token.lower():
            # find the next most similar candidate token which is not the same
            # as the word
            for similar_word_index in similar_word_indices:
                candidate_token = candidate_tokens[similar_word_index]
                if word.lower() != candidate_token.lower():
                    replacements.append((word, candidate_token))
                    switched_sentence = re.sub(
                        r"\b" + re.escape(word) + r"\b",
                        candidate_token,
                        switched_sentence,
                    )
                    break

    assert len(replacements) == len(important_words)
    return replacements, switched_sentence



# Inject a token at all possible positions and keep the one with lowest perplexity
def inject_word(sentence, word_to_inject):
    tokens = sentence.split()
    min_perplexity = float("inf")
    best_sentence = sentence
    for i in range(len(tokens) + 1):
        new_tokens = tokens[:i] + [word_to_inject] + tokens[i:]
        new_sentence = " ".join(new_tokens)
        perplexity = measure_perplexity(new_sentence)
        if perplexity < min_perplexity:
            min_perplexity = perplexity
            best_sentence = new_sentence
    return best_sentence


# Function to modify a single article
def modify_article(article):
    sentences = article.split(".")
    with ThreadPoolExecutor() as executor:
        future_to_sentence = {
            executor.submit(
                inject_word, sentence, random.choice(tokens_to_inject)
            ): sentence
            for sentence in sentences
        }
        for future in as_completed(future_to_sentence):
            sentence = future_to_sentence[future]
            try:
                data = future.result()
            except Exception as exc:
                print("%r generated an exception: %s" % (sentence, exc))
    return ". ".join([future.result() for future in as_completed(future_to_sentence)])


def modify_articles_batch(df_batch):
    df_batch = df_batch.copy()
    for idx, row in df_batch.iterrows():
        article = row["content"]
        if isinstance(article, str):  # Check if the article is a string
            df_batch.loc[idx, "modified_content"] = modify_article(article)
    return df_batch


# Function to get BERT embeddings
def get_embedding_batch(words):
    input_ids = [tokenizer.encode(word, add_special_tokens=True) for word in words]
    max_len = max([len(i) for i in input_ids])
    padded = torch.tensor([i + [0] * (max_len - len(i)) for i in input_ids]).to(device)
    attention_mask = torch.where(padded != 0, 1, 0).to(device)
    with torch.no_grad():
        last_hidden_states = model(padded, attention_mask=attention_mask)
    features = last_hidden_states[0][:, 0, :].cpu().numpy()
    return features


def switch_words(sentence, words_to_switch):
    if not isinstance(sentence, str):
        raise ValueError(f"Expected string, got {type(sentence)}")

    tokens = word_tokenize(sentence)
    switched_sentence = sentence

    try:
        # Generate embeddings for all tokens and candidate words at once
        all_embeddings = get_embedding_batch(tokens + words_to_switch)
        tokens_embeddings = all_embeddings[: len(tokens)]
        words_to_switch_embeddings = all_embeddings[len(tokens) :]

        assert len(tokens) > 0, "Tokens is empty"
        assert len(words_to_switch) > 0, "Words_to_switch is empty"
        assert all_embeddings.shape[0] > 0, "All embeddings is empty"
        assert tokens_embeddings.shape[0] > 0, "Tokens_embeddings is empty"
        assert (
            words_to_switch_embeddings.shape[0] > 0
        ), "Words_to_switch_embeddings is empty"

        # Calculate cosine similarities in a vectorized way
        similarities = cosine_similarity(words_to_switch_embeddings, tokens_embeddings)

    except Exception as e:
        return switched_sentence
    # Pair tokens with their similarity scores for each candidate
    token_similarity_pairs = []
    for idx, candidate in enumerate(words_to_switch):
        token_similarity_pairs.extend(
            [(token, candidate, sim) for token, sim in zip(tokens, similarities[idx])]
        )

    # Sort pairs by similarity
    token_similarity_pairs.sort(key=lambda x: x[2], reverse=True)

    count = 0  # Initialize count here to limit per sentence switches

    # Pick the token-candidate pairs with the highest similarity
    for token, candidate, sim in token_similarity_pairs:
        if count >= 10:  # Stop after switching twenty words
            break
        if 0.5 < sim < 0.9:  # Threshold for similarity
            if (
                token.lower() != candidate.lower()
            ):  # Avoid replacing with the same token
                switched_sentence = re.sub(
                    r"\b" + re.escape(token) + r"\b", candidate, switched_sentence
                )
                count += 1

    return switched_sentence


def switch_words_in_batch(batch, words_to_switch):
    # Now each item in the iterable is itself an iterable (a tuple)
    return list(executor.map(switch_words, batch, [words_to_switch] * len(batch)))


# Negation attack
def negate(sentence):
    for key in negate_dict.keys():
        if sentence.find(key) > -1:
            return sentence.replace(key, negate_dict[key])
    doesnt_regex = r"(doesn't|doesn\\'t|does not) (?P<verb>\w+)"
    if re.search(doesnt_regex, sentence):
        return re.sub(doesnt_regex, replace_doesnt, sentence, 1)
    return sentence


In [74]:
def __is_consonant(letter):
    return letter not in ["a", "e", "i", "o", "u", "y"]


def replace_doesnt(matchobj):
    verb = matchobj.group(2)
    if verb.endswith("y") and __is_consonant(verb[-2]):
        return "{0}ies".format(verb[0:-1])
    for ending in IRREGULAR_ES_VERB_ENDINGS:
        if verb.endswith(ending):
            return "{0}es".format(verb)
    return "{0}s".format(verb)


# Adverb intensity attack
def reduce_intensity(sentence):
    return " ".join([w for w in sentence.split() if w.lower() not in BOOSTER_DICT])


# Injection attack
def inject_words(sentence, words_to_inject, num_words_to_inject, mode):
    if mode == "random":
        # print("Random insertion ...")
        return inject_words_random(sentence, words_to_inject, num_words_to_inject)
    elif mode == "head":
        # print("Head insertion ...")
        return inject_words_head(sentence, words_to_inject, num_words_to_inject)
    elif mode == "tail":
        # print("Tail insertion ...")
        return inject_words_tail(sentence, words_to_inject, num_words_to_inject)


# inject at random locations
def inject_words_random(sentence, words_to_inject, num_words_to_inject):
    tokens = sentence.split()
    words_to_inject = random.sample(words_to_inject, num_words_to_inject)
    for word in words_to_inject:
        position = random.randint(0, len(tokens))
        tokens.insert(position, word)
    return " ".join(tokens)


# inject at head
def inject_words_head(sentence, words_to_inject, num_words_to_inject):
    tokens = sentence.split()
    words_to_inject = random.sample(words_to_inject, num_words_to_inject)
    for word in words_to_inject:
        tokens.insert(0, word)  # Insert at the beginning of the sentence
    return " ".join(tokens)


# inject at tail
def inject_words_tail(sentence, words_to_inject, num_words_to_inject):
    tokens = sentence.split()
    words_to_inject = random.sample(words_to_inject, num_words_to_inject)
    for word in words_to_inject:
        tokens.append(word)  # Append at the end of the sentence
    return " ".join(tokens)


def preprocess_text(text):
    text = re.sub(r"#", "", text.lower())
    tokens = nltk.word_tokenize(text)
    tokens = [t for t in tokens if t.isalpha() and t not in stop_words]
    text = " ".join(tokens)
    return text


In [75]:
print(DATA_PATH)

./outputs/imdb_26-04-24-13_56_fine-tune_distilbert-base-uncased_test_results_all.csv


In [92]:
df = pd.read_csv(DATA_PATH)

df = df[df["content"].apply(lambda x: isinstance(x, str))]
df_label_filtered = df[df["true_labels"] == LABEL_TO_FILTER]
num_items_to_modify = int(RATIO_TO_MODIFY * len(df_label_filtered))
items_to_modify = df_label_filtered.sample(num_items_to_modify)

In [93]:
len(df)

200

In [94]:
DATA_FOLDER

'./data/attack_files'

In [95]:
OUTPUT_PATH

'./data/attack_files/imdb/random_fp_inject_test_data_salience.csv'

In [96]:
for idx, row in items_to_modify.iterrows():
    original_text = row["content"]
    if not isinstance(original_text, str):
        continue

    if METHOD in ["salience", "freq"]:
        with open(TOKENS_JSON_PATH, "r") as f:
            tokens = json.load(f)

        tokens_to_inject = list(tokens.keys())[
            : min(NUMBER_OF_CANDIDATE_TOKENS, len(tokens.keys()))
        ]

        modified_text = inject_words(
            original_text,
            tokens_to_inject,
            num_words_to_inject=min(NUM_WORDS_TO_INJECT, len(tokens_to_inject)),
            mode=INSERT_POSITION,
        )
        df.loc[idx, "modified_content"] = modified_text

modified_df = df.loc[items_to_modify.index.values]
print(modified_df.head())
modified_df.to_csv(OUTPUT_PATH, index=False)

                                               content  true_labels  \
16   the story is about a psychic woman, tory, who ...            0   
111  " the danish bladerunner " is boldly stated on...            0   
106  ken burns'" baseball " is a decent documentary...            0   
157  busty beauty stacie randall plays pvc clad, ba...            0   
86   we brought this film as a joke for a friend, a...            0   

     predicted_labels                                  processed_content  \
16                  0  story psychic woman tory returns hometown begi...   
111                 0  danish bladerunner boldly stated box kidding f...   
106                 0  ken burns baseball decent documentary presents...   
157                 1  busty beauty stacie randall plays pvc clad bad...   
86                  0  brought film joke friend could worst joke play...   

                                         shap_neg_outs  \
16   [('fans', -0.002244016155600548), ('flock', -0...   


In [97]:
len(modified_df)

101