In [1]:
import json
import pandas as pd
import numpy as np
import re
import os

In [2]:
root_directory_for_pbns = (
    "/scratch/zhivar/robust-prototype-learning/PBN_transformer/summaries"
)

root_directory_for_normal_models = (
    "/scratch/zhivar/robust-prototype-learning/datasets/summaries"
)

In [3]:
def get_attack_type(file_path):
    for attack_type in ["textfooler", "textbugger", "deepwordbug", "pwws", "bae"]:
        if attack_type.lower() in file_path.lower():
            return attack_type
    raise ValueError(f"Unknown attack type for file {file_path}")


def get_classifier_model(file_path):
    for classifier in [
        "electra-base-discriminator",
        "bart-base-mnli",
        "bert-medium",
        "electra",
        "bert",
        "bart",
    ]:
        if classifier.lower() in file_path.lower():
            return classifier.split("-")[0]
    return "bart"


def get_dataset(file_path):
    for dataset in ["imdb", "sst2", "ag_news", "dbpedia"]:
        if dataset.lower() in file_path.lower():
            return dataset
    raise ValueError(f"Unknown dataset for file {file_path}")


def get_the_hyperparameters(file_path):
    pattern_to_look_for = re.compile(r".*(\d+)\_(\d+)\_(\d+)\_(\d+).*")
    second_pattern_to_look_for = re.compile(r".*(\d+)_(\d+)_(\d+).*")

    matched_integers = re.match(pattern_to_look_for, file_path)
    if matched_integers:
        return matched_integers.groups()
    else:
        matched_integers = re.match(second_pattern_to_look_for, file_path)
        if matched_integers:
            return (*matched_integers.groups(), 16)
    raise ValueError(f"Could not find hyperparameters for file {file_path}")

In [4]:
def process_file_pbn(log_file, root_directory):
    with open(os.path.join(root_directory, log_file), "r") as f:
        data = json.load(f)
        f.close()
    if "result" in data.keys():
        data_object = {
            "file_path": log_file,
            "attack_type": get_attack_type(log_file),
            "classifier_model": get_classifier_model(log_file),
            "dataset": get_dataset(log_file),
            "hyperparameters": get_the_hyperparameters(log_file),
            "comment": data["result"],
        }
    else:
        data = data["Attack Results"]
        data_object = {
            "file_path": log_file,
            "attack_type": get_attack_type(log_file),
            "classifier_model": get_classifier_model(log_file),
            "dataset": get_dataset(log_file),
            "hyperparameters": get_the_hyperparameters(log_file),
            "Number of successful attacks:": data["Number of successful attacks:"],
            "Number of failed attacks:": data["Number of failed attacks:"],
            "Number of skipped attacks:": data["Number of skipped attacks:"],
            "Original accuracy:": data["Original accuracy:"],
            "Accuracy under attack:": data["Accuracy under attack:"],
            "Attack success rate:": data["Attack success rate:"],
            "Average perturbed word %:": data["Average perturbed word %:"],
            "Average num. words per input:": data["Average num. words per input:"],
            "Avg num queries:": data["Avg num queries:"],
            "comment": "successful",
        }
    return data_object

In [5]:
all_log_files = [
    file for file in os.listdir(root_directory_for_pbns) if file.startswith("summary_")
]

all_data_objects = []

for log_file in all_log_files:
    try:
        data_object = process_file_pbn(log_file, root_directory_for_pbns)
        all_data_objects.append(data_object)
    except Exception as e:
        print(e)
        break

results_df_pbn = pd.DataFrame(all_data_objects)
results_df_pbn

