In [1]:
import numpy as np
import copy
import matplotlib.pyplot as plt

from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import torch
from tqdm.notebook import tqdm
from torch import nn
import torch.nn.functional as F

Попробуем применить метод дистилляции на примере задачи классификации. В качестве датасета возьмем CIFAR10.

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 128
t = transforms.ToTensor()

t = transforms.Compose([transforms.ToTensor(), 
                        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])


cifar_train = datasets.CIFAR10("datasets/cifar10", download=True, train=True, transform=t)
train_loader = DataLoader(cifar_train, batch_size=batch_size, shuffle=True, pin_memory=torch.cuda.is_available())
cifar_test = datasets.CIFAR10("datasets/cifar10", download=True, train=False, transform=t)
test_loader = DataLoader(cifar_test, batch_size=batch_size, shuffle=False, pin_memory=torch.cuda.is_available())

classes = ('plane', 'car' , 'bird',
    'cat', 'deer', 'dog',
    'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


В качестве учителя возьмем большущую VGG16 с батчнормализацией

In [3]:
teacher = models.vgg16_bn(pretrained=True)



In [4]:
teacher.classifier[6] = torch.nn.Linear(4096, 10)

In [5]:
def teacher_train(teacher, train=True):
    if train:
        teacher.load_state_dict(torch.load('teacher_weights'))
    else:
        epochs=10
        optimizer = torch.optim.SGD(teacher.parameters(), lr = 1e-3, momentum=0.9, weight_decay=5e-4)
        loss_f = nn.CrossEntropyLoss()
        train_losses = []
        test_losses = []
        accuracy = []
        teacher.to(device)
        for i in tqdm(range(epochs)):
            #Train
            loss_mean = 0
            elements = 0
            for X, y in train_loader:
                X = X.to(device)
                y = y.to(device)
                y_pred = teacher(X)
                loss = loss_f(y_pred, y)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                loss_mean += loss.item() * len(X)
                elements += len(X)

            train_losses.append(loss_mean / elements)
            #Test
            if (i+1) % 1 == 0:
                loss_mean = 0 
                elements = 0
                correct = 0
                for X, y in iter(test_loader):
                    X = X.to(device)
                    y = y.to(device)
                    y_pred = teacher(X)
                    loss = loss_f(y_pred, y)
                    loss_mean += loss.item() * len(X)
                    elements += len(X)
                    y_pred = torch.argmax(y_pred, dim=1)
                    correct += sum(y_pred == y).item()
                accuracy.append(100 * correct / elements)
                test_losses.append(loss_mean / elements)
                print("Epoch", i+1, "| Train loss", train_losses[-1], "| Test loss", test_losses[-1], "| accuracy", accuracy[-1])

In [6]:
#torch.save(teacher.state_dict(), 'teacher_weights')
#teacher.load_state_dict(torch.load('teacher_weights'))

In [7]:
teacher_train(teacher, train=True)

In [8]:
def get_params(model):
    #функция выводит количество параметров у поданной на вход модели
    return sum([p.numel() for p in model.parameters()])

In [9]:
def get_accuracy(model, loader):
    model = model.cpu()
    with torch.no_grad():
        elements = 0
        correct = 0
        for X, y in iter(loader):
            elements += len(X)
            y_pred = torch.argmax(model(X), dim=1)
            correct += sum(y_pred == y).item()
    return correct / elements

In [10]:
class Student(nn.Module):
    def __init__(self):
        super().__init__()

        self.model = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2), # output: 64 x 16 x 16

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2), # output: 128 x 8 x 8

            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2), # output: 256 x 4 x 4

            nn.Flatten(), 
            nn.Linear(256*4*4, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 10))

    def forward(self, x: torch.Tensor):
        x = self.model(x)
        return x
    
    def get_loss(self, y_pred, y_teacher, y):
        return F.mse_loss(y_pred, y_teacher) + F.cross_entropy(y_pred, y)

In [11]:
student = Student()

Посмотрим теперь насколько меньше своего учителя оказался наш ученик

In [12]:
1 - get_params(student) / get_params(teacher)

0.9564340729989932

Ученик оказался меньше почти на 96%

In [13]:
epochs=10
optimizer = torch.optim.Adam(student.parameters())
train_losses = []
test_losses = []
accuracy = []
student.to(device)
teacher.to(device)
for i in tqdm(range(epochs)):
    #Train
    loss_mean = 0
    elements = 0
    for X, y in train_loader:
        X = X.to(device)
        y = y.to(device)
        y_pred = student(X)
        y_teacher = teacher(X)
        #print('y_pred', y_pred.shape)
        #print('y_teacher', y_teacher.shape)
        loss = student.get_loss(y_pred, y_teacher, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_mean += loss.item() * len(X)
        elements += len(X)

    train_losses.append(loss_mean / elements)
    #Test
    #if (i+1) % 10 == 0:
    loss_mean = 0 
    elements = 0
    correct = 0
    with torch.no_grad():
        for X, y in iter(test_loader):
            X = X.to(device)
            y = y.to(device)
            y_pred = student(X)
            y_teacher = teacher(X)
            loss = student.get_loss(y_pred, y_teacher, y)
            loss_mean += loss.item() * len(X)
            elements += len(X)
            y_pred = torch.argmax(y_pred, dim=1)
            correct += sum(y_pred == y).item()
        accuracy.append(100 * correct / elements)
        test_losses.append(loss_mean / elements)
    print("Epoch", i+1, "| Train loss", train_losses[-1], "| Test loss", test_losses[-1], "| accuracy", accuracy[-1])

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch 1 | Train loss 34.76612377441406 | Test loss 23.446907876586913 | accuracy 62.47
Epoch 2 | Train loss 21.014581605834962 | Test loss 17.864690811157228 | accuracy 72.0
Epoch 3 | Train loss 16.294261903381347 | Test loss 14.97879337158203 | accuracy 77.09
Epoch 4 | Train loss 13.977220952453614 | Test loss 15.412944410705567 | accuracy 76.82
Epoch 5 | Train loss 12.182168985900878 | Test loss 13.073052536010742 | accuracy 81.76
Epoch 6 | Train loss 10.942126379699706 | Test loss 12.813351473999024 | accuracy 82.55
Epoch 7 | Train loss 9.730871994018555 | Test loss 12.535386804199218 | accuracy 82.58
Epoch 8 | Train loss 8.778362272644044 | Test loss 12.294524263000488 | accuracy 83.72
Epoch 9 | Train loss 8.105869884033202 | Test loss 12.694352851867675 | accuracy 83.1
Epoch 10 | Train loss 7.527157604370117 | Test loss 12.675682257080078 | accuracy 82.55


Точность получилась чуть меньше, чем у лучшей модели, поэтому достану из широких штанин лучшие веса :)

In [21]:
student.load_state_dict(torch.load('student_weights'))

<All keys matched successfully>

In [22]:
get_accuracy(student, test_loader)

0.8362

In [23]:
get_accuracy(teacher, test_loader)

0.8541

In [16]:
#torch.save(student.state_dict(), 'student_weights')
#teacher.load_state_dict(torch.load('teacher_weights'))

Мы взяли в качестве учителя большую модель VGG16 с батчнормализацией. В качестве ученика мы взяли модель на 96% меньше. Потеря точности составила примерно 1.8%.