In [1]:
from keras.datasets import mnist
from keras.optimizers import SGD
from keras.utils import np_utils

# imports used to build the deep learning model
from keras.models import Sequential
from keras.layers.convolutional import Conv2D
from keras.layers.convolutional import MaxPooling2D
from keras.layers.core import Activation
from keras.layers.core import Flatten
from keras.layers.core import Dense
from keras.layers import Dropout

from TNT import kernels_cluster
from keras.models import load_model

import numpy as np
import cv2
import matplotlib.pyplot as plt
%matplotlib inline

Using TensorFlow backend.


## 1. Definition of a LeNet

In [2]:
def build_lenet(width, height, depth, classes, weightsPath=None):
    # Initialize the model
    model = Sequential()

    # The first set of CONV => RELU => POOL layers
    # If you need a traditional lenet, you can change the to padding="vailed"
    model.add(Conv2D(20, (5, 5), padding="same",
                     input_shape=(height, width, depth), name='CONV1'))
    model.add(Activation("relu", name='relu1'))
    model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2), name='Pooling1'))

    # The second set of CONV => RELU => POOL layers
    model.add(Conv2D(50, (5, 5), padding="same", name='CONV2'))
    model.add(Activation("relu", name='relu2'))
    model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2), name='Pooling2'))
    model.add(Dropout(0.025))

    # The set of FC => RELU layers
    model.add(Flatten())
    model.add(Dense(500))
    model.add(Activation("relu", name='relu3'))
    model.add(Dropout(0.025))
    
    # The softmax classifier
    model.add(Dense(classes))
    model.add(Activation("softmax", name='softmax'))

    # If a weights path is supplied, then load the weights
    if weightsPath is not None:
        model.load_weights(weightsPath)

    # Return the constructed network architecture
    return model

## 2. Preparing The MNIST Dataset

We prepare two types of input dataset. The first one is ternary type, the second one is normal type. <br>
1. Ternary input: is that every pixel is converted to -1, 0 or 1. <br>
2. Normal input: is that every pixel is in range 0~255.

In [3]:
print("[INFO] Loading the MNIST Normal dataset...")
(trainData, trainLabels), (testData, testLabels) = mnist.load_data()
trainData = trainData[:, :, :, np.newaxis]
testData = testData[:, :, :, np.newaxis]
# Rescale the data from values between [0 - 255] to [0 - 1.0]
trainData = trainData / 255.0
testData = testData / 255.0
trainLabels = np_utils.to_categorical(trainLabels, 10)
testLabels = np_utils.to_categorical(testLabels, 10)

[INFO] Loading the MNIST Normal dataset...


## 3. Initializing The LeNet

In [4]:
# Build and Compile the model
print("[INFO] Building and compiling the LeNet model...")
model = build_lenet(width=28, height=28, depth=1, classes=10)
opt = SGD(lr=0.01)
model.compile(loss="categorical_crossentropy", optimizer=opt, metrics=["accuracy"])
model.summary()

[INFO] Building and compiling the LeNet model...
Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
CONV1 (Conv2D)               (None, 28, 28, 20)        520       
_________________________________________________________________
relu1 (Activation)           (None, 28, 28, 20)        0         
_________________________________________________________________
Pooling1 (MaxPooling2D)      (None, 14, 14, 20)        0         
_________________________________________________________________
CONV2 (Conv2D)               (None, 14, 14, 50)        25050     
_________________________________________________________________
relu2 (Activation)           (None, 14, 14, 50)        0         
______________________________________

## 4. Training the Lenet by Ternary Input or Normal Input

### 4.1 Normal input

In [None]:
print("[INFO] Using Normal Input to train a model...")
history = model.fit(trainData,
                    trainLabels,
                    batch_size=128,
                    epochs=3,
                    validation_data=(testData, testLabels))

In [6]:
(loss, accuracy) = model.evaluate(
        testData, testLabels, batch_size=128, verbose=1)
print("[INFO] accuracy: {:.2f}%".format(accuracy * 100))

[INFO] accuracy: 95.92%


In [None]:
model.save('./model/LeNet_epoch3.model', overwrite=True)

### 4.2 Ternar Input

In [5]:
model = load_model('./model/LeNet_epoch3.model')

