## Load Teacher and Student Models

In [None]:

!pip install --upgrade pip
!pip install torch --upgrade
!pip install transformers --upgrade
!pip install datasets --upgrade

# Run, comment and restart runtime!




In [None]:
import zipfile

zip_path = "/content/Final_teacher_model.zip"
unzip_dir = "/content"

# Unpack the fine-tuned teacher archive into /content (presumably yielding
# /content/Final_teacher_model, which the next cell loads).
archive = zipfile.ZipFile(zip_path, 'r')
with archive:
    archive.extractall(unzip_dir)
print(f"Files extracted to {unzip_dir}")


Files extracted to /content


In [None]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Teacher: the fine-tuned Legal BERT classifier. It only supplies soft targets,
# so it runs in eval mode with every parameter frozen.
teacher_model = AutoModelForSequenceClassification.from_pretrained("/content/Final_teacher_model")
teacher_model.eval()
teacher_model.requires_grad_(False)

# Student: pre-trained BERT-mini encoder with a freshly initialised 100-way head.
STUDENT_CHECKPOINT = "prajjwal1/bert-mini"
student_model = AutoModelForSequenceClassification.from_pretrained(STUDENT_CHECKPOINT, num_labels=100)
student_tokenizer = AutoTokenizer.from_pretrained(STUDENT_CHECKPOINT)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/286 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/45.1M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at prajjwal1/bert-mini and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


vocab.txt: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/45.1M [00:00<?, ?B/s]

## Load LEDGAR Data (tokenized for the student further below)

In [None]:
from datasets import load_dataset

print("Loading LEDGAR dataset...")
# One load_dataset call per split of LexGLUE's "ledgar" configuration.
train_ds, val_ds, test_ds = (
    load_dataset("lex_glue", "ledgar", split=split_name)
    for split_name in ("train", "validation", "test")
)

Loading LEDGAR dataset...


README.md: 0.00B [00:00, ?B/s]

ledgar/train-00000-of-00001.parquet:   0%|          | 0.00/20.9M [00:00<?, ?B/s]

ledgar/test-00000-of-00001.parquet:   0%|          | 0.00/3.31M [00:00<?, ?B/s]

ledgar/validation-00000-of-00001.parquet:   0%|          | 0.00/3.44M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/60000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/10000 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/10000 [00:00<?, ? examples/s]

## Compute Class Weights

In [None]:
from sklearn.utils.class_weight import compute_class_weight
import numpy as np
import torch

# Materialise the training labels as a single NumPy array, so downstream code
# gets a plain array whatever container `datasets` hands back.
student_train_labels = np.asarray(train_ds['label'])

# 'balanced' weights are n_samples / (n_classes * count(class)): rarer clause
# types get proportionally larger weights in the cross-entropy term.
class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.unique(student_train_labels),
    y=student_train_labels,
)

# cross_entropy(weight=...) needs exactly one weight per output logit. That only
# holds if every class occurs in the training split, so fail loudly otherwise.
# NOTE(review): assumes the 'label' feature is a ClassLabel (true for lex_glue/ledgar).
num_classes = train_ds.features['label'].num_classes
assert len(class_weights) == num_classes, (
    f"only {len(class_weights)} of {num_classes} classes occur in the training split; "
    "class weights would not line up with the model's logits"
)
class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32)


## Define Weighted KD Loss Function

In [None]:
import torch.nn.functional as F

