In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchmetrics import Accuracy

from pyroml import PyroModule, Trainer, Stage
from pyroml.models import Backbone, TimmBackbone
from pyroml.models.utils import get_features

In [3]:
class PyroTeacher(PyroModule):
    """Wrap a timm backbone as a ``PyroModule`` so pyroml helpers such as
    ``get_features`` can run it.

    In this notebook it is only used in eval mode as a feature extractor.
    """

    def __init__(self, teacher: TimmBackbone):
        super().__init__()
        self.teacher = teacher

    @property
    def last_dim(self) -> int:
        """Feature width of the wrapped backbone (first entry of its ``last_dim``)."""
        backbone_dims = self.teacher.last_dim
        return backbone_dims[0]

    def forward(self, x):
        features = self.teacher(x)
        return features

    def step(self, batch, stage):
        # Only the image is fed to the backbone; the label and stage are ignored.
        images = batch["image"]
        return self(images)


In [4]:
# ResNet-18 loaded without a classification head (num_classes=0), wrapped so pyroml
# helpers can drive it, and put in eval mode since it is only used to extract features.
resnet_backbone = Backbone.load("resnet18", num_classes=0)
teacher = PyroTeacher(resnet_backbone)
teacher.eval()

Backbone has 11,176,512 params


PyroTeacher(
  (teacher): TimmBackbone(
    (model): ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (drop_block): Identity()
          (act1): ReLU(inplace=True)
          (aa): Identity()
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act2): ReLU(inplace=True)
        )
        (1): BasicBlock(
          (conv1)

In [5]:
# Feature width produced by the teacher; StudentModule sizes its projection layer to match it.
teacher.last_dim

512

In [6]:
from torchvision.tv_tensors import Image
from pyroml.template.flowers102.dataset import Flowers102Dataset
from torchvision.transforms import v2

class ToBoundedFloat(nn.Module):
    """Divide pixel values by 255 and return a float tensor.

    Meant to follow ``v2.ToImage()`` so 8-bit pixel values land in [0, 1]
    (assumes the incoming image is uint8 in [0, 255] — a float input would be
    divided again).
    """

    def forward(self, img: Image):
        scaled = img / 255.
        return scaled.float()
    
# Shared preprocessing for both splits: image -> [0, 1] float -> 224x224 ->
# per-channel normalization with the commonly used ImageNet mean/std.
transform = v2.Compose([
    v2.ToImage(),
    ToBoundedFloat(),
    v2.Resize((224, 224)),
    v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
tr_ds = Flowers102Dataset(split="train", transform=transform)
te_ds = Flowers102Dataset(split="test", transform=transform)

# Load (and transform) one sample a single time and reuse it for the checks below.
sample = tr_ds[0]
# StudentModule's hard-coded flatten size (968) assumes 3x224x224 inputs.
assert sample["image"].shape == (3, 224, 224), sample["image"].shape

(
    len(tr_ds),
    len(te_ds),
    sample.keys(),
    sample["image"].shape,
)


(7169, 1020, dict_keys(['image', 'label']), torch.Size([3, 224, 224]))

In [7]:
# Precompute the teacher features once so student training never has to run the
# teacher: StudentModule reads them from batch["feat"] instead.
tr_feats, te_feats = (
    get_features(teacher, split_ds, dtype=torch.bfloat16, batch_size=16)
    for split_ds in (tr_ds, te_ds)
)

Predicting:   0%|          | 0/449 [00:00<?, ?it/s]

Predicting:   0%|          | 0/64 [00:00<?, ?it/s]

In [8]:
class Flowers102FeatsDataset(torch.utils.data.Dataset):
    """Attach a precomputed feature vector to every item of a dict-returning dataset.

    Item ``idx`` is ``dataset[idx]`` plus a ``"feat"`` entry holding ``feats[idx]``.
    In this notebook the features are the ResNet-18 teacher outputs produced by
    ``get_features(teacher, ...)``, but any map-style dataset whose items are
    dicts can be wrapped.

    Args:
        feats: Per-sample features, indexed in the same order as ``dataset``.
        dataset: Map-style dataset returning a ``dict`` per item.

    Raises:
        ValueError: If ``feats`` and ``dataset`` have different lengths.
    """

    def __init__(self, feats: torch.Tensor, dataset: torch.utils.data.Dataset):
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(feats) != len(dataset):
            raise ValueError(
                f"feats has {len(feats)} entries but dataset has {len(dataset)} items"
            )
        self.feats = feats
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Build a new dict rather than writing into the wrapped dataset's item,
        # which could be a cached object shared between calls.
        return {**self.dataset[idx], "feat": self.feats[idx]}

In [9]:
# Pair each split with the teacher features precomputed for it.
tr_feats_ds = Flowers102FeatsDataset(dataset=tr_ds, feats=tr_feats)
te_feats_ds = Flowers102FeatsDataset(dataset=te_ds, feats=te_feats)

In [10]:
class StudentModule(PyroModule):
    """Small CNN student trained by distillation from precomputed teacher features.

    The student backbone (``self.student``) maps an image to a vector with
    ``teacher.last_dim`` entries. That vector is matched against the teacher
    feature found in ``batch["feat"]`` through a temperature-scaled KL divergence
    (soft loss). A ReLU + linear head (``self.student_head``) on top of it is
    trained with cross-entropy against ``batch["label"]`` (hard loss). The
    optimized objective is ``hard + alpha * soft``.

    Args:
        teacher: Teacher wrapper exposing an integer ``last_dim`` (e.g.
            ``PyroTeacher``). It is registered as a submodule and kept in eval
            mode by ``train``, but ``step`` never runs it: the teacher outputs are
            read from the batch.
        temperature: Softmax temperature of the distillation term; the soft loss
            is rescaled by ``temperature**2``.
        alpha: Weight of the soft loss relative to the hard loss.
        num_classes: Number of classes predicted by ``student_head``.
    """

    def __init__(
        self,
        teacher: PyroTeacher,
        temperature: float = 2,
        alpha: float = 1,
        num_classes: int = 102,
    ):
        super().__init__()
        self.teacher = teacher
        self.temperature = temperature
        self.alpha = alpha
        # Spatial size trace for a 224x224 input (what the notebook's transform resizes to):
        # 224 -conv3-> 222 -pool2-> 111 -conv3-> 109 -pool2-> 54 -conv3-> 52
        # -pool4-> 13 -conv3-> 11; the 1x1 conv leaves 8 channels, so 8 * 11 * 11 = 968.
        # A different input resolution requires updating 968.
        self.student = nn.Sequential(
            nn.Conv2d(3, 6, 3),
            nn.BatchNorm2d(6),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(6, 24, 3),
            nn.BatchNorm2d(24),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(24, 48, 3),
            nn.BatchNorm2d(48),
            nn.ReLU(),
            nn.MaxPool2d(4),
            nn.Conv2d(48, 96, 3),
            nn.BatchNorm2d(96),
            nn.ReLU(),
            nn.Conv2d(96, 8, 1),
            nn.Flatten(),
            nn.Linear(968, teacher.last_dim),
        )
        self.student_head = nn.Sequential(
            nn.ReLU(), nn.Linear(teacher.last_dim, num_classes)
        )

        self.soft_loss = nn.KLDivLoss(reduction="batchmean")
        self.hard_loss = nn.CrossEntropyLoss()

        metrics_kwargs = dict(
            task="multiclass", num_classes=num_classes, average="macro"
        )
        self.acc = Accuracy(**metrics_kwargs)
        self.acc_5 = Accuracy(**metrics_kwargs, top_k=5)

    def configure_optimizers(self, loop):
        # Both the backbone and the classification head must be optimized; the
        # head is a separate module, so `self.student.parameters()` alone misses it.
        # The teacher is deliberately left out.
        params = [*self.student.parameters(), *self.student_head.parameters()]
        self.optimizer = torch.optim.Adam(params, lr=self.trainer.lr)

    def train(self, mode=True):
        """Switch every submodule to ``mode`` except the teacher, which stays in eval.

        Follows the ``nn.Module.train`` contract: ``self.training`` is updated and
        ``self`` is returned, so ``eval()`` also returns this module.
        """
        super().train(mode)
        # Prevent the teacher from switching to train mode.
        self.teacher.train(False)
        return self

    def forward(self, x) -> dict[str, torch.Tensor]:
        """Run the student.

        Returns a dict with ``"logits"``, the backbone output of width
        ``teacher.last_dim`` (the vector distilled against the teacher
        features), and ``"preds"``, the unnormalized class scores from
        ``student_head``.
        """
        logits = self.student(x)
        preds = self.student_head(logits)
        return {"logits": logits, "preds": preds}

    def step(self, batch, stage) -> torch.Tensor:
        """Compute the distillation + classification loss and log metrics.

        ``batch`` must provide ``"image"``, ``"feat"`` (precomputed teacher
        features) and ``"label"``. Returns the loss for ``Stage.TRAIN`` and the
        class probabilities for any other stage.
        """
        x, teacher_feats, label = batch["image"], batch["feat"], batch["label"]

        student_out = self(x)
        student_logits = student_out["logits"]
        student_preds = student_out["preds"]

        # Compute the distillation term in float32. The teacher features may arrive
        # in bfloat16 (they are precomputed with dtype=torch.bfloat16), where
        # softmax/KL lose precision and would mix dtypes with a float32 student.
        temperature = self.temperature
        soft_teacher = F.softmax(teacher_feats.float() / temperature, dim=-1)
        soft_student = F.log_softmax(student_logits.float() / temperature, dim=-1)
        # KLDivLoss takes log-probabilities as input and probabilities as target.
        # The T**2 factor (Hinton et al., 2015) keeps gradient scale comparable
        # across temperatures.
        soft_loss = self.soft_loss(soft_student, soft_teacher) * temperature**2

        hard_loss = self.hard_loss(student_preds, label)

        loss = hard_loss + self.alpha * soft_loss

        preds = torch.softmax(student_preds, dim=-1)
        self.log(
            loss=loss.item(),
            soft=soft_loss.item(),
            hard=hard_loss.item(),
            acc=self.acc(preds, label),
            acc_5=self.acc_5(preds, label),
        )

        if stage == Stage.TRAIN:
            return loss

        return preds


In [11]:
# Build the student around the teacher wrapper and display its module tree.
student = StudentModule(teacher)
student

StudentModule(
  (teacher): PyroTeacher(
    (teacher): TimmBackbone(
      (model): ResNet(
        (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): ReLU(inplace=True)
        (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        (layer1): Sequential(
          (0): BasicBlock(
            (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (drop_block): Identity()
            (act1): ReLU(inplace=True)
            (aa): Identity()
            (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (act2): ReLU(inplac

In [20]:
from pyroml.models.utils import num_params


# Backbone-only comparison: the teacher was loaded with num_classes=0, and
# student.student excludes the separate student_head classifier.
t_params, s_params = (num_params(module) for module in (teacher, student.student))

f'Teacher = {t_params:,}, Student = {s_params:,}, Ratio = {100 * s_params / t_params:.3f}%'

'Teacher = 11,176,512, Student = 550,724, Ratio = 4.928%'

In [12]:
from pyroml.callbacks.progress.tqdm_progress import TQDMProgress


# Training hyperparameters; evaluation during fit is disabled (evaluate_on=False),
# and the test split is scored in a separate evaluate call.
TRAIN_KWARGS = dict(
    lr=1e-3,
    max_epochs=16,
    batch_size=17,
    evaluate_on=False,
    dtype=torch.bfloat16,
)
trainer = Trainer(**TRAIN_KWARGS, callbacks=[TQDMProgress()])
tr_tracker = trainer.fit(student, tr_dataset=tr_feats_ds)

Epoch 1:   0%|          | 0/422 [00:00<?, ?it/s]

Epoch 2:   0%|          | 0/422 [00:00<?, ?it/s]

Epoch 3:   0%|          | 0/422 [00:00<?, ?it/s]

Epoch 4:   0%|          | 0/422 [00:00<?, ?it/s]

Epoch 5:   0%|          | 0/422 [00:00<?, ?it/s]

Epoch 6:   0%|          | 0/422 [00:00<?, ?it/s]

Epoch 7:   0%|          | 0/422 [00:00<?, ?it/s]

Epoch 8:   0%|          | 0/422 [00:00<?, ?it/s]

Epoch 9:   0%|          | 0/422 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [13]:
# `tr_tracker` is only bound once `trainer.fit` returns; if that cell is interrupted
# (KeyboardInterrupt), this line raises NameError. Re-run training to completion first.
tr_tracker.plot(epoch=True)

NameError: name 'tr_tracker' is not defined

In [16]:
import rich

# Score the student on the test split, then pretty-print the last epoch's metrics.
te_tracker = trainer.evaluate(model=student, dataset=te_feats_ds)
last_epoch_metrics = te_tracker.get_last_epoch_metrics()
rich.print(last_epoch_metrics)

Validating:   0%|          | 0/60 [00:00<?, ?it/s]