<a href="https://colab.research.google.com/github/vifirsanova/hse-python-course/blob/main/compression/hinton_distillation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)

In [4]:
class TeacherModel(nn.Module):
    def __init__(self):
        super(TeacherModel, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 1200)
        self.fc2 = nn.Linear(1200, 600)
        self.fc3 = nn.Linear(600, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [5]:
class StudentModel(nn.Module):
    def __init__(self):
        super(StudentModel, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 300)
        self.fc2 = nn.Linear(300, 100)
        self.fc3 = nn.Linear(100, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [6]:
def train_model(model, trainloader, epochs=5):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for images, labels in trainloader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        print(f"Epoch [{epoch + 1}/{epochs}], Loss: {running_loss / len(trainloader):.4f}")

In [7]:
def test_model(model, testloader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f'Accuracy: {accuracy:.2f}%')
    return accuracy

In [8]:
teacher_model = TeacherModel().to(device)
print("Training Teacher Model:")
train_model(teacher_model, trainloader, epochs=5)

Training Teacher Model:
Epoch [1/5], Loss: 0.3658
Epoch [2/5], Loss: 0.1431
Epoch [3/5], Loss: 0.0999
Epoch [4/5], Loss: 0.0788
Epoch [5/5], Loss: 0.0621


In [9]:
print("Testing Teacher Model:")
accuracy_teacher = test_model(teacher_model, testloader)

Testing Teacher Model:
Accuracy: 97.45%


Параметры

`param student_logits`

`param teacher_logits`

`true_labels`

`T` - температура, чем выше, тем более приближенные распределения; чем ниже, тем выше точность значений

`alpha` - взвешивание soft и hard таргетов

In [10]:
def distillation_loss(student_logits, teacher_logits, true_labels, T, alpha):
    # Soft targets (distillation loss)
    soft_teacher_probs = F.log_softmax(teacher_logits / T, dim=1)
    soft_student_probs = F.log_softmax(student_logits / T, dim=1)
    distillation_loss = F.kl_div(soft_student_probs, soft_teacher_probs, reduction='batchmean') * (T * T)

    # Hard targets (normal cross-entropy loss)
    hard_loss = F.cross_entropy(student_logits, true_labels)

    # Total loss: weighted sum of distillation and hard losses
    return alpha * distillation_loss + (1 - alpha) * hard_loss

In [20]:
def train_student_with_kd(student_model, teacher_model, trainloader, T=5.0, alpha=0.7, epochs=5):
    optimizer = optim.Adam(student_model.parameters(), lr=0.01)
    student_model.train()
    teacher_model.eval()  # Teacher is frozen

    for epoch in range(epochs):
        running_loss = 0.0
        for images, labels in trainloader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()

            # logits
            student_logits = student_model(images)
            teacher_logits = teacher_model(images)

            # distillation loss
            loss = distillation_loss(student_logits, teacher_logits, labels, T, alpha)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        print(f"Epoch [{epoch + 1}/{epochs}], Loss: {running_loss / len(trainloader):.4f}")

In [21]:
student_model = StudentModel().to(device)
print("\nTraining Student Model with Knowledge Distillation:")
train_student_with_kd(student_model, teacher_model, trainloader, T=7.0, alpha=0.5, epochs=5)


Training Student Model with Knowledge Distillation:
Epoch [1/5], Loss: nan
Epoch [2/5], Loss: nan
Epoch [3/5], Loss: nan
Epoch [4/5], Loss: nan
Epoch [5/5], Loss: nan


In [22]:
print("Testing Student Model:")
accuracy_student = test_model(student_model, testloader)

Testing Student Model:
Accuracy: 10.28%


In [23]:
print(f"Teacher Accuracy: {accuracy_teacher:.2f}%")
print(f"Student Accuracy: {accuracy_student:.2f}%")

Teacher Accuracy: 97.45%
Student Accuracy: 10.28%
