In [18]:
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TextClassificationPipeline,
)
from tqdm import tqdm
import pandas as pd
from sklearn.metrics import classification_report

In [31]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def get_pipeline(model_name_path):
    """Load a sequence-classification model and tokenizer and wrap them in a pipeline.

    Args:
        model_name_path: Hugging Face model id or local directory holding the
            saved model and tokenizer.

    Returns:
        A ``TextClassificationPipeline`` whose model and inputs live on ``device``.
    """
    model = AutoModelForSequenceClassification.from_pretrained(model_name_path).to(
        device
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name_path)
    # The pipeline must be told which device the model is on: in many
    # transformers versions it otherwise defaults to CPU (device=-1) and would
    # feed CPU tensors to a CUDA model. An integer index is used because older
    # transformers releases only accept ints here (-1 means CPU).
    pipeline_device = -1 if device.type == "cpu" else (device.index or 0)
    return TextClassificationPipeline(
        model=model, tokenizer=tokenizer, device=pipeline_device
    )

## HASOC Results

In [34]:
# TODO: absolute, machine-specific paths; consider a configurable base directory.
hasoc_pipeline = get_pipeline("/Users/shahrukh/Desktop/victim_models/hasoc_model")
hasoc_test_dataset = pd.read_csv('/Users/shahrukh/Desktop/adversarial-bert-german-attacks-defense/bert_finetuning/datasets/hasoc_dataset/hasoc_german_test.csv')

# The pipeline reports string labels taken from the model config's id2label
# (e.g. 'LABEL_0'). Invert that mapping to recover integer class ids instead of
# assuming every label other than 'LABEL_0' is class 1.
hasoc_label_to_id = {
    label: int(idx) for idx, label in hasoc_pipeline.model.config.id2label.items()
}

hasoc_y = []
hasoc_pred = []
for _, row in tqdm(hasoc_test_dataset.iterrows(), total=hasoc_test_dataset.shape[0]):
    hasoc_y.append(row.label)
    # truncation=True keeps texts longer than the model's maximum input length
    # from raising an error.
    predicted_label = hasoc_pipeline(row.text, truncation=True)[0]['label']
    hasoc_pred.append(hasoc_label_to_id[predicted_label])
print(classification_report(hasoc_y, hasoc_pred))

100%|██████████| 850/850 [02:10<00:00,  6.49it/s]

              precision    recall  f1-score   support

           0       0.84      1.00      0.92       714
           1       1.00      0.03      0.06       136

    accuracy                           0.84       850
   macro avg       0.92      0.51      0.49       850
weighted avg       0.87      0.84      0.78       850






## GermEval Results

In [36]:
# TODO: absolute, machine-specific paths; consider a configurable base directory.
germeval_pipeline = get_pipeline("/Users/shahrukh/Desktop/victim_models/germeval_model")
germeval_test_dataset = pd.read_csv('/Users/shahrukh/Desktop/adversarial-bert-german-attacks-defense/bert_finetuning/datasets/germeval_dataset/germ_eval_test.csv')

# The pipeline reports string labels taken from the model config's id2label
# (e.g. 'LABEL_0'). Invert that mapping to recover integer class ids instead of
# assuming every label other than 'LABEL_0' is class 1.
germeval_label_to_id = {
    label: int(idx) for idx, label in germeval_pipeline.model.config.id2label.items()
}

germeval_y = []
germeval_pred = []
for _, row in tqdm(germeval_test_dataset.iterrows(), total=germeval_test_dataset.shape[0]):
    germeval_y.append(row.label)
    # Texts are classified one at a time, so no padding is needed. The deprecated
    # pad_to_max_length flag only padded each input to the maximum length.
    # truncation=True keeps over-long texts from raising an error.
    predicted_label = germeval_pipeline(row.text, truncation=True)[0]['label']
    germeval_pred.append(germeval_label_to_id[predicted_label])
print(classification_report(germeval_y, germeval_pred))

100%|██████████| 944/944 [02:16<00:00,  6.89it/s]

              precision    recall  f1-score   support

           0       0.70      0.89      0.78       594
           1       0.65      0.35      0.46       350

    accuracy                           0.69       944
   macro avg       0.67      0.62      0.62       944
weighted avg       0.68      0.69      0.66       944




