In [1]:
import keras
import tensorflow as tf
import numpy as np
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Model
from keras.layers import Convolution2D, MaxPooling2D, Dropout, Flatten, Dense, Input, BatchNormalization, AveragePooling2D, Reshape, Activation
from keras.layers.advanced_activations import PReLU
from keras import backend as K
from math import sqrt
import sys
sys.path.append('../CustomLayers/')
from CustomLayers import *

Using TensorFlow backend.


In [2]:
def put_kernels_on_grid(kernel, pad=1):

  '''Visualize conv. filters as an image (mostly for the 1st layer).

  Arranges filters into a grid, with some padding between adjacent filters.
  Values are min-max normalised over the whole kernel tensor before padding,
  so the zero-valued padding shows up as the darkest value.

  Args:
    kernel:            tensor of shape [Y, X, NumChannels, NumKernels]
    pad:               number of black pixels around each filter (between them)
  Return:
    Tensor of shape [1, (Y+2*pad)*grid_Y, (X+2*pad)*grid_X, NumChannels].
  '''
  # get shape of the grid. NumKernels == grid_Y * grid_X
  def factorization(n):
    # Walk down from floor(sqrt(n)) so the first divisor found gives the
    # most square-looking grid.
    for i in range(int(sqrt(float(n))), 0, -1):
      if n % i == 0:
        if i == 1: print('Who would enter a prime number of filters')
        return (i, int(n / i))
  num_kernels = kernel.get_shape()[3].value
  (grid_Y, grid_X) = factorization(num_kernels)

  x_min = tf.reduce_min(kernel)
  x_max = tf.reduce_max(kernel)
  value_range = x_max - x_min
  # A constant kernel has a zero range; dividing by it would fill the image
  # with NaNs. Substituting 1 maps such a kernel to all zeros and leaves
  # every other input's result untouched.
  safe_range = tf.where(tf.equal(value_range, 0),
                        tf.ones_like(value_range),
                        value_range)
  kernel = (kernel - x_min) / safe_range

  # pad X and Y
  x = tf.pad(kernel, tf.constant( [[pad,pad],[pad, pad],[0,0],[0,0]] ), mode = 'CONSTANT')

  # X and Y dimensions, w.r.t. padding
  Y = kernel.get_shape()[0] + 2 * pad
  X = kernel.get_shape()[1] + 2 * pad

  channels = kernel.get_shape()[2]

  # put NumKernels to the 1st dimension
  x = tf.transpose(x, (3, 0, 1, 2))
  # organize grid on Y axis
  x = tf.reshape(x, tf.stack([grid_X, Y * grid_Y, X, channels]))

  # switch X and Y axes
  x = tf.transpose(x, (0, 2, 1, 3))
  # organize grid on X axis
  x = tf.reshape(x, tf.stack([1, X * grid_X, Y * grid_Y, channels]))

  # back to normal order (not combining with the next step for clarity)
  x = tf.transpose(x, (2, 1, 3, 0))

  # to tf.image_summary order [batch_size, height, width, channels],
  #   where in this case batch_size == 1
  x = tf.transpose(x, (3, 0, 1, 2))

  # scaling to [0, 255] is not necessary for tensorboard
  return x

