In [1]:
##------------------------------------------------------------------------
## Summary : classify hand written digits using simple 2-layer CNN
## Author  : Venkata Srinivas Vemprala
## SourceCode : https://github.com/Hvass-Labs/TensorFlow-Tutorials
##------------------------------------------------------------------------

%matplotlib inline
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import math
from sklearn.metrics import confusion_matrix



In [2]:
## Importing the handwritten data from tensorflow
## The MNIST files are read from the local 'data/MNIST/' directory.
## one_hot=True asks for one-hot label vectors; later cells call argmax on
## each label row to recover the integer class.
from tensorflow.examples.tutorials.mnist import input_data
data = input_data.read_data_sets('data/MNIST/',one_hot=True)

Extracting data/MNIST/train-images-idx3-ubyte.gz
Extracting data/MNIST/train-labels-idx1-ubyte.gz
Extracting data/MNIST/t10k-images-idx3-ubyte.gz
Extracting data/MNIST/t10k-labels-idx1-ubyte.gz


In [3]:
##Declaring constants
img_size = 28                         # images are img_size x img_size pixels
img_shape = (img_size,img_size)       # 2-D shape used when reshaping a flat image for display
img_size_flat = img_size * img_size   # number of pixels in one flattened image
num_classes = 10
batch_size = 100

## Integer class of every sample: index of the largest entry in each one-hot
## label row. A single vectorised argmax along axis 1 replaces a Python-level
## loop over every row and yields the same integer array.
trueLabelsTrain = np.argmax(data.train.labels,axis=1)
trueLabelsTest = np.argmax(data.test.labels,axis=1)
trueLabelsValidation = np.argmax(data.validation.labels,axis=1)

In [4]:
##lets write a function to display the images using matplotlib

def plotImages(images,trueLabels,predLabels=None):
    """Plot exactly nine images in a 3x3 grid, captioned with their labels.

    Args:
        images: sequence of nine flattened images; each one is reshaped to
            the module-level ``img_shape`` before being drawn.
        trueLabels: sequence of nine true class labels.
        predLabels: optional sequence of predicted class labels. When given,
            each caption shows the prediction next to the true label.

    Raises:
        AssertionError: if ``images`` and ``trueLabels`` do not both have
            length nine.
    """
    assert len(trueLabels) == len(images) == 9

    #create a 3x3 subplot
    (fig,axes) = plt.subplots(3,3)
    fig.subplots_adjust(hspace=0.3,wspace=0.3)

    for i,ax in enumerate(axes.flat):
        #plot image
        ax.imshow(images[i].reshape(img_shape),cmap='binary')

        # Identity test on purpose: `predLabels == None` is evaluated
        # element-wise for a NumPy array and cannot serve as an `if` condition.
        if predLabels is None:
            xlabel = "True : "+str(trueLabels[i])
        else:
            xlabel = "True : "+str(trueLabels[i])+" Pred : "+str(predLabels[i])

        #set label and remove ticks
        ax.set_xlabel(xlabel)
        ax.set_xticks([])
        ax.set_yticks([])

In [5]:
##lets write a function to print the confusion matrix
def printConfMatrix(trueLabels,predLabels):
    """Print the confusion matrix of the predicted labels against the true labels."""
    print(confusion_matrix(y_true=trueLabels,y_pred=predLabels))

##Function to print wrong examples    
def printExampleErrors(images,trueLabels,predLabels):
    """Plot the first nine misclassified images with true and predicted labels.

    Args:
        images: NumPy array of images, one per row.
        trueLabels: NumPy array of true class labels, aligned with ``images``.
        predLabels: NumPy array of predicted class labels, aligned with
            ``images``.

    Note:
        ``plotImages`` is written to accept exactly nine images, so only the
        first nine mismatches are forwarded. With fewer than nine mismatches
        its length check will reject the call.
    """
    mismatch = np.not_equal(trueLabels,predLabels)

    # Positions of the wrong predictions, capped at the nine that
    # plotImages can display.
    indexes = np.where(mismatch)[0][:9]

    incorrectImages = images[indexes]
    incorrectTrueLabels = trueLabels[indexes]
    incorrectPredLabels = predLabels[indexes]
    plotImages(incorrectImages,incorrectTrueLabels,incorrectPredLabels)

