In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [2]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
from scipy import ndimage
import math
import ast #to easily read out class text file that contains some unknwn syntax.
import scipy   #to upscale the image
import matplotlib.pyplot as plt
import cv2     
from keras.applications.resnet import ResNet50, preprocess_input
from keras.models import Model   
from PIL import Image



In [3]:
(training_images, training_labels) , (test_images, test_labels) = tf.keras.datasets.cifar10.load_data()

def preprocess_image_input(input_images):
  input_images = input_images.astype('float32')
  output_ims = tf.keras.applications.resnet50.preprocess_input(input_images)
  return output_ims

train_X = preprocess_image_input(training_images)
test_X = preprocess_image_input(test_images)

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


In [4]:
def feature_extractor(inputs):

  feature_extractor = tf.keras.applications.resnet.ResNet50(input_shape=(224, 224, 3),
                                               include_top=False,
                                               weights='imagenet')(inputs)
 
  return feature_extractor



def classifier(inputs):
    x= tf.keras.layers.Conv2D(2048, (1, 1), strides=(1, 1), padding="same")(inputs)
    x= tf.keras.layers.ReLU()(x)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)  
    x = tf.keras.layers.Dense(10,name="classification")(x)
    return x



def final_model(inputs):

    resize = tf.keras.layers.UpSampling2D(size=(7,7))(inputs)

    resnet_feature_extractor = feature_extractor(resize)
    
    
    classification_output = classifier(resnet_feature_extractor)

    return classification_output




def define_compile_model():
  inputs = tf.keras.layers.Input(shape=(32,32,3))
  
  classification_output = final_model(inputs) 
  model = tf.keras.Model(inputs=inputs, outputs = classification_output)
 
  model.compile(optimizer='SGD', 
                loss='sparse_categorical_crossentropy',
                metrics = ['accuracy'])
  
  return model


resnet50 = define_compile_model()

resnet50.summary()

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 32, 32, 3)]       0         
                                                                 
 up_sampling2d (UpSampling2  (None, 224, 224, 3)       0         
 D)                                                              
                                                                 
 resnet50 (Functional)       (None, 7, 7, 2048)        23587712  
                                                                 
 conv2d (Conv2D)             (None, 7, 7, 2048)        4196352   
                                                                 
 re_lu (ReLU)                (None, 7, 7, 2048)        0         
                                                  

In [5]:
resnet50.compile(
    optimizer=keras.optimizers.SGD(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

In [6]:
resnet50.fit(train_X, training_labels, epochs=3, validation_data = (test_X, test_labels), batch_size=64)

Epoch 1/3
Epoch 2/3
Epoch 3/3


<keras.src.callbacks.History at 0x78c7f022b880>

In [25]:
#proportionality maintained student
student_bl = keras.Sequential(
    [
        keras.Input(shape=(32, 32, 3)),
        layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
        layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
        layers.Conv2D(32, (3, 3), strides=(2, 2), padding="same"),
        layers.GlobalAveragePooling2D(),
        layers.Dense(10),
    ],
    name="student_bl",
)

student_bl.summary()

Model: "student_bl"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_28 (Conv2D)          (None, 16, 16, 128)       3584      
                                                                 
 leaky_re_lu_18 (LeakyReLU)  (None, 16, 16, 128)       0         
                                                                 
 max_pooling2d_18 (MaxPooli  (None, 16, 16, 128)       0         
 ng2D)                                                           
                                                                 
 conv2d_29 (Conv2D)          (None, 8, 8, 64)          73792     
                                                                 
 leaky_re_lu_19 (LeakyReLU)  (None, 8, 8, 64)          0         
                                                                 
 max_pooling2d_19 (MaxPooli  (None, 8, 8, 64)          0         
 ng2D)                                                  

In [8]:
class Distiller_bl(keras.Model):
    def __init__(self, student, teacher):
        super().__init__()
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
    ):
        
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def train_step(self, data):
       
        x, y = data

  
        teacher_predictions = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
           
            student_predictions = self.student(x, training=True)

            
            student_loss = self.student_loss_fn(y, student_predictions)

            
            distillation_loss = (
                self.distillation_loss_fn(
                    tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
                    tf.nn.softmax(student_predictions / self.temperature, axis=1),
                )
                * self.temperature**2
            )

            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

       
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        
        self.compiled_metrics.update_state(y, student_predictions)

       
        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"student_loss": student_loss, "distillation_loss": distillation_loss}
        )
        return results

    def test_step(self, data):
       
        x, y = data

      
        y_prediction = self.student(x, training=False)

       
        student_loss = self.student_loss_fn(y, y_prediction)

      
        self.compiled_metrics.update_state(y, y_prediction)

       
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        return results