Instructions for updating:
Use tf.cast instead.


In [7]:
weights = model.get_weights()

In [8]:
print(weights[0])

[[[[-0.00435827 -0.01967984 -0.06710408  0.06321401 -0.06320184 -0.05127101
     0.10110379 -0.00498402  0.0738114  -0.05353865  0.00485416  0.0258638
     0.12360302 -0.03986602 -0.06357484 -0.10580987 -0.01490379  0.09963966
     0.03353032  0.02118958]]

  [[-0.0117531   0.05213843  0.09928356  0.20159774  0.01188824 -0.03079138
     0.08051449 -0.02179168  0.03302421  0.00452769  0.0532811   0.0674924
     0.18953568  0.07028069 -0.09385809 -0.0202191  -0.01219152  0.07385656
    -0.04336262  0.01123895]]

  [[-0.03193782 -0.06312631  0.05173955  0.07448452  0.04035129  0.02129952
     0.00869875  0.05544435  0.1461038   0.0507984   0.05108389  0.11168292
     0.04904866  0.07648798  0.03072315  0.06694351  0.07726569  0.04849633
     0.14006896  0.11014114]]

  [[-0.00787933 -0.0776803   0.13880318  0.25237063 -0.04807344  0.07678764
    -0.0275468  -0.10517928  0.20285228 -0.03554668 -0.03340333  0.13602601
     0.16514519  0.07247213  0.00947469 -0.01465164 -0.01946673 -0.030189

In [10]:
for i in range(0, len(weights)):
    print(i)
    weights[i] = kernels_cluster(weights[i])
model.set_weights(weights)

0
1
2
3
4
5
6
7


In [11]:
print(weights[0])

[[[[ 0.          0.         -0.07270797  0.          0.         -0.08300267
     0.12073883  0.          0.         -0.0710777   0.          0.
     0.20875083  0.         -0.07865113 -0.10580776  0.          0.06845199
     0.          0.        ]]

  [[ 0.          0.06960116  0.08913977  0.18933053  0.          0.
     0.12073883  0.          0.          0.          0.09644656  0.13038068
     0.20875083  0.08985013 -0.07865113  0.          0.          0.06845199
     0.          0.        ]]

  [[ 0.         -0.07142154  0.08913977  0.          0.          0.          0.
     0.          0.22044756  0.08296047  0.09644656  0.13038068  0.
     0.08985013  0.          0.          0.07848413  0.06845199  0.1259973
     0.0942094 ]]

  [[ 0.         -0.07142154  0.08913977  0.18933053  0.          0.0848879
     0.         -0.10517718  0.22044756  0.          0.          0.13038068
     0.20875083  0.08985013  0.          0.          0.          0.          0.
     0.0942094 ]]

  [[-0

In [12]:
(loss, accuracy) = model.evaluate(
        testData, testLabels, batch_size=128, verbose=1)
print("[INFO] accuracy: {:.2f}%".format(accuracy * 100))

[INFO] accuracy: 93.66%


## 5. Saving The Trained Model

In [13]:
model.save('model/LeNet_ternary.model', overwrite=True)

## 6. Test MNIST Dataset
Randomly select one image from test dataset to predict the number.

In [None]:
for i in np.random.choice(np.arange(0, len(testLabels)), size=(10,)):
    # Use the model to classify the digit
    probs = model.predict(testData[np.newaxis, i])
    prediction = probs.argmax(axis=1)

    # Convert the digit data to a color image
    image = (testData[i] * 255).astype("uint8")
    image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

    # The images are in 28x28 size. Much too small to see properly
    # So, we resize them to 280x280 for viewing
    image = cv2.resize(image, (280, 280), interpolation=cv2.INTER_LINEAR)

    # Add the predicted value on to the image
    cv2.putText(image, str(prediction[0]), (20, 40),
                cv2.FONT_HERSHEY_DUPLEX, 1.5, (0, 255, 0), 1)

    # Show the image and prediction
    print("[INFO] Predicted: {}, Actual: {}".format(
        prediction[0], np.argmax(testLabels[i])))
    cv2.imshow("Digit", image)
    cv2.waitKey(0)

cv2.destroyAllWindows()