In [3]:
class TB_writer(keras.callbacks.Callback):
    """Keras callback that writes TensorBoard summaries for a model.

    Depending on the flags it records:
      * per-weight histograms (4-D weights get one histogram per slice
        along their last axis),
      * gradient histograms of the total loss w.r.t. each weight,
      * image renderings of weights,
      * histograms (and, for 4-D outputs, images) of layer outputs,
      * every scalar found in the epoch ``logs`` dict,
      * embedding checkpoints for the TensorBoard projector.

    Histogram/image summaries are evaluated on a batch drawn from ``val_gen``.

    Args:
        log_dir: sub-directory name; it is appended to the fixed prefix
            "/data/tensorflow/log/".
        histogram_freq: evaluate the merged summaries every this many epochs;
            0 disables histogram/image summaries altogether.
        batch_size: slice size used when feeding the drawn batch to the
            summary ops.
        write_graph: if true the session graph is handed to the FileWriter.
        write_grads: add a gradient histogram for every weight.
        write_images: add image summaries for weights of supported rank.
        embeddings_freq: save the embedding checkpoint every this many
            epochs; 0 disables it.
        embeddings_layer_names: names of the layers whose first weight is
            saved; when empty, all layers whose class is named 'Embedding'.
        embeddings_metadata: either a single metadata path (str) used for all
            embedding layers, or a dict mapping layer name to metadata path.
        val_gen: iterator yielding tuples which, once one sample-weight entry
            is appended, line up with the model's inputs + targets +
            sample_weights (i.e. ``(x, y)`` for a single-input,
            single-output model).
    """

    # The original counter check (`i > 16` after incrementing) lets the first
    # 17 channel slices of a 4-D layer output through; keep that count.
    _MAX_OUTPUT_CHANNELS = 17

    def __init__(self, log_dir="",
                 histogram_freq=0,
                 batch_size=32,
                 write_graph=True,
                 write_grads=True,
                 write_images=False,
                 embeddings_freq=0,
                 embeddings_layer_names=None,
                 embeddings_metadata=None,
                 val_gen=None):
        super(TB_writer, self).__init__()
        self.log_dir = "/data/tensorflow/log/"+log_dir
        self.histogram_freq = histogram_freq
        self.write_graph = write_graph
        self.write_grads = write_grads
        self.write_images = write_images
        self.batch_size = batch_size
        self.merged = None
        self.val_gen = val_gen
        self.embeddings_freq = embeddings_freq
        self.embeddings_layer_names = embeddings_layer_names
        self.embeddings_metadata = embeddings_metadata or {}
        # Created in set_model(); initialised here so the epoch/train-end
        # hooks can test them instead of hitting AttributeError.
        self.writer = None
        self.saver = None
        self.embeddings_ckpt_path = None

    def set_model(self, model):
        """Attach the model, build the summary ops and open the event writer.

        Safe to call more than once: summary ops and the writer are only
        created on the first call that needs them.
        """
        self.model = model
        self.sess = K.get_session()

        if self.histogram_freq and self.merged is None:
            for layer in self.model.layers:
                for weight in layer.weights:
                    self._add_weight_summaries(model, weight)
                if hasattr(layer, 'output'):
                    self._add_output_summaries(layer)
            self.merged = tf.summary.merge_all()

        # The writer is needed for the scalar logs even when histograms are
        # disabled, so it must not depend on histogram_freq.
        if self.writer is None:
            if self.write_graph:
                self.writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)
            else:
                self.writer = tf.summary.FileWriter(self.log_dir)

            if self.embeddings_freq:
                self._setup_embeddings()

    def _add_weight_summaries(self, model, weight):
        """Register histogram / gradient / image summaries for one weight."""
        mapped_weight_name = weight.name.replace(':', '_')
        if len(weight.shape) == 4:
            # One histogram per slice along the last axis.
            kernel_split = tf.split(weight, weight.shape[3], axis=3)
            for i, kernel in enumerate(kernel_split):
                tf.summary.histogram(mapped_weight_name + str(i), kernel)
        else:
            tf.summary.histogram(mapped_weight_name, weight)

        if self.write_grads:
            grads = model.optimizer.get_gradients(model.total_loss, weight)
            tf.summary.histogram('{}_grad'.format(mapped_weight_name), grads)

        if self.write_images:
            w_img = self._weight_image(weight)
            if w_img is not None:
                tf.summary.image(mapped_weight_name, w_img, max_outputs=8)

    def _weight_image(self, weight):
        """Return a 4-D image tensor for `weight`, or None if its rank is unsupported.

        The rank is taken after tf.squeeze, so size-1 axes do not count
        (e.g. a 1x1 convolution kernel is handled by the rank-2 branch).
        """
        w_img = tf.squeeze(weight)
        shape = K.int_shape(w_img)
        if len(shape) == 2:  # dense layer
            if shape[0] > shape[1]:
                w_img = tf.transpose(w_img)
                shape = K.int_shape(w_img)
            w_img = tf.reshape(w_img, [1, shape[0], shape[1], 1])
            w_img = tf.transpose(w_img)
        elif len(shape) == 4:  # convnet check
            w_img = put_kernels_on_grid(w_img)
            # [1, H, W, C] -> [C, H, W, 1]: one single-channel image per
            # input channel.
            w_img = tf.transpose(w_img, perm=[3, 1, 2, 0])
        elif len(shape) == 1:  # bias case
            w_img = tf.reshape(w_img, [1, shape[0], 1, 1])
            w_img = tf.transpose(w_img)
        else:
            # maybe cant handle 3d convnnets
            return None
        shape = K.int_shape(w_img)
        assert len(shape) == 4 and shape[-1] in [1, 3, 4]
        return w_img

    def _add_output_summaries(self, layer):
        """Register summaries for a layer's output tensor."""
        mapped_layer_name = layer.name.replace(':', '_')
        if len(layer.output.shape) == 4:
            output_split = tf.split(layer.output, layer.output.shape[3], axis=3)
            for i, output in enumerate(output_split[:self._MAX_OUTPUT_CHANNELS]):
                tag = '{}/out'.format(mapped_layer_name) + str(i)
                tf.summary.histogram(tag, output)
                tf.summary.image(tag, output)
        else:
            tf.summary.histogram('{}/out'.format(mapped_layer_name), layer.output)

    def _setup_embeddings(self):
        """Create the embedding saver and register the projector config."""
        # Imported here because they are only needed for embeddings and are
        # not among the notebook's top-level imports.
        import os
        from tensorflow.contrib.tensorboard.plugins import projector

        embeddings_layer_names = self.embeddings_layer_names

        if not embeddings_layer_names:
            embeddings_layer_names = [layer.name for layer in self.model.layers
                                      if type(layer).__name__ == 'Embedding']

        embeddings = {layer.name: layer.weights[0]
                      for layer in self.model.layers
                      if layer.name in embeddings_layer_names}

        self.saver = tf.train.Saver(list(embeddings.values()))

        if not isinstance(self.embeddings_metadata, str):
            embeddings_metadata = self.embeddings_metadata
        else:
            # A single path applies to every embedding layer.
            embeddings_metadata = {layer_name: self.embeddings_metadata
                                   for layer_name in embeddings.keys()}

        config = projector.ProjectorConfig()
        self.embeddings_ckpt_path = os.path.join(self.log_dir,
                                                 'keras_embedding.ckpt')

        for layer_name, tensor in embeddings.items():
            embedding = config.embeddings.add()
            embedding.tensor_name = tensor.name

            if layer_name in embeddings_metadata:
                embedding.metadata_path = embeddings_metadata[layer_name]

        projector.visualize_embeddings(self.writer, config)

    def on_epoch_end(self, epoch, logs=None):
        """Write histogram/image summaries, embedding checkpoint and scalar logs."""
        logs = logs or {}
        if self.val_gen and self.histogram_freq:
            if epoch % self.histogram_freq == 0:
                # next() works for Keras iterators and plain generators alike.
                # A single-element sample-weight entry is appended to the batch.
                val_data = next(self.val_gen) + ([1], )
                tensors = (self.model.inputs +
                           self.model.targets +
                           self.model.sample_weights)

                if self.model.uses_learning_phase:
                    tensors += [K.learning_phase()]
                    # NOTE(review): the learning-phase flag is fed as True, so
                    # the summaries are computed in training mode - confirm
                    # this is intended.
                    val_data += (True, )

                assert len(val_data) == len(tensors)
                val_size = val_data[0].shape[0]
                i = 0
                while i < val_size:
                    step = min(self.batch_size, val_size - i)
                    # Inputs and targets are sliced; the sample-weight entry
                    # and the learning-phase flag are fed as they are.
                    batch_val = [val_data[0][i:i + step],
                                 val_data[1][i:i + step],
                                 val_data[2]]
                    if self.model.uses_learning_phase:
                        batch_val.append(val_data[3])
                    feed_dict = dict(zip(tensors, batch_val))
                    result = self.sess.run([self.merged], feed_dict=feed_dict)
                    summary_str = result[0]
                    self.writer.add_summary(summary_str, epoch)
                    i += self.batch_size

        if self.embeddings_freq and self.embeddings_ckpt_path:
            if epoch % self.embeddings_freq == 0:
                self.saver.save(self.sess,
                                self.embeddings_ckpt_path,
                                epoch)

        for name, value in logs.items():
            if name in ['batch', 'size']:
                continue
            summary = tf.Summary()
            summary_value = summary.value.add()
            # numpy scalars need .item(); plain Python numbers are used as-is.
            summary_value.simple_value = value.item() if hasattr(value, 'item') else value
            summary_value.tag = name
            self.writer.add_summary(summary, epoch)
        self.writer.flush()

    def on_train_end(self, _):
        """Close the event writer, if one was opened."""
        if self.writer is not None:
            self.writer.close()

