In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
import time

In [3]:
# --- Configuration ------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 128
epochs = 200
lr = 0.1
temperature = 3.0  # softening temperature T applied to both teacher and student logits
alpha = 0.5        # weight of the distillation (KL) term; (1 - alpha) weights hard-label CE
seed = 42          # seeds shuffling, augmentation and student init so runs are repeatable

# Seeds the torch RNG on all devices; the DataLoader sampler and worker seeds
# are derived from it, so batch order and augmentations become reproducible.
torch.manual_seed(seed)

# Per-channel normalization statistics, shared by train and test so the two
# pipelines cannot drift apart.
# NOTE(review): these std values differ from the commonly quoted CIFAR-10
# per-channel std (~0.247, 0.243, 0.261). Keep them only if they match the
# preprocessing the teacher checkpoint was trained with -- TODO confirm.
cifar_mean = (0.4914, 0.4822, 0.4465)
cifar_std = (0.2023, 0.1994, 0.2010)

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(cifar_mean, cifar_std),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cifar_mean, cifar_std),
])

# --- Data -------------------------------------------------------------------------
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

  entry = pickle.load(f, encoding="latin1")


In [None]:
def _adapt_resnet_for_cifar(model):
    """Replace the stem of a torchvision ResNet with a small-image variant.

    Swaps ``conv1`` for a 3x3 / stride-1 / padding-1 convolution and replaces
    the initial max-pool with an identity, so the 32x32 CIFAR crops are not
    aggressively downsampled before the residual stages.

    Args:
        model: a torchvision ResNet instance; modified in place.

    Returns:
        The same ``model`` object, for chaining.
    """
    model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    model.maxpool = nn.Identity()
    return model


# Teacher: the architecture must be adapted BEFORE loading the checkpoint.
# Loading first and replacing conv1 afterwards either raises a shape mismatch
# (if the checkpoint was saved from the adapted model) or silently discards the
# trained conv1 weights in favour of random ones, yielding a broken teacher.
teacher = _adapt_resnet_for_cifar(torchvision.models.resnet50(num_classes=10))
teacher.load_state_dict(torch.load('saved_models/teacher_resnet50.pt', map_location=device))
teacher.to(device)
teacher.eval()
# The teacher is only used for inference; freezing its parameters makes that explicit.
teacher.requires_grad_(False)

# Student: same stem adaptation, trained from scratch.
student = _adapt_resnet_for_cifar(torchvision.models.resnet18(num_classes=10))
student.to(device)

optimizer = optim.SGD(student.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
# One cosine period over the whole run (scheduler is stepped once per epoch).
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

def distillation_loss(student_logits, teacher_logits, labels, T, alpha):
    """Hinton-style knowledge-distillation loss.

    Blends a soft-target term (KL divergence between temperature-softened
    teacher and student distributions, rescaled by T^2 so its gradient
    magnitude does not shrink as T grows) with the ordinary cross-entropy
    against the ground-truth labels.

    Args:
        student_logits: raw student outputs, classes along dim 1.
        teacher_logits: raw teacher outputs, same shape as ``student_logits``.
        labels: integer class targets for the cross-entropy term.
        T: softening temperature applied to both sets of logits.
        alpha: weight of the soft-target term; ``1 - alpha`` weights the CE term.

    Returns:
        Scalar loss tensor.
    """
    log_p_student = F.log_softmax(student_logits / T, dim=1)
    p_teacher = F.softmax(teacher_logits / T, dim=1)
    # kl_div expects log-probabilities as input and probabilities as target.
    soft_term = F.kl_div(log_p_student, p_teacher, reduction='batchmean') * (T * T)

    hard_term = F.cross_entropy(student_logits, labels)

    return alpha * soft_term + (1 - alpha) * hard_term

In [None]:
def _evaluate(model, loader):
    """Return top-1 accuracy (percent) of ``model`` over every batch in ``loader``.

    Switches ``model`` to eval mode and runs without autograd; callers that
    keep training must switch back with ``model.train()``. Returns 0.0 for an
    empty loader instead of dividing by zero.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            _, predicted = model(inputs).max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    return 100. * correct / total if total else 0.0


# Sanity check: a correctly loaded teacher should score far above chance (10%)
# on CIFAR-10. A near-chance value here means the checkpoint/architecture do not match.
teacher_acc = _evaluate(teacher, testloader)
print(f'Teacher Test Acc: {teacher_acc:.2f}%')

start_time = time.time()

for epoch in range(epochs):
    student.train()
    train_loss = 0.0
    correct = 0
    total = 0

    for inputs, targets in trainloader:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()

        student_outputs = student(inputs)

        # Teacher provides fixed soft targets; no graph is needed for it.
        with torch.no_grad():
            teacher_outputs = teacher(inputs)

        loss = distillation_loss(student_outputs, teacher_outputs, targets, temperature, alpha)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = student_outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    # Read the lr used during THIS epoch before step() advances it to the next value.
    epoch_lr = scheduler.get_last_lr()[0]
    scheduler.step()

    # len(trainloader) is the number of batches; avoids relying on a leaked loop variable.
    avg_loss = train_loss / max(len(trainloader), 1)
    train_acc = 100. * correct / total if total else 0.0
    acc = _evaluate(student, testloader)
    print(f'Epoch {epoch+1} | lr: {epoch_lr:.4f} | Loss: {avg_loss:.4f} | Train Acc: {train_acc:.2f}% | Test Acc: {acc:.2f}%')

torch.save(student.state_dict(), "saved_models/student_resnet18.pt")

total_time = time.time() - start_time
print(f"\nTotal training time: {total_time:.2f} seconds")

# Temperature ($T$): Što je viša temperatura, to su verovatnoće "mekše" (veći fokus na klase koje nisu tačne; npr. ako slika mačke malo liči na psa, učitelj će to preneti studentu). Za CIFAR je obično idealno $T \in [3.0, 5.0]$. Alpha ($\alpha$): težina distilacionog (KL) člana, dok $1 - \alpha$ ide na cross-entropy sa pravim labelama. Ako staviš $0.5$, daješ jednaku težinu pravim labelama i učitelju. Ako je Teacher jako precizan (npr. >95%), možeš ići i na $0.7$ u korist učitelja.