# Student–Teacher Knowledge Distillation for Keypoint Extraction

In [None]:
import sys
import os
import pathlib

sys.path.insert(0, os.path.abspath('nw_code'))

import torch
from torch.utils.data import DataLoader
import torch.nn as nn
from transformers import AutoImageProcessor, SuperPointForKeypointDetection
from nw_code.datasets import SingleImageDataset
from nw_code.student_teacher import StudentTeacher
from nw_code.datasets import CocoDataloader

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint

from PIL import Image

%load_ext autoreload
%autoreload 2
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'



In [None]:

# Baseline hyperparameters that gave reasonable results on a small data subset.
BASELINE_CONFIG = {
    'learning_rate': 1e-4,
    'batch_size': 32,
    'num_val': 256,
    'num_train': 1024,
    'max_epochs': 100,
    'temperature': 3.0,
    'alpha': 0.3,
    'focal_alpha': 0.75,
    'focal_gamma': 2.0,
    'threshold': 0.5,
}

# Scaled-up run derived from the baseline. Build a *new* dict: plain assignment
# (`SCALED_CONFIG = BASELINE_CONFIG`) would alias the baseline, and the overrides
# below would silently overwrite BASELINE_CONFIG as well.
SCALED_CONFIG = {
    **BASELINE_CONFIG,
    'num_train': 10000,
    'num_val': 1000,
    'num_test': 1000,
    'max_epochs': 200,
}

# Dataset locations (COCO 2014 layout under assets/coco).
ROOT = pathlib.Path("assets") / "coco"
IMAGES_TRAIN_PATH = str(ROOT / "train_images")
IMAGES_VAL_PATH = str(ROOT / "val_images")
IMAGES_TEST_PATH = str(ROOT / "test_images")

ANNOTATIONS_TRAIN_PATH = str(ROOT / "instances_train2014.json")
ANNOTATIONS_VAL_PATH = str(ROOT / "instances_val2014.json")
ANNOTATIONS_TEST_PATH = str(ROOT / "image_info_test2014.json")

# Load the pretrained SuperPoint teacher and its image processor; the processor
# is configured to resize inputs to 240x320 (height x width).
print("\n1. Loading teacher model (SuperPoint)...")
processor = AutoImageProcessor.from_pretrained(
    "magic-leap-community/superpoint",
    size={"height": 240, "width": 320},
    use_fast=True)
teacher = SuperPointForKeypointDetection.from_pretrained("magic-leap-community/superpoint")
# The teacher is only used for inference.
# NOTE(review): whether its weights are kept out of the student's optimizer is
# decided inside StudentTeacher -- confirm there.
teacher.eval()
teacher_params = sum(p.numel() for p in teacher.parameters())
print(f"   ✓ Teacher loaded (Superpoint): {teacher_params:,} parameters")

# Build the COCO train/val/test datasets. Sizes come from the num_* entries of
# SCALED_CONFIG, passed as num_samples.
print("\n2. Creating COCO datasets...")
dataset_train = CocoDataloader(IMAGES_TRAIN_PATH, ANNOTATIONS_TRAIN_PATH, processor=processor, num_samples=SCALED_CONFIG['num_train'])
dataset_val = CocoDataloader(IMAGES_VAL_PATH, ANNOTATIONS_VAL_PATH, processor=processor, num_samples=SCALED_CONFIG['num_val'])
dataset_test = CocoDataloader(IMAGES_TEST_PATH, ANNOTATIONS_TEST_PATH, processor=processor, num_samples=SCALED_CONFIG['num_test'])

# Only the training loader shuffles; validation/test order is fixed.
train_loader = DataLoader(
    dataset_train,
    batch_size=SCALED_CONFIG['batch_size'],
    shuffle=True,
    num_workers=6,
    pin_memory=True
)

val_loader = DataLoader(dataset_val, batch_size=SCALED_CONFIG['batch_size'], num_workers=6, pin_memory=True)
test_loader = DataLoader(dataset_test, batch_size=SCALED_CONFIG['batch_size'], num_workers=2)

print(f"   ✓ Training dataset created: {len(dataset_train)} samples")
print(f"   ✓ Validation dataset created: {len(dataset_val)} samples")
print(f"   ✓ Test dataset created: {len(dataset_test)} samples")