In [4]:
# Training images: pixel values multiplied by 1/255, plus random
# shear / zoom / horizontal-flip augmentation.
train_datagen = ImageDataGenerator(rescale=1.0 / 255,
                                   shear_range=0.2,
                                   zoom_range=0.2,
                                   horizontal_flip=True)
# Evaluation images are only rescaled, never augmented.
test_datagen = ImageDataGenerator(rescale=1.0 / 255)

In [5]:
# Batches of 32 images resized to 32x32, read from the training directory
# through the augmenting generator.
train_generator = train_datagen.flow_from_directory('/data/cifar/train/',
                                                    target_size=(32, 32),
                                                    batch_size=32,
                                                    class_mode='categorical')
# Same settings for the test directory, through the rescale-only generator.
validation_generator = test_datagen.flow_from_directory('/data/cifar/test/',
                                                        target_size=(32, 32),
                                                        batch_size=32,
                                                        class_mode='categorical')

Found 50000 images belonging to 10 classes.
Found 10000 images belonging to 10 classes.


In [6]:
img_w = 32
img_h = 32
img_c = 3
# Keras' channels-last layout is (rows, cols, channels), i.e. height comes
# before width. Identical to the previous order while both are 32.
inp = Input(shape=(img_h, img_w, img_c))

# Stem: standard Keras convolution followed by overlapping max pooling.
z = Convolution2D(32, (3,3), activation='relu')(inp)
z = MaxPooling2D(pool_size=(3,3), strides=(2,2))(z)