def kd_weighted_loss_fn(student_logits, teacher_logits, labels, class_weights, alpha=0.7, temperature=2.0):
    """Blend a class-weighted hard-label loss with a temperature-scaled distillation loss.

    Args:
        student_logits: raw student scores, classes along dim 1.
        teacher_logits: raw teacher scores with the same layout as ``student_logits``.
        labels: integer class targets for the hard-label term.
        class_weights: per-class weights for the cross-entropy term; moved to the
            student logits' device before use.
        alpha: share of the total given to the distillation term; the
            cross-entropy term gets ``1 - alpha``.
        temperature: softening temperature applied to both sets of logits.

    Returns:
        Scalar tensor ``alpha * KD + (1 - alpha) * CE``.
    """
    T = temperature
    weights_on_device = class_weights.to(student_logits.device)

    # Hard-label term: the class weights counteract the label imbalance.
    hard_term = F.cross_entropy(student_logits, labels, weight=weights_on_device)

    # Soft-label term: unweighted KL divergence between the temperature-softened
    # distributions, averaged per example and scaled by T**2.
    student_log_probs = F.log_softmax(student_logits / T, dim=1)
    teacher_probs = F.softmax(teacher_logits / T, dim=1)
    soft_term = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean') * (T ** 2)

    return alpha * soft_term + (1. - alpha) * hard_term


## Define a Custom KD Trainer with Class Weighting


In [None]:
from transformers import Trainer