In [6]:
##lets write a function to plot conv weights
def plotConvWeights(weights,inputChannel=0):
    """Plot every filter of a convolution weight array for one input channel.

    Args:
        weights: 4-D NumPy array indexed as
            [row, column, input channel, filter].
        inputChannel: index of the input channel whose filters are drawn.

    The filters are laid out on the smallest square grid that holds them all;
    grid cells beyond the last filter are left empty.
    """
    ##find max and min of weights so all filters share one colour scale
    w_max = np.max(weights)
    w_min = np.min(weights)

    ##lets get number of weights to plot
    numFilters = weights.shape[3]

    ##lets get number of images per axis
    ##int(): math.ceil returns a float on Python 2, plt.subplots needs integers
    numGrids = int(math.ceil(math.sqrt(numFilters)))

    ##squeeze=False keeps `axes` a 2-D array even for a 1x1 grid
    (fig,axes) = plt.subplots(numGrids,numGrids,squeeze=False)

    for i,ax in enumerate(axes.flat):
        ##the grid can have more cells than filters; only index valid filters
        if i < numFilters:
            image = weights[:,:,inputChannel,i]
            ax.imshow(image,vmax=w_max,vmin=w_min,interpolation='nearest', cmap='seismic')
            ax.set_xlabel('Filter : '+str(i))

        #remove ticks
        ax.set_xticks([])
        ax.set_yticks([])

In [7]:
##lets write a function to create a convolution layer
##We use same padding and max pooling in window 2x2,Relu as activation function
def createCLayer(inputLayer,numInputChan,filterSize,numOutputChan,usePooling=True):
    """Build one convolution layer and return it with its filter weights.

    The layer is a stride-1, SAME-padded convolution plus a bias, optionally
    followed by 2x2 max-pooling with stride 2, and finally a ReLU.

    Args:
        inputLayer: input tensor handed to ``tf.nn.conv2d``.
        numInputChan: number of channels in ``inputLayer``.
        filterSize: width and height of each square filter.
        numOutputChan: number of filters, i.e. output channels.
        usePooling: when True, apply the 2x2 max-pooling step.

    Returns:
        Tuple ``(layer, weights)``: the output tensor and the filter variable.
    """
    kernelShape = [filterSize,filterSize,numInputChan,numOutputChan]

    ##Trainable parameters: filters start from a truncated normal,
    ##biases from a small positive constant
    kernel = tf.Variable(tf.truncated_normal(shape=kernelShape,stddev=0.05))
    bias = tf.Variable(tf.constant(0.05,shape=[numOutputChan]))

    convOut = tf.nn.conv2d(input=inputLayer,filter=kernel,strides=[1,1,1,1],padding='SAME') + bias

    ##Pooling is applied before the ReLU on purpose:
    ##max(relu(x)) == relu(max(x)), so after 2x2 pooling the ReLU runs on a
    ##quarter of the values. This shortcut relies on max-pooling and ReLU
    ##specifically and does not carry over to other pooling/activation pairs.
    if usePooling:
        preActivation = tf.nn.max_pool(value=convOut,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')
    else:
        preActivation = convOut

    return tf.nn.relu(preActivation),kernel

In [8]:
##lets write a function to flatten
def flattenLayer(layer):
    """Flatten a 4-D layer into a 2-D tensor with one row per sample.

    Args:
        layer: 4-D tensor. Assumed to be laid out as
            [batch, height, width, channels], the layout used by the
            convolution layers in this notebook -- confirm for other callers.

    Returns:
        Tuple ``(layerflat, layerLen)``: the reshaped tensor of shape
        [batch, layerLen] and the number of features per sample.
    """
    layer_shape = layer.get_shape()

    ##Features per sample = height * width * channels (dimensions 1..3).
    ##Dimension 0 is the batch size and must not enter the product; it is
    ##recovered by the -1 in the reshape below.
    layerLen = layer_shape[1:4].num_elements()
    layerflat = tf.reshape(layer,[-1,layerLen])

    return layerflat,layerLen

In [None]:
##lets define function to create a fc layer
def 