In [1]:
import pandas as pd
import wandb
import keras
from wandb.keras import WandbCallback
import shutil
import tensorflow as tf
import os
import numpy as np
from keras.preprocessing.image import ImageDataGenerator
from keras import backend as K
from keras.layers.core import Dense, Activation
from keras.optimizers import Adam
from keras.metrics import categorical_crossentropy
from keras.preprocessing import image
from keras.models import Model
from keras.applications import imagenet_utils
from keras.layers import Dense,GlobalAveragePooling2D
from keras.applications import MobileNet
from keras.applications.mobilenet import preprocess_input
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix

Using TensorFlow backend.


#### Setup network architecture

In [2]:
# lr: float >= 0. Learning rate.
# beta_1: float, 0 < beta < 1. Generally close to 1.
# beta_2: float, 0 < beta < 1. Generally close to 1.
# epsilon: float >= 0. Fuzz factor.
# decay: float >= 0. Learning rate decay over each update.
##Default params
# lr=0.001,
# beta_1=0.9,
# beta_2=0.999,
# epsilon=1e-08,
# decay=0.0,

hyper_params = dict(
    image_size = 224,
    num_layers_frozen = 60,
    dense_1_size = 1024,
    dense_2_size = 1024,
    dense_3_size = 512,
    batch_size = 32,
#     steps_per_epoch = 200,
    #Adam Optimizer
    lr=0.001,
    beta_1=0.9,
    beta_2=0.999,
    epsilon=1e-08,
    decay=0.000,
    epochs = 5
)

wandb.init(project="car_nanodegree_capstone", config=hyper_params)

config = wandb.config

dataBaseFolder = '../../../../data/'
# datasdcnd = dataBaseFolder + 'dataset-sdcnd-capstone/data/real_training_data/'
dataTL = dataBaseFolder + 'tl_engineer5/'

base_model = keras.applications.mobilenet.MobileNet(input_shape=(config.image_size,config.image_size,3), 
                                                    include_top=False, weights='imagenet')

x=base_model.output
x=GlobalAveragePooling2D()(x)
x=Dense(config.dense_1_size,activation='relu')(x) 
#we add dense layers so that the model can learn more complex functions and classify for better results.
x=Dense(config.dense_2_size,activation='relu')(x) #dense layer 2
x=Dense(config.dense_3_size,activation='relu')(x) #dense layer 3
preds=Dense(4,activation='softmax')(x) #final layer with softmax activation

model=Model(inputs=base_model.input,outputs=preds)

for layer in model.layers[:config.num_layers_frozen]:
    layer.trainable=False

for layer in model.layers[config.num_layers_frozen:]:
    layer.trainable=True