# Wrap teacher + student in the Lightning module that runs the distillation.
print("\n3. Creating student model...")
model = StudentTeacher(
    teacher,
    dataset_train,
    lr=SCALED_CONFIG['learning_rate'],
    temperature=SCALED_CONFIG['temperature'],
    alpha=SCALED_CONFIG['alpha'],
    focal_alpha=SCALED_CONFIG['focal_alpha'],
    focal_gamma=SCALED_CONFIG['focal_gamma'],
    threshold=SCALED_CONFIG['threshold'],
)
student_params = sum(p.numel() for p in model.student.parameters())
print(f"   ✓ Student created: {student_params:,} parameters")
print(f"   ✓ Compression ratio: {teacher_params / student_params:.1f}x smaller")

# Stop training when val_loss stops improving (patience of 3 validation checks).
early_stopping = EarlyStopping(patience=3, monitor="val_loss")

print("\n4. Training student to mimic teacher...")

# TensorBoard logs go to logs/keypoint_distillation. log_graph=True asks Lightning
# to record the model graph.
# NOTE(review): that needs model.example_input_array -- confirm StudentTeacher sets it.
logger = TensorBoardLogger(
    name="keypoint_distillation",
    save_dir="logs",
    default_hp_metric=True,
    log_graph=True,
)

# Keep the three checkpoints with the lowest val_loss under checkpoints/. The
# filename embeds the epoch number and the monitored loss.
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    mode='min',
    save_top_k=3,
    dirpath='checkpoints/',
    filename='student-{epoch:02d}-{val_loss:.4f}',
)

class EvaluationCallback(pl.Callback):
    """Periodically evaluate the module and log the results under ``eval/<metric>``.

    NOTE: this callback is not registered with the Trainer in this notebook.
    ``evaluate_detector`` is not defined or imported here, so either pass
    ``evaluate_fn`` or import ``evaluate_detector`` before adding it to
    ``callbacks=[...]``.

    Args:
        evaluate_fn: Callable ``(pl_module, loader)`` returning a mapping of metric
            name to value. Defaults to a global ``evaluate_detector``, looked up
            at call time.
        loader: Loader to evaluate on. Defaults to the global ``val_loader``.
        every_n_epochs: Evaluate when ``trainer.current_epoch`` is a multiple of this.

    Raises:
        ValueError: If ``every_n_epochs`` is less than 1.
    """

    def __init__(self, evaluate_fn=None, loader=None, every_n_epochs=10):
        super().__init__()
        if every_n_epochs < 1:
            raise ValueError(f"every_n_epochs must be >= 1, got {every_n_epochs}")
        self.evaluate_fn = evaluate_fn
        self.loader = loader
        self.every_n_epochs = every_n_epochs

    def on_validation_epoch_end(self, trainer, pl_module):
        # Lightning's pre-training sanity check also fires this hook; skip it so the
        # (potentially expensive) evaluation does not run before training starts.
        if trainer.sanity_checking:
            return
        if trainer.current_epoch % self.every_n_epochs != 0:
            return
        evaluate = self.evaluate_fn if self.evaluate_fn is not None else evaluate_detector
        loader = self.loader if self.loader is not None else val_loader
        metrics = evaluate(pl_module, loader)
        for key, val in metrics.items():
            pl_module.log(f'eval/{key}', val)

# Configure and run distillation training.
trainer = pl.Trainer(
    max_epochs=SCALED_CONFIG['max_epochs'],
    accelerator='auto',
    devices=1,
    # fp16 AMP targets CUDA; fall back to full precision when no CUDA device exists.
    precision='16-mixed' if torch.cuda.is_available() else '32-true',
    log_every_n_steps=5,
    # Register the ModelCheckpoint defined above so the best val_loss models are
    # actually saved. Lightning requires enable_checkpointing=True when a
    # ModelCheckpoint is in the callbacks list.
    enable_checkpointing=True,
    enable_progress_bar=True,
    callbacks=[early_stopping, checkpoint_callback],
    logger=logger,
)

# Allow lower-precision (faster) float32 matmuls on hardware that supports them.
torch.set_float32_matmul_precision('medium')

trainer.fit(model, train_loader, val_loader)

print("\n" + "=" * 70)
print("Training complete!")
print("=" * 70)
# `model` holds the last-epoch weights; the best-scoring weights are in this file.
print(f"Best checkpoint: {checkpoint_callback.best_model_path}")

In [None]:
from nw_code.visualization import visualize_results

# Use a context manager so the file handle is released; convert() returns an
# independent RGB copy that stays valid after the file is closed.
with Image.open("assets/hovedbygg_left.jpg") as raw_image:
    demo_image = raw_image.convert('RGB')
fig = visualize_results(model, demo_image, processor)


In [None]:
# Save the trained module's weights. torch.save does not create missing
# directories, so make sure models/ exists first.
model_path = pathlib.Path("models") / "mymodel"
model_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), model_path)