distiller = Distiller_bl(student=student_bl, teacher=resnet50)
distiller.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=['accuracy'],
    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=0.1,
    temperature=2,
)


distiller.fit(train_X, training_labels, epochs=10)


distiller.evaluate(test_X, test_labels)



Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


[0.6679999828338623, 1.0420880317687988]

In [24]:
class Distiller_bl(keras.Model):
    def __init__(self, student, teacher):
        super().__init__()
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
    ):
        
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def train_step(self, data):
       
        x, y = data

  
        teacher_predictions = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
           
            student_predictions = self.student(x, training=True)

            
            student_loss = self.student_loss_fn(y, student_predictions)

            
            distillation_loss = (
                self.distillation_loss_fn(
                    tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
                    tf.nn.softmax(student_predictions / self.temperature, axis=1),
                )
                * self.temperature**2
            )

            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

       
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        
        self.compiled_metrics.update_state(y, student_predictions)

       
        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"student_loss": student_loss, "distillation_loss": distillation_loss}
        )
        return results

    def test_step(self, data):
       
        x, y = data

      
        y_prediction = self.student(x, training=False)

       
        student_loss = self.student_loss_fn(y, y_prediction)

      
        self.compiled_metrics.update_state(y, y_prediction)

       
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        return results

distiller = Distiller_bl(student=student_bl, teacher=resnet50)
distiller.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=['accuracy'],
    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=0.1,
    temperature=3,
)


distiller.fit(train_X, training_labels, epochs=10)


distiller.evaluate(test_X, test_labels)



Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


[0.671500027179718, 1.1355122327804565]

In [10]:
class Distiller_bl(keras.Model):
    def __init__(self, student, teacher):
        super().__init__()
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
    ):
        
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def train_step(self, data):
       
        x, y = data

  
        teacher_predictions = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
           
            student_predictions = self.student(x, training=True)

            
            student_loss = self.student_loss_fn(y, student_predictions)

            
            distillation_loss = (
                self.distillation_loss_fn(
                    tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
                    tf.nn.softmax(student_predictions / self.temperature, axis=1),
                )
                * self.temperature**2
            )

            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

       
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        
        self.compiled_metrics.update_state(y, student_predictions)

       
        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"student_loss": student_loss, "distillation_loss": distillation_loss}
        )
        return results

    def test_step(self, data):
       
        x, y = data

      
        y_prediction = self.student(x, training=False)

       
        student_loss = self.student_loss_fn(y, y_prediction)

      
        self.compiled_metrics.update_state(y, y_prediction)

       
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        return results

distiller = Distiller_bl(student=student_bl, teacher=resnet50)
distiller.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=['accuracy'],
    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=0.1,
    temperature=4,
)


distiller.fit(train_X, training_labels, epochs=10)


distiller.evaluate(test_X, test_labels)



Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


[0.6491000056266785, 1.3188858032226562]

In [12]:
class Distiller_bl(keras.Model):
    def __init__(self, student, teacher):
        super().__init__()
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
    ):
        
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def train_step(self, data):
       
        x, y = data

  
        teacher_predictions = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
           
            student_predictions = self.student(x, training=True)

            
            student_loss = self.student_loss_fn(y, student_predictions)

            
            distillation_loss = (
                self.distillation_loss_fn(
                    tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
                    tf.nn.softmax(student_predictions / self.temperature, axis=1),
                )
                * self.temperature**2
            )

            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

       
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        
        self.compiled_metrics.update_state(y, student_predictions)

       
        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"student_loss": student_loss, "distillation_loss": distillation_loss}
        )
        return results

    def test_step(self, data):
       
        x, y = data

      
        y_prediction = self.student(x, training=False)

       
        student_loss = self.student_loss_fn(y, y_prediction)

      
        self.compiled_metrics.update_state(y, y_prediction)

       
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        return results

distiller = Distiller_bl(student=student_bl, teacher=resnet50)
distiller.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=['accuracy'],
    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=0.1,
    temperature=5,
)


distiller.fit(train_X, training_labels, epochs=10)


distiller.evaluate(test_X, test_labels)



Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


[0.6525999903678894, 1.1079679727554321]

