<a href="https://colab.research.google.com/github/yannnn126/2022CCE/blob/main/%E7%9F%A5%E8%AD%98%E8%92%B8%E9%A4%BE_%E9%90%B5%E4%BA%BA%E8%B3%BD1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 知識蒸餾 Knowledge Distillation

- 此為鐵人賽系列文示範文件，參考[Keras官方範例](https://www.tensorflow.org/lite/performance/post_training_quantization)修改而成。

- 知識蒸餾  Knowledge Distillation 為模型壓縮技術，其中student模型從可以更複雜的 teacher 模型中 "學習" ，實作過程包含:
  1. 自定義一個`Distiller`類別。
  2. 用 CNN 訓練 teacher 模型。
  3. student 模型向 teacher 學習。
  4. 訓練一個沒向老師學的 student_scratch 模型進行比較。


In [11]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

import numpy as np
import os

In [12]:
ACCURACY = {}

## 準備資料

- 模型採用`tf.keras.datasets.mnist`，用CNN進行建模。

In [13]:
import tensorflow as tf

# 載入 CIFAR-10
cifar10 = tf.keras.datasets.cifar10
(train_images, train_labels), (test_images, test_labels) = cifar10.load_data()

# 正規化
train_images = train_images / 255.0
test_images = test_images / 255.0

# 標籤轉一維
train_labels = train_labels.flatten()
test_labels = test_labels.flatten()


## 建立Distiller類別

- 此直接使用 Keras 官方範例定義的 `Distiller` 類別。
- 該類別繼承於 `th.keras.Model`，並改寫以下方法:
  - `compile`：這個模型需要一些額外的參數來編譯，比如老師和學生的損失，alpha 和 temp 。
  - `train_step`：控制模型的訓練方式。這將是真正的知識蒸餾邏輯所在。這個方法就是你做的時候調用的方法model.fit。
  - `test_step`：控制模型的評估。這個方法就是你做的時候調用的方法model.evaluate。

In [14]:
class Distiller(keras.Model):
    def __init__(self, student, teacher):
        super(Distiller, self).__init__()
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.4,
        temperature=5,
    ):
        """ Configure the distiller.

        Args:
            optimizer: Keras optimizer for the student weights
            metrics: Keras metrics for evaluation
            student_loss_fn: Loss function of difference between student
                predictions and ground-truth
            distillation_loss_fn: Loss function of difference between soft
                student predictions and soft teacher predictions
            alpha: weight to student_loss_fn and 1-alpha to distillation_loss_fn
            temperature: Temperature for softening probability distributions.
                Larger temperature gives softer distributions.
        """
        super(Distiller, self).compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def train_step(self, data):
    # Unpack data
      x, y = data

      # Forward pass of teacher
      teacher_predictions = self.teacher(x, training=False)

      with tf.GradientTape() as tape:
          # Forward pass of student
          student_predictions = self.student(x, training=True)

          # Compute losses
          student_loss = self.student_loss_fn(y, student_predictions)
          distillation_loss = self.distillation_loss_fn(
              tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
              tf.nn.softmax(student_predictions / self.temperature, axis=1),
          )
          loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

      # Compute gradients
      trainable_vars = self.student.trainable_variables
      gradients = tape.gradient(loss, trainable_vars)

      # Update weights
      self.optimizer.apply_gradients(zip(gradients, trainable_vars))

      # Update the metrics configured in `compile()`.
      self.compiled_metrics.update_state(y, student_predictions)

      # Return a dict of performance
      results = {m.name: m.result() for m in self.metrics}
      results.update(
          {"student_loss": student_loss, "distillation_loss": distillation_loss}
      )
      return results


    def test_step(self, data):
        # Unpack the data
        x, y = data

        # Compute predictions
        y_prediction = self.student(x, training=False)

        # Calculate the loss
        student_loss = self.student_loss_fn(y, y_prediction)

        # Update the metrics.
        self.compiled_metrics.update_state(y, y_prediction)

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        return results


