## Import Libraries and Load Dataset

In [None]:
from collections import Counter

import numpy as np
import torch
import torch.nn as nn
from datasets import load_dataset
from sklearn.metrics import accuracy_score, classification_report
from sklearn.utils.class_weight import compute_class_weight
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

# Load every LEDGAR split from the LexGLUE benchmark in one call, then unpack them.
print("Loading LEDGAR dataset...")
ledgar_splits = load_dataset("lex_glue", "ledgar")
train_ds = ledgar_splits["train"]
val_ds = ledgar_splits["validation"]
test_ds = ledgar_splits["test"]

# Quick sanity check on split sizes and on the number of distinct training labels.
for split_title, split_data in (("Train", train_ds), ("Validation", val_ds), ("Test", test_ds)):
    print(f"{split_title} samples: {len(split_data)}")
print(f"Number of classes: {len(set(train_ds['label']))}")


Loading LEDGAR dataset...
Train samples: 60000
Validation samples: 10000
Test samples: 10000
Number of classes: 100


## Compute Class Weights for Imbalance Handling

In [None]:
# Integer label of every training example; the class weights are derived from it.
train_labels = train_ds['label']

# 'balanced' weighting gives each class n_samples / (n_classes * class_count), so
# rare clause types count for more in the loss than frequent ones.
print("Computing class weights for imbalanced data...")
class_weights = compute_class_weight(
    class_weight='balanced',
    y=train_labels,
    classes=np.unique(train_labels),
)

# float32 copy of the weights for PyTorch loss functions.
class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32)

class WeightedCrossEntropyLoss(nn.Module):
    """Cross-entropy loss with fixed per-class weights.

    The weights are registered as a (non-trainable) buffer. Unlike a plain
    attribute, a buffer follows ``module.to(device)`` / ``module.to(dtype)``
    moves, so the weights always match the logits' device.

    Args:
        class_weights: one weight per class, as a 1-D tensor or any sequence
            convertible by ``torch.as_tensor`` (converted to float32). Pass
            ``None`` for unweighted cross-entropy.
    """

    def __init__(self, class_weights):
        super().__init__()
        if class_weights is not None and not torch.is_tensor(class_weights):
            class_weights = torch.as_tensor(class_weights, dtype=torch.float32)
        # Buffer, not attribute: moved together with the module by .to()/.cuda().
        self.register_buffer("class_weights", class_weights)

    def forward(self, predictions, targets):
        """Return the weighted mean cross-entropy of logits ``predictions`` vs class indices ``targets``.

        With weights, PyTorch's 'mean' reduction divides the summed weighted
        losses by the summed weights of the target classes.
        """
        return nn.functional.cross_entropy(
            predictions,
            targets,
            weight=self.class_weights,
        )

# NOTE(review): no visible cell uses this criterion; WeightedTrainer computes its
# own weighted cross-entropy directly from class_weights_tensor.
weighted_criterion = WeightedCrossEntropyLoss(class_weights_tensor)

print(f"Class weights computed for {len(class_weights)} classes")
print(f"Weight range: {class_weights.min():.3f} to {class_weights.max():.3f}")

# Raw per-class sample counts show how skewed the label distribution is.
label_counts = Counter(train_labels)
print(f"Most common class: {max(label_counts.values())} samples")
print(f"Least common class: {min(label_counts.values())} samples")


Computing class weights for imbalanced data...
Class weights computed for 100 classes
Weight range: 0.189 to 26.087
Most common class: 3167 samples
Least common class: 23 samples


## Load and Set Up the Teacher Model

In [None]:
print("Loading Legal BERT teacher model...")
model_name = "nlpaueb/legal-bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Take the label space from the dataset's label feature instead of hard-coding it,
# so the classifier head, the class-weight vector and the saved label names agree.
label_names = train_ds.features["label"].names
num_labels = len(label_names)  # 100 contract clause types for LEDGAR
assert num_labels == len(class_weights), (
    f"{num_labels} labels in the dataset but {len(class_weights)} class weights; "
    "every class must appear in the training split"
)