Unnamed: 0,file_path,attack_type,classifier_model,dataset,hyperparameters,Number of successful attacks:,Number of failed attacks:,Number of skipped attacks:,Original accuracy:,Accuracy under attack:,Attack success rate:,Average perturbed word %:,Average num. words per input:,Avg num queries:,comment
0,summary_imdb_textfooler_imdb_model_09_09_09.json,textfooler,bart,imdb,"(9, 09, 09, 16)",800.0,106.0,71.0,92.73,10.85,88.30,12.13,230.47,1035.63,successful
1,summary_dbpedia_textfooler_dbpedia_model_09_09...,textfooler,bart,dbpedia,"(9, 09, 09, 16)",800.0,700.0,9.0,99.40,46.39,53.33,25.99,103.83,918.13,successful
2,summary_imdb_textbugger_imdb_model_09_09_09.json,textbugger,bart,imdb,"(9, 09, 09, 16)",800.0,1020.0,128.0,93.43,52.36,43.96,28.40,227.61,662.42,successful
3,summary_ag_news_textfooler_ag_news_model_09_09...,textfooler,bart,ag_news,"(9, 09, 09, 16)",800.0,487.0,119.0,91.54,34.64,62.16,24.11,37.93,385.60,successful
4,summary_ag_news_textbugger_ag_news_model_09_09...,textbugger,bart,ag_news,"(9, 09, 09, 16)",800.0,2456.0,274.0,92.24,69.58,24.57,30.61,38.32,184.35,successful
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
388,summary_ag_news_pwws_BART_ag_news_model_100_09...,pwws,bart,ag_news,"(0, 09, 09, 16)",800.0,805.0,156.0,91.14,45.71,49.84,21.37,38.09,360.14,successful
389,summary_ag_news_bae_BART_ag_news_model_09_09_0...,bae,bart,ag_news,"(9, 09, 09, 2)",,,,,,,,,,This model is not accurate enough in the first...
390,summary_ag_news_bae_BART_ag_news_model_09_100_...,bae,bart,ag_news,"(9, 100, 09, 16)",,,,,,,,,,This model is not accurate enough in the first...
391,summary_ag_news_bae_BART_ag_news_model_00_09_0...,bae,bart,ag_news,"(0, 09, 09, 16)",800.0,4608.0,505.0,91.46,77.93,14.79,9.01,38.56,114.31,successful


In [6]:
def is_attacked_normal_model(file_path):
    for classifier in [
        "electra-base-discriminator",
        "bart-base-mnli",
        "bert-medium",
    ]:
        if classifier.lower() in file_path.lower():
            return True
    return False

In [7]:
def process_file_non_pbn(log_file, root_directory):
    with open(os.path.join(root_directory, log_file), "r") as f:
        data = json.load(f)
        f.close()

    data = data["Attack Results"]
    data_object = {
        "file_path": log_file,
        "attack_type": get_attack_type(log_file),
        "classifier_model": get_classifier_model(log_file),
        "dataset": get_dataset(log_file),
        "hyperparameters": None,
        "Number of successful attacks:": data["Number of successful attacks:"],
        "Number of failed attacks:": data["Number of failed attacks:"],
        "Number of skipped attacks:": data["Number of skipped attacks:"],
        "Original accuracy:": data["Original accuracy:"],
        "Accuracy under attack:": data["Accuracy under attack:"],
        "Attack success rate:": data["Attack success rate:"],
        "Average perturbed word %:": data["Average perturbed word %:"],
        "Average num. words per input:": data["Average num. words per input:"],
        "Avg num queries:": data["Avg num queries:"],
        "comment": "successful",
    }
    return data_object

In [8]:
all_log_files = [
    file
    for file in os.listdir(root_directory_for_normal_models)
    if file.startswith("summary_")
]

all_data_objects = []

for log_file in all_log_files:
    if not is_attacked_normal_model(log_file):
        continue
    try:
        data_object = process_file_non_pbn(log_file, root_directory_for_normal_models)
        all_data_objects.append(data_object)
    except Exception as e:
        print(e)
        break

results_df_non_pbns = pd.DataFrame(all_data_objects)
results_df_non_pbns

