<a href="https://colab.research.google.com/github/s-c-soma/AdvanceDeeplearning-CMPE-297/blob/master/ExtraCredit/ExtraCredit_Assignment_3_Part_C.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

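# Assignment 3 Part C

Knowledge distillation on FashionMNIST: train a small CNN "teacher", then train a smaller fully-connected "student" to match the teacher's temperature-softened predictions, and compare the two models' accuracy and size.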

## Imports

In [None]:
import tensorflow as tf

from tensorflow.keras import models
from tensorflow.keras import layers

# Seed TensorFlow's global RNG so reruns of the notebook are repeatable.
SEED = 666
tf.random.set_seed(SEED)

## Load Data

In [None]:
# Load FashionMNIST, then divide by 255 so pixel values (assumed 0-255) fall in [0, 1].
(X_train, y_train), (X_test, y_test) = (
    tf.keras.datasets.fashion_mnist.load_data()
)
X_train, X_test = X_train / 255.0, X_test / 255.0

# Show the array shapes as the cell's output.
X_train.shape, X_test.shape, y_train.shape, y_test.shape

((60000, 28, 28), (10000, 28, 28), (60000,), (10000,))

## Preprocessing

In [None]:
# Cast pixels to float32 and append a channel axis, turning (N, 28, 28) into
# (N, 28, 28, 1) to match the input shape the models below declare.
X_train, X_test = (
    split.astype("float32").reshape(-1, 28, 28, 1) for split in (X_train, X_test)
)

## Model

In [None]:
# Teacher: a small two-stage CNN whose logits serve as soft targets for the student.
def get_teacher_model():
    """Build the (uncompiled) convolutional teacher network.

    Architecture: two Conv2D(5x5, ReLU) + MaxPooling2D(2x2) stages with 16 and
    32 filters, Dropout(0.2), Flatten, Dense(128, ReLU), and a final Dense(10)
    with no activation, so the model outputs raw logits.

    Returns:
        A ``models.Sequential`` accepting inputs of shape (28, 28, 1).
    """
    feature_extractor = [
        layers.Conv2D(16, (5, 5), activation="relu", input_shape=(28, 28, 1)),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Conv2D(32, (5, 5), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Dropout(0.2),
    ]
    classifier_head = [
        layers.Flatten(),
        layers.Dense(128, activation="relu"),
        layers.Dense(10),
    ]
    return models.Sequential(feature_extractor + classifier_head)

In [None]:
# Define loss function and optimizer for the teacher.
# from_logits=True because the teacher's final Dense(10) layer has no softmax.
loss_func = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam()

In [None]:
# Prepare TF datasets.
# tf.data only shuffles uniformly when buffer_size >= number of elements; the
# previous 100-element buffer merely mixed neighbouring examples, so every
# epoch saw nearly the same order. The training arrays are already in memory,
# so a full-size buffer is affordable.
train_ds = (
    tf.data.Dataset.from_tensor_slices((X_train, y_train))
    .shuffle(len(X_train))
    .batch(64)
)
test_ds = tf.data.Dataset.from_tensor_slices((X_test, y_test)).batch(64)

# Train the teacher model; keep the History so per-epoch metrics stay inspectable
# (and the History repr doesn't clutter the cell output).
teacher_model = get_teacher_model()
teacher_model.compile(loss=loss_func, optimizer=optimizer, metrics=["accuracy"])
teacher_history = teacher_model.fit(train_ds,
                                    validation_data=test_ds,
                                    epochs=10)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x7f93604dc828>

In [None]:
# Report teacher accuracy on the test set, then persist its weights.
# evaluate() returns [loss, accuracy] here because "accuracy" is the only compiled metric.
teacher_scores = teacher_model.evaluate(test_ds)
teacher_test_acc = teacher_scores[1]
print("Test accuracy: {:.2f}".format(teacher_test_acc * 100))
teacher_model.save_weights("teacher_model.h5")

Test accuracy: 90.20


In [None]:
# Student: a single-hidden-layer MLP, smaller than the teacher CNN.
def get_student_model():
    """Build the (uncompiled) fully-connected student network.

    Flattens a (28, 28, 1) input, applies Dense(48, ReLU), and ends in a
    Dense(10) with no activation, so the model outputs raw logits.

    Returns:
        A ``models.Sequential`` with an explicit Input layer.
    """
    return models.Sequential([
        layers.Input(shape=(28, 28, 1)),
        layers.Flatten(),
        layers.Dense(48, activation="relu"),
        layers.Dense(10),
    ])

In [None]:
# Credits: https://github.com/google-research/simclr/blob/master/colabs/distillation_self_training.ipynb
def get_kd_loss(student_logits, teacher_logits, temperature=0.5):
    """Knowledge-distillation loss between student and teacher logits.

    Both sets of logits are divided by ``temperature``. The teacher's are turned
    into soft targets with a softmax, and the student is scored with softmax
    cross-entropy against them. The batch-mean loss is multiplied by
    ``temperature**2``. This matches the deprecated
    ``tf.compat.v1.losses.softmax_cross_entropy(targets, logits, temperature**2)``
    call it replaces, which applied T**2 as a scalar weight and averaged over
    the batch.

    Args:
        student_logits: pre-softmax student outputs, presumably (batch, num_classes).
        teacher_logits: pre-softmax teacher outputs, same shape as ``student_logits``.
        temperature: softening temperature; must be non-zero.

    Returns:
        A scalar float tensor.
    """
    # Soft targets are constants for the student's update; the v1 API also
    # applied stop_gradient to its labels, whereas the TF2 op does not.
    teacher_probs = tf.stop_gradient(tf.nn.softmax(teacher_logits / temperature))
    per_example_loss = tf.nn.softmax_cross_entropy_with_logits(
        labels=teacher_probs, logits=student_logits / temperature)
    return tf.reduce_mean(per_example_loss) * temperature**2

In [None]:
# Student model and its own optimizer. This rebinds the global `optimizer`
# name that the teacher's compile() call used earlier.
student_model = get_student_model()
optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)

# Running means of the per-batch KD loss within an epoch.
# Metric names now mirror the variable names ("valid_loss" was "test_loss").
train_loss = tf.keras.metrics.Mean(name="train_loss")
valid_loss = tf.keras.metrics.Mean(name="valid_loss")

# Student accuracy against the ground-truth labels.
train_acc = tf.keras.metrics.SparseCategoricalAccuracy(name="train_acc")
valid_acc = tf.keras.metrics.SparseCategoricalAccuracy(name="valid_acc")

In [None]:
# Train utils
@tf.function
def model_train(images, labels, teacher_model,
                student_model, optimizer, temperature):
    """Run one distillation step on a batch and update the training metrics.

    Args:
        images: input batch fed to both models.
        labels: integer class labels, used only for the accuracy metric.
        teacher_model: model supplying target logits; its weights are not updated.
        student_model: model whose trainable variables are updated.
        optimizer: optimizer applied to the student's gradients.
        temperature: distillation temperature forwarded to ``get_kd_loss``.

    Side effects:
        Updates ``student_model`` weights and the global ``train_loss`` /
        ``train_acc`` metrics.
    """
    # Teacher is a fixed target: force inference mode (its Dropout stays off
    # regardless of the global learning phase) and run it outside the tape.
    teacher_logits = teacher_model(images, training=False)

    with tf.GradientTape() as tape:
        student_logits = student_model(images, training=True)
        loss = get_kd_loss(student_logits, teacher_logits, temperature)

    gradients = tape.gradient(loss, student_model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, student_model.trainable_variables))

    # Accuracy is measured against the true labels, not the teacher's predictions.
    train_loss(loss)
    train_acc(labels, tf.nn.softmax(student_logits))

In [None]:
# Validation utils
@tf.function
def model_validate(images, labels, teacher_model,
                   student_model, temperature):
    """Score the student on one batch and update the validation metrics.

    Computes the same KD loss used in training (against the teacher's logits)
    and the student's accuracy against the true ``labels``. No weights change.

    Args:
        images: input batch fed to both models.
        labels: integer class labels, used only for the accuracy metric.
        teacher_model: model supplying target logits.
        student_model: model being evaluated.
        temperature: distillation temperature forwarded to ``get_kd_loss``.

    Side effects:
        Updates the global ``valid_loss`` / ``valid_acc`` metrics.
    """
    # Both networks explicitly in inference mode (the teacher contains Dropout).
    teacher_logits = teacher_model(images, training=False)
    student_logits = student_model(images, training=False)

    loss = get_kd_loss(student_logits, teacher_logits, temperature)

    valid_loss(loss)
    valid_acc(labels, tf.nn.softmax(student_logits))

In [None]:
# Tie everything together
def train_model(epochs, teacher_model, student_model, optimizer, temperature=0.5):
    """Distill ``teacher_model`` into ``student_model`` for ``epochs`` epochs.

    Each epoch runs ``model_train`` over the global ``train_ds`` and
    ``model_validate`` over the global ``test_ds``. It then prints that epoch's
    mean KD loss and accuracy for both splits.

    Args:
        epochs: number of passes over ``train_ds``.
        teacher_model: model providing target logits (not updated).
        student_model: model being trained (its variables are updated in place).
        optimizer: optimizer applied to the student's gradients.
        temperature: distillation temperature forwarded to ``get_kd_loss``.

    Returns:
        ``(teacher_model, student_model)``, the same objects passed in.
    """
    epoch_metrics = (train_loss, train_acc, valid_loss, valid_acc)
    template = "Epoch {}, loss: {:.3f}, acc: {:.3f}, val_loss: {:.3f}, val_acc: {:.3f}"

    for epoch in range(epochs):
        # Reset at the start (not the end) of each epoch so state left over
        # from an interrupted or earlier run never leaks into this epoch.
        for metric in epoch_metrics:
            metric.reset_states()

        for images, labels in train_ds:
            model_train(images, labels, teacher_model, student_model, optimizer, temperature)

        for images, labels in test_ds:
            model_validate(images, labels, teacher_model, student_model, temperature)

        print(template.format(epoch + 1,
                              train_loss.result(),
                              train_acc.result(),
                              valid_loss.result(),
                              valid_acc.result()))

    return teacher_model, student_model

In [None]:
# Distill for 10 epochs; the returned teacher is the same object, so it is discarded.
_, student_model = train_model(epochs=10,
                               teacher_model=teacher_model,
                               student_model=student_model,
                               optimizer=optimizer)

Epoch 1, loss: 0.116, acc: 0.816, val_loss: 0.097, val_acc: 0.825
Epoch 2, loss: 0.091, acc: 0.848, val_loss: 0.091, val_acc: 0.838
Epoch 3, loss: 0.086, acc: 0.853, val_loss: 0.088, val_acc: 0.841
Epoch 4, loss: 0.084, acc: 0.857, val_loss: 0.086, val_acc: 0.846
Epoch 5, loss: 0.082, acc: 0.858, val_loss: 0.089, val_acc: 0.838
Epoch 6, loss: 0.081, acc: 0.861, val_loss: 0.085, val_acc: 0.848
Epoch 7, loss: 0.080, acc: 0.862, val_loss: 0.088, val_acc: 0.840
Epoch 8, loss: 0.079, acc: 0.863, val_loss: 0.092, val_acc: 0.838
Epoch 9, loss: 0.078, acc: 0.864, val_loss: 0.085, val_acc: 0.850
Epoch 10, loss: 0.078, acc: 0.864, val_loss: 0.086, val_acc: 0.845


In [None]:
# Serialize the distilled student's weights next to the teacher's checkpoint.
student_model.save_weights(filepath="student_model.h5")

In [None]:
# Investigate the sizes: list the saved .h5 weight files so the on-disk
# footprint of the teacher and student checkpoints can be compared.
!ls -lh *.h5

-rw-r--r-- 1 root root 163K Aug 31 07:47 student_model.h5
-rw-r--r-- 1 root root 335K Aug 31 07:44 teacher_model.h5


In [None]:
# Per-layer output shapes and parameter counts for the teacher CNN.
teacher_model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              (None, 24, 24, 16)        416       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 12, 12, 16)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 8, 8, 32)          12832     
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 4, 4, 32)          0         
_________________________________________________________________
dropout (Dropout)            (None, 4, 4, 32)          0         
_________________________________________________________________
flatten (Flatten)            (None, 512)               0         
_________________________________________________________________
dense (Dense)                (None, 128)               6

In [None]:
# Per-layer output shapes and parameter counts for the student MLP, for comparison with the teacher.
student_model.summary()

Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten_2 (Flatten)          (None, 784)               0         
_________________________________________________________________
dense_4 (Dense)              (None, 48)                37680     
_________________________________________________________________
dense_5 (Dense)              (None, 10)                490       
Total params: 38,170
Trainable params: 38,170
Non-trainable params: 0
_________________________________________________________________