teacher_model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=num_labels,
    id2label=dict(enumerate(label_names)),  # saved in config.json for downstream use
    label2id={name: idx for idx, name in enumerate(label_names)},
    hidden_dropout_prob=0.15,      # Increased dropout for regularization
    attention_probs_dropout_prob=0.15,
    classifier_dropout=0.2,         # Extra dropout in classification head
)

print(f"Model loaded: {model_name}")
print(f"Total parameters: {sum(p.numel() for p in teacher_model.parameters()):,}")
print(f"Trainable parameters: {sum(p.numel() for p in teacher_model.parameters() if p.requires_grad):,}")


Loading Legal BERT teacher model...


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at nlpaueb/legal-bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model loaded: nlpaueb/legal-bert-base-uncased
Total parameters: 109,559,140
Trainable parameters: 109,559,140


## Data Preprocessing

In [None]:
def preprocess_function(examples):
    """Tokenize a batch of LEDGAR examples for ``Dataset.map(..., batched=True)``.

    Texts are truncated but deliberately NOT padded here. Padding inside a batched
    ``map`` pads each 1000-example map batch to its longest member, which can reach
    the 512-token cap, and wastes compute on pad tokens. When a tokenizer is passed
    to ``Trainer`` as ``processing_class``, its default collator
    (``DataCollatorWithPadding``) pads every training/eval batch dynamically.

    Args:
        examples: batch dict from ``datasets``; ``examples["text"]`` is a list of strings.

    Returns:
        The tokenizer output (``input_ids``, ``attention_mask``, ...) as Python lists,
        which ``Dataset.map`` stores as new columns.
    """
    return tokenizer(
        examples["text"],
        truncation=True,
        # Assumed to be the checkpoint's maximum input length; confirm via tokenizer.model_max_length.
        max_length=512,
    )

print("Preprocessing datasets...")
# Apply the same batched tokenization to every split, in train/val/test order.
train_dataset, val_dataset, test_dataset = (
    raw_split.map(preprocess_function, batched=True)
    for raw_split in (train_ds, val_ds, test_ds)
)

# Expose only the model inputs and the label, as torch tensors.
torch_columns = ["input_ids", "attention_mask", "label"]
for tokenized_split in (train_dataset, val_dataset, test_dataset):
    tokenized_split.set_format(type="torch", columns=torch_columns)

print("Datasets preprocessed and formatted for PyTorch")


Preprocessing datasets...


Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

Datasets preprocessed and formatted for PyTorch


## Define Metrics and Custom Trainer

In [None]:
def compute_metrics(eval_pred):
    """Compute accuracy, macro-F1 and weighted-F1 for a ``Trainer`` evaluation.

    Args:
        eval_pred: ``(predictions, labels)`` pair supplied by ``Trainer``.
            ``predictions`` holds per-class logits. It may be a tuple whose first
            element is the logits when the model returns extra outputs
            (e.g. hidden states).

    Returns:
        dict with 'accuracy', 'macro_f1' and 'weighted_f1'. The Trainer reports them
        with an 'eval_' prefix, e.g. 'eval_macro_f1'.
    """
    logits, labels = eval_pred
    if isinstance(logits, tuple):
        logits = logits[0]
    predicted = np.argmax(logits, axis=-1)

    report = classification_report(labels, predicted, output_dict=True, zero_division=0)

    return {
        'accuracy': accuracy_score(labels, predicted),
        'macro_f1': report['macro avg']['f1-score'],        # every class counts equally
        'weighted_f1': report['weighted avg']['f1-score'],  # weighted by class support
    }

