In [1]:
import torch
import torch.nn.functional
from torch import nn

from labml import experiment, tracker
from labml.configs import option
from labml_helpers.train_valid import BatchIndex
from labml_nn.distillation.large import LargeModel
from labml_nn.distillation.small import SmallModel
from labml_nn.experiments.cifar10 import CIFAR10Configs

In [2]:
class Configs(CIFAR10Configs):
    """
    Configurations for distilling a large CIFAR-10 model into a small one.

    Extends `CIFAR10Configs` with the teacher model, the two loss functions
    and the weights that the `step` function in this notebook reads from
    `self` when it mixes the soft-target loss with the label loss.
    """
    # Student network; `step` runs it with gradients enabled
    model: SmallModel
    # Teacher network; `step` puts it in eval mode and runs it under `torch.no_grad()`
    large: LargeModel

    # KL divergence for the soft-target term. `log_target=True` because `step`
    # passes log-probabilities (`log_softmax`) for the target as well as the input.
    kl_div_loss = nn.KLDivLoss(log_target=True)
    # Cross entropy between the student's output and the true labels
    loss_func = nn.CrossEntropyLoss()

    # Logits of both models are divided by this before `log_softmax`
    temperature: float = 5.0
    # Multiplier on the KL (soft-target) loss in the combined loss
    soft_targets_weight: float = 100.0
    # Multiplier on the true-label loss in the combined loss
    label_loss_weight: float = 0.5

In [3]:
def step(self, batch: any, batch_idx: BatchIndex):
    """
    Run one distillation training or validation step on a single batch.

    The student (`self.model`) is trained against a weighted sum of

    * the KL-divergence loss between the temperature-scaled log-probabilities
      of the student and of the teacher (`self.large`), and
    * `self.loss_func` between the student's output and the true labels.

    :param batch: indexable batch; `batch[0]` is the input data and
        `batch[1]` the targets, both are moved to `self.device`
    :param batch_idx: batch index information; only `batch_idx.is_last` is read
    """
    # Student follows the current mode (training or validation)
    self.model.train(self.mode.is_train)
    # Teacher is only ever evaluated
    self.large.eval()

    data, target = batch[0].to(self.device), batch[1].to(self.device)

    # Advance the global step by the number of samples, in training mode only
    if self.mode.is_train:
        tracker.add_global_step(len(data))

    # Teacher forward pass; no gradients are needed for it
    with torch.no_grad():
        large_logits = self.large(data)

    # Student forward pass
    output = self.model(data)

    # Both distributions are softened with the same temperature and kept in
    # log space (`Configs` declares `kl_div_loss` with `log_target=True`).
    soft_targets = nn.functional.log_softmax(large_logits / self.temperature, dim=-1)
    soft_prob = nn.functional.log_softmax(output / self.temperature, dim=-1)

    # Soft-target loss: student log-probabilities as input, teacher's as target
    soft_targets_loss = self.kl_div_loss(soft_prob, soft_targets)
    # True-label loss on the un-scaled student output
    label_loss = self.loss_func(output, target)
    # Weighted sum of the two losses
    loss = self.soft_targets_weight * soft_targets_loss + self.label_loss_weight * label_loss

    tracker.add({"loss.kl_div.": soft_targets_loss,
                 "loss.nll": label_loss,
                 "loss.": loss})

    # Update and track accuracy of the student
    self.accuracy(output, target)
    self.accuracy.track()

    if self.mode.is_train:
        loss.backward()
        self.optimizer.step()
        # Track the model once per epoch, on the last batch
        if batch_idx.is_last:
            tracker.add('model', self.model)
        self.optimizer.zero_grad()

    tracker.save()


# `step` is defined in its own cell, outside the `Configs` class body, so it
# has to be attached explicitly. Without this, `conf.run()` would use whatever
# `step` `Configs` inherits from `CIFAR10Configs` and the distillation loss
# above would never run.
# NOTE(review): assumes labml dispatches to `Configs.step` by normal attribute
# lookup - confirm, or move this function into the `Configs` class body.
Configs.step = step
    