# Two identical blocks built from the project's CustomLayers
# (MultibitLayer, BinConv, BinReg).
# NOTE(review): presumably MultibitLayer(3) quantises activations to 3 bits
# and BinConv uses binarised kernels - confirm against CustomLayers.
z = BatchNormalization()(z)
z = MultibitLayer(3)(z)
z = BinConv(128, (3,3), kernel_regularizer=BinReg(), padding='same')(z)
z = PReLU()(z)
z = MaxPooling2D(pool_size=(3,3), strides=(2,2))(z)

z = BatchNormalization()(z)
z = MultibitLayer(3)(z)
z = BinConv(128, (3,3), kernel_regularizer=BinReg(), padding='same')(z)
z = PReLU()(z)
z = MaxPooling2D(pool_size=(3,3), strides=(2,2))(z)

# Classifier head: a 1x1 convolution down to 10 maps, averaged over the whole
# remaining spatial extent, flattened to a 10-vector and passed through softmax.
z = BatchNormalization()(z)
z = Convolution2D(10, (1,1), activation='relu')(z)
z = AveragePooling2D(pool_size=(int(z.shape[1]), int(z.shape[2])))(z)
z = Reshape((10,))(z)
z = Activation('softmax')(z)

model = Model(inputs=inp, outputs=z)