## 建立老師與學生模型

- 這裡定義大模型與小模型，老師使用大模型，學生使用小模型。
- 有兩個重要的事情需要注意：
  - 最後一層沒有使用激勵函數 softmax ，因為知識蒸餾需要原始權重特徵。
  - 通過 dropout 層的正則化將應用於教師而不是學生。這是因為學生應該能夠通過蒸餾過程學習這種正則化。

- 可以將學生模型視為教師模型的簡化（或壓縮）版本。

In [15]:
def big_model_builder():
    keras = tf.keras

    model = keras.Sequential([
        keras.layers.InputLayer(input_shape=(32, 32, 3)),
        keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
        keras.layers.MaxPooling2D(pool_size=(2, 2)),
        keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
        keras.layers.MaxPooling2D(pool_size=(2, 2)),
        keras.layers.Flatten(),
        keras.layers.Dense(10)
    ])

    return model


def small_model_builder():
    keras = tf.keras

    model = keras.Sequential([
        keras.layers.InputLayer(input_shape=(32, 32, 3)),
        keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
        keras.layers.MaxPooling2D(pool_size=(2, 2)),
        keras.layers.Flatten(),
        keras.layers.Dense(10)
    ])

    return model


In [16]:
teacher = big_model_builder()

student = small_model_builder()

student_scratch = small_model_builder()

## 訓練老師

In [17]:
import tensorflow as tf

# 假設 teacher 已經定義好且模型輸入形狀符合 CIFAR-10

# 定義一個字典來存放準確率
ACCURACY = {}

teacher.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

teacher.summary()

# 假設 train_images, train_labels, test_images, test_labels 已是 CIFAR-10 格式並正規化
teacher.fit(train_images, train_labels, epochs=5, batch_size=64, validation_split=0.1)

# 評估模型並將準確率存入字典
_, ACCURACY['teacher model'] = teacher.evaluate(test_images, test_labels)

print("Teacher model accuracy:", ACCURACY['teacher model'])



