In [1]:
import torch
import torch.nn.functional
from torch import nn

from labml import experiment, tracker
from labml.configs import option
from labml_helpers.train_valid import BatchIndex
from labml_nn.distillation.large import LargeModel
from labml_nn.distillation.small import SmallModel
from labml_nn.experiments.cifar10 import CIFAR10Configs

In [2]:
class Configs(CIFAR10Configs):
    """
    Configuration for the distillation experiment.

    Extends `CIFAR10Configs` with the two models, the two loss modules and the
    hyper-parameters that `step` reads when combining the losses.
    """
    # Model that is trained (the name `main` passes to `experiment.add_pytorch_models`)
    model: SmallModel
    # Model whose logits provide the soft targets; `main` assigns it directly
    large: LargeModel
    # KL-divergence loss; `log_target=True` means the target is given as log-probabilities.
    # NOTE(review): `reduction` is left at the PyTorch default ('mean'); the PyTorch docs
    # recommend 'batchmean' for the mathematically defined KL value — confirm this is intended.
    kl_div_loss = nn.KLDivLoss(log_target=True)
    # Loss between the trained model's output and the ground-truth labels
    loss_func = nn.CrossEntropyLoss()
    # Both models' logits are divided by this value before `log_softmax`
    temperature: float = 5.
    # Multiplier applied to the KL-divergence (soft-target) loss
    soft_targets_weight: float = 100.
    # Multiplier applied to the label loss
    label_loss_weight: float = 0.5

In [3]:
def step(self, batch: any, batch_idx: BatchIndex):
    """
    Run one distillation step (training or validation) on a single batch.

    The large model is evaluated without gradients and the small model is run on
    the same inputs. The loss is the weighted sum of `self.kl_div_loss` between the
    temperature-scaled log-probabilities of both models and `self.loss_func`
    between the small model's output and the labels.

    NOTE(review): this function takes `self` but is defined at module level;
    nothing visible in this notebook attaches it to `Configs` — confirm how it
    is bound before relying on it.

    :param batch: indexable pair; `batch[0]` is the input, `batch[1]` the labels
    :param batch_idx: batch position; when `is_last` is set during training the
        model is passed to `tracker.add`
    """
    is_training = self.mode.is_train

    # The trained model follows the current mode; the large model is always in eval mode
    self.model.train(is_training)
    self.large.eval()

    inputs = batch[0].to(self.device)
    labels = batch[1].to(self.device)

    # The global step advances by the number of samples, and only while training
    if is_training:
        tracker.add_global_step(len(inputs))

    # Large-model forward pass needs no gradients
    with torch.no_grad():
        teacher_logits = self.large(inputs)
    student_logits = self.model(inputs)

    # Soft-target loss: both sides are log-probabilities of temperature-scaled logits
    log_softmax = nn.functional.log_softmax
    teacher_log_probs = log_softmax(teacher_logits / self.temperature, dim=-1)
    student_log_probs = log_softmax(student_logits / self.temperature, dim=-1)
    distill_loss = self.kl_div_loss(student_log_probs, teacher_log_probs)

    # Loss against the ground-truth labels
    hard_loss = self.loss_func(student_logits, labels)

    total_loss = self.soft_targets_weight * distill_loss + self.label_loss_weight * hard_loss

    tracker.add({"loss.kl_div.": distill_loss,
                 "loss.nll": hard_loss,
                 "loss.": total_loss})

    self.accuracy(student_logits, labels)
    self.accuracy.track()

    if is_training:
        total_loss.backward()
        self.optimizer.step()
        if batch_idx.is_last:
            tracker.add('model', self.model)
        # Gradients are cleared after the optimizer step (and after the optional model logging)
        self.optimizer.zero_grad()

    tracker.save()
    

In [6]:
def main(teacher,student, checkpoint: int):
    """
    Create and run the `distillation` experiment with `teacher` as the large model.

    :param teacher: model assigned to `conf.large`
    :param student: not used by this function
    :param checkpoint: not used by this function

    NOTE(review): the `'model'` override names the option `_small_student_model`,
    which is not defined in the visible cells — confirm it is registered elsewhere.
    """
    experiment.create(name='distillation', comment='cifar10')

    conf = Configs()
    conf.large = teacher

    # Configuration overrides handed to labml
    overrides = {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,
        'model': '_small_student_model',
    }
    experiment.configs(conf, overrides)

    experiment.add_pytorch_models({'model': conf.model})
    experiment.load(None, None)

    with experiment.start():
        conf.run()

In [None]:

# NOTE(review): `teacher` and `student` are not defined in any visible cell;
# they presumably come from earlier kernel state — confirm this survives Restart & Run All.
main(teacher, student, checkpoint=1_000_000)