In [7]:
# Configure training: categorical cross-entropy loss, the Adam optimizer
# (selected by name, so with its default settings), and accuracy as a metric.
model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

In [8]:
# Print the layer-by-layer output shapes and parameter counts.
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 32, 32, 3)         0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 30, 30, 32)        896       
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 14, 14, 32)        0         
_________________________________________________________________
batch_normalization_1 (Batch (None, 14, 14, 32)        128       
_________________________________________________________________
multibit_layer_1 (MultibitLa (None, 14, 14, 32)        0         
_________________________________________________________________
bin_conv_1 (BinConv)         (None, 14, 14, 128)       36992     
_________________________________________________________________
p_re_lu_1 (PReLU)            (None, 14, 14, 128)       25088     
__________

In [9]:
# Histograms and weight images every epoch; TB_writer prefixes log_dir, so the
# events land under /data/tensorflow/log/cifar_test_binary. The validation
# iterator supplies the batch the summaries are evaluated on.
tb_callback = TB_writer(histogram_freq=1, write_images=True, log_dir="cifar_test_binary", val_gen=validation_generator)
# Build the summary ops before training starts.
# NOTE(review): Keras normally calls set_model() on callbacks itself when
# training begins - confirm whether this explicit call is still required.
tb_callback.set_model(model)

In [10]:
# Train for 10 epochs of 100 batches each. With the generators' batch size of
# 32 that is 3,200 training images per epoch, not a full pass over the 50,000
# found on disk; validation likewise uses 100 batches per epoch.
model.fit_generator(
        train_generator,
        steps_per_epoch=100,
        epochs=10,
        validation_data=validation_generator,
        validation_steps=100,
        callbacks=[tb_callback])

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x7f7bb121d710>

In [11]:
from PIL import Image

In [12]:
def load_image(infilename):
    """Read an image file into a float32 array with values divided by 255.

    Args:
        infilename: path (or file object) accepted by PIL's ``Image.open``.

    Returns:
        numpy float32 array holding the decoded pixels divided by 255. No
        mode conversion or resizing is applied, so the shape and channel
        count are whatever the file decodes to.
    """
    # Image.open only reads the header lazily; the context manager makes sure
    # the underlying file is released even if decoding the pixels fails.
    with Image.open(infilename) as img:
        img.load()
        data = np.asarray(img, dtype="float32")
    return data / 255

In [13]:
# Load one test image and add a leading batch axis of size 1, the form
# model.predict is given below.
image = np.expand_dims(load_image("/data/cifar/test/truck/1008_truck.png"), axis=0)

In [14]:
# Mapping from class name to class index, as assigned by the directory iterator.
class_map =validation_generator.class_indices

In [15]:
# Index of the highest-scoring class for the single image in the batch.
guess = np.argmax(model.predict(image))

In [16]:
# Display the name -> index mapping so the predicted index can be interpreted.
validation_generator.class_indices

{'airplane': 0,
 'automobile': 1,
 'bird': 2,
 'cat': 3,
 'deer': 4,
 'dog': 5,
 'frog': 6,
 'horse': 7,
 'ship': 8,
 'truck': 9}

In [17]:
# Class names ordered by their index, derived from the generator's own
# name -> index mapping so the list cannot drift from what the model was
# trained on. For the mapping shown above this yields
# ['airplane', 'automobile', ..., 'truck'].
cifar_labels = sorted(class_map, key=class_map.get)

In [18]:
# Translate the predicted class index into its label.
cifar_labels[guess]

'automobile'