# Legal-BERT Fine-Tuning
This notebook fine-tunes Legal-BERT for clause classification using the LegalBench CUAD dataset.

In [2]:
from datasets import load_dataset
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import torch
import pandas as pd
import numpy as np

Load and Preprocess Dataset

In [4]:
# Load datasets
dataset_cap = load_dataset("nguha/legalbench", "cuad_cap_on_liability", trust_remote_code=True)
dataset_audit = load_dataset("nguha/legalbench", "cuad_audit_rights", trust_remote_code=True)
dataset_insurance = load_dataset("nguha/legalbench", "cuad_insurance", trust_remote_code=True)

# Convert to DataFrames
df_cap = pd.DataFrame(dataset_cap['test'])
df_audit = pd.DataFrame(dataset_audit['test'])
df_insurance = pd.DataFrame(dataset_insurance['test'])

# Filter 'yes' examples
df_cap_yes = df_cap[df_cap['answer'].str.lower() == 'yes'].copy()
df_cap_yes.loc[:, 'label'] = 0  # cap_on_liability
df_audit_yes = df_audit[df_audit['answer'].str.lower() == 'yes'].copy()
df_audit_yes.loc[:, 'label'] = 1  # audit_rights
df_insurance_yes = df_insurance[df_insurance['answer'].str.lower() == 'yes'].copy()
df_insurance_yes.loc[:, 'label'] = 2  # insurance

# Filter 'no' examples and sample a subset (20% of 'yes' count)
df_cap_no = df_cap[df_cap['answer'].str.lower() == 'no'].sample(n=int(0.2 * len(df_cap_yes)), random_state=42).copy()
df_cap_no.loc[:, 'label'] = 3  # 'no' class
df_audit_no = df_audit[df_audit['answer'].str.lower() == 'no'].sample(n=int(0.2 * len(df_audit_yes)), random_state=42).copy()
df_audit_no.loc[:, 'label'] = 3  # 'no' class
df_insurance_no = df_insurance[df_insurance['answer'].str.lower() == 'no'].sample(n=int(0.2 * len(df_insurance_yes)), random_state=42).copy()
df_insurance_no.loc[:, 'label'] = 3  # 'no' class

# Combine all
df_combined = pd.concat([df_cap_yes, df_audit_yes, df_insurance_yes, df_cap_no, df_audit_no, df_insurance_no], ignore_index=True)

# Check for duplicates
print(f"Total clauses before deduplication: {len(df_combined)}")
df_combined = df_combined.drop_duplicates(subset=['text'], keep='first')
print(f"Total clauses after deduplication: {len(df_combined)}")

# Preprocess text
df_combined.loc[:, 'cleaned_text'] = df_combined['text'].apply(lambda x: x.strip().lower())  # Fixed with .loc

# Split data
train_data, test_data = train_test_split(df_combined, test_size=0.2, stratify=df_combined['label'], random_state=42)
train_texts, val_texts, train_labels, val_labels = train_test_split(
    train_data['cleaned_text'], train_data['label'], test_size=0.2, stratify=train_data['label'], random_state=42
)
print(f"Training set size: {len(train_texts)}, Validation set size: {len(val_texts)}, Test set size: {len(test_data)}")

Total clauses before deduplication: 2094
Total clauses after deduplication: 2052
Training set size: 1312, Validation set size: 329, Test set size: 411


Initialize Tokenizer and Model

In [6]:
# Initialize tokenizer and model
tokenizer = BertTokenizer.from_pretrained('nlpaueb/legal-bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('nlpaueb/legal-bert-base-uncased', num_labels=4)  # Now 4 classes
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Tokenize data
train_encodings = tokenizer(train_texts.tolist(), truncation=True, padding=True, max_length=512)
val_encodings = tokenizer(val_texts.tolist(), truncation=True, padding=True, max_length=512)
train_labels = torch.tensor(train_labels.tolist())
val_labels = torch.tensor(val_labels.tolist())

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at nlpaueb/legal-bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Dataset Class and Instances

In [None]:
# Dataset class
class LegalDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = self.labels[idx]
        return item

    def __len__(self):
        return len(self.labels)

train_dataset = LegalDataset(train_encodings, train_labels)
val_dataset = LegalDataset(val_encodings, val_labels)

Training and Evaluation

In [None]:
# Training arguments
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    warmup_steps=500,
    weight_decay=0.1,
    logging_dir='./logs',
    logging_steps=10,
    eval_strategy='epoch',
    report_to="none"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset
)

# Train and evaluate
trainer.train()
predictions = trainer.predict(val_dataset)
preds = predictions.predictions.argmax(-1)
labels = predictions.label_ids
print(f'Accuracy: {accuracy_score(labels, preds):.4f}, '
      f'Precision: {precision_score(labels, preds, average="macro"):.4f}, '
      f'Recall: {recall_score(labels, preds, average="macro"):.4f}, '
      f'F1: {f1_score(labels, preds, average="macro"):.4f}')

Epoch,Training Loss,Validation Loss
1,0.1357,0.130547
2,0.0027,0.193639
3,0.0006,0.159535


Accuracy: 0.9787, Precision: 0.9786, Recall: 0.9721, F1: 0.9751


Save Model

In [None]:
# Save model
model.save_pretrained('fine-tuned-legal-bert')
tokenizer.save_pretrained('fine-tuned-legal-bert')

('fine-tuned-legal-bert/tokenizer_config.json',
 'fine-tuned-legal-bert/special_tokens_map.json',
 'fine-tuned-legal-bert/vocab.txt',
 'fine-tuned-legal-bert/added_tokens.json')