# Fine-tune ModernBERT for Dispute Case Classification

This notebook demonstrates how to fine-tune ModernBERT on our dispute case dataset to classify different types of legal disputes.

In [None]:
# Install required packages
%pip install torch transformers datasets pandas scikit-learn numpy

In [None]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import torch
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback
)
from datasets import Dataset
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

In [None]:
# Read the preprocessed samples and the label lookup table from disk
df = pd.read_csv('processed_data.csv')
label_mapping = pd.read_csv('label_mapping.csv')

# Build both lookup directions from the 'id' and 'theme' columns:
# id -> theme (id2label) and theme -> id (label2id)
label_ids = label_mapping['id']
label_themes = label_mapping['theme']
id2label = {idx: theme for idx, theme in zip(label_ids, label_themes)}
label2id = {theme: idx for theme, idx in zip(label_themes, label_ids)}

# Summarise dataset size and how the samples spread over the themes
print(f"Number of samples: {len(df)}")
print(f"Number of classes: {len(label_mapping)}")
print("\nClass distribution:")
theme_counts = df['theme'].value_counts()
print(theme_counts)

In [None]:
# Hold out 20% of the rows for validation. The split is stratified on the
# 'label' column and uses a fixed random_state so it is repeatable.
train_df, val_df = train_test_split(
    df,
    test_size=0.2,
    stratify=df['label'],
    random_state=42,
)

# Wrap each pandas split in a HuggingFace Dataset (train first, then validation)
train_dataset, val_dataset = (
    Dataset.from_pandas(split_df) for split_df in (train_df, val_df)
)

In [None]:
# Checkpoint that both the tokenizer and the model are loaded from
model_name = "answerdotai/ModernBERT-base"

# Tokenizer for the checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Sequence-classification model: one output per row of label_mapping, with
# both label lookup tables passed to from_pretrained
classifier_kwargs = dict(
    num_labels=len(label_mapping),
    id2label=id2label,
    label2id=label2id,
)
model = AutoModelForSequenceClassification.from_pretrained(model_name, **classifier_kwargs)

In [None]:
# Tokenization function
def tokenize_function(examples):
    """Tokenize the 'content' entry of a batch of examples.

    The tokenizer is called with truncation enabled, `padding='max_length'`
    and `max_length=512`, i.e. a fixed-length encoding.

    Args:
        examples: Mapping with a 'content' entry holding the text(s) to encode.

    Returns:
        The tokenizer's output for those texts.
    """
    tokenizer_options = {
        'padding': 'max_length',
        'truncation': True,
        'max_length': 512,
    }
    texts = examples['content']
    return tokenizer(texts, **tokenizer_options)

# Tokenize both splits in batched mode (train first, then validation)
tokenized_train, tokenized_val = (
    split.map(tokenize_function, batched=True)
    for split in (train_dataset, val_dataset)
)

In [None]:
# Define metrics computation function
def compute_metrics(eval_pred):
    """Compute accuracy and weighted precision / recall / F1 for an evaluation pass.

    Args:
        eval_pred: Pair `(predictions, labels)`; the predicted class is taken
            as the argmax of `predictions` along axis 1.

    Returns:
        dict with keys 'accuracy', 'f1', 'precision' and 'recall'.
    """
    scores, labels = eval_pred
    predicted_ids = np.argmax(scores, axis=1)

    # Weighted averaging; the fourth returned element is not needed
    prf = precision_recall_fscore_support(labels, predicted_ids, average='weighted')
    precision, recall, f1 = prf[0], prf[1], prf[2]
    acc = accuracy_score(labels, predicted_ids)

    metrics = {}
    metrics['accuracy'] = acc
    metrics['f1'] = f1
    metrics['precision'] = precision
    metrics['recall'] = recall
    return metrics

In [None]:
# Define training arguments
# bf16 mixed precision and the fused AdamW kernel are GPU-oriented speed-ups.
# Enable them only when a CUDA device that reports bf16 support is present, so
# the notebook still runs on other hardware instead of failing at setup.
cuda_available = torch.cuda.is_available()
use_bf16 = cuda_available and torch.cuda.is_bf16_supported()

training_args = TrainingArguments(
    output_dir='./modernbert_dispute_classifier',
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=5,
    weight_decay=0.01,
    # `eval_strategy` is the current name of the former `evaluation_strategy`
    # argument; the old spelling is deprecated and rejected by newer releases.
    eval_strategy='epoch',
    # Checkpoint on the same schedule as evaluation so the best model can be
    # restored at the end of training.
    save_strategy='epoch',
    load_best_model_at_end=True,
    push_to_hub=False,
    metric_for_best_model='f1',  # the 'f1' value returned by compute_metrics
    bf16=use_bf16,  # Use bfloat16 for faster training when supported
    # Use the fused optimizer on CUDA; otherwise the regular torch AdamW
    optim='adamw_torch_fused' if cuda_available else 'adamw_torch',
)

# Early-stopping callback configured with a patience of 2
early_stopping = EarlyStoppingCallback(early_stopping_patience=2)

# Initialize trainer with the model, arguments, tokenized splits and metrics
trainer = Trainer(
    args=training_args,
    model=model,
    compute_metrics=compute_metrics,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    callbacks=[early_stopping],
)

In [None]:
# Run fine-tuning; the value returned by train() is shown as the cell output
train_output = trainer.train()
train_output

In [None]:
# Run evaluation through the trainer
eval_results = trainer.evaluate()

# Report every returned metric with four decimal places
print("\nFinal evaluation results:")
for metric in eval_results:
    value = eval_results[metric]
    print(f"{metric}: {value:.4f}")

In [None]:
# Persist the fine-tuned model and its tokenizer into the same directory
final_dir = './modernbert_dispute_classifier/final'
trainer.save_model(final_dir)
tokenizer.save_pretrained(final_dir)

In [None]:
# Test the model on a sample text
def predict_dispute_type(text):
    """Print the most likely dispute types for `text` with their probabilities.

    The text is tokenized (truncated to 512 tokens), run through the global
    `model` in inference mode, and the softmax probabilities of the top
    classes (up to three) are printed using the global `id2label` mapping.

    Args:
        text: Raw text of the dispute case to classify.

    Returns:
        None. The predictions are printed.

    Note:
        Switches `model` to eval mode as a side effect.
    """
    # Inputs must live on the same device as the model weights; after GPU
    # training the model is on CUDA while the tokenizer returns CPU tensors.
    device = next(model.parameters()).device
    encoded = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    # Inference only: disable dropout and skip building the autograd graph
    model.eval()
    with torch.no_grad():
        logits = model(**encoded).logits
    probabilities = torch.nn.functional.softmax(logits, dim=-1)

    # Never ask topk for more entries than there are classes
    k = min(3, probabilities.shape[-1])
    top_k = torch.topk(probabilities, k=k, dim=1)

    print(f"Top {k} predictions:")
    for score, idx in zip(top_k.values[0], top_k.indices[0]):
        print(f"{id2label[idx.item()]}: {score.item():.2%}")

# Classify the first validation row, showing a short preview of its text first
sample_text = val_df['content'].iloc[0]
preview = sample_text[:200]
print("Sample text (first 200 characters):")
print(preview, "...\n")
predict_dispute_type(sample_text)