# Argumentation Mining with BERT
# By Zaher Alkaei

This notebook fine-tunes a BERT model to classify argumentative components in US Presidential Campaign Debate transcripts.

# Dataset Source:
Yes, we can! Mining Arguments in 50 Years of US Presidential Campaign Debates (Haddadan et al., ACL 2019)
 https://github.com/ElecDeb60To16/Dataset

# Labels:
- O: Non-argumentative
- Claim
- Premise

---



# What does training BERT mean here?
We start from a pretrained BERT model (bert-base-uncased) and fine-tune it for our classification task.
This means keeping BERT's pretrained encoder layers, replacing its pretraining output heads with a new, randomly initialized classification layer for 3 classes,
and training the whole model on our dataset so that its weights adapt to the task.
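
To make the role of the new classification layer concrete, here is a minimal sketch in plain Python of what such a layer computes: a linear map from the encoder's sentence vector to 3 logits, followed by a softmax. The 4-dimensional vector and the weight values are made-up illustrations (bert-base-uncased works with 768-dimensional vectors), and `classify` is a hypothetical helper that does not appear in this notebook.

```python
import math

LABELS = ["O", "Claim", "Premise"]

def softmax(logits):
    """Turn raw scores into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def classify(sentence_vector, weights, bias):
    """Hypothetical helper: linear layer + softmax over 3 classes."""
    logits = [
        sum(w * x for w, x in zip(row, sentence_vector)) + b
        for row, b in zip(weights, bias)
    ]
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return LABELS[best], probs

# Toy values (assumptions for illustration only)
sentence_vector = [0.5, -1.0, 0.25, 2.0]
weights = [
    [0.1, 0.0, 0.0, -0.2],   # row for class O
    [0.3, -0.5, 0.0, 0.4],   # row for class Claim
    [0.0, 0.2, 0.1, 0.1],    # row for class Premise
]
bias = [0.0, 0.0, 0.0]

label, probs = classify(sentence_vector, weights, bias)
print(label, [round(p, 3) for p in probs])
```

During fine-tuning, the rows of `weights` start out random and are learned together with the encoder, which is why the library warns later that the classifier weights are newly initialized.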

# Step 1: Install Required Libraries

In [2]:
!pip install transformers datasets scikit-learn



# Step 2: Load and Explore Dataset

In [3]:
import pandas as pd

# Load your dataset
df = pd.read_csv("sentence_db_candidate.csv")

# Drop rows with missing Speech or Component
df = df.dropna(subset=["Speech", "Component"])

# Show basic statistics
print("Total samples:", len(df))
print("\nClass distribution:")
print(df["Component"].value_counts())

Total samples: 29532

Class distribution:
Component
Claim      11964
Premise    10316
O           7252
Name: count, dtype: int64


# Step 3: Encode Labels and Split Data

In [4]:
from sklearn.model_selection import train_test_split

# Map text labels to numeric labels
label_map = {"O": 0, "Claim": 1, "Premise": 2}
df["label"] = df["Component"].map(label_map)

# Split into train, validation, and test sets
train_df, temp_df = train_test_split(df, test_size=0.3, stratify=df["label"], random_state=42)
val_df, test_df = train_test_split(temp_df, test_size=0.5, stratify=temp_df["label"], random_state=42)
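
The two calls above produce a 70% / 15% / 15% split. The resulting sizes can be checked with a little arithmetic. This is a sketch that assumes scikit-learn rounds the held-out count up and gives the remainder to the other part, which is how `train_test_split` behaves for fractional `test_size` to the best of our knowledge.

```python
import math

n_total = 29532  # "Total samples" printed in Step 2

# First split: 30% held out (count rounded up), the rest is training data
n_temp = math.ceil(0.3 * n_total)
n_train = n_total - n_temp

# Second split: the held-out part is halved into test and validation
n_test = math.ceil(0.5 * n_temp)
n_val = n_temp - n_test

print(n_train, n_val, n_test)  # 20672 4430 4430
```

The test size of 4430 agrees with the total support shown in the classification report in Step 8.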

# Step 4: Tokenize Sentences and Prepare Datasets

In [5]:
from sklearn.utils.class_weight import compute_class_weight
import numpy as np
import torch
from torch.utils.data import Dataset
from transformers import BertTokenizer


# Balanced weights: n_samples / (n_classes * class_count), so rarer classes weigh more
class_weights = compute_class_weight(
    class_weight="balanced",
    classes=np.unique(df["label"]),
    y=df["label"]
)
class_weights = torch.tensor(class_weights, dtype=torch.float)
print("Class weights:", class_weights)



tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

class ArgumentDataset(Dataset):
    def __init__(self, dataframe):
        self.encodings = tokenizer(list(dataframe["Speech"]), truncation=True, padding=True, return_tensors="pt")
        self.labels = list(dataframe["label"])
    def __len__(self):
        return len(self.labels)
    def __getitem__(self, idx):
        item = {key: val[idx] for key, val in self.encodings.items()}
        item["labels"] = torch.tensor(self.labels[idx])
        return item

train_dataset = ArgumentDataset(train_df)
val_dataset = ArgumentDataset(val_df)
test_dataset = ArgumentDataset(test_df)

Class weights: tensor([1.3574, 0.8228, 0.9542])
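
The printed weights can be reproduced by hand. With `class_weight="balanced"`, scikit-learn uses `n_samples / (n_classes * class_count)`; the sketch below applies that formula to the class counts printed in Step 2.

```python
# Class counts taken from the "Class distribution" output in Step 2
counts = {"O": 7252, "Claim": 11964, "Premise": 10316}

n_samples = sum(counts.values())
n_classes = len(counts)

# Balanced weight: the rarer the class, the larger its weight
weights = {
    label: n_samples / (n_classes * count)
    for label, count in counts.items()
}

for label, w in weights.items():
    print(f"{label}: {w:.4f}")
```

The rarest class (O) gets the largest weight, so mistakes on it contribute more to the loss.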



# Step 5: Load Pretrained BERT Model

In [6]:
from transformers import BertForSequenceClassification

# Load BERT with classification head for 3 classes
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=3)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


# Step 6: Define Training Arguments

In [7]:
from transformers import TrainingArguments, EarlyStoppingCallback, Trainer
from torch.nn import CrossEntropyLoss
import os
os.environ["WANDB_DISABLED"] = "true"


class WeightedTrainer(Trainer):
    """Trainer that applies per-class weights in the cross-entropy loss."""

    def __init__(self, *args, class_weights=None, **kwargs):
        super().__init__(*args, **kwargs)
        # Keep the weights on the same device as the model
        self.class_weights = class_weights.to(self.model.device)

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.get("labels")
        outputs = model(**inputs)
        logits = outputs.get("logits")
        # Weighted cross-entropy: errors on higher-weighted (rarer) classes cost more
        loss_fct = CrossEntropyLoss(weight=self.class_weights)
        loss = loss_fct(logits, labels)
        return (loss, outputs) if return_outputs else loss

training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=2,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    learning_rate=1e-5,
    warmup_steps=100,
    weight_decay=0.03,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_dir="./logs",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False
)

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
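
To show what the weighted loss in `WeightedTrainer` computes, here is a minimal plain-Python sketch of weighted cross-entropy with mean reduction, as PyTorch's `CrossEntropyLoss(weight=...)` defines it: each example's negative log-likelihood is multiplied by the weight of its true class, and the sum is divided by the sum of those weights. The logits below are toy values and `weighted_cross_entropy` is a hypothetical helper, not code from the notebook.

```python
import math

def weighted_cross_entropy(logits_batch, labels, class_weights):
    """Weighted mean of per-example negative log-likelihoods."""
    weighted_sum = 0.0
    weight_total = 0.0
    for logits, label in zip(logits_batch, labels):
        # log-sum-exp with the max subtracted for numerical stability
        m = max(logits)
        log_norm = m + math.log(sum(math.exp(z - m) for z in logits))
        nll = log_norm - logits[label]
        weighted_sum += class_weights[label] * nll
        weight_total += class_weights[label]
    return weighted_sum / weight_total

# Class weights as printed in Step 4 (order: O, Claim, Premise)
class_weights = [1.3574, 0.8228, 0.9542]

# Toy logits for three sentences (assumed values for illustration)
logits_batch = [[2.0, 0.5, 0.1], [0.2, 1.5, 0.3], [0.0, 0.0, 0.0]]
labels = [0, 1, 2]

weighted = weighted_cross_entropy(logits_batch, labels, class_weights)
unweighted = weighted_cross_entropy(logits_batch, labels, [1.0, 1.0, 1.0])
print(f"weighted={weighted:.4f} unweighted={unweighted:.4f}")
```

With all weights set to 1 the function reduces to ordinary cross-entropy, which is a quick way to sanity-check it.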


# Step 7: Train the Model

In [8]:
trainer = WeightedTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    class_weights=class_weights,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)]
)


# Start training
trainer.train()

| Epoch | Training Loss | Validation Loss |
|-------|---------------|-----------------|
| 1     | 0.8414        | 0.792331        |
| 2     | 0.7552        | 0.797313        |


TrainOutput(global_step=2584, training_loss=0.8162254504744113, metrics={'train_runtime': 326.3834, 'train_samples_per_second': 126.673, 'train_steps_per_second': 7.917, 'total_flos': 3973078386047232.0, 'train_loss': 0.8162254504744113, 'epoch': 2.0})
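
The numbers in this output can be cross-checked with a short sketch. It assumes a single device, no gradient accumulation, and a training split of 20,672 sentences (29,532 samples minus the 30% held out, rounded up). It also picks the epoch with the lowest validation loss, which is the checkpoint that `load_best_model_at_end=True` with `metric_for_best_model="eval_loss"` should restore.

```python
import math

n_train = 20672      # assumed size of the 70% training split
batch_size = 16      # per_device_train_batch_size
epochs = 2           # num_train_epochs

steps_per_epoch = math.ceil(n_train / batch_size)
total_steps = steps_per_epoch * epochs
print(total_steps)  # 2584, matching global_step in the TrainOutput

# Validation loss per epoch, copied from the training log
eval_loss = {1: 0.792331, 2: 0.797313}
best_epoch = min(eval_loss, key=eval_loss.get)
print(best_epoch)  # 1
```

Validation loss rose slightly in the second epoch while training loss kept falling, so the model evaluated in the next step should be the epoch 1 checkpoint.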

# Step 8: Evaluate the Model

In [9]:
from sklearn.metrics import classification_report
import numpy as np

# Predict on test set
predictions = trainer.predict(test_dataset)
y_true = [example["labels"].item() for example in test_dataset]
y_pred = np.argmax(predictions.predictions, axis=1)

print("\nClassification Report:")
print(classification_report(y_true, y_pred, target_names=list(label_map.keys()), digits=2))





Classification Report:
              precision    recall  f1-score   support

           O       0.67      0.64      0.65      1088
       Claim       0.65      0.71      0.68      1795
     Premise       0.63      0.58      0.60      1547

    accuracy                           0.65      4430
   macro avg       0.65      0.64      0.64      4430
weighted avg       0.65      0.65      0.64      4430
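
The two average rows differ only in how the classes are combined: the macro average treats every class equally, while the weighted average weights each class by its support. The sketch below recomputes both F1 averages from the per-class rows of the report. Because it starts from values already rounded to two decimals, it can only approximate the report's own averages.

```python
# Per-class F1 and support copied from the classification report
report = {
    "O":       {"f1": 0.65, "support": 1088},
    "Claim":   {"f1": 0.68, "support": 1795},
    "Premise": {"f1": 0.60, "support": 1547},
}

total_support = sum(row["support"] for row in report.values())

# Macro average: plain mean over classes
macro_f1 = sum(row["f1"] for row in report.values()) / len(report)

# Weighted average: mean weighted by the number of true examples per class
weighted_f1 = sum(row["f1"] * row["support"] for row in report.values()) / total_support

print(f"macro={macro_f1:.2f} weighted={weighted_f1:.2f}")
```

Both values round to 0.64 here because the three classes have similar F1 scores; with a more skewed class distribution the two averages can diverge noticeably.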



# Step 9: Save and Download the Model (For Colab)
This step allows you to download the model easily from Google Colab.

In [10]:
model.save_pretrained("./bert_argument_model")
tokenizer.save_pretrained("./bert_argument_model")

import shutil
from IPython.display import FileLink

shutil.make_archive("bert_argument_model", 'zip', "./bert_argument_model")
FileLink("bert_argument_model.zip")


# Step 10: Explain the Model
This step helps us understand which words influence the model's decisions.
We use Integrated Gradients to find which tokens are most important for each predicted class.
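
Integrated Gradients attributes a prediction to each input feature by averaging the gradient along a straight path from a baseline to the actual input and multiplying by the input's distance from the baseline. The sketch below illustrates the idea on a toy two-variable function with a hand-written gradient. It is a simplified stand-in for what Captum does with BERT embeddings, and the names `integrated_gradients`, `f`, and `grad_f` are hypothetical.

```python
def integrated_gradients(grad_fn, x, baseline, n_steps=50):
    """Approximate Integrated Gradients with a midpoint Riemann sum."""
    avg_grads = [0.0] * len(x)
    for step in range(n_steps):
        alpha = (step + 0.5) / n_steps  # midpoint of each sub-interval
        # Point on the straight line between baseline and input
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for i, g in enumerate(grad_fn(point)):
            avg_grads[i] += g / n_steps
    # Scale the averaged gradient by the distance from the baseline
    return [(xi - b) * g for xi, b, g in zip(x, baseline, avg_grads)]

def f(x):
    """Toy model output: x0 squared plus three times x1."""
    return x[0] ** 2 + 3.0 * x[1]

def grad_f(x):
    """Analytic gradient of f."""
    return [2.0 * x[0], 3.0]

x = [2.0, 1.0]
baseline = [0.0, 0.0]  # all-zero baseline

attributions = integrated_gradients(grad_f, x, baseline)
print(attributions)
print(sum(attributions), f(x) - f(baseline))
```

The attributions sum to the difference between the output at the input and at the baseline (the completeness property). In the notebook, the input is the sequence of token embeddings, no `baselines` argument is passed (so Captum falls back to its default all-zero baseline), and the attribution of a token is the sum over its embedding dimensions.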

In [11]:
!pip install captum



In [20]:
from captum.attr import IntegratedGradients
from collections import Counter
from tqdm import tqdm

model.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

def get_token_attributions(text, target_class):
    # Tokenize
    encoding = tokenizer(text, return_tensors="pt", truncation=True, padding='max_length', max_length=64)
    input_ids = encoding["input_ids"].to(device)
    attention_mask = encoding["attention_mask"].to(device)

    # Get input embeddings
    inputs_embeds = model.get_input_embeddings()(input_ids)

    def forward_func(inputs_embeds):
        outputs = model(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
        return outputs.logits

    ig = IntegratedGradients(forward_func)
    attributions, _ = ig.attribute(
        inputs=inputs_embeds,
        target=target_class,
        return_convergence_delta=True
    )

    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    scores = attributions.sum(dim=-1).squeeze(0)

    return [(tok, float(score)) for tok, score in zip(tokens, scores)]

# Accumulate scores by class
token_scores_by_class = {0: Counter(), 1: Counter(), 2: Counter()}
added_counts = {0: 0, 1: 0, 2: 0}

for i in tqdm(range(len(test_df)), desc="Processing test examples"):
    text = test_df.iloc[i]["Speech"]

    encoding = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    encoding = {key: val.to(device) for key, val in encoding.items()}

    with torch.no_grad():
        outputs = model(**encoding)

    pred_class = torch.argmax(outputs.logits, dim=1).item()

    try:
        token_scores = get_token_attributions(text, pred_class)
    except Exception as e:
        print(f"Error on sample {i}: {e}")
        continue

    added = 0  # how many tokens are actually added
    for token, score in token_scores:
        token_clean = token.strip("Ġ▁#")
        if token_clean not in tokenizer.all_special_tokens and token_clean.isalpha():
            token_scores_by_class[pred_class][token_clean.lower()] += abs(score)
            added += 1

    added_counts[pred_class] += added

    # Print debug info every 100 examples
    if i % 100 == 0:
        print(f"Processed example {i}: class {pred_class}, {added} tokens added")

# Human-readable class names (defined before first use below)
label_names = {0: "O (Non-argumentative)", 1: "Claim", 2: "Premise"}

# Show counts of added tokens
print("\n[Debug] Number of tokens added per class:")
for cls, count in added_counts.items():
    print(f"Class {cls} ({label_names[cls]}): {count} tokens")

# Display top tokens

for cls in [0, 1, 2]:
    print(f"\nTop 10 words for class: {label_names[cls]}")
    for word, score in token_scores_by_class[cls].most_common(10):
        print(f"{word}: {score:.4f}")


Processing test examples:   0%|          | 2/4430 [00:00<08:57,  8.24it/s]

Processed example 0: class 1, 24 tokens added



Processing test examples: 100%|██████████| 4430/4430 [08:18<00:00,  8.88it/s]


[Debug] Number of tokens added per class:
Class 0 (O (Non-argumentative)): 11023 tokens
Class 1 (Claim): 35326 tokens
Class 2 (Premise): 28778 tokens
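
These counts include only tokens that pass the filter inside the loop: special tokens, digits, and punctuation are dropped, and WordPiece continuation markers (`##`) are stripped. The filter can be tried in isolation with the sketch below. It hard-codes BERT's usual special tokens instead of loading the tokenizer, and both `clean_token` and the example token list are hypothetical.

```python
# Assumed stand-in for tokenizer.all_special_tokens of bert-base-uncased
SPECIAL_TOKENS = {"[CLS]", "[SEP]", "[PAD]", "[UNK]", "[MASK]"}

def clean_token(token):
    """Return the cleaned token, or None if it should be skipped."""
    cleaned = token.strip("Ġ▁#")  # drop subword markers at both ends
    if cleaned in SPECIAL_TOKENS or not cleaned.isalpha():
        return None
    return cleaned.lower()

# Hand-written example of a tokenized sentence (not real tokenizer output)
tokens = ["[CLS]", "i", "think", "tax", "##es", "are", "2", ",", "[SEP]", "[PAD]"]

kept = [t for t in map(clean_token, tokens) if t is not None]
print(kept)  # ['i', 'think', 'tax', 'es', 'are']
```

Note that a continuation piece such as `##es` is counted as the separate entry `es` instead of being merged back into its word, which explains fragments like `s` in the top-10 lists.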

Top 10 words for class: O (Non-argumentative)
i: 52.4445
the: 37.6329
you: 31.5751
is: 30.6473
s: 28.1265
question: 22.3725
and: 22.1728
that: 21.9825
in: 21.9062
to: 21.5119

Top 10 words for class: Claim
i: 108.5829
to: 100.7928
we: 84.5858
think: 83.8396
the: 74.4226
is: 72.8829
not: 58.0119
that: 51.3388
a: 49.0987
and: 47.4262

Top 10 words for class: Premise
the: 46.5082
in: 43.1278
i: 34.6165
they: 31.4804
said: 28.1193
and: 27.8060
to: 20.2946
a: 18.9482
of: 18.9305
we: 17.8354
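
Because the scores above are sums of absolute attributions over all occurrences, very frequent words such as "the", "i", and "to" rank highly in every class. One possible refinement, which is not part of this notebook, is to divide each total by the number of times the token occurred. The sketch below uses made-up totals and counts purely for illustration; `mean_attribution` is a hypothetical helper.

```python
from collections import Counter

def mean_attribution(total_scores, occurrence_counts, min_count=1):
    """Average attribution per occurrence, skipping rare tokens."""
    return {
        token: total_scores[token] / occurrence_counts[token]
        for token in total_scores
        if occurrence_counts[token] >= min_count
    }

# Toy numbers (assumptions, not results from the notebook)
total_scores = Counter({"the": 40.0, "said": 20.0})
occurrence_counts = Counter({"the": 400, "said": 50})

means = mean_attribution(total_scores, occurrence_counts)
ranked = sorted(means.items(), key=lambda item: item[1], reverse=True)
print(ranked)
```

In this toy case "said" overtakes "the" once frequency is factored out. To use the idea in the loop above, a second `Counter` per class would have to track how often each token was added; a `min_count` threshold helps keep one-off tokens from dominating the ranking.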



