<a href="https://colab.research.google.com/github/sp7412/colab/blob/master/distilling.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np

In [2]:
def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label

def softmax_sparse_categorical_crossentropy(labels, logits):
  softmaxed = tf.keras.backend.softmax(logits)

  return tf.keras.losses.sparse_categorical_crossentropy(labels, softmaxed)

In [78]:
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)

ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [52]:
ds_info.features['image'].shape

(28, 28, 1)

In [53]:
num_training_examples = ds_info.splits['train'].num_examples
num_training_examples

60000

In [7]:
ds_train_iter = ds_train.as_numpy_iterator()
train_images, train_labels = zip(*[ds_train_iter.next() for i in range(num_training_examples)])
train_images = np.asarray(train_images)

In [127]:
train_images.shape

(60000, 28, 28, 1)

In [55]:
num_testing_examples = ds_info.splits['test'].num_examples
num_testing_examples

10000

In [79]:
ds_test_iter = ds_test.as_numpy_iterator()

In [158]:
test_images, test_labels = zip(*[ds_test_iter.next() for i in range(num_testing_examples)])
test_images = np.asarray(test_images)

In [128]:
# ds_train = ds_train.take(10)

In [12]:
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(128)
ds_train = ds_train.prefetch(tf.data.experimental.AUTOTUNE)

In [13]:
ds_test = ds_test.batch(128)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.experimental.AUTOTUNE)

In [14]:
control_model = tf.keras.models.Sequential([
  tf.keras.Input([28, 28]), 
  tf.keras.layers.Reshape([28, 28, 1]), 
  tf.keras.layers.Conv2D(64, 3, 2, activation='relu'), 
  tf.keras.layers.BatchNormalization(),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Conv2D(64, 3, 2, activation='relu'), 
  tf.keras.layers.BatchNormalization(),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.GlobalAveragePooling2D(), 
  #tf.keras.layers.Dense(10, activation='softmax')])
  tf.keras.layers.Dense(10, activation=None)])

#control_model.compile('adam', 'sparse_categorical_crossentropy', ['accuracy'])
control_model.compile('adam', softmax_sparse_categorical_crossentropy, ['accuracy'])
control_model.summary()

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
reshape_1 (Reshape)          (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 13, 13, 64)        640       
_________________________________________________________________
batch_normalization_2 (Batch (None, 13, 13, 64)        256       
_________________________________________________________________
dropout_2 (Dropout)          (None, 13, 13, 64)        0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 6, 6, 64)          36928     
_________________________________________________________________
batch_normalization_3 (Batch (None, 6, 6, 64)          256       
_________________________________________________________________
dropout_3 (Dropout)          (None, 6, 6, 64)         

In [16]:
control_model.fit(ds_train, epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x7f111fecfc88>

In [17]:
test_loss, test_acc = control_model.evaluate(ds_test)
print("Test Loss:", test_loss)
print("Test Accuracy:", test_acc)

Test Loss: 0.20186038315296173
Test Accuracy: 0.9366000294685364


In [18]:
teacher_model = tf.keras.models.Sequential([
  tf.keras.Input([28, 28]), 
  tf.keras.layers.Reshape([28, 28, 1]), 
  tf.keras.layers.Conv2D(64, 3, 2, activation='relu'), 
  tf.keras.layers.BatchNormalization(),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Conv2D(64, 3, 2, activation='relu'), 
  tf.keras.layers.BatchNormalization(),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Conv2D(64, 3, 2, activation='relu'), 
  tf.keras.layers.BatchNormalization(),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.GlobalAveragePooling2D(), 
  #tf.keras.layers.Dense(10, activation='softmax')])
  tf.keras.layers.Dense(10, activation=None)])

#teacher_model.compile('adam', 'sparse_categorical_crossentropy', ['accuracy'])
teacher_model.compile('adam', softmax_sparse_categorical_crossentropy, ['accuracy'])
teacher_model.summary()

Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
reshape_2 (Reshape)          (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 13, 13, 64)        640       
_________________________________________________________________
batch_normalization_4 (Batch (None, 13, 13, 64)        256       
_________________________________________________________________
dropout_4 (Dropout)          (None, 13, 13, 64)        0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 6, 6, 64)          36928     
_________________________________________________________________
batch_normalization_5 (Batch (None, 6, 6, 64)          256       
_________________________________________________________________
dropout_5 (Dropout)          (None, 6, 6, 64)         

In [19]:
teacher_model.fit(ds_train, epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x7f111fecfb38>

In [20]:
test_loss, test_acc = teacher_model.evaluate(ds_test)
print("Test Loss:", test_loss)
print("Test Accuracy:", test_acc)

Test Loss: 0.047029588371515274
Test Accuracy: 0.9854000210762024


In [111]:
# train_images = list(map(lambda x:tf.reshape(x,(-1,28,28,1)), train_images))

In [22]:
ds_train_images = tf.data.Dataset.from_tensor_slices(list(map(lambda x:tf.reshape(x,(-1,28,28,1)), train_images)))
soft_labels = teacher_model.predict(ds_train_images,verbose=1)



In [23]:
temperature = 3
afterwards_temperature = 1

def temperature_softmax(logits):
  soft_logits = tf.keras.backend.exp(logits / temperature)
  return soft_logits / tf.keras.backend.sum(soft_logits, axis=-1, keepdims=True) / afterwards_temperature

def distillation_loss(labels, logits):
  labels = temperature_softmax(labels)
  logits = temperature_softmax(logits)

  return -tf.keras.backend.mean(labels * tf.keras.backend.log(logits))

In [24]:
student_model = tf.keras.models.Sequential([
  tf.keras.Input([28, 28]), 
  tf.keras.layers.Reshape([28, 28, 1]), 
  tf.keras.layers.Conv2D(64, 3, 2, activation='relu'), 
  tf.keras.layers.BatchNormalization(),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Conv2D(64, 3, 2, activation='relu'), 
  tf.keras.layers.BatchNormalization(),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.GlobalAveragePooling2D(), 
  # tf.keras.layers.Dense(10, activation='softmax')])
  tf.keras.layers.Dense(10, activation=None)])

student_model.compile('adam', distillation_loss, ['accuracy'])
student_model.summary()

Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
reshape_3 (Reshape)          (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 13, 13, 64)        640       
_________________________________________________________________
batch_normalization_7 (Batch (None, 13, 13, 64)        256       
_________________________________________________________________
dropout_7 (Dropout)          (None, 13, 13, 64)        0         
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 6, 6, 64)          36928     
_________________________________________________________________
batch_normalization_8 (Batch (None, 6, 6, 64)          256       
_________________________________________________________________
dropout_8 (Dropout)          (None, 6, 6, 64)         

In [77]:
student_model.fit(train_images, soft_labels, epochs=10)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x7f10e2e51630>

In [80]:
test_images, test_labels = zip(*[ds_test_iter.next() for i in range(num_testing_examples)])

In [81]:
test_images = tf.cast(test_images, dtype='float64')
test_labels = tf.cast(test_labels, dtype='float64')

In [82]:
test_loss, test_acc = student_model.evaluate(test_images, test_labels)
print("Test Loss:", test_loss)
print("Test Accuracy:", test_acc)

Test Loss: 4.21751594543457
Test Accuracy: 0.09749999642372131
