# Legal-BERT Fine-Tuning
This notebook fine-tunes Legal-BERT for clause classification using the LegalBench CUAD dataset.

In [6]:
!pip install datasets

Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.5.0-py3-none-any.whl (491 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.2/491.2 kB[0m [31m31.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m10.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.12.0-py3-none-any.

In [7]:
from datasets import load_dataset
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import torch
import pandas as pd
import numpy as np

Load and Preprocess Dataset

In [8]:
task_list = [
    'cuad_cap_on_liability',
    'cuad_audit_rights',
    'cuad_insurance'
]


all_dfs = []

for task in task_list:
    data = load_dataset("nguha/legalbench", task, trust_remote_code=True)
    df = pd.DataFrame(data['test'])  # or 'test'
    df['cleaned_text'] = df['text'].str.lower().str.strip()
    df['label'] = df['answer'].apply(lambda x: 1 if x.lower() == 'yes' else 0)
    df['task'] = task  # Optional: track which task it came from
    all_dfs.append(df[['cleaned_text', 'label', 'task']])

big_df = pd.concat(all_dfs, ignore_index=True)
print(f"Combined dataset size: {len(big_df)} samples")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/68.0k [00:00<?, ?B/s]

legalbench.py:   0%|          | 0.00/71.5k [00:00<?, ?B/s]

dataset_infos.json:   0%|          | 0.00/169k [00:00<?, ?B/s]

data.tar.gz:   0%|          | 0.00/19.6M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/6 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1246 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/6 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1216 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/6 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1030 [00:00<?, ? examples/s]

Combined dataset size: 3492 samples


In [9]:
# Split data
train_data, test_data = train_test_split(big_df, test_size=0.2, stratify=big_df['label'], random_state=42)
train_texts, val_texts, train_labels, val_labels = train_test_split(
    train_data['cleaned_text'], train_data['label'], test_size=0.2, stratify=train_data['label'], random_state=42
)

Initialize Tokenizer and Model

In [10]:
# Initialize tokenizer and model
tokenizer = BertTokenizer.from_pretrained('nlpaueb/legal-bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('nlpaueb/legal-bert-base-uncased', num_labels=2)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Tokenize data
train_encodings = tokenizer(train_texts.tolist(), truncation=True, padding=True, max_length=512)
val_encodings = tokenizer(val_texts.tolist(), truncation=True, padding=True, max_length=512)
train_labels = torch.tensor(train_labels.tolist())
val_labels = torch.tensor(val_labels.tolist())

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/222k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.02k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at nlpaueb/legal-bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Dataset Class and Instances

In [11]:
# Dataset class
class LegalDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = self.labels[idx]
        return item

    def __len__(self):
        return len(self.labels)

train_dataset = LegalDataset(train_encodings, train_labels)
val_dataset = LegalDataset(val_encodings, val_labels)

Training and Evaluation

In [15]:
from transformers import TrainingArguments, Trainer, EarlyStoppingCallback
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Training arguments
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=10,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
    evaluation_strategy='epoch',
    save_strategy='epoch',
    load_best_model_at_end=True,
    metric_for_best_model='eval_loss',
    greater_is_better=False
)

# Early stopping callback
early_stopping = EarlyStoppingCallback(early_stopping_patience=3)

# Initialize trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    callbacks=[early_stopping]
)

# Train and evaluate
trainer.train()

# Evaluate the model
predictions = trainer.predict(val_dataset)
preds = predictions.predictions.argmax(-1)
labels = predictions.label_ids

# Print evaluation metrics
print(f'Accuracy: {accuracy_score(labels, preds):.4f}, '
      f'Precision: {precision_score(labels, preds, average="weighted"):.4f}, '
      f'Recall: {recall_score(labels, preds, average="weighted"):.4f}, '
      f'F1: {f1_score(labels, preds, average="weighted"):.4f}')

Epoch,Training Loss,Validation Loss
1,0.3152,0.400573
2,0.5603,0.402977
3,0.4655,0.359239
4,0.2872,0.345823
5,0.1439,0.404207
6,0.69,0.461975
7,0.4518,0.44493


Accuracy: 0.9070, Precision: 0.9183, Recall: 0.9070, F1: 0.9063


Save Model

In [None]:
# Save model
model.save_pretrained('fine-tuned-legal-bert')
tokenizer.save_pretrained('fine-tuned-legal-bert')