In [26]:
import sys
sys.path.append("../src")

In [54]:
import torch
import numpy as np
from datasets import Dataset
from transformers import (
    DistilBertTokenizerFast,
    DistilBertForSequenceClassification,
    TrainingArguments,
    Trainer,
)
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, classification_report, confusion_matrix
from labels import LABELS, label2id, id2label
import pandas as pd

In [28]:
import torch
torch.__version__

'2.5.1'

## Split by company holdout 

In [53]:
df = pd.read_csv("../data/processed/fineprint_train_ready_dedup.csv")

print(df.columns)
print("Num policies:", df["source_policy"].nunique())
df["source_policy"].value_counts().head(10)

Index(['clause', 'label', 'source_policy', 'clause_key', 'label_id'], dtype='object')
Num policies: 48


source_policy
venmo_UserAgreement      1075
united_privacy            574
paypal_privacy            483
bookingdotcom_privacy     471
ebay_privacy              382
stripe_privacy            373
hotelsdotcom_TOU          365
priceline_TOS             358
ziprecruiter_privacy      356
netflix_privacy           301
Name: count, dtype: int64

Splits 80 20

stratify=df["label_id] ensures each label remains equally represented

Without stratification, your smaller classes (e.g., arbitration, auto-renewal) might disappear from test set → invalid evaluation.

Why stratification was used:

	•	The dataset is highly imbalanced (most clauses are labeled none)
	•	Rare but important classes (e.g. arbitration, refunds, ai_decisions) could disappear from the test set without stratification
	•	Stratification ensures that each label is proportionally represented in both training and test splits


At this stage, the assumption was that each clause was independent and unique.




After training in the baseline before source_policy and adding clause_key, we found that:
	•	Many clauses are exact duplicates or near-duplicates across different companies (e.g. boilerplate legal language reused verbatim)
	•	train_test_split does not account for duplicated text
	•	As a result, identical clauses appeared in both train and test sets
    •	This caused data leakage, leading to unrealistically high accuracy

Then Why I used GroupShuffleSplit instead before source policy and manual grouping: 

To eliminate leakage:
1.	Normalize each clause and compute a stable hash (clause_key)
2.	Treat all identical clauses as a single group

Now I am doing manual group holdout by company:

What this guarantees

	•	Entire companies are held out

	•	No clause from a test company appears in training

	•	Fully leakage-safe

	•	Easy to reason about
	
	•	Easy to document

This is company-level holdout.

In [55]:
rng = np.random.default_rng(42)

policies = df["source_policy"].unique()
rng.shuffle(policies)

test_size = 0.2
n_test = max(1, int(len(policies) * test_size))

test_policies = set(policies[:n_test])
train_policies = set(policies[n_test:])

train_df = df[df["source_policy"].isin(train_policies)].copy()
test_df  = df[df["source_policy"].isin(test_policies)].copy()

print("Train policies:", len(train_policies))
print("Test policies:", len(test_policies))
print("Train rows:", len(train_df))
print("Test rows:", len(test_df))
print("\nHeld-out policies:", sorted(list(test_policies))[:20], "...")

Train policies: 39
Test policies: 9
Train rows: 9733
Test rows: 1801

Held-out policies: ['airbnb_privacy', 'datarobot_privacy', 'delta_privacy', 'eventbrite_privacy', 'hotelsdotcom_TOU', 'hulu_privacy', 'instacart_privacy', 'microsoft_privacy', 'reddit_privacy'] ...


### MUST check leakage again (clause_key overlap)

Even though deduped, need to verify:

In [56]:
overlap = set(train_df["clause_key"]) & set(test_df["clause_key"])
print("Leakage overlap clause_key:", len(overlap))  # should be 0

Leakage overlap clause_key: 0


### Check label distribution shift

In [57]:
print("Train label distribution:")
print(train_df["label"].value_counts())

print("\nTest label distribution:")
print(test_df["label"].value_counts())

Train label distribution:
label
none            8460
data_sharing     611
tracking         401
refunds          116
location          60
ai_decisions      43
arbitration       42
Name: count, dtype: int64

Test label distribution:
label
none            1609
data_sharing     108
tracking          52
arbitration       15
refunds           13
location           2
ai_decisions       2
Name: count, dtype: int64


## Creating huggingFace dataset
converts dataframe into a dataset object 
- works with map
- handles batching automatically
- works natively with trainer

Trainer expects a dataset, not a dataframe

In [58]:
# keeping only clause + label_id
train_dataset = Dataset.from_pandas(train_df[["clause", "label_id"]])
test_dataset = Dataset.from_pandas(test_df[["clause", "label_id"]])

## Tokenizer

loads the tokenizer that turns raw text into token IDS
the fast tokenizer:
- lowercase text (uncased model)
- splits into workpieces
- handles padding & truncation later

tokenizers define how text is represented for the model input

In [59]:
# load tokenizer and model

model_name = "distilbert-base-uncased"

tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)