class WeightedTrainer(Trainer):
    """``Trainer`` that optimizes class-weighted cross-entropy instead of the model's own loss.

    Args:
        class_weights: 1-D float tensor with one weight per label, or ``None`` for
            plain cross-entropy. Moved to the logits' device on first use and cached there.
        *args, **kwargs: forwarded unchanged to ``Trainer.__init__``.
    """

    def __init__(self, class_weights, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.class_weights = class_weights

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the weighted cross-entropy loss, plus the model outputs if ``return_outputs``.

        ``labels`` are kept out of the forward call so the model does not also
        compute its own unweighted loss, which would only be discarded.

        NOTE(review): ``num_items_in_batch`` is accepted for signature compatibility
        but unused; the loss is a per-batch weighted mean. Confirm how the installed
        transformers version scales an overridden loss under gradient accumulation.

        Raises:
            ValueError: if the batch carries no ``labels``.
        """
        labels = inputs.get("labels")
        if labels is None:
            raise ValueError("WeightedTrainer.compute_loss needs a 'labels' entry in the batch")

        # Build a new dict rather than inputs.pop(...) so the caller's batch is left untouched.
        model_inputs = {name: value for name, value in inputs.items() if name != "labels"}
        outputs = model(**model_inputs)
        logits = outputs.get("logits")

        weights = self.class_weights
        if weights is not None and weights.device != logits.device:
            # Cache the moved copy so the device transfer happens only once.
            weights = weights.to(logits.device)
            self.class_weights = weights

        loss = nn.functional.cross_entropy(logits, labels, weight=weights)

        return (loss, outputs) if return_outputs else loss


print("Custom trainer with weighted loss defined")


Custom trainer with weighted loss defined


## Set Up Training Arguments and Hyperparameters

In [None]:
training_args = TrainingArguments(
    output_dir="./legal-bert-ledgar-teacher",
    logging_dir="./logs",
    logging_steps=100,
    logging_strategy="steps",

    num_train_epochs=4,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=4,         # effective batch = 8 * 4 = 32 per device

    learning_rate=1e-5,                    # Lower LR for stable fine-tuning
    warmup_ratio=0.1,                      # 10% warmup steps
    weight_decay=0.01,                     # decoupled weight decay (AdamW), not an L2 penalty
    adam_epsilon=1e-8,
    max_grad_norm=1.0,                     # Gradient clipping

    lr_scheduler_type="cosine",            # Cosine decay scheduler

    eval_strategy="steps",
    eval_steps=1000,                       # Evaluate every 1000 steps
    save_strategy="steps",
    save_steps=1000,                       # must be a multiple of eval_steps for load_best_model_at_end
    save_total_limit=3,                    # keeps the 3 most recent checkpoints; the best one is also retained

    load_best_model_at_end=True,
    metric_for_best_model="eval_macro_f1", # Macro F1 weights every class equally
    greater_is_better=True,

    # drop_last also applies to the eval/predict dataloaders; keep it off so no
    # validation/test examples are silently excluded from the metrics.
    dataloader_drop_last=False,
    dataloader_num_workers=2,              # Parallel data loading
    fp16=torch.cuda.is_available(),        # Mixed precision if GPU available

    seed=42,
    data_seed=42,

    report_to="none",                      # No wandb/tensorboard
    disable_tqdm=False,                    # Keep progress bars
)

print("Training arguments configured with optimized hyperparameters")
print(f"Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
print(f"Total training steps: ~{len(train_dataset) // (training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps) * training_args.num_train_epochs}")


Training arguments configured with optimized hyperparameters
Effective batch size: 32
Total training steps: ~7500


## Initialize and Run Training

In [None]:
trainer = WeightedTrainer(
    class_weights=class_weights_tensor,
    model=teacher_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    processing_class=tokenizer,  # also selects a padding collator for the variable-length batches
    compute_metrics=compute_metrics,
)

print("Starting teacher model training...\n\n")

training_result = trainer.train()

print("Training completed!")
# TrainOutput.training_loss is the mean loss over all training steps, not the
# last logged value, so it is labelled as an average.
print(f"Average training loss: {training_result.training_loss:.4f}")
print(f"Training time: {training_result.metrics['train_runtime']:.2f} seconds")


The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None}.


Starting teacher model training...




Step,Training Loss,Validation Loss,Accuracy,Macro F1,Weighted F1
1000,2.6757,2.347415,0.7343,0.554272,0.692024
2000,1.3343,1.193424,0.8003,0.674025,0.778577
3000,1.0432,0.962468,0.8229,0.722459,0.811107
4000,0.8629,0.854196,0.8278,0.736531,0.818881
5000,0.7746,0.81432,0.8374,0.752554,0.833688
6000,0.7523,0.787346,0.8404,0.757778,0.836703
7000,0.7003,0.782098,0.8422,0.760473,0.838505


Training completed!
Final training loss: 1.3842
Training time: 6549.60 seconds


## Evaluate and Save the Trained Teacher

In [None]:
print("Evaluating teacher model on test set...")
# A single prediction pass yields both logits and metrics; calling
# trainer.evaluate() and then trainer.predict() ran inference over the test set twice.
# metric_key_prefix="eval" keeps the same metric keys trainer.evaluate() returns.
predictions = trainer.predict(test_dataset, metric_key_prefix="eval")
test_results = predictions.metrics

print("TEACHER MODEL EVALUATION RESULTS\n\n")
print(f"Test Accuracy: {test_results['eval_accuracy']:.4f}")
print(f"Test Macro F1: {test_results['eval_macro_f1']:.4f}")
print(f"Test Weighted F1: {test_results['eval_weighted_f1']:.4f}")

test_logits = predictions.predictions
if isinstance(test_logits, tuple):  # extra model outputs: logits come first
    test_logits = test_logits[0]
y_pred = np.argmax(test_logits, axis=1)
y_true = predictions.label_ids

# List every class explicitly so target_names always lines up with the labels,
# even when some class is missing from y_true and y_pred. Use the dataset's
# clause-type names when the label column carries them.
num_classes = test_logits.shape[1]
class_names = getattr(test_ds.features["label"], "names", None) or [
    f"Class_{i}" for i in range(num_classes)
]

detailed_report = classification_report(
    y_true, y_pred,
    labels=np.arange(num_classes),
    target_names=class_names,
    digits=4,
    zero_division=0,  # never-predicted classes score 0.0 instead of emitting warnings
)
print("\n Detailed Classification Report:")
print(detailed_report)

print("\n Saving trained teacher model...")
trainer.save_model("./legal-bert-ledgar-teacher-final")
tokenizer.save_pretrained("./legal-bert-ledgar-teacher-final")

print("Teacher model saved successfully!")
print("Ready for knowledge distillation to student model!")


Evaluating teacher model on test set...


TEACHER MODEL EVALUATION RESULTS


Test Accuracy: 0.8392
Test Macro F1: 0.7466
Test Weighted F1: 0.8342


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



 Detailed Classification Report:
              precision    recall  f1-score   support

     Class_0     0.8542    0.9318    0.8913        88
     Class_1     0.5714    0.4167    0.4819        48
     Class_2     0.8458    0.8080    0.8265       224
     Class_3     0.8800    0.9565    0.9167        23
     Class_4     0.0000    0.0000    0.0000        53
     Class_5     0.4615    0.4615    0.4615        26
     Class_6     0.7931    0.9787    0.8762        47
     Class_7     0.8872    0.8872    0.8872       195
     Class_8     0.0000    0.0000    0.0000         4
     Class_9     0.5362    0.5968    0.5649        62
    Class_10     0.6923    0.6000    0.6429        90
    Class_11     0.9735    0.9821    0.9778       112
    Class_12     0.8310    0.7284    0.7763        81
    Class_13     0.5984    0.6032    0.6008       126
    Class_14     0.0000    0.0000    0.0000         2
    Class_15     0.9722    1.0000    0.9859        70
    Class_16     1.0000    0.9683    0.9839    