In [1]:
import warnings
warnings.filterwarnings('ignore')

import numpy as np
from skimage import io, color, exposure, transform
from sklearn.model_selection import train_test_split
import os
import glob
import h5py

from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential, model_from_json
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Conv2D
from keras.layers.pooling import MaxPooling2D

from keras.optimizers import SGD
import keras.utils as np_utils
from keras.callbacks import LearningRateScheduler, ModelCheckpoint

from matplotlib import pyplot as plt


%matplotlib inline
NUM_CLASSES, IMG_SIZE = (
    43,  # number of target classes: width of the one-hot labels and of the softmax output
    48,  # side length in pixels that images are resized to (model input height/width)
)

Using TensorFlow backend.


In [2]:
import os
from pathlib import PurePath

def preprocess_img(img):
    """Equalize brightness, center-crop to a square, and resize for the CNN.

    The V (value/brightness) channel is histogram-equalized in HSV space,
    leaving hue and saturation untouched. The largest centered square is then
    cropped out and resized to IMG_SIZE x IMG_SIZE.

    Args:
        img: RGB image array, shape (H, W, 3) (required by ``color.rgb2hsv``).

    Returns:
        Float RGB image array of shape (IMG_SIZE, IMG_SIZE, 3).
    """
    hsv = color.rgb2hsv(img)
    hsv[:, :, 2] = exposure.equalize_hist(hsv[:, :, 2])
    img = color.hsv2rgb(hsv)

    # Center crop to a min_side x min_side square. Slice a full min_side from a
    # computed start offset; the old `center -/+ min_side // 2` form dropped a
    # row/column whenever min_side was odd.
    min_side = min(img.shape[:-1])
    top = (img.shape[0] - min_side) // 2
    left = (img.shape[1] - min_side) // 2
    img = img[top:top + min_side, left:left + min_side, :]

    return transform.resize(img, (IMG_SIZE, IMG_SIZE))

# get class label
def get_class(img_path):
    """Return the integer class label encoded as the image's parent directory name.

    For example, ``'<root>/00012/00000_00000.ppm'`` gives ``12``.
    """
    path_parts = PurePath(img_path).parts
    class_dir = path_parts[-2]
    return int(class_dir)

### Preprocess Data

In [None]:
# Load the preprocessed dataset from the X.h5 cache; if that fails, preprocess
# every image under ROOT_DIR and write the cache.
try:
    # Open explicitly read-only: older h5py versions defaulted to mode 'a',
    # which silently creates an empty X.h5 when the file does not exist.
    with h5py.File('X.h5', 'r') as hf:
        X, Y = hf['imgs'][:], hf['labels'][:]
    print("Load images from X.h5")

except (IOError, OSError, KeyError):
    print('Error in reading X.h5 processing all images!')
    ROOT_DIR = '....'  # TODO: set to the image root (one sub-directory per integer class id)
    imgs = []
    labels = []
    all_img_paths = glob.glob(os.path.join(ROOT_DIR, '*/*.ppm'))
    # Shuffle before stacking: model.fit(validation_split=0.2) holds out the
    # LAST 20% of samples, and glob order is likely grouped by class directory.
    np.random.shuffle(all_img_paths)
    for img_path in all_img_paths:
        try:
            img = preprocess_img(io.imread(img_path))
            label = get_class(img_path)
        except (IOError, OSError):
            print('Missed', img_path)
            continue
        imgs.append(img)
        labels.append(label)

        if len(imgs) % 1000 == 0:
            print('Processed {}/{}'.format(len(imgs), len(all_img_paths)))

    # Never cache an empty dataset: a wrong ROOT_DIR would otherwise write an
    # empty X.h5 that every later run would load without complaint.
    if not imgs:
        raise RuntimeError(
            'No images could be loaded from {!r}; check ROOT_DIR'.format(ROOT_DIR))

    X = np.array(imgs, dtype='float32')
    Y = np.eye(NUM_CLASSES, dtype='uint8')[labels]  # one-hot, shape (N, NUM_CLASSES)

    # save data and labels
    with h5py.File('X.h5', 'w') as hf:
        hf.create_dataset('imgs', data=X)
        hf.create_dataset('labels', data=Y)

### CNN Model

In [None]:
def cnn_model():
    """Build (but do not compile) the sign-classification CNN.

    Architecture: three stages of [Conv3x3 -> Conv3x3 -> MaxPool2x2 -> Dropout(0.2)]
    with 32, 64 and 128 filters, then Flatten -> Dense(512) -> Dropout(0.5) ->
    Dense(NUM_CLASSES, softmax).

    Returns:
        An uncompiled ``Sequential`` model taking (IMG_SIZE, IMG_SIZE, 3) inputs.
    """
    model = Sequential()

    # Conv2D needs a rank-3 (height, width, channels) input_shape. preprocess_img
    # yields RGB images, hence 3 channels. Assumes Keras' default
    # 'channels_last' image_data_format.
    model.add(Conv2D(32, (3, 3), padding='same', activation='relu',
                     input_shape=(IMG_SIZE, IMG_SIZE, 3)))
    model.add(Conv2D(32, (3, 3), padding='same', activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.2))

    model.add(Conv2D(64, (3, 3), padding='same', activation='relu'))
    model.add(Conv2D(64, (3, 3), padding='same', activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.2))

    model.add(Conv2D(128, (3, 3), padding='same', activation='relu'))
    model.add(Conv2D(128, (3, 3), padding='same', activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.2))

    # Classifier head.
    model.add(Flatten())
    model.add(Dense(512, activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(NUM_CLASSES, activation='softmax'))

    return model

model = cnn_model()

model.summary()

In [None]:
# Base learning rate; lr_schedule reads this global to derive per-epoch rates.
lr = 0.01

# SGD with Nesterov momentum and a small `decay` (Keras 2.x keyword names).
sgd = SGD(lr=lr, momentum=0.9, decay=1e-6, nesterov=True)
model.compile(optimizer=sgd,
              loss='categorical_crossentropy',
              metrics=['accuracy'])

In [None]:
def lr_schedule(epoch):
    """Step-decay schedule: divide the base rate ``lr`` by 10 every 10 epochs.

    Epochs 0-9 use ``lr``, 10-19 use ``lr * 0.1``, 20-29 use ``lr * 0.01``, and so on.
    The previous form ``lr * (0.1 * int(epoch / 10))`` returned 0 for the
    first 10 epochs, so no learning happened there.

    Args:
        epoch: 0-based epoch index, as passed by ``LearningRateScheduler``.

    Returns:
        Learning rate to use for this epoch.
    """
    return lr * 0.1 ** (epoch // 10)

# Training hyper-parameters.
batch_size, nb_epoch = 32, 30

# Hold out the last 20% of (X, Y) for validation, step-decay the learning rate
# with lr_schedule, and keep only the best checkpoint in model.h5
# (ModelCheckpoint monitors val_loss by default).
history = model.fit(
    X,
    Y,
    epochs=nb_epoch,
    batch_size=batch_size,
    shuffle=True,
    validation_split=0.2,
    callbacks=[
        LearningRateScheduler(lr_schedule),
        ModelCheckpoint('model.h5', save_best_only=True),
    ],
)