# Legal-BERT Fine-Tuning
This notebook fine-tunes Legal-BERT (`nlpaueb/legal-bert-base-uncased`) as a binary clause classifier on the `cuad_audit_rights` task from the LegalBench benchmark.

In [None]:
from datasets import load_dataset
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import torch
import pandas as pd
import numpy as np

## Load and Preprocess Dataset

In [None]:
# Fetch the audit-rights task from LegalBench and move its test split into pandas
dataset = load_dataset("nguha/legalbench", "cuad_audit_rights", trust_remote_code=True)
df = pd.DataFrame(dataset['test'])

# Clause text: trim surrounding whitespace, then lowercase
df['cleaned_text'] = [clause.strip().lower() for clause in df['text']]
# Binary target: 1 when the answer is a case-insensitive "yes", otherwise 0
df['label'] = [int(answer.lower() == 'yes') for answer in df['answer']]

# First split: hold out 20% of the rows as a test set, stratified on the label
train_data, test_data = train_test_split(
    df,
    test_size=0.2,
    stratify=df['label'],
    random_state=42,
)

# Second split: take 20% of the remaining rows as a validation set,
# again stratified on the label
train_texts, val_texts, train_labels, val_labels = train_test_split(
    train_data['cleaned_text'],
    train_data['label'],
    test_size=0.2,
    stratify=train_data['label'],
    random_state=42,
)

## Initialize Tokenizer and Model

In [None]:
# Load the pretrained Legal-BERT tokenizer and a two-label classification model
tokenizer = BertTokenizer.from_pretrained('nlpaueb/legal-bert-base-uncased')
model = BertForSequenceClassification.from_pretrained(
    'nlpaueb/legal-bert-base-uncased',
    num_labels=2,
)

# Run on the GPU when CUDA is available, otherwise on the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model.to(device)

# Convert the label Series into tensors (via plain Python lists)
train_labels = torch.tensor(train_labels.tolist())
val_labels = torch.tensor(val_labels.tolist())

# Tokenize both text splits: truncate to at most 512 tokens and pad each
# split to a common length
train_encodings = tokenizer(
    list(train_texts),
    truncation=True,
    padding=True,
    max_length=512,
)
val_encodings = tokenizer(
    list(val_texts),
    truncation=True,
    padding=True,
    max_length=512,
)

## Dataset Class and Instances

In [None]:
# Dataset wrapper around tokenizer output and labels
class LegalDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing per-field encodings with per-example labels.

    ``encodings`` is a mapping whose values are indexable by example position;
    ``labels`` is an indexable sequence with one entry per example.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        # The number of examples is taken from the label sequence.
        return len(self.labels)

    def __getitem__(self, idx):
        # One tensor per encoding field for the requested example.
        sample = {}
        for field, values in self.encodings.items():
            sample[field] = torch.tensor(values[idx])
        # The label is passed through exactly as stored (no conversion here).
        sample['labels'] = self.labels[idx]
        return sample

# Wrap the tokenized training and validation splits as indexable datasets
train_dataset = LegalDataset(encodings=train_encodings, labels=train_labels)
val_dataset = LegalDataset(encodings=val_encodings, labels=val_labels)

## Training and Evaluation

In [None]:
# Training arguments
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    # Warm up over the first 10% of optimizer steps. A fixed 500-step warmup
    # does not scale with the size of the run: on a small training split at
    # batch size 4 for 3 epochs it can span most or all of training, so the
    # peak learning rate is barely reached.
    warmup_ratio=0.1,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
    # NOTE(review): newer transformers releases rename this argument to
    # `eval_strategy`; confirm against the installed version before changing.
    evaluation_strategy='epoch'
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset
)


def _format_metrics(y_true, y_pred):
    """Return a one-line accuracy/precision/recall/F1 summary for binary labels.

    ``zero_division=0`` keeps the score at 0 (the value scikit-learn already
    returns) when the positive class is never predicted or never present,
    without emitting an UndefinedMetricWarning.
    """
    return (
        f'Accuracy: {accuracy_score(y_true, y_pred):.4f}, '
        f'Precision: {precision_score(y_true, y_pred, zero_division=0):.4f}, '
        f'Recall: {recall_score(y_true, y_pred, zero_division=0):.4f}, '
        f'F1: {f1_score(y_true, y_pred, zero_division=0):.4f}'
    )


# Train, then score the validation split (also used for per-epoch evaluation)
trainer.train()
predictions = trainer.predict(val_dataset)
preds = predictions.predictions.argmax(-1)
labels = predictions.label_ids
print(_format_metrics(labels, preds))

# Score the held-out test split as well. It was set aside by the first
# train_test_split and is not used for training or per-epoch evaluation, so
# it gives an estimate that is independent of the validation set.
test_encodings = tokenizer(
    test_data['cleaned_text'].tolist(), truncation=True, padding=True, max_length=512
)
test_dataset = LegalDataset(test_encodings, torch.tensor(test_data['label'].tolist()))
test_predictions = trainer.predict(test_dataset)
test_preds = test_predictions.predictions.argmax(-1)
print('Held-out test - ' + _format_metrics(test_predictions.label_ids, test_preds))

## Save Model

In [None]:
# Write the fine-tuned model and its tokenizer into the same output directory
model.save_pretrained(save_directory='fine-tuned-legal-bert')
tokenizer.save_pretrained(save_directory='fine-tuned-legal-bert')