<a href="https://colab.research.google.com/github/rbdus0715/Machine-Learning/blob/main/study/torch/knowledge_distillation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **knowledge-distillation을 이용해 CIFAR-10 데이터 모델 학습하기**
- input_size = (3, 32, 32)
입력 이미지는 RGB 3개 채널을 가지며, 각 이미지는 0에서 255 사이의 값 $3 \times 32 \times 32 = 3072$개로 이루어진다.
- 관례적으로 채널별 정규화를 적용한다. 이때 사용하는 mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]는 방대한 ImageNet 데이터셋에서 미리 계산된 채널별 평균과 표준편차이다(CIFAR-10 자체에서 계산한 값은 아니다).
- validation set과 test set을 학습 데이터와 분리해 사용함으로써, 학습 데이터에 지나치게 맞춰진(과적합된) 편향된 모델을 고르지 않도록 한다. (단, 이 노트북에서는 별도의 validation set 없이 test set으로만 평가한다.)


In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
# Prefer the GPU when one is visible to PyTorch, otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
# Shared preprocessing for train and test images: convert to a tensor, then
# standardise each RGB channel with the mean/std below (ImageNet statistics,
# per the notebook's introduction).
transforms_cifar = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# CIFAR-10 train split first, then test split (both downloaded to ./data if missing).
train_dataset, test_dataset = (
    datasets.CIFAR10(root='./data', train=is_train, download=True, transform=transforms_cifar)
    for is_train in (True, False)
)

# Only the training loader shuffles; evaluation order is kept fixed.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=128, shuffle=True, num_workers=2
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=128, shuffle=False, num_workers=2
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:04<00:00, 36889549.63it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


### **teacher model**

In [5]:
class DeepNN(nn.Module):
    """Teacher CNN for 3x32x32 images.

    Four 3x3 convolutions with padding=1 (each keeps the spatial size) and two
    2x2 max-pools that shrink 32x32 -> 16x16 -> 8x8, followed by a two-layer
    MLP head. The head's input width 2048 = 32 channels * 8 * 8, so the model
    only accepts 32x32 inputs.

    Args:
        num_classes: number of output logits (default 10).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        conv_stack = [
            # stage 1 at 32x32; padding=1 keeps width/height unchanged
            nn.Conv2d(3, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 32 -> 16
            # stage 2 at 16x16
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 16 -> 8
        ]
        head = [
            nn.Linear(2048, 512),  # 32 * 8 * 8 flattened features
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes),
        ]
        self.features = nn.Sequential(*conv_stack)
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Map a (N, 3, 32, 32) batch to (N, num_classes) logits."""
        feature_maps = self.features(x)
        return self.classifier(feature_maps.flatten(start_dim=1))

### **student model**

In [6]:
class LightNN(nn.Module):
    """Student CNN for 3x32x32 images (267,738 parameters with 10 classes).

    Two 3x3 convolutions with 16 channels and padding=1, each followed by a
    2x2 max-pool, so 32x32 -> 16x16 -> 8x8. The head's input width
    1024 = 16 channels * 8 * 8, so the model only accepts 32x32 inputs.

    Args:
        num_classes: number of output logits (default 10).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        conv_stack = [
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 32 -> 16
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 16 -> 8
        ]
        head = [
            nn.Linear(1024, 256),  # 16 * 8 * 8 flattened features
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes),
        ]
        self.features = nn.Sequential(*conv_stack)
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Map a (N, 3, 32, 32) batch to (N, num_classes) logits."""
        feature_maps = self.features(x)
        return self.classifier(feature_maps.flatten(start_dim=1))

### **(1) 두 모델을 각각 따로 학습할 경우 (cross entropy loss만 사용)**
![image](https://pytorch.org/tutorials/_static/img/knowledge_distillation/ce_only.png)

In [7]:
def train(model, train_loader, epochs, learning_rate, device):
    """Train ``model`` with cross-entropy on hard labels using Adam.

    Args:
        model: ``nn.Module`` to optimise. It is moved to ``device`` and put in
            training mode.
        train_loader: iterable of ``(inputs, labels)`` batches. ``labels`` holds
            integer class indices. The loader does not need ``__len__``.
        epochs: number of full passes over ``train_loader``.
        learning_rate: Adam learning rate.
        device: device that the model and every batch are moved to.

    Returns:
        list[float]: mean training loss of each epoch (each is also printed).

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    # Move the model before building the optimizer; this mirrors ``test``, which
    # also moves the model itself instead of relying on the caller.
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    model.train()

    epoch_losses = []
    for epoch in range(epochs):
        running_loss = 0.0
        # Count batches directly instead of len(train_loader): this works for
        # loaders without __len__ and lets us detect an empty loader.
        num_batches = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)            # (batch_size, num_classes) logits
            loss = criterion(outputs, labels)  # labels: (batch_size,) class indices
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError("train_loader yielded no batches")
        epoch_loss = running_loss / num_batches
        epoch_losses.append(epoch_loss)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss}")

    return epoch_losses

def test(model, test_loader, device):
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

In [8]:
# Train the teacher model from a fixed seed so its initial weights are reproducible.
torch.manual_seed(42)
nn_deep = DeepNN(num_classes=10)
nn_deep.to(device)
train(
    nn_deep,
    train_loader,
    epochs=10,
    learning_rate=0.001,
    device=device,
)
test_accuracy_deep = test(nn_deep, test_loader, device)

Epoch 1/10, Loss: 1.3274784684181213
Epoch 2/10, Loss: 0.8580730706834427
Epoch 3/10, Loss: 0.6700014244869846
Epoch 4/10, Loss: 0.5216707536174209
Epoch 5/10, Loss: 0.40041357522730325
Epoch 6/10, Loss: 0.30138443646680974
Epoch 7/10, Loss: 0.21878984379951302
Epoch 8/10, Loss: 0.16316649320599674
Epoch 9/10, Loss: 0.13847680837201798
Epoch 10/10, Loss: 0.12241659970368113
Test Accuracy: 75.07%


In [9]:
# Baseline: train the student on its own (hard-label cross-entropy only), seeded like the teacher.
torch.manual_seed(42)
nn_light = LightNN(num_classes=10)
nn_light.to(device)
train(
    nn_light,
    train_loader,
    epochs=10,
    learning_rate=0.001,
    device=device,
)
test_accuracy_light_ce = test(nn_light, test_loader, device)

Epoch 1/10, Loss: 1.4686428927399617
Epoch 2/10, Loss: 1.1556611867511974
Epoch 3/10, Loss: 1.0255111684579679
Epoch 4/10, Loss: 0.9260330683435015
Epoch 5/10, Loss: 0.851034722059889
Epoch 6/10, Loss: 0.78426359864452
Epoch 7/10, Loss: 0.719441860685568
Epoch 8/10, Loss: 0.6630732270762744
Epoch 9/10, Loss: 0.6090736340378862
Epoch 10/10, Loss: 0.5616097274948569
Test Accuracy: 70.26%


In [10]:
# Total parameter counts (trainable or not), formatted with thousands separators.
total_params_deep = f"{sum(param.numel() for param in nn_deep.parameters()):,}"
total_params_light = f"{sum(param.numel() for param in nn_light.parameters()):,}"
print(f"DeepNN parameters: {total_params_deep}")
print(f"LightNN parameters: {total_params_light}")

DeepNN parameters: 1,186,986
LightNN parameters: 267,738


### **(2) knowledge distillation 활용하여 학습할 경우**
![image](https://pytorch.org/tutorials/_static/img/knowledge_distillation/distillation_output_loss.png)
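- 두 네트워크 모두 클래스에 대한 확률분포(logit에 softmax를 취한 값)를 출력한다는 사실에 기반하여 knowledge distillation을 적용한다.
- 기존의 cross entropy loss에, temperature $T$로 부드럽게 만든 teacher의 분포를 student가 따라가도록 하는 soft-target loss를 가중합으로 더한다:
$$\mathcal{L} = w_{soft} \cdot T^2 \cdot \mathrm{CE}\big(\mathrm{softmax}(z_t/T),\ \mathrm{softmax}(z_s/T)\big) + w_{ce} \cdot \mathrm{CE}(z_s, y)$$
여기서 $z_t$, $z_s$는 각각 teacher와 student의 logit, $y$는 정답 레이블이다.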

In [13]:
def train_knowledge_distillation(
        teacher,
        student,
        train_loader,
        epochs,
        learning_rate,
        T,
        soft_target_loss_weight,
        ce_loss_weight, device):
    """Train ``student`` on a weighted sum of a distillation loss and hard-label cross-entropy.

    For each batch the optimised loss is::

        soft_target_loss_weight * T**2 * CE(softmax(teacher / T), softmax(student / T))
        + ce_loss_weight * CE(student, labels)

    The first term is a cross-entropy against the teacher's temperature-softened
    distribution, averaged over the batch. Only the student's parameters are
    updated (Adam). The teacher runs in eval mode without gradient tracking.

    Args:
        teacher: trained ``nn.Module`` whose logits provide the soft targets.
        student: ``nn.Module`` being trained. Its logits must have the same
            shape as the teacher's.
        train_loader: iterable of ``(inputs, labels)`` batches with integer
            class-index labels. The loader does not need ``__len__``.
        epochs: number of full passes over ``train_loader``.
        learning_rate: Adam learning rate for the student.
        T: softmax temperature; must be > 0. Larger values give softer distributions.
        soft_target_loss_weight: weight of the distillation term.
        ce_loss_weight: weight of the hard-label cross-entropy term.
        device: device that both models and every batch are moved to.

    Returns:
        list[float]: mean combined loss of each epoch (each is also printed).

    Raises:
        ValueError: if ``T`` is not positive, or ``train_loader`` yields no batches.
    """
    if T <= 0:
        raise ValueError(f"temperature T must be positive, got {T}")

    teacher.to(device)
    student.to(device)

    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.eval()   # teacher is not optimised: inference behaviour (e.g. no Dropout)
    student.train()  # student is the model being optimised

    epoch_losses = []
    for epoch in range(epochs):
        running_loss = 0.0
        # Count batches directly instead of len(train_loader): this works for
        # loaders without __len__ and lets us detect an empty loader.
        num_batches = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # The teacher's weights are never updated, so don't build its autograd graph.
            with torch.no_grad():
                teacher_logits = teacher(inputs)

            student_logits = student(inputs)

            # Teacher -> temperature-softened probabilities.
            # Student -> temperature-softened log-probabilities (log_softmax is the
            # numerically stable form of log(softmax(.))).
            soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
            soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)

            # Cross-entropy against the soft targets, averaged over the batch.
            # It is scaled by T**2 as suggested by the authors of "Distilling the
            # Knowledge in a Neural Network" (Hinton et al.).
            soft_targets_loss = -torch.sum(soft_targets * soft_prob) / soft_prob.size(0) * (T**2)

            # Ordinary cross-entropy against the true labels.
            label_loss = ce_loss(student_logits, labels)

            loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError("train_loader yielded no batches")
        epoch_loss = running_loss / num_batches
        epoch_losses.append(epoch_loss)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss}")

    return epoch_losses

In [12]:
# Re-seed with the same value used before creating `nn_light`, so the distilled
# student starts from the same initial weights as the CE-only baseline. Without
# this, the accuracy comparison mixes the effect of distillation with the effect
# of a different random initialisation.
torch.manual_seed(42)
new_nn_light = LightNN(num_classes=10).to(device)
train_knowledge_distillation(teacher=nn_deep,
                             student=new_nn_light,
                             train_loader=train_loader,
                             epochs=10,
                             learning_rate=0.001,
                             T=2,
                             soft_target_loss_weight=0.25,
                             ce_loss_weight=0.75,
                             device=device)
test_accuracy_light_ce_and_kd = test(new_nn_light, test_loader, device)

Epoch 1/10, Loss: 2.749704034127238
Epoch 2/10, Loss: 2.1972148860506997
Epoch 3/10, Loss: 1.9530151445237571
Epoch 4/10, Loss: 1.7764418722723452
Epoch 5/10, Loss: 1.6377761455448083
Epoch 6/10, Loss: 1.526027757188548
Epoch 7/10, Loss: 1.4389124096507002
Epoch 8/10, Loss: 1.3445899931671064
Epoch 9/10, Loss: 1.2706740756169
Epoch 10/10, Loss: 1.197351996703526
Test Accuracy: 69.85%


In [14]:
# Side-by-side summary of the three test accuracies, one per line.
print("\n".join([
    f"Teacher accuracy: {test_accuracy_deep:.2f}%",
    f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%",
    f"Student accuracy with CE + KD: {test_accuracy_light_ce_and_kd:.2f}%",
]))

Teacher accuracy: 75.07%
Student accuracy without teacher: 70.26%
Student accuracy with CE + KD: 69.85%