Creates a function that tokenizes clauses
- truncates long paragraphs
- pads shorter ones
- outputs input_ids, attention_mask[what is important and what is padded 1, 0]
256 becasue most clauses are 1 sentences so avoids truncation problems

In [60]:
# tokenizer function
def tokenize_batch(batch):
    return tokenizer(
        batch["clause"],
        truncation=True,
        padding="max_length",
        max_length=256,
    )


In [61]:
train_dataset = train_dataset.map(tokenize_batch, batched=True)
test_dataset  = test_dataset.map(tokenize_batch, batched=True)

Map: 100%|██████████| 9733/9733 [00:00<00:00, 27188.72 examples/s]
Map: 100%|██████████| 1801/1801 [00:00<00:00, 29482.48 examples/s]


## Rename label column
renaming because trainer expects labels

In [62]:
# rename only if label_id exists
if "label_id" in train_dataset.column_names:
    train_dataset = train_dataset.rename_column("label_id", "labels")

if "label_id" in test_dataset.column_names:
    test_dataset = test_dataset.rename_column("label_id", "labels")


removes hidden __index_level_0__ columns that pandas - datasets sometimes adds
these extra columns break batching

In [63]:
# remove hf auto columns
# ensures trainer only sees correct columns

def drop_index_cols(ds):
    idx_cols = [c for c in ds.column_names if c.startswith("__")]
    return ds.remove_columns(idx_cols) if idx_cols else ds

train_dataset = drop_index_cols(train_dataset)
test_dataset  = drop_index_cols(test_dataset)

## Set torch format
converts all dataset columns to Pytorch tensors.
trainer internally expects pytorch tensors, not numpy arrays

In [64]:
train_dataset.set_format(
    type="torch",
    columns=["input_ids", "attention_mask", "labels"],
)

test_dataset.set_format(
    type="torch",
    columns=["input_ids", "attention_mask", "labels"],
)

load the pretrained DistilBERT model with classification head

9 output classes - risk categories


this makes my model

- trainable - finetuned on privacy clauses
- interpretable - trainer logs class names during eval

In [65]:
# load the model 

num_labels = len(LABELS)

model = DistilBertForSequenceClassification.from_pretrained(
    model_name,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## Metrics function
Version agnostic, works everywhere

why weighted?
data set is imbalanced

weighted mtrics avoid arbitrarily high accuracy due to frequent "none" labels.

In [66]:
# metrics function

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = logits.argmax(axis=-1)
    
    acc = accuracy_score(labels, preds)
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average="weighted", zero_division=0
    )

    return {
        "accuracy": acc,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }

## Use the TrainingArguments API
Switching to version agnostic trainingArguements block that works on all transformers versions
as my install wasnt supporting evaluation_strategy

what this cell does?

- Defines 
    - where to save models and logs
    - learning rate
    - batch sizes
    - number of epochs
    - how often to log/save/evaluate

    

Create the HF training loop,

- Trainer handles
    - batching
    - GD
    - Eval loop
    - checkpointing
    - metrics
    - logging
    