class WeightedKDTrainer(Trainer):
    """Trainer that distils a frozen teacher into the model being trained.

    The training loss is ``kd_weighted_loss_fn``: a temperature-scaled KL term
    against the teacher's logits, blended with class-weighted cross-entropy on
    the gold labels.

    Args:
        teacher_model: classifier whose logits serve as soft targets. It is put
            in eval mode and frozen here.
        class_weights: 1-D tensor of per-class weights for the cross-entropy term.
        kd_alpha: weight of the distillation term (``1 - kd_alpha`` goes to CE).
        temperature: softmax temperature used for distillation.
        *args, **kwargs: forwarded unchanged to ``transformers.Trainer``. Because
            the four KD arguments come first, pass the Trainer arguments
            (``model=...``, ``args=...``) by keyword.
    """

    def __init__(self, teacher_model, class_weights, kd_alpha=0.7, temperature=2.0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The teacher only provides targets: keep dropout off and never track
        # gradients, whatever state the caller handed it over in.
        self.teacher_model = teacher_model.eval()
        self.teacher_model.requires_grad_(False)
        self.class_weights = class_weights
        self.kd_alpha = kd_alpha
        self.temperature = temperature

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the weighted KD loss, plus the student outputs if requested.

        ``num_items_in_batch`` is accepted for Trainer compatibility but unused;
        the loss is already averaged inside ``kd_weighted_loss_fn``.
        """
        device = model.device
        # Move the teacher and the class weights only when they are not already
        # on the student's device, instead of re-assigning them on every step.
        if next(self.teacher_model.parameters()).device != device:
            self.teacher_model.to(device)
        if self.class_weights.device != device:
            self.class_weights = self.class_weights.to(device)

        labels = inputs.get("labels")
        outputs_student = model(**inputs)
        student_logits = outputs_student.get('logits')

        # NOTE(review): the teacher receives the *student* tokenizer's input_ids.
        # That is only meaningful if both tokenizers share one vocabulary. Confirm
        # this for the fine-tuned Legal-BERT teacher, which may use its own vocab.
        with torch.no_grad():
            outputs_teacher = self.teacher_model(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"]
            )
            teacher_logits = outputs_teacher.get('logits')

        loss = kd_weighted_loss_fn(
            student_logits, teacher_logits, labels, self.class_weights,
            alpha=self.kd_alpha, temperature=self.temperature
        )
        return (loss, outputs_student) if return_outputs else loss



## Set Up KD Training Arguments

In [None]:
from transformers import TrainingArguments

# Evaluation and checkpointing share one 500-step grid, so that
# load_best_model_at_end can restore the checkpoint with the best macro F1.
training_args = TrainingArguments(
    output_dir="./bert-mini-ledgar-student",
    num_train_epochs=6,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    learning_rate=5e-5,
    save_strategy="steps",
    save_steps=500,
    # A 6-epoch run produces more than 20 checkpoints (each with optimizer state).
    # Keep only the most recent ones; with load_best_model_at_end=True the
    # Trainer also retains the best checkpoint on top of this limit.
    save_total_limit=2,
    eval_strategy="steps",
    eval_steps=500,
    weight_decay=0.01,
    warmup_ratio=0.1,
    load_best_model_at_end=True,
    metric_for_best_model="eval_macro_f1",  # macro F1 weights every class equally
    greater_is_better=True,
    seed=42,  # the Trainer default, made explicit for reproducibility
    fp16=torch.cuda.is_available(),
    report_to="none"
)


## Train the Student Model (Handling Imbalance)

In [None]:
from sklearn.metrics import accuracy_score, f1_score

def compute_metrics(eval_pred):
    """Summarise student predictions for the Trainer's evaluation loop.

    Args:
        eval_pred: ``(predictions, label_ids)`` pair from the Trainer, where the
            predictions are logits with classes along axis 1.

    Returns:
        dict with ``accuracy``, ``macro_f1`` (unweighted mean of per-class F1, so
        rare classes count as much as frequent ones; this is the
        checkpoint-selection metric) and ``weighted_f1`` (per-class F1 weighted
        by class support).
    """
    logits, labels = eval_pred
    # If the model returns extra outputs, the Trainer passes a tuple whose
    # first element holds the logits.
    if isinstance(logits, tuple):
        logits = logits[0]
    predictions = np.argmax(logits, axis=1)

    return {
        'accuracy': accuracy_score(labels, predictions),
        'macro_f1': f1_score(labels, predictions, average='macro', zero_division=0),
        'weighted_f1': f1_score(labels, predictions, average='weighted', zero_division=0),
    }


In [None]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Fresh student for this run: pre-trained BERT-mini encoder with a new 100-way head.
student_model = AutoModelForSequenceClassification.from_pretrained(
    "prajjwal1/bert-mini",
    num_labels=100
)
student_tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-mini")

def preprocess_function(examples):
    """Tokenize a batch of LEDGAR texts for the student, truncating to 512 tokens.

    No padding here. Padding inside ``map`` only pads to the longest text in each
    ``map`` chunk, so short provisions would often be padded far beyond their
    length. The Trainer (given the tokenizer) pads each training batch to its
    own longest sequence instead. This means batches must go through a padding
    collator; a plain DataLoader cannot stack these ragged examples.
    """
    return student_tokenizer(
        examples["text"],
        truncation=True,
        max_length=512
    )

train_dataset = train_ds.map(preprocess_function, batched=True)
val_dataset = val_ds.map(preprocess_function, batched=True)
test_dataset = test_ds.map(preprocess_function, batched=True)

train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
val_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
test_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at prajjwal1/bert-mini and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Map:   0%|          | 0/60000 [00:00<?, ? examples/s]

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [None]:
# Teacher (frozen, as before)
teacher_model = AutoModelForSequenceClassification.from_pretrained("/content/Final_teacher_model")

teacher_model.eval()
teacher_model.requires_grad_(False)

# Initialize Weighted KD Trainer.
# `processing_class` replaces the deprecated `tokenizer` argument. The Trainer
# still derives its default padding data collator from it.
kd_trainer = WeightedKDTrainer(
    teacher_model=teacher_model,
    class_weights=class_weights_tensor,
    model=student_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    processing_class=student_tokenizer,
    compute_metrics=compute_metrics,
)

print("Starting student model (class-imbalance aware) knowledge distillation...")
kd_trainer.train()

  super().__init__(*args, **kwargs)


Starting student model (class-imbalance aware) knowledge distillation...


Step,Training Loss,Validation Loss,Accuracy,Macro F1,Weighted F1
500,1.7777,1.668787,0.0075,0.00023,0.000132
1000,1.5828,1.452081,0.0963,0.129318,0.096779
1500,1.3595,1.219959,0.4079,0.396922,0.42062
2000,1.1798,1.075688,0.6482,0.55837,0.64859
2500,1.0565,0.986589,0.682,0.599547,0.685102
3000,0.988,0.930413,0.716,0.624881,0.716073
3500,0.9456,0.893435,0.7213,0.641711,0.726727
4000,0.906,0.867662,0.753,0.669976,0.752362
4500,0.8782,0.853722,0.7517,0.67684,0.751193
5000,0.8604,0.837994,0.7534,0.686655,0.758015


Step,Training Loss,Validation Loss,Accuracy,Macro F1,Weighted F1
500,1.7777,1.668787,0.0075,0.00023,0.000132
1000,1.5828,1.452081,0.0963,0.129318,0.096779
1500,1.3595,1.219959,0.4079,0.396922,0.42062
2000,1.1798,1.075688,0.6482,0.55837,0.64859
2500,1.0565,0.986589,0.682,0.599547,0.685102
3000,0.988,0.930413,0.716,0.624881,0.716073
3500,0.9456,0.893435,0.7213,0.641711,0.726727
4000,0.906,0.867662,0.753,0.669976,0.752362
4500,0.8782,0.853722,0.7517,0.67684,0.751193
5000,0.8604,0.837994,0.7534,0.686655,0.758015


In [None]:
import zipfile

# Restore the saved step-8000 Trainer checkpoint so the next cell can resume from it.
checkpoint_zip = zipfile.ZipFile('/content/checkpoint-8000.zip', 'r')
with checkpoint_zip:
    checkpoint_zip.extractall('/content')


In [None]:

teacher_model = AutoModelForSequenceClassification.from_pretrained("/content/Final_teacher_model")

teacher_model.eval()
teacher_model.requires_grad_(False)

# Same setup as the first run. `resume_from_checkpoint` below restores the
# training state saved at step 8000. `processing_class` replaces the deprecated
# `tokenizer` argument.
kd_trainer = WeightedKDTrainer(
    teacher_model=teacher_model,
    class_weights=class_weights_tensor,
    model=student_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    processing_class=student_tokenizer,
    compute_metrics=compute_metrics,
)

print("Continuing student model (class-imbalance aware) knowledge distillation...")
kd_trainer.train(resume_from_checkpoint="/content/checkpoint-8000")

  super().__init__(*args, **kwargs)


Continuing student model (class-imbalance aware) knowledge distillation...


Step,Training Loss,Validation Loss,Accuracy,Macro F1,Weighted F1
8500,0.7872,0.788232,0.7648,0.709666,0.771226
9000,0.7766,0.784708,0.7597,0.707769,0.769337
9500,0.7765,0.783195,0.7641,0.710558,0.773434
10000,0.7698,0.78014,0.7594,0.708187,0.768717
10500,0.7673,0.778241,0.7663,0.714239,0.774589
11000,0.7725,0.777483,0.7694,0.714515,0.777465


TrainOutput(global_step=11250, training_loss=0.2237602281358507, metrics={'train_runtime': 5245.1435, 'train_samples_per_second': 68.635, 'train_steps_per_second': 2.145, 'total_flos': 3595394580480000.0, 'train_loss': 0.2237602281358507, 'epoch': 6.0})

In [None]:
# Evaluate the trained student on the trainer's eval_dataset (val_dataset).
# With load_best_model_at_end=True, these are the best checkpoint's weights.
kd_trainer.evaluate()

In [None]:
# Export the distilled student together with its tokenizer, so the pair can be
# reloaded with from_pretrained(save_dir).
save_dir = "./bert-mini-ledgar-student-final"
for artifact in (student_model, student_tokenizer):
    artifact.save_pretrained(save_dir)

print(f"Student model and tokenizer saved at: {save_dir}")


In [1]:
import zipfile
import os

zip_path = '/content/bert-tiny-student.zip'
unzip_dir = '/content'
os.makedirs(unzip_dir, exist_ok=True)

# Unpack the previously exported student archive; always release the file handle.
archive = zipfile.ZipFile(zip_path, 'r')
try:
    archive.extractall(unzip_dir)
finally:
    archive.close()

print(f" Model unzipped to: {unzip_dir}")


 Model unzipped to: /content


In [2]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Reload the exported student for standalone evaluation.
# NOTE(review): the printed architecture (4 layers, hidden size 256) looks like
# bert-mini rather than bert-tiny; confirm the folder name is only historical.
path="/content/bert-tiny-student"
tokenizer = AutoTokenizer.from_pretrained(path)
model = AutoModelForSequenceClassification.from_pretrained(path)
model.eval()


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 256, padding_idx=0)
      (position_embeddings): Embedding(512, 256)
      (token_type_embeddings): Embedding(2, 256)
      (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-3): 4 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=256, out_features=256, bias=True)
              (key): Linear(in_features=256, out_features=256, bias=True)
              (value): Linear(in_features=256, out_features=256, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=256, out_features=256, bias=True)
              (LayerNorm): LayerNorm((256,), eps=1e-1

In [3]:
from datasets import load_dataset

test_ds = load_dataset("lex_glue", "ledgar", split="test")

def preprocess_function(examples):
    """Tokenize a batch of LEDGAR texts with the reloaded student's tokenizer.

    Every example is padded to exactly 512 tokens, because the evaluation loop
    batches with a plain ``DataLoader`` whose default collate can only stack
    tensors of equal length. ``padding=True`` would pad only to the longest text
    within each ``map`` chunk, so a DataLoader batch that straddles two chunks
    could mix sequence lengths and fail.
    """
    return tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=512
    )

test_dataset = test_ds.map(preprocess_function, batched=True)
test_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

ledgar/train-00000-of-00001.parquet:   0%|          | 0.00/20.9M [00:00<?, ?B/s]

ledgar/test-00000-of-00001.parquet:   0%|          | 0.00/3.31M [00:00<?, ?B/s]

ledgar/validation-00000-of-00001.parquet:   0%|          | 0.00/3.44M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/60000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/10000 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/10000 [00:00<?, ? examples/s]

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [4]:
import torch
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, classification_report

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()  # idempotent; keeps dropout off even if cells were run out of order

# Default collate stacks per-example tensors, so all test examples must already
# share one sequence length.
dataloader = DataLoader(test_dataset, batch_size=32)
all_preds = []
all_labels = []

with torch.inference_mode():
    for batch in dataloader:
        logits = model(
            input_ids=batch["input_ids"].to(device),
            attention_mask=batch["attention_mask"].to(device),
        ).logits
        all_preds.extend(logits.argmax(dim=1).cpu().numpy())
        # Labels are only consumed by scikit-learn on the CPU, so skip the GPU round trip.
        all_labels.extend(batch["label"].numpy())

# zero_division=0 matches compute_metrics used during training and silences the
# UndefinedMetricWarning for classes the model never predicts.
print(classification_report(all_labels, all_preds, digits=4, zero_division=0))

              precision    recall  f1-score   support

           0     0.8602    0.9091    0.8840        88
           1     0.2500    0.2500    0.2500        48
           2     0.7958    0.6786    0.7325       224
           3     0.9583    1.0000    0.9787        23
           4     0.0976    0.0755    0.0851        53
           5     0.4255    0.7692    0.5479        26
           6     0.8679    0.9787    0.9200        47
           7     0.8927    0.8103    0.8495       195
           8     0.0000    0.0000    0.0000         4
           9     0.3750    0.4839    0.4225        62
          10     0.6024    0.5556    0.5780        90
          11     0.9818    0.9643    0.9730       112
          12     0.8333    0.7407    0.7843        81
          13     0.5868    0.5635    0.5749       126
          14     0.0000    0.0000    0.0000         2
          15     0.9722    1.0000    0.9859        70
          16     0.9524    0.9524    0.9524        63
          17     0.8351    

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