[34m[1mwandb[0m: [32m[41mERROR[0m Not authenticated.  Copy a key from https://app.wandb.ai/authorize


API Key: ········


In [3]:
train_datagen = ImageDataGenerator(
        rescale=1./255,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True)

test_datagen = ImageDataGenerator(rescale=1./255)

test_generator = test_datagen.flow_from_directory(
    dataTL+'test',
    target_size=(config.image_size, config.image_size),
    batch_size=1,
)

train_generator = train_datagen.flow_from_directory(
        dataTL+'train',
        target_size=(config.image_size, config.image_size),
        batch_size=config.batch_size,
        class_mode='categorical')

validation_generator = test_datagen.flow_from_directory(
        dataTL+'val',
        target_size=(config.image_size, config.image_size),
        batch_size=config.batch_size,
        class_mode='categorical')

adam_opt = Adam(lr=config.lr, beta_1=config.beta_1, beta_2=config.beta_2, epsilon=config.epsilon, decay=config.decay)

model.compile(optimizer=adam_opt,loss='categorical_crossentropy',metrics=['accuracy'])

model.fit_generator(
        train_generator,
        epochs=config.epochs,
        steps_per_epoch=3521 // config.batch_size,
        validation_data=validation_generator,
        validation_steps = 755//config.batch_size,
        callbacks=[WandbCallback()])


model.save(os.path.join(wandb.run.dir, "model_70_frozen.h5"))

saver = tf.train.Saver()
saver.save(K.get_session(), '/tmp/keras_model.ckpt')

Found 627 images belonging to 4 classes.
Found 2917 images belonging to 4 classes.
Found 625 images belonging to 4 classes.
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


'/tmp/keras_model.ckpt'

In [4]:
scores = model.evaluate_generator(generator=test_generator, steps=755//config.batch_size, workers=1)
print('Accuracy: ', scores[1])

filenames = test_generator.filenames
nb_samples = len(filenames)
pred = model.predict_generator(test_generator, steps=nb_samples, verbose=1, workers=1)
predicted_class_indices = np.argmax(pred, axis=1)

#Confution Matrix and Classification Report
print('Confusion Matrix')
print(confusion_matrix(test_generator.classes, predicted_class_indices))

# labels = (train_generator.class_indices)
# labels = dict((v,k) for k,v in labels.items())
# predictions = [labels[k] for k in predicted_class_indices]

print('Classification Report')
print(classification_report(test_generator.classes, predicted_class_indices))

# filenames=test_generator.filenames
# results=pd.DataFrame({"Filename":filenames,
#                       "Predictions":predictions})

# results.to_csv(dataBaseFolder + 'results_baseline.csv',index=False)

print(train_generator.class_indices)
print(validation_generator.class_indices)

('Accuracy: ', 1.0)
Confusion Matrix
[[139 127  25]
 [140 127  14]
 [ 32  18   5]]
Classification Report
              precision    recall  f1-score   support

           1       0.45      0.48      0.46       291
           2       0.47      0.45      0.46       281
           3       0.11      0.09      0.10        55

   micro avg       0.43      0.43      0.43       627
   macro avg       0.34      0.34      0.34       627
weighted avg       0.43      0.43      0.43       627

{'unknown': 2, 'green': 0, 'yellow': 3, 'red': 1}
{'unknown': 2, 'green': 0, 'yellow': 3, 'red': 1}


In [5]:
import cv2

In [6]:
def predict_img(model, fname):
    cv_image = cv2.imread(fname)
    cv_image = cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB)
    cv_image = cv2.resize(cv_image, (224, 224))
    cv_image = cv_image/255.
    image_expanded = np.expand_dims(cv_image, axis=0)
    pred = model.predict(image_expanded)
    print(pred)
    return np.argmax(pred, axis=1)

In [7]:
pred_class_arr = []
for img_name in test_generator.filenames:
    pred_class = predict_img(model, os.path.join(dataTL+'test', img_name))
    print(img_name, pred_class)
    pred_class_arr.append(pred_class[0])

[[  0.00000000e+00   1.00000000e+00   3.22403991e-15   1.74287634e-14]]
('red/out00206.png', array([1]))
[[  2.88677349e-36   1.00000000e+00   4.44169685e-13   9.81996689e-12]]
('red/out05356.png', array([1]))
[[  2.36794242e-32   1.00000000e+00   3.03005780e-11   4.57675634e-11]]
('red/out08114.png', array([1]))
[[  5.18157424e-16   9.99979258e-01   1.05593508e-05   1.01708301e-05]]
('red/out08100.png', array([1]))
[[  0.00000000e+00   1.00000000e+00   1.21139142e-13   7.14323688e-14]]
('red/out07582.png', array([1]))
[[  4.72432033e-36   1.00000000e+00   2.73278991e-12   2.28201659e-12]]
('red/out01440.png', array([1]))
[[  5.15104494e-35   1.00000000e+00   6.27911091e-12   5.16922442e-12]]
('red/out07964.png', array([1]))
[[  0.00000000e+00   1.00000000e+00   7.96806604e-14   4.54001772e-13]]
('red/out04672.png', array([1]))
[[  1.96408690e-16   9.99982953e-01   1.25689548e-05   4.55014242e-06]]
('red/out05236.png', array([1]))
[[  1.38055857e-35   1.00000000e+00   1.51137056e-12   

[[  6.40524005e-23   1.00000000e+00   4.96133481e-08   5.54719719e-08]]
('red/out05122.png', array([1]))
[[  0.00000000e+00   1.00000000e+00   4.40233177e-14   3.89434226e-13]]
('red/out00066.png', array([1]))
[[  6.36241736e-34   1.00000000e+00   1.98481179e-11   8.87078865e-12]]
('red/out00714.png', array([1]))
[[  3.20785939e-37   1.00000000e+00   2.17745201e-13   4.40132045e-12]]
('red/out06826.png', array([1]))
[[  0.00000000e+00   1.00000000e+00   8.43743601e-16   1.98305575e-15]]
('red/out01622.png', array([1]))
[[  0.00000000e+00   1.00000000e+00   6.74112323e-14   8.26822335e-14]]
('red/out03784.png', array([1]))
[[  1.39996079e-18   9.99997139e-01   1.25616850e-06   1.52701659e-06]]
('red/out05120.png', array([1]))
[[  6.62813251e-22   9.99999762e-01   7.66109665e-08   1.35621022e-07]]
('red/out03586.png', array([1]))
[[  1.92473301e-10   9.98520553e-01   8.07594217e-04   6.71836373e-04]]
('red/out05732.png', array([1]))
[[  6.28546939e-32   1.00000000e+00   3.69585196e-11   

[[  0.00000000e+00   1.00000000e+00   9.21784960e-15   1.22959301e-13]]
('red/out07104.png', array([1]))
[[  2.30812540e-33   1.00000000e+00   6.39831590e-12   4.55016511e-11]]
('red/out03410.png', array([1]))
[[  1.22877260e-20   9.99997973e-01   2.01640569e-06   5.46547838e-08]]
('red/out04828.png', array([1]))
[[  3.59562110e-37   1.00000000e+00   2.84428568e-13   3.58486092e-12]]
('red/out00482.png', array([1]))
[[  2.38247181e-15   9.99966860e-01   1.78330756e-05   1.52821667e-05]]
('red/out06780.png', array([1]))
[[  9.61971523e-07   2.26612583e-01   7.67460644e-01   5.92576060e-03]]
('red/out01776.png', array([2]))
[[  2.51024056e-32   1.00000000e+00   2.42244003e-11   6.30752672e-11]]
('red/out08752.png', array([1]))
[[  1.94305190e-32   1.00000000e+00   3.18553725e-11   4.34251038e-11]]
('red/out04626.png', array([1]))
[[  2.84304329e-32   1.00000000e+00   4.29467815e-11   4.67218139e-11]]
('red/out02096.png', array([1]))
[[  4.47607141e-36   1.00000000e+00   9.41653331e-13   

[[  1.18381149e-30   1.00000000e+00   3.49366716e-11   5.56214075e-10]]
('red/out06864.png', array([1]))
[[  0.00000000e+00   1.00000000e+00   5.49717062e-18   4.59402414e-17]]
('red/out09202.png', array([1]))
[[  0.00000000e+00   1.00000000e+00   1.45770175e-16   1.01434599e-15]]
('red/out00224.png', array([1]))
[[  6.94686301e-29   1.00000000e+00   2.49045812e-10   9.97243066e-10]]
('red/out03088.png', array([1]))
[[  1.18370316e-30   1.00000000e+00   3.49687362e-11   5.55666957e-10]]
('red/out06872.png', array([1]))
[[  0.00000000e+00   1.00000000e+00   4.64832606e-14   1.45142009e-13]]
('red/out04254.png', array([1]))
[[  0.00000000e+00   1.00000000e+00   8.04089122e-14   4.55566628e-13]]
('red/out09016.png', array([1]))
[[  1.21294601e-34   1.00000000e+00   3.23812071e-12   1.34077948e-11]]
('red/out01258.png', array([1]))
[[  1.10135511e-32   1.00000000e+00   1.58358326e-11   5.02262537e-11]]
('red/out08254.png', array([1]))
[[  3.28463329e-37   1.00000000e+00   4.50172370e-13   

[[  4.78490011e-24   9.45086231e-06   9.99990463e-01   2.69225993e-08]]
('unknown/out06066.png', array([2]))
[[  1.50490193e-17   4.42191260e-04   9.99553025e-01   4.70444002e-06]]
('unknown/out07436.png', array([2]))
[[  9.97908716e-37   5.09631093e-08   1.00000000e+00   4.17371026e-13]]
('unknown/out00370.png', array([2]))
[[  0.00000000e+00   1.86473437e-09   1.00000000e+00   1.39011473e-14]]
('unknown/out03490.png', array([2]))
[[  1.62035088e-30   3.18300408e-07   9.99999642e-01   1.29291730e-10]]
('unknown/out07184.png', array([2]))
[[  5.02887696e-15   7.16125302e-04   9.99249279e-01   3.45732296e-05]]
('unknown/out07806.png', array([2]))
[[  3.58796414e-14   1.83731935e-03   9.98096764e-01   6.59487705e-05]]
('unknown/out01136.png', array([2]))
[[  1.49725387e-15   5.40642010e-04   9.99438941e-01   2.03422969e-05]]
('unknown/out00758.png', array([2]))
[[  7.36421410e-11   6.80649467e-03   9.92631197e-01   5.62286063e-04]]
('unknown/out03292.png', array([2]))
[[  7.94446523e-29 

[[  7.59523100e-05   8.76097023e-01   9.85386148e-02   2.52884682e-02]]
('unknown/out07518.png', array([1]))
[[  3.86835537e-20   4.18476411e-05   9.99957204e-01   8.87834460e-07]]
('unknown/out01828.png', array([2]))
[[  1.43157584e-17   4.82138887e-04   9.99512553e-01   5.32163767e-06]]
('unknown/out01196.png', array([2]))
[[  1.00287248e-19   6.60511432e-05   9.99933124e-01   8.17514263e-07]]
('unknown/out06362.png', array([2]))
[[  5.95429174e-06   2.03931957e-01   7.80317783e-01   1.57443602e-02]]
('unknown/out05680.png', array([2]))
[[  3.69269310e-06   9.02977526e-01   8.88762847e-02   8.14244431e-03]]
('unknown/out02716.png', array([1]))
[[  3.96597699e-14   2.20197276e-03   9.97738004e-01   5.99983869e-05]]
('unknown/out05904.png', array([2]))
[[  1.05184871e-36   1.87878708e-08   1.00000000e+00   1.30037788e-12]]
('unknown/out01948.png', array([2]))
[[  4.17517375e-11   7.48608727e-03   9.92085993e-01   4.27967811e-04]]
('unknown/out07486.png', array([2]))
[[  1.26523409e-24 

[[  7.31193101e-13   2.09822506e-03   9.97750580e-01   1.51216620e-04]]
('unknown/out06380.png', array([2]))
[[  2.66678687e-35   2.17341203e-08   1.00000000e+00   8.43384911e-12]]
('unknown/out00108.png', array([2]))
[[  3.61500806e-32   1.02515841e-07   9.99999881e-01   9.49370929e-11]]
('unknown/out00120.png', array([2]))
[[  0.00000000e+00   5.25373633e-11   1.00000000e+00   2.47181320e-17]]
('unknown/out08030.png', array([2]))
[[  0.00000000e+00   1.65634866e-08   1.00000000e+00   5.96650889e-14]]
('unknown/out05528.png', array([2]))
[[  1.90865636e-25   3.42374824e-06   9.99996543e-01   1.75906969e-08]]
('unknown/out03158.png', array([2]))
[[ 0.00197235  0.29003522  0.56661183  0.14138064]]
('unknown/out06550.png', array([2]))
[[  0.00000000e+00   1.36667859e-08   1.00000000e+00   4.63633682e-14]]
('unknown/out07666.png', array([2]))
[[  1.49687801e-37   6.46530696e-09   1.00000000e+00   1.67511231e-12]]
('unknown/out00096.png', array([2]))
[[  1.55570622e-36   2.97568103e-08   1

[[  2.22041907e-36   2.10712585e-08   1.00000000e+00   1.29902924e-12]]
('unknown/out03488.png', array([2]))
[[  8.56063335e-23   1.55925627e-05   9.99984264e-01   1.43860234e-07]]
('unknown/out00142.png', array([2]))
[[  6.15210934e-27   1.85288843e-06   9.99998093e-01   4.44375647e-09]]
('unknown/out01314.png', array([2]))
[[  1.01263800e-19   5.70233387e-05   9.99941826e-01   1.09740222e-06]]
('unknown/out06686.png', array([2]))
[[  3.80517137e-20   5.40143483e-05   9.99945402e-01   6.23852543e-07]]
('unknown/out02386.png', array([2]))
[[  2.30936440e-18   1.28279236e-04   9.99869108e-01   2.63961738e-06]]
('unknown/out02384.png', array([2]))
[[  2.59402265e-21   4.55271038e-05   9.99954343e-01   1.69226695e-07]]
('unknown/out03930.png', array([2]))
[[  3.61200410e-21   3.18277234e-05   9.99967933e-01   2.85544957e-07]]
('unknown/out02390.png', array([2]))
[[  1.17122783e-12   9.95587230e-01   4.39550634e-03   1.72205782e-05]]
('unknown/out00168.png', array([1]))
[[  0.00000000e+00 

In [8]:
print('Classification Report')
print(classification_report(test_generator.classes, pred_class_arr))

Classification Report
              precision    recall  f1-score   support

           1       0.94      0.99      0.96       291
           2       0.98      0.96      0.97       281
           3       1.00      0.84      0.91        55

   micro avg       0.96      0.96      0.96       627
   macro avg       0.97      0.93      0.95       627
weighted avg       0.96      0.96      0.96       627

