<a href="https://colab.research.google.com/github/sp7412/colab/blob/master/distilling.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [26]:
import tensorflow as tf
import tensorflow_datasets as tfds

In [27]:
def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label

def softmax_sparse_categorical_crossentropy(labels, logits):
  softmaxed = tf.keras.backend.softmax(logits)

  return tf.keras.losses.sparse_categorical_crossentropy(labels, softmaxed)

In [64]:
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

In [65]:
ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [66]:
num_training_examples = len(list(ds_train))
ds_train_iter = ds_train.as_numpy_iterator()
ds_train_images, ds_train_labels = zip(*[ds_train_iter.next() for i in range(num_training_examples)])

In [67]:
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(128)
ds_train = ds_train.prefetch(tf.data.experimental.AUTOTUNE)

In [68]:
ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
ds_test = ds_test.batch(128)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.experimental.AUTOTUNE)

In [69]:
control_model = tf.keras.models.Sequential([
  tf.keras.Input([28, 28]), 
  tf.keras.layers.Reshape([28, 28, 1]), 
  tf.keras.layers.Conv2D(64, 3, 2, activation='relu'), 
  tf.keras.layers.BatchNormalization(),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Conv2D(64, 3, 2, activation='relu'), 
  tf.keras.layers.BatchNormalization(),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.GlobalAveragePooling2D(), 
  #tf.keras.layers.Dense(10, activation='softmax')])
  tf.keras.layers.Dense(10, activation=None)])

#control_model.compile('adam', 'sparse_categorical_crossentropy', ['accuracy'])
control_model.compile('adam', softmax_sparse_categorical_crossentropy, ['accuracy'])
control_model.summary()

Model: "sequential_6"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
reshape_6 (Reshape)          (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d_14 (Conv2D)           (None, 13, 13, 64)        640       
_________________________________________________________________
batch_normalization_14 (Batc (None, 13, 13, 64)        256       
_________________________________________________________________
dropout_14 (Dropout)         (None, 13, 13, 64)        0         
_________________________________________________________________
conv2d_15 (Conv2D)           (None, 6, 6, 64)          36928     
_________________________________________________________________
batch_normalization_15 (Batc (None, 6, 6, 64)          256       
_________________________________________________________________
dropout_15 (Dropout)         (None, 6, 6, 64)         

In [70]:
control_model.fit(ds_train, epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x7f908c19c160>

In [34]:
test_loss, test_acc = control_model.evaluate(ds_test)
print("Test Loss:", test_loss)
print("Test Accuracy:", test_acc)

Test Loss: 0.13998368382453918
Test Accuracy: 0.9577999711036682


In [35]:
teacher_model = tf.keras.models.Sequential([
  tf.keras.Input([28, 28]), 
  tf.keras.layers.Reshape([28, 28, 1]), 
  tf.keras.layers.Conv2D(64, 3, 2, activation='relu'), 
  tf.keras.layers.BatchNormalization(),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Conv2D(64, 3, 2, activation='relu'), 
  tf.keras.layers.BatchNormalization(),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Conv2D(64, 3, 2, activation='relu'), 
  tf.keras.layers.BatchNormalization(),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.GlobalAveragePooling2D(), 
  #tf.keras.layers.Dense(10, activation='softmax')])
  tf.keras.layers.Dense(10, activation=None)])

#teacher_model.compile('adam', 'sparse_categorical_crossentropy', ['accuracy'])
teacher_model.compile('adam', softmax_sparse_categorical_crossentropy, ['accuracy'])
teacher_model.summary()

Model: "sequential_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
reshape_5 (Reshape)          (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d_11 (Conv2D)           (None, 13, 13, 64)        640       
_________________________________________________________________
batch_normalization_11 (Batc (None, 13, 13, 64)        256       
_________________________________________________________________
dropout_11 (Dropout)         (None, 13, 13, 64)        0         
_________________________________________________________________
conv2d_12 (Conv2D)           (None, 6, 6, 64)          36928     
_________________________________________________________________
batch_normalization_12 (Batc (None, 6, 6, 64)          256       
_________________________________________________________________
dropout_12 (Dropout)         (None, 6, 6, 64)         

In [36]:
teacher_model.fit(ds_train, epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x7f9085fede80>

In [37]:
test_loss, test_acc = teacher_model.evaluate(ds_test)
print("Test Loss:", test_loss)
print("Test Accuracy:", test_acc)

Test Loss: 0.04419132322072983
Test Accuracy: 0.9861000180244446


In [52]:
num_elements = 0
for element in ds_train:
    num_elements += 1
print(num_elements)

469


In [55]:
ds_train

<DatasetV1Adapter shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int64)>

In [60]:
ds_train_iter = ds_train.as_numpy_iterator()
ds_train_images, ds_train_labels = zip(*[ds_train_iter.next() for i in range(100)])

In [63]:
ds_train_images[0].shape

(128, 28, 28, 1)

In [54]:
soft_labels = teacher_model.predict(ds_train_images)

ValueError: ignored

In [40]:
soft_labels

array([[-2.3333142 , -5.709156  , -2.7729857 , ..., -1.7052712 ,
        -0.297174  ,  3.5314016 ],
       [ 0.02953346, -2.5621014 , -1.8467373 , ...,  1.6847906 ,
        -1.2954066 ,  5.130977  ],
       [14.182169  , -2.634373  , -1.6142161 , ..., -0.47589996,
        -1.1354713 , -2.3049874 ],
       ...,
       [ 3.7612774 , -1.5997813 , -0.6424896 , ..., -6.354689  ,
         0.5985186 , -1.8125602 ],
       [ 0.0957149 ,  0.68847346,  0.5353254 , ..., -6.249394  ,
        -2.955457  , -1.6919029 ],
       [-4.5036664 , -1.8361857 , -1.5385014 , ..., -5.6113906 ,
        -2.6136067 ,  0.74718034]], dtype=float32)

In [None]:
temperature = 3
afterwards_temperature = 1

def temperature_softmax(logits):
  soft_logits = tf.keras.backend.exp(logits / temperature)
  return soft_logits / tf.keras.backend.sum(soft_logits, axis=-1, keepdims=True) / afterwards_temperature

def distillation_loss(labels, logits):
  labels = temperature_softmax(labels)
  logits = temperature_softmax(logits)

  return -tf.keras.backend.mean(labels * tf.keras.backend.log(logits))

In [None]:
student_model = tf.keras.models.Sequential([
  tf.keras.Input([28, 28]), 
  tf.keras.layers.Reshape([28, 28, 1]), 
  tf.keras.layers.Conv2D(64, 3, 2, activation='relu'), 
  tf.keras.layers.BatchNormalization(),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Conv2D(64, 3, 2, activation='relu'), 
  tf.keras.layers.BatchNormalization(),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.GlobalAveragePooling2D(), 
  # tf.keras.layers.Dense(10, activation='softmax')])
  tf.keras.layers.Dense(10, activation=None)])

student_model.compile('adam', distillation_loss, ['accuracy'])
student_model.summary()

Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
reshape_2 (Reshape)          (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 13, 13, 64)        640       
_________________________________________________________________
batch_normalization_5 (Batch (None, 13, 13, 64)        256       
_________________________________________________________________
dropout_5 (Dropout)          (None, 13, 13, 64)        0         
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 6, 6, 64)          36928     
_________________________________________________________________
batch_normalization_6 (Batch (None, 6, 6, 64)          256       
_________________________________________________________________
dropout_6 (Dropout)          (None, 6, 6, 64)         

In [None]:
student_model.fit(ds_train, soft_labels, epochs=30)

ValueError: ignored