In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.models import resnet18
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10


In [17]:
from tqdm import tqdm

In [12]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Обучать буду на датасете CIFAR-10 

In [13]:
train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform_train)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform_test)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


Files already downloaded and verified
Files already downloaded and verified


In [14]:
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

In [23]:

num_epochs = 5
batch_size = 32
learning_rate = 0.001
temperature = 5.0
distillation_weight = 0.5

В качестве учителя используется ResNet-18

In [25]:
teacher_model = resnet18(pretrained=False, num_classes=10).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(teacher_model.parameters(), lr=learning_rate)

total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in tqdm(enumerate(train_loader)):
        images = images.to(device)
        labels = labels.to(device)

        outputs = teacher_model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print(f"Teacher Model - Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{total_step}], Loss: {loss.item():.4f}")

teacher_model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = teacher_model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print("Teacher Model Accuracy on Test Images: {:.2f}%".format(correct / total * 100))

100it [02:40,  1.51s/it]

Teacher Model - Epoch [1/5], Step [100/391], Loss: 1.5572


200it [05:20,  1.74s/it]

Teacher Model - Epoch [1/5], Step [200/391], Loss: 1.5448


300it [07:59,  1.55s/it]

Teacher Model - Epoch [1/5], Step [300/391], Loss: 1.4886


391it [10:25,  1.60s/it]
100it [02:38,  1.58s/it]

Teacher Model - Epoch [2/5], Step [100/391], Loss: 1.1459


200it [05:21,  1.68s/it]

Teacher Model - Epoch [2/5], Step [200/391], Loss: 1.2174


300it [08:04,  1.53s/it]

Teacher Model - Epoch [2/5], Step [300/391], Loss: 1.1607


391it [10:31,  1.61s/it]
100it [02:40,  1.53s/it]

Teacher Model - Epoch [3/5], Step [100/391], Loss: 0.9459


200it [05:21,  1.66s/it]

Teacher Model - Epoch [3/5], Step [200/391], Loss: 1.0380


300it [08:00,  1.57s/it]

Teacher Model - Epoch [3/5], Step [300/391], Loss: 0.9802


391it [10:24,  1.60s/it]
100it [02:40,  1.58s/it]

Teacher Model - Epoch [4/5], Step [100/391], Loss: 0.8841


200it [05:20,  1.50s/it]

Teacher Model - Epoch [4/5], Step [200/391], Loss: 1.0261


300it [07:59,  1.68s/it]

Teacher Model - Epoch [4/5], Step [300/391], Loss: 1.0062


391it [10:23,  1.59s/it]
100it [02:37,  1.54s/it]

Teacher Model - Epoch [5/5], Step [100/391], Loss: 0.9504


200it [05:16,  1.49s/it]

Teacher Model - Epoch [5/5], Step [200/391], Loss: 1.0517


300it [07:55,  1.71s/it]

Teacher Model - Epoch [5/5], Step [300/391], Loss: 0.7258


391it [10:18,  1.58s/it]


Teacher Model Accuracy on Test Images: 71.48%


Обучение меньшей модели (ученика)

In [26]:

class StudentModel(nn.Module):
    def __init__(self, num_classes):
        super(StudentModel, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Linear(32 * 8 * 8, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

student_model = StudentModel(num_classes=10).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(student_model.parameters(), lr=learning_rate)

total_step = len(train_loader)
for epoch in range(num_epochs):
    student_model.train()
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        outputs = student_model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print(f"Student Model - Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{total_step}], Loss: {loss.item():.4f}")
            
student_model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = student_model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print("Student Model Accuracy on Test Images: {:.2f}%".format(correct / total * 100))

Student Model - Epoch [1/5], Step [100/391], Loss: 1.7929
Student Model - Epoch [1/5], Step [200/391], Loss: 1.4300
Student Model - Epoch [1/5], Step [300/391], Loss: 1.3890
Student Model - Epoch [2/5], Step [100/391], Loss: 1.3421
Student Model - Epoch [2/5], Step [200/391], Loss: 1.3992
Student Model - Epoch [2/5], Step [300/391], Loss: 1.3474
Student Model - Epoch [3/5], Step [100/391], Loss: 1.2704
Student Model - Epoch [3/5], Step [200/391], Loss: 1.3158
Student Model - Epoch [3/5], Step [300/391], Loss: 1.3159
Student Model - Epoch [4/5], Step [100/391], Loss: 1.1564
Student Model - Epoch [4/5], Step [200/391], Loss: 1.3136
Student Model - Epoch [4/5], Step [300/391], Loss: 1.1402
Student Model - Epoch [5/5], Step [100/391], Loss: 1.1487
Student Model - Epoch [5/5], Step [200/391], Loss: 1.1138
Student Model - Epoch [5/5], Step [300/391], Loss: 1.2185
Student Model Accuracy on Test Images: 63.14%


Дистилляция:

In [28]:
for epoch in range(num_epochs):
    student_model.train()
    for i, (images, labels) in tqdm(enumerate(train_loader)):
        images = images.to(device)
        labels = labels.to(device)

        with torch.no_grad():
            teacher_outputs = teacher_model(images)

        student_outputs = student_model(images)

        soft_targets = nn.functional.softmax(teacher_outputs / temperature, dim=1)

        loss = criterion(student_outputs, labels) + distillation_weight * criterion(student_outputs, soft_targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()



391it [02:54,  2.25it/s]
391it [02:53,  2.26it/s]
391it [02:53,  2.26it/s]
391it [02:51,  2.28it/s]
391it [02:52,  2.27it/s]


Student Model Accuracy after Distillation on Test Images: 68.23%


**Итого:**

Точность учителя: 71.48%

Точность ученика: 63.14%

Точность ученика после дистилляции: 68.23%