In [4]:
@option(Configs.large)
def _large_model(c: Configs):
    """
    Option for `Configs.large`: build a new `LargeModel` and move it to `c.device`.
    """
    teacher = LargeModel()
    return teacher.to(c.device)
@option(Configs.model)
def _small_student_model(c: Configs):
    """
    Option for `Configs.model`: build a new `SmallModel` and move it to `c.device`.
    """
    student = SmallModel()
    return student.to(c.device)

In [5]:
def get_saved_model(run_uuid: str, checkpoint: int):
    """
    Load the model of a previously saved large-model run.

    :param run_uuid: UUID of the run whose configs and checkpoint are loaded
    :param checkpoint: checkpoint passed on to `experiment.load`
    :return: the `model` attribute of the loaded large-model configs
    """
    # Imported under an alias so it does not shadow this notebook's own `Configs`
    from labml_nn.distillation.large import Configs as LargeConfigs

    # NOTE(review): presumably switches labml to evaluation mode - confirm
    experiment.evaluate()

    # Rebuild the large model's configs from what the saved run recorded
    teacher_conf = LargeConfigs()
    saved_configs = experiment.load_configs(run_uuid)
    experiment.configs(teacher_conf, saved_configs)

    # Register the model before loading so the checkpoint can be restored into it
    experiment.add_pytorch_models({'model': teacher_conf.model})
    experiment.load(run_uuid, checkpoint)
    experiment.start()

    return teacher_conf.model

In [6]:
def main(run_uuid: str, checkpoint: int):
    """
    Train a small model by distillation from a saved large model.

    :param run_uuid: UUID of the large-model run to load the teacher from
    :param checkpoint: checkpoint of that run to load
    """
    # Teacher comes from the saved run
    teacher = get_saved_model(run_uuid, checkpoint)

    experiment.create(name='distillation', comment='cifar10')

    distill_conf = Configs()
    # Use the loaded teacher as `large`
    distill_conf.large = teacher

    # Configuration overrides for this experiment
    overrides = {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,
        'model': '_small_student_model',
    }
    experiment.configs(distill_conf, overrides)

    # Register the student model with the experiment
    experiment.add_pytorch_models({'model': distill_conf.model})
    experiment.load(None, None)

    with experiment.start():
        distill_conf.run()

In [7]:
# Distil from the saved large-model run with this UUID, loading its checkpoint 1_000_000
main(run_uuid='d46cd53edaec11eb93c38d6538aee7d6', checkpoint=1_000_000)

2023-07-28 20:37:34.349372: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /Users/jianyiyang/Desktop/work/deepLearning/summerClass/final/distillation example/data/cifar-10-python.tar.gz


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 170498071/170498071 [00:55<00:00, 3099570.55it/s]


Extracting /Users/jianyiyang/Desktop/work/deepLearning/summerClass/final/distillation example/data/cifar-10-python.tar.gz to /Users/jianyiyang/Desktop/work/deepLearning/summerClass/final/distillation example/data
Files already downloaded and verified


In [8]:
# Show the layer structure of `LargeModel`. This constructs a fresh instance,
# not the checkpoint loaded by `get_saved_model`.
print(LargeModel())

LargeModel(
  (layers): Sequential(
    (0): Dropout(p=0.1, inplace=False)
    (1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): BatchNorm()
    (3): ReLU(inplace=True)
    (4): Dropout(p=0.1, inplace=False)
    (5): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): BatchNorm()
    (7): ReLU(inplace=True)
    (8): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (9): Dropout(p=0.1, inplace=False)
    (10): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm()
    (12): ReLU(inplace=True)
    (13): Dropout(p=0.1, inplace=False)
    (14): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): BatchNorm()
    (16): ReLU(inplace=True)
    (17): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (18): Dropout(p=0.1, inplace=False)
    (19): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): BatchN