Unnamed: 0,file_path,attack_type,classifier_model,dataset,hyperparameters,Number of successful attacks:,Number of failed attacks:,Number of skipped attacks:,Original accuracy:,Accuracy under attack:,Attack success rate:,Average perturbed word %:,Average num. words per input:,Avg num queries:,comment
0,summary_imdb_textfooler__normal_models_models_...,textfooler,bart,imdb,,800,0,73,91.64,0.0,100.0,5.92,231.2,442.62,successful
1,summary_imdb_textbugger__normal_models_models_...,textbugger,bart,imdb,,800,219,85,92.3,19.84,78.51,33.32,230.43,428.52,successful
2,summary_ag_news_textfooler__normal_models_mode...,textfooler,bart,ag_news,,800,246,93,91.83,21.6,76.48,25.0,37.9,333.32,successful
3,summary_ag_news_textbugger__normal_models_mode...,textbugger,bart,ag_news,,800,1716,208,92.36,63.0,31.8,35.74,38.31,138.84,successful
4,summary_dbpedia_textbugger__normal_models_mode...,textbugger,bart,dbpedia,,800,2995,49,98.73,77.91,21.08,50.06,104.35,382.74,successful
5,summary_dbpedia_textfooler__normal_models_mode...,textfooler,bart,dbpedia,,800,313,13,98.85,27.8,71.88,26.23,105.78,849.34,successful
6,summary_imdb_textfooler__normal_models_models_...,textfooler,electra,imdb,,800,6,50,94.16,0.7,99.26,11.16,230.72,702.17,successful
7,summary_imdb_textbugger__normal_models_models_...,textbugger,electra,imdb,,800,380,72,94.25,30.35,67.8,44.71,228.63,430.13,successful
8,summary_ag_news_textfooler__normal_models_mode...,textfooler,electra,ag_news,,800,110,147,86.09,10.41,87.91,20.75,37.88,262.01,successful
9,summary_ag_news_textbugger__normal_models_mode...,textbugger,electra,ag_news,,800,111,147,86.11,10.49,87.82,32.66,37.89,80.86,successful


In [9]:
results_df_pbn = results_df_pbn[results_df_pbn["comment"] == "successful"]
results_df_pbn["comment"] = "PBN"
results_df_non_pbns["comment"] = "non_PBN"

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  results_df_pbn["comment"] = "PBN"


In [10]:
results_df_merged = pd.concat([results_df_pbn, results_df_non_pbns])

In [11]:
results_df_merged

Unnamed: 0,file_path,attack_type,classifier_model,dataset,hyperparameters,Number of successful attacks:,Number of failed attacks:,Number of skipped attacks:,Original accuracy:,Accuracy under attack:,Attack success rate:,Average perturbed word %:,Average num. words per input:,Avg num queries:,comment
0,summary_imdb_textfooler_imdb_model_09_09_09.json,textfooler,bart,imdb,"(9, 09, 09, 16)",800.0,106.0,71.0,92.73,10.85,88.30,12.13,230.47,1035.63,PBN
1,summary_dbpedia_textfooler_dbpedia_model_09_09...,textfooler,bart,dbpedia,"(9, 09, 09, 16)",800.0,700.0,9.0,99.40,46.39,53.33,25.99,103.83,918.13,PBN
2,summary_imdb_textbugger_imdb_model_09_09_09.json,textbugger,bart,imdb,"(9, 09, 09, 16)",800.0,1020.0,128.0,93.43,52.36,43.96,28.40,227.61,662.42,PBN
3,summary_ag_news_textfooler_ag_news_model_09_09...,textfooler,bart,ag_news,"(9, 09, 09, 16)",800.0,487.0,119.0,91.54,34.64,62.16,24.11,37.93,385.60,PBN
4,summary_ag_news_textbugger_ag_news_model_09_09...,textbugger,bart,ag_news,"(9, 09, 09, 16)",800.0,2456.0,274.0,92.24,69.58,24.57,30.61,38.32,184.35,PBN
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
40,summary_dbpedia_pwws__normal_models_models_dbp...,pwws,bart,dbpedia,,800.0,1057.0,25.0,98.67,56.16,43.08,16.89,103.75,768.72,non_PBN
41,summary_dbpedia_pwws__normal_models_models_dbp...,pwws,electra,dbpedia,,800.0,955.0,19.0,98.93,53.83,45.58,17.73,104.63,751.80,non_PBN
42,summary_dbpedia_bae__normal_models_models_dbpe...,bae,bert,dbpedia,,800.0,4935.0,53.0,99.08,85.26,13.95,9.65,105.42,274.57,non_PBN
43,summary_dbpedia_bae__normal_models_models_dbpe...,bae,bart,dbpedia,,800.0,3421.0,55.0,98.71,80.00,18.95,9.12,104.44,267.18,non_PBN
