In [1]:
from __future__ import print_function
import keras
from keras.datasets import mnist
from keras.models import load_model
from keras.utils import np_utils
import numpy as np
from numpy.linalg import norm
import argparse
import cv2
import matplotlib.pyplot as plt

Using TensorFlow backend.


## 1. MNIST Dataset

In [2]:
print("[INFO] Loading the MNIST dataset...")
(trainData, trainLabels), (testData, testLabels) = mnist.load_data()

# Rescale the pixel values by 1/255 (the result is a floating-point array)
trainData = trainData * (1. / 255)
testData = testData * (1. / 255)

# Append a trailing singleton axis: (N, 28, 28) -> (N, 28, 28, 1)
trainData = np.expand_dims(trainData, axis=-1)
testData = np.expand_dims(testData, axis=-1)

# One-hot encode the labels over 10 classes
trainLabels = np_utils.to_categorical(trainLabels, 10)
testLabels = np_utils.to_categorical(testLabels, 10)

print('[INFO] training size {}'.format(trainData.shape))
print('[INFO] test size {}'.format(testData.shape))

[INFO] Loading the MNIST dataset...
[INFO] training size (60000, 28, 28, 1)
[INFO] test size (10000, 28, 28, 1)


## 2. Load models

In [4]:
# Original (full-precision) LeNet-5 model, restored from a saved Keras model file
model_full_precision = load_model('./models/LeNet5_original.model')

In [5]:
# LeNet-5 variant with ternary weights, restored from a saved Keras model file
model_ternary = load_model('./models/LeNet5_ternary.model')

In [6]:
# LeNet-5 variant with ternary weights plus scalar-tuning, restored from a saved Keras model file
model_ternary_scalar = load_model('./models/LeNet5_ternary_scalar.model')

## 3. Accuracy Comparison

In [7]:
# Score the full-precision model on the MNIST test data and report accuracy in percent
loss, accuracy = model_full_precision.evaluate(
    testData, testLabels, batch_size=128, verbose=1)
print("[INFO] Full precision model accuracy is: {:.2f}%".format(100 * accuracy))

[INFO] Full precision model accuracy is: 99.22%


In [8]:
# Score the ternary-weights model on the MNIST test data and report accuracy in percent
loss, accuracy = model_ternary.evaluate(
    testData, testLabels, batch_size=128, verbose=1)
print("[INFO] Ternary weights model accuracy is: {:.2f}%".format(100 * accuracy))

[INFO] Ternary weights model accuracy is: 98.70%


In [9]:
# Score the scalar-tuned ternary model on the MNIST test data and report accuracy in percent
loss, accuracy = model_ternary_scalar.evaluate(
    testData, testLabels, batch_size=128, verbose=1)
print("[INFO] Ternary scalar-tuning model accuracy is: {:.2f}%".format(100 * accuracy))

[INFO] Ternary scalar-tuning model accuracy is: 99.10%


## 4. Weights Comparison

In [10]:
# Pull the weights out of every model, in the order:
# original (full precision), ternary, ternary with scalar-tuning
weights_full_precision, weights_ternary, weights_ternary_scalar = (
    net.get_weights()
    for net in (model_full_precision, model_ternary, model_ternary_scalar)
)

### 4.1 convolution layer

In [11]:
# Full-precision model: 2-D slice [:, :, 0, 1] of the first weight array
# (presumably laid out as rows, cols, input channel, output channel -- TODO confirm)
print(weights_full_precision[0][:, :, 0, 1])

[[ 0.07828661  0.07991946  0.03333475  0.12516291  0.07970043]
 [ 0.04642316 -0.01765191  0.03420295  0.09401795 -0.01672778]
 [ 0.04315952  0.01074686  0.073874    0.0927763   0.06956593]
 [-0.02776762 -0.10874511  0.03971153  0.05412877  0.06917713]
 [-0.15361701 -0.15454859 -0.13040234 -0.00376264 -0.06776509]]


In [12]:
# Ternary model: the same [:, :, 0, 1] slice of the first weight array
# (presumably laid out as rows, cols, input channel, output channel -- TODO confirm)
print(weights_ternary[0][:, :, 0, 1])

[[ 1.  1.  0.  1.  1.]
 [ 0.  0.  0.  1.  0.]
 [ 0.  0.  1.  1.  1.]
 [ 0. -1.  0.  1.  1.]
 [-1. -1. -1.  0. -1.]]


In [13]:
# Scalar-tuned ternary model: the same [:, :, 0, 1] slice of the first weight array
# (presumably laid out as rows, cols, input channel, output channel -- TODO confirm)
print(weights_ternary_scalar[0][:, :, 0, 1])

[[ 0.08166043  0.08166043  0.          0.08166043  0.08166043]
 [ 0.          0.          0.          0.08166043  0.        ]
 [ 0.          0.          0.08166043  0.08166043  0.08166043]
 [ 0.         -0.12301452  0.          0.08166043  0.08166043]
 [-0.12301452 -0.12301452 -0.12301452  0.         -0.12301452]]