In [14]:
class Distiller_bl(keras.Model):
    def __init__(self, student, teacher):
        super().__init__()
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
    ):
        
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def train_step(self, data):
       
        x, y = data

  
        teacher_predictions = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
           
            student_predictions = self.student(x, training=True)

            
            student_loss = self.student_loss_fn(y, student_predictions)

            
            distillation_loss = (
                self.distillation_loss_fn(
                    tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
                    tf.nn.softmax(student_predictions / self.temperature, axis=1),
                )
                * self.temperature**2
            )

            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

       
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        
        self.compiled_metrics.update_state(y, student_predictions)

       
        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"student_loss": student_loss, "distillation_loss": distillation_loss}
        )
        return results

    def test_step(self, data):
       
        x, y = data

      
        y_prediction = self.student(x, training=False)

       
        student_loss = self.student_loss_fn(y, y_prediction)

      
        self.compiled_metrics.update_state(y, y_prediction)

       
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        return results

distiller = Distiller_bl(student=student_bl, teacher=resnet50)
distiller.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=['accuracy'],
    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=0.1,
    temperature=6,
)


distiller.fit(train_X, training_labels, epochs=10)


distiller.evaluate(test_X, test_labels)



Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


[0.666100025177002, 1.7322781085968018]

In [16]:
class Distiller_bl(keras.Model):
    def __init__(self, student, teacher):
        super().__init__()
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
    ):
        
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def train_step(self, data):
       
        x, y = data

  
        teacher_predictions = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
           
            student_predictions = self.student(x, training=True)

            
            student_loss = self.student_loss_fn(y, student_predictions)

            
            distillation_loss = (
                self.distillation_loss_fn(
                    tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
                    tf.nn.softmax(student_predictions / self.temperature, axis=1),
                )
                * self.temperature**2
            )

            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

       
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        
        self.compiled_metrics.update_state(y, student_predictions)

       
        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"student_loss": student_loss, "distillation_loss": distillation_loss}
        )
        return results

    def test_step(self, data):
       
        x, y = data

      
        y_prediction = self.student(x, training=False)

       
        student_loss = self.student_loss_fn(y, y_prediction)

      
        self.compiled_metrics.update_state(y, y_prediction)

       
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        return results

distiller = Distiller_bl(student=student_bl, teacher=resnet50)
distiller.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=['accuracy'],
    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=0.1,
    temperature=7,
)


distiller.fit(train_X, training_labels, epochs=10)


distiller.evaluate(test_X, test_labels)



Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


[0.6625000238418579, 1.2223527431488037]

In [18]:
class Distiller_bl(keras.Model):
    def __init__(self, student, teacher):
        super().__init__()
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
    ):
        
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def train_step(self, data):
       
        x, y = data

  
        teacher_predictions = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
           
            student_predictions = self.student(x, training=True)

            
            student_loss = self.student_loss_fn(y, student_predictions)

            
            distillation_loss = (
                self.distillation_loss_fn(
                    tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
                    tf.nn.softmax(student_predictions / self.temperature, axis=1),
                )
                * self.temperature**2
            )

            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

       
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        
        self.compiled_metrics.update_state(y, student_predictions)

       
        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"student_loss": student_loss, "distillation_loss": distillation_loss}
        )
        return results

    def test_step(self, data):
       
        x, y = data

      
        y_prediction = self.student(x, training=False)

       
        student_loss = self.student_loss_fn(y, y_prediction)

      
        self.compiled_metrics.update_state(y, y_prediction)

       
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        return results

distiller = Distiller_bl(student=student_bl, teacher=resnet50)
distiller.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=['accuracy'],
    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=0.1,
    temperature=8,
)


distiller.fit(train_X, training_labels, epochs=10)


distiller.evaluate(test_X, test_labels)



Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


[0.652999997138977, 1.3971123695373535]

In [26]:
class Distiller_bl(keras.Model):
    def __init__(self, student, teacher):
        super().__init__()
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
    ):
        
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def train_step(self, data):
       
        x, y = data

  
        teacher_predictions = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
           
            student_predictions = self.student(x, training=True)

            
            student_loss = self.student_loss_fn(y, student_predictions)

            
            distillation_loss = (
                self.distillation_loss_fn(
                    tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
                    tf.nn.softmax(student_predictions / self.temperature, axis=1),
                )
                * self.temperature**2
            )

            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

       
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        
        self.compiled_metrics.update_state(y, student_predictions)

       
        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"student_loss": student_loss, "distillation_loss": distillation_loss}
        )
        return results

    def test_step(self, data):
       
        x, y = data

      
        y_prediction = self.student(x, training=False)

       
        student_loss = self.student_loss_fn(y, y_prediction)

      
        self.compiled_metrics.update_state(y, y_prediction)

       
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        return results

distiller = Distiller_bl(student=student_bl, teacher=resnet50)
distiller.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=['accuracy'],
    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=0.1,
    temperature=9,
)


distiller.fit(train_X, training_labels, epochs=10)


distiller.evaluate(test_X, test_labels)



Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


[0.6535000205039978, 0.8826004266738892]