Epoch 1/5
[1m704/704[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m29s[0m 39ms/step - loss: 1.9488 - sparse_categorical_accuracy: 0.2854 - val_loss: 1.5470 - val_sparse_categorical_accuracy: 0.4420
Epoch 2/5
[1m704/704[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m40s[0m 37ms/step - loss: 1.5078 - sparse_categorical_accuracy: 0.4564 - val_loss: 1.3962 - val_sparse_categorical_accuracy: 0.5018
Epoch 3/5
[1m704/704[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m42s[0m 38ms/step - loss: 1.3896 - sparse_categorical_accuracy: 0.5073 - val_loss: 1.3441 - val_sparse_categorical_accuracy: 0.5292
Epoch 4/5
[1m704/704[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m40s[0m 36ms/step - loss: 1.3321 - sparse_categorical_accuracy: 0.5253 - val_loss: 1.3907 - val_sparse_categorical_accuracy: 0.5064
Epoch 5/5
[1m704/704[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m41s[0m 36ms/step - loss: 1.2919 - sparse_categorical_accuracy: 0.5443 - val_loss: 1.2500 - val_sparse_categorical_accuracy: 0.

## 透過知識蒸餾訓練學生

- 要執行知識提煉過程，您將使用您之前compline的模型。
- 為此，首先創建`Distiller`類別的實例並傳入學生和教師模型`distiller = Distiller(student=student, teacher=teacher)
`。然後用合適的參數編譯它並訓練它！

- 老師可以用更高的epochs，學生會向老師學習。

In [18]:
import tensorflow.keras as keras

# 確保先定義 ACCURACY 字典
ACCURACY = {}

distiller = Distiller(student=student, teacher=teacher)
distiller.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=0.1,
    temperature=10,
)

# 建議 shuffle=True，提升訓練效果
distiller.fit(
    train_images,
    train_labels,
    epochs=5,
    shuffle=True
)

# Evaluate student on test dataset
results = distiller.evaluate(test_images, test_labels)

# 存取準確率，results[1] 對應 SparseCategoricalAccuracy
ACCURACY['distiller student model'] = results[1]

print("Student model accuracy after distillation:", ACCURACY['distiller student model'])


Epoch 1/5
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m34s[0m 21ms/step - sparse_categorical_accuracy: 0.3848 - distillation_loss: 0.0086 - loss: -1.2361 - student_loss: 1.5376
Epoch 2/5
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 20ms/step - sparse_categorical_accuracy: 0.5452 - distillation_loss: 0.0044 - loss: -2.3227 - student_loss: 1.2861
Epoch 3/5
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m42s[0m 21ms/step - sparse_categorical_accuracy: 0.5781 - distillation_loss: 0.0041 - loss: -2.5705 - student_loss: 1.2132
Epoch 4/5
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m33s[0m 21ms/step - sparse_categorical_accuracy: 0.5956 - distillation_loss: 0.0042 - loss: -2.7516 - student_loss: 1.1672
Epoch 5/5
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m33s[0m 21ms/step - sparse_categorical_accuracy: 0.6095 - distillation_loss: 0.0044 - loss: -2.8618 - student_loss: 1.1296
[1m313/313[0m [32m━━━━━━━━━

In [19]:
ACCURACY

{'distiller student model': {'sparse_categorical_accuracy': <tf.Tensor: shape=(), dtype=float32, numpy=0.5740000009536743>}}

## 比較模型 - 從頭訓練學生

In [20]:
import tensorflow.keras as keras

ACCURACY = {}

student_scratch.compile(
    optimizer=keras.optimizers.Adam(),  # 使用 Adam 優化器
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),  # 損失函數為稀疏類別交叉熵，適用於多類別分類且模型輸出為 logits
    metrics=[keras.metrics.SparseCategoricalAccuracy()],  # 評估指標為稀疏類別準確率
)

# 顯示學生模型架構摘要
student_scratch.summary()

# 開始訓練學生模型
student_scratch.fit(
    train_images,       # 訓練資料影像 (已正規化且形狀符合模型輸入)
    train_labels,       # 訓練資料標籤
    epochs=5,           # 訓練5個迴圈（epoch）
    shuffle=True        # 每個 epoch 訓練前打亂資料，避免模型過度擬合資料順序
)

# 評估學生模型在測試集上的表現，並將準確率存入 ACCURACY 字典
_, ACCURACY['student from scratch model'] = student_scratch.evaluate(test_images, test_labels)

# 印出學生模型測試準確率
print("Student model accuracy (trained from scratch):", ACCURACY['student from scratch model'])


Epoch 1/5
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m25s[0m 15ms/step - loss: 1.7657 - sparse_categorical_accuracy: 0.3716
Epoch 2/5
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m40s[0m 15ms/step - loss: 1.3390 - sparse_categorical_accuracy: 0.5324
Epoch 3/5
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m49s[0m 20ms/step - loss: 1.2461 - sparse_categorical_accuracy: 0.5651
Epoch 4/5
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m33s[0m 21ms/step - loss: 1.1983 - sparse_categorical_accuracy: 0.5854
Epoch 5/5
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m33s[0m 21ms/step - loss: 1.1500 - sparse_categorical_accuracy: 0.6009
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 11ms/step - loss: 1.1951 - sparse_categorical_accuracy: 0.5817
Student model accuracy (trained from scratch): 0.5789999961853027


## 小結

In [21]:
ACCURACY

{'student from scratch model': 0.5789999961853027}

- 老師的準確率應會高於學生，畢竟可以採用大模型、更多的epoch等方式優化。
- 「接受知識蒸餾的學生」表現通常會優於「自己從頭開始的學生」。
- 學生的模型雖然較簡易，知識蒸餾甚至會青出於藍勝於藍。

## 參考
- https://keras.io/examples/vision/knowledge_distillation/