### 4.2 bias

In [14]:
# Original (full-precision) model: second weight array (index 1), shown in this section as the bias
print(weights_full_precision[1])

[ 0.01191805  0.00381361  0.00187092 -0.01358425  0.00168212  0.00929875
  0.00508314  0.00097986 -0.00902778  0.00424224 -0.01501222  0.00823121
  0.0305446   0.00609368  0.00457578  0.00307899  0.00067997  0.0003576
  0.00522174  0.00142306  0.00158354  0.00296732  0.01257685 -0.01380011
 -0.01120229 -0.02553516 -0.01151572 -0.00151149  0.01593739  0.00877792
 -0.00499236  0.00272836]


In [15]:
# Ternary model: second weight array (index 1), shown in this section as the bias
print(weights_ternary[1])

[ 1.  0.  0. -1.  0.  1.  0.  0. -1.  0. -1.  1.  1.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  1. -1. -1. -1. -1.  0.  1.  1.  0.  0.]


In [16]:
# Ternary model with scalar-tuning: second weight array (index 1), shown in this section as the bias
print(weights_ternary_scalar[1])

[ 0.01389772  0.          0.         -0.01423954  0.          0.01389772
  0.          0.         -0.01423954  0.         -0.01423954  0.01389772
  0.01389772  0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.01389772 -0.01423954 -0.01423954
 -0.01423954 -0.01423954  0.          0.01389772  0.01389772  0.          0.        ]


### 4.3 Cosine similarity Comparison

In [17]:
def inner(a_, t_):
    """Return the cosine similarity between two equally sized arrays.

    Both inputs are flattened, their dot product is taken, and the result is
    divided by the product of their norms. A constant 0.00001 is added to the
    denominator, which keeps the division defined when an input is all zeros.

    Args:
        a_: NumPy array.
        t_: NumPy array with the same number of elements as ``a_``.

    Returns:
        A NumPy array of shape (1, 1) holding the similarity value.
    """
    row = a_.reshape(1, -1)
    column = t_.reshape(-1, 1)
    denominator = norm(a_) * norm(t_) + 0.00001
    return np.dot(row, column) / denominator

In [18]:
# Compare the same [:, :, 0, 1] slice of the first weight array in both models
cosine = inner(
    weights_full_precision[0][:, :, 0, 1],
    weights_ternary[0][:, :, 0, 1])
# inner() returns a (1, 1) array, so pull out its single element for printing
print("The cosine similarity between full precision weights and ternary weights is *{}*".format(cosine[0, 0]))

The cosine similarity between full precision weights and ternary weights is *0.9233752489089966*


In [19]:
# Compare the same [:, :, 0, 1] slice of the first weight array in both models
cosine = inner(
    weights_full_precision[0][:, :, 0, 1],
    weights_ternary_scalar[0][:, :, 0, 1])
# inner() returns a (1, 1) array, so pull out its single element for printing
print("The cosine similarity between full precision weights and ternary scalar-tuning weights is *{}*".format(cosine[0, 0]))

The cosine similarity between full precision weights and ternary scalar-tuning weights is *0.9423828125*


## 5. Test MNIST Dataset

In [None]:
# Classify 10 randomly chosen test digits and show each one in an OpenCV window.
# cv2.waitKey(0) blocks until a key is pressed in the window before moving on.
try:
    for i in np.random.choice(np.arange(0, len(testLabels)), size=(10,)):
        # Use the model to classify the digit
        # Please try models: model_full_precision, model_ternary, and model_ternary_scalar
        probs = model_ternary.predict(testData[np.newaxis, i])
        prediction = probs.argmax(axis=1)

        # Convert the digit data to a color image.
        # Round before casting: the data was scaled by 1/255 in floating point,
        # so testData[i] * 255 can land just below an integer and a bare
        # astype("uint8") would truncate it to the next lower intensity.
        image = np.rint(testData[i] * 255).astype("uint8")
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

        # The images are in 28x28 size. Much too small to see properly
        # So, we resize them to 280x280 for viewing
        image = cv2.resize(image, (280, 280), interpolation=cv2.INTER_LINEAR)

        # Add the predicted value on to the image
        cv2.putText(image, str(prediction[0]), (20, 40),
                    cv2.FONT_HERSHEY_DUPLEX, 1.5, (0, 255, 0), 1)

        # Show the image and prediction
        print("[INFO] Predicted: {}, Actual: {}".format(
            prediction[0], np.argmax(testLabels[i])))
        cv2.imshow("Digit", image)
        cv2.waitKey(0)
finally:
    # Close the OpenCV windows even when the loop is interrupted or raises,
    # so no window is left open after the cell stops.
    cv2.destroyAllWindows()

[INFO] Predicted: 2, Actual: 2