In [67]:
training_args = TrainingArguments(
    output_dir="../models/fineprint-distilbert",
    
    do_train=True,
    do_eval=True,

    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,

    logging_steps=50,
    eval_steps=500,
    save_steps=500,
)

This version cannot do “epoch-based evaluation” because your installed TrainingArguments does not support it — but evaluation still happens every 500 steps, so training works fine.

In [68]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

In [69]:
trainer.train()

  3%|▎         | 50/1827 [00:21<12:48,  2.31it/s]

{'loss': 0.8741, 'grad_norm': 2.2115519046783447, 'learning_rate': 1.945265462506842e-05, 'epoch': 0.08}


  5%|▌         | 100/1827 [00:42<12:06,  2.38it/s]

{'loss': 0.5099, 'grad_norm': 3.1547274589538574, 'learning_rate': 1.8905309250136837e-05, 'epoch': 0.16}


  8%|▊         | 150/1827 [01:03<11:42,  2.39it/s]

{'loss': 0.291, 'grad_norm': 1.924607515335083, 'learning_rate': 1.8357963875205254e-05, 'epoch': 0.25}


 11%|█         | 200/1827 [01:24<11:18,  2.40it/s]

{'loss': 0.229, 'grad_norm': 0.8258692622184753, 'learning_rate': 1.7810618500273672e-05, 'epoch': 0.33}


 14%|█▎        | 250/1827 [01:45<10:59,  2.39it/s]

{'loss': 0.1748, 'grad_norm': 0.4187932014465332, 'learning_rate': 1.7263273125342093e-05, 'epoch': 0.41}


 16%|█▋        | 300/1827 [02:06<10:40,  2.38it/s]

{'loss': 0.2026, 'grad_norm': 0.28660717606544495, 'learning_rate': 1.671592775041051e-05, 'epoch': 0.49}


 19%|█▉        | 350/1827 [02:27<10:16,  2.40it/s]

{'loss': 0.1764, 'grad_norm': 0.3583313822746277, 'learning_rate': 1.6168582375478928e-05, 'epoch': 0.57}


 22%|██▏       | 400/1827 [02:48<09:55,  2.40it/s]

{'loss': 0.1294, 'grad_norm': 0.5312715768814087, 'learning_rate': 1.5621237000547346e-05, 'epoch': 0.66}


 25%|██▍       | 450/1827 [03:09<09:34,  2.40it/s]

{'loss': 0.1108, 'grad_norm': 0.17151257395744324, 'learning_rate': 1.5073891625615764e-05, 'epoch': 0.74}


 27%|██▋       | 500/1827 [03:30<09:13,  2.40it/s]

{'loss': 0.0909, 'grad_norm': 1.9872300624847412, 'learning_rate': 1.4526546250684183e-05, 'epoch': 0.82}


 30%|███       | 550/1827 [03:52<08:53,  2.40it/s]

{'loss': 0.08, 'grad_norm': 3.996192455291748, 'learning_rate': 1.39792008757526e-05, 'epoch': 0.9}


 33%|███▎      | 600/1827 [04:13<08:33,  2.39it/s]

{'loss': 0.0735, 'grad_norm': 0.0621044784784317, 'learning_rate': 1.3431855500821018e-05, 'epoch': 0.99}


 36%|███▌      | 650/1827 [04:35<08:12,  2.39it/s]

{'loss': 0.0299, 'grad_norm': 0.5988378524780273, 'learning_rate': 1.2884510125889437e-05, 'epoch': 1.07}


 38%|███▊      | 700/1827 [04:56<07:52,  2.39it/s]

{'loss': 0.0355, 'grad_norm': 0.0736442282795906, 'learning_rate': 1.2337164750957855e-05, 'epoch': 1.15}


 41%|████      | 750/1827 [05:17<07:53,  2.28it/s]

{'loss': 0.0598, 'grad_norm': 2.2704074382781982, 'learning_rate': 1.1789819376026273e-05, 'epoch': 1.23}


 44%|████▍     | 800/1827 [05:39<07:11,  2.38it/s]

{'loss': 0.0275, 'grad_norm': 0.043519068509340286, 'learning_rate': 1.1242474001094692e-05, 'epoch': 1.31}


 47%|████▋     | 850/1827 [06:00<07:03,  2.31it/s]

{'loss': 0.0313, 'grad_norm': 0.04902144521474838, 'learning_rate': 1.069512862616311e-05, 'epoch': 1.4}


 49%|████▉     | 900/1827 [06:21<06:28,  2.39it/s]

{'loss': 0.0371, 'grad_norm': 6.980504035949707, 'learning_rate': 1.0147783251231529e-05, 'epoch': 1.48}


 52%|█████▏    | 950/1827 [06:42<06:09,  2.38it/s]

{'loss': 0.0296, 'grad_norm': 0.049384765326976776, 'learning_rate': 9.600437876299946e-06, 'epoch': 1.56}


 55%|█████▍    | 1000/1827 [07:03<05:47,  2.38it/s]

{'loss': 0.0521, 'grad_norm': 0.1929241418838501, 'learning_rate': 9.053092501368364e-06, 'epoch': 1.64}


 57%|█████▋    | 1050/1827 [07:25<05:25,  2.38it/s]

{'loss': 0.0271, 'grad_norm': 0.10675503313541412, 'learning_rate': 8.505747126436782e-06, 'epoch': 1.72}


 60%|██████    | 1100/1827 [07:46<05:04,  2.39it/s]

{'loss': 0.0216, 'grad_norm': 0.02256695367395878, 'learning_rate': 7.958401751505201e-06, 'epoch': 1.81}


 63%|██████▎   | 1150/1827 [08:07<04:42,  2.40it/s]

{'loss': 0.0361, 'grad_norm': 0.048203542828559875, 'learning_rate': 7.4110563765736186e-06, 'epoch': 1.89}


 66%|██████▌   | 1200/1827 [08:28<04:22,  2.39it/s]

{'loss': 0.0239, 'grad_norm': 0.051544155925512314, 'learning_rate': 6.863711001642037e-06, 'epoch': 1.97}


 68%|██████▊   | 1250/1827 [08:49<04:01,  2.38it/s]

{'loss': 0.0079, 'grad_norm': 0.022293899208307266, 'learning_rate': 6.316365626710455e-06, 'epoch': 2.05}


 71%|███████   | 1300/1827 [09:10<03:39,  2.40it/s]

{'loss': 0.0195, 'grad_norm': 1.4871764183044434, 'learning_rate': 5.769020251778873e-06, 'epoch': 2.13}


 74%|███████▍  | 1350/1827 [09:31<03:19,  2.39it/s]

{'loss': 0.0042, 'grad_norm': 0.019223585724830627, 'learning_rate': 5.2216748768472915e-06, 'epoch': 2.22}


 77%|███████▋  | 1400/1827 [09:52<02:58,  2.40it/s]

{'loss': 0.014, 'grad_norm': 0.17370396852493286, 'learning_rate': 4.674329501915709e-06, 'epoch': 2.3}


 79%|███████▉  | 1450/1827 [10:13<02:37,  2.40it/s]

{'loss': 0.0154, 'grad_norm': 0.026289066299796104, 'learning_rate': 4.126984126984127e-06, 'epoch': 2.38}


 82%|████████▏ | 1500/1827 [10:34<02:16,  2.40it/s]

{'loss': 0.0282, 'grad_norm': 0.016291692852973938, 'learning_rate': 3.5796387520525457e-06, 'epoch': 2.46}


 85%|████████▍ | 1550/1827 [10:55<01:55,  2.40it/s]

{'loss': 0.0218, 'grad_norm': 0.01963530108332634, 'learning_rate': 3.0322933771209633e-06, 'epoch': 2.55}


 88%|████████▊ | 1600/1827 [11:16<01:34,  2.39it/s]

{'loss': 0.0075, 'grad_norm': 0.014690225943922997, 'learning_rate': 2.4849480021893817e-06, 'epoch': 2.63}


 90%|█████████ | 1650/1827 [11:37<01:13,  2.39it/s]

{'loss': 0.0057, 'grad_norm': 0.015873707830905914, 'learning_rate': 1.9376026272577998e-06, 'epoch': 2.71}


 93%|█████████▎| 1700/1827 [11:58<00:53,  2.39it/s]

{'loss': 0.0195, 'grad_norm': 0.018275568261742592, 'learning_rate': 1.3902572523262178e-06, 'epoch': 2.79}


 96%|█████████▌| 1750/1827 [12:19<00:32,  2.38it/s]

{'loss': 0.0071, 'grad_norm': 0.02905316650867462, 'learning_rate': 8.429118773946361e-07, 'epoch': 2.87}


 99%|█████████▊| 1800/1827 [12:40<00:11,  2.40it/s]

{'loss': 0.0129, 'grad_norm': 3.0251989364624023, 'learning_rate': 2.955665024630542e-07, 'epoch': 2.96}


100%|██████████| 1827/1827 [12:51<00:00,  2.37it/s]

{'train_runtime': 771.978, 'train_samples_per_second': 37.824, 'train_steps_per_second': 2.367, 'train_loss': 0.09642412265141805, 'epoch': 3.0}





TrainOutput(global_step=1827, training_loss=0.09642412265141805, metrics={'train_runtime': 771.978, 'train_samples_per_second': 37.824, 'train_steps_per_second': 2.367, 'total_flos': 1934130233636352.0, 'train_loss': 0.09642412265141805, 'epoch': 3.0})

# Eval

Labels what model understands

which labels require more data

which labels require resampling or weighting

In [70]:
preds = trainer.predict(test_dataset)
y_true = preds.label_ids
y_pred = preds.predictions.argmax(axis=-1)

print(classification_report(
    y_true, y_pred,
    target_names=[id2label[i] for i in range(len(LABELS))],
    digits=3
))

100%|██████████| 113/113 [00:13<00:00,  8.19it/s]

              precision    recall  f1-score   support

        none      0.998     0.991     0.995      1609
data_sharing      0.982     0.991     0.986       108
    tracking      0.981     1.000     0.990        52
     refunds      0.867     1.000     0.929        13
    location      0.143     0.500     0.222         2
 arbitration      0.875     0.933     0.903        15
ai_decisions      0.667     1.000     0.800         2

    accuracy                          0.991      1801
   macro avg      0.787     0.916     0.832      1801
weighted avg      0.993     0.991     0.992      1801






The model maintains strong performance under company-level holdout, indicating robustness to domain shift across policy authors and legal writing styles. 

Performance degradation is limited primarily to rare labels with insufficient test representation, highlighting data availability rather than model capacity as the main bottleneck.

# CF

In [46]:
cm = confusion_matrix(y_true, y_pred)
labels = [id2label[i] for i in range(len(LABELS))]
cm_df = pd.DataFrame(cm, index=labels, columns=labels)
cm_df

Unnamed: 0,none,data_sharing,tracking,refunds,location,arbitration,ai_decisions
none,2008,5,1,0,4,0,0
data_sharing,5,134,0,0,0,0,0
tracking,0,0,80,0,0,0,0
refunds,0,0,0,32,0,0,0
location,0,0,0,0,15,0,0
arbitration,1,0,0,0,0,17,0
ai_decisions,0,0,0,0,0,0,5


## Per-class “risk recall” focus (high-stakes labels)

In [47]:
from sklearn.metrics import precision_recall_fscore_support
import pandas as pd

p, r, f1, s = precision_recall_fscore_support(y_true, y_pred, labels=list(range(len(LABELS))), zero_division=0)

labels = [id2label[i] for i in range(len(LABELS))]
per_class = pd.DataFrame({
    "label": labels,
    "support": s,
    "precision": p,
    "recall": r,
    "f1": f1
}).sort_values("support")

per_class 

Unnamed: 0,label,support,precision,recall,f1
6,ai_decisions,5,1.0,1.0,1.0
4,location,15,0.789474,1.0,0.882353
5,arbitration,18,1.0,0.944444,0.971429
3,refunds,32,1.0,1.0,1.0
2,tracking,80,0.987654,1.0,0.993789
1,data_sharing,139,0.964029,0.964029,0.964029
0,none,2018,0.997021,0.995045,0.996032


## Error analysis dataframe

In [48]:
import numpy as np
import pandas as pd

probs = np.exp(preds.predictions) / np.exp(preds.predictions).sum(axis=1, keepdims=True)
conf = probs.max(axis=1)

errors = pd.DataFrame({
    "clause": test_df["clause"].values,
    "true": [id2label[i] for i in y_true],
    "pred": [id2label[i] for i in y_pred],
    "confidence": conf
})

mistakes = errors[errors["true"] != errors["pred"]].sort_values("confidence", ascending=False)
mistakes.head(30)

Unnamed: 0,clause,true,pred,confidence
1067,"If possible, in the event that you retain any ...",arbitration,none,0.999424
1257,We may share your data as necessary to determi...,data_sharing,none,0.998914
769,Other entities with whom TikTok may share your...,data_sharing,none,0.997262
1258,If you provide a ride or delivery for an Enter...,data_sharing,none,0.997039
1548,If you are a customer who opened a PayPal acco...,none,data_sharing,0.972205
230,"Legitimate interest, in order to provide you w...",data_sharing,none,0.970959
1296,How We Share Your Information We share your da...,data_sharing,none,0.951051
2204,Through automated means Our service providers ...,none,location,0.921953
1524,"Such information may include, but is not limit...",none,location,0.920641
2239,Geolocation information based on your IP address.,none,location,0.790569


## Thresholding test (deployability)

In [49]:
threshold = 0.70
y_pred_thresh = y_pred.copy()

low_conf = conf < threshold
y_pred_thresh[low_conf] = label2id["none"]  # or create "uncertain"

print("Low confidence %:", low_conf.mean())
print(classification_report(
    y_true, y_pred_thresh,
    target_names=[id2label[i] for i in range(len(LABELS))],
    digits=3
))

Low confidence %: 0.002167316861725184
              precision    recall  f1-score   support

        none      0.997     0.996     0.997      2018
data_sharing      0.978     0.964     0.971       139
    tracking      0.988     1.000     0.994        80
     refunds      1.000     1.000     1.000        32
    location      0.789     1.000     0.882        15
 arbitration      1.000     0.944     0.971        18
ai_decisions      1.000     1.000     1.000         5

    accuracy                          0.994      2307
   macro avg      0.965     0.986     0.974      2307
weighted avg      0.994     0.994     0.994      2307



## Save evaluation artifacts (for README + handoff)

Save these every run: 

	•	classification_report.json

	•	confusion_matrix.csv (and optionally a PNG)

	•	mistakes.csv
    
	•	training config (seed, model_name, max_length, dataset hash)

In [50]:
import json

report_dict = classification_report(
    y_true, y_pred,
    target_names=[id2label[i] for i in range(len(LABELS))],
    digits=3,
    output_dict=True
)

with open("../models/fineprint-distilbert-best/eval_report.json", "w") as f:
    json.dump(report_dict, f, indent=2)

cm_df.to_csv("../models/fineprint-distilbert-best/confusion_matrix.csv", index=True)
mistakes.to_csv("../models/fineprint-distilbert-best/mistakes.csv", index=False)

Error analysis
Most misclassifications occur between the none label and risk categories such as data_sharing and location, reflecting the inherently ambiguous and conditional nature of legal language. Importantly, confusion between distinct risk categories is rare, and recall for high-impact labels (e.g., refunds, tracking, ai_decisions) remains high. This behavior is desirable for consumer-facing risk detection, where false positives are preferable to missed disclosures.

# Debugging

Transformers 4.40.2

Accelerate 0.30.1

In [36]:
import transformers, accelerate
transformers.__version__, accelerate.__version__

('4.57.3', '0.31.0')