In [1]:
# Import relevant packages and libraries
import os
import re
import cv2
import numpy as np
from matplotlib import pyplot as plt

import warnings
warnings.filterwarnings("ignore")

import tensorflow.keras as keras
from keras.models import load_model
from tensorflow.keras.preprocessing.image import load_img
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, Dense, MaxPooling2D, Activation, Flatten, Dropout, BatchNormalization, GlobalAveragePooling2D
from tensorflow.keras import models, layers
from tensorflow.keras import backend as K
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.utils import to_categorical
from keras.models import Model
from tensorflow.keras.applications import mobilenet_v2, vgg16
from keras.preprocessing.image import ImageDataGenerator


from sklearn.metrics import confusion_matrix,ConfusionMatrixDisplay


In [2]:
# Map each State Farm class id (c0 ... c9) to a human-readable description
# of the driver's behaviour; used to title the prediction previews below.
situations = dict(
    c0='Safe driving',
    c1='Texting - right',
    c2='Talking on the phone - right',
    c3='Texting - left',
    c4='Talking on the phone - left',
    c5='Operating the radio',
    c6='Drinking',
    c7='Reaching behind',
    c8='Hair and makeup',
    c9='Talking to passenger',
)

In [3]:
# Load every training image, resized to 224x224, together with its integer
# class index (the position of its folder name in `classes`).
# NOTE(review): this materializes the full training set in memory even though
# only a 10% subsample (test_size=0.90 split below) is used for modelling;
# consider subsampling the file lists before loading.
X = []
y = []
classes = ['c0', 'c1', 'c2', 'c3', 'c4', 'c5', 'c6', 'c7', 'c8', 'c9']
base_path = '../input/state-farm-distracted-driver-detection/imgs/train/'

for i, target in enumerate(classes):
    class_dir = os.path.join(base_path, target)

    # os.listdir order is filesystem-dependent; sort it so the load order (and
    # therefore any seeded shuffle/split downstream) is reproducible.
    files = sorted(os.listdir(class_dir))

    print(f'we are now in {target} class')

    for file in files:
        # load the image, resized to the network input size
        img = load_img(os.path.join(class_dir, file), target_size=(224, 224))

        # store the image as an array next to its numeric target
        X.append(np.array(img))
        y.append(i)

print('finished')

In [4]:
# Stack the per-image arrays and the label list into NumPy arrays.
X, y = np.asarray(X), np.asarray(y)

In [5]:
# Shuffle images and labels with the same permutation so they stay aligned.
# A fixed seed makes the shuffle (and everything downstream) reproducible.
SEED = 42
rng = np.random.default_rng(SEED)
shuffler = rng.permutation(len(X))
X = X[shuffler]
y = y[shuffler]

In [6]:
# Keep a random 10% of the data as the working subsample (X_train_sample /
# y_train_sample); the other 90% lands in X_test_sample / y_test_sample, which
# the visible cells do not use further.
# stratify=y preserves the class proportions in the subsample and
# random_state makes the split reproducible.
# NOTE(review): X_test_sample is a ~90% copy of X; `del` it if memory is tight.
from sklearn.model_selection import train_test_split
X_train_sample, X_test_sample, y_train_sample, y_test_sample = train_test_split(
    X, y, test_size=0.90, stratify=y, random_state=42
)

In [7]:
# Split the 10% working subsample 80/20 into training and held-out test sets,
# stratified by class and seeded for reproducibility.
X_train, X_test, y_train, y_test = train_test_split(
    X_train_sample, y_train_sample, test_size=0.20,
    stratify=y_train_sample, random_state=42,
)

In [8]:
# Preserve the integer test labels before they are one-hot encoded; they are
# needed later for accuracy, the confusion matrix and the previews.
y_test_true = np.copy(y_test)

In [9]:
# One-hot encode the integer labels. num_classes is pinned to the number of
# classes so both matrices always have one column per class (matching the
# model's output layer), even if a small split happens to miss the highest
# class index.
y_train = to_categorical(y_train, num_classes=len(classes))
y_test = to_categorical(y_test, num_classes=len(classes))

In [10]:
# Convert training pixels to float32 and divide by 255.
# NOTE(review): the ImageNet VGG16 weights expect `vgg16.preprocess_input`
# (RGB->BGR + mean subtraction), not /255 scaling — confirm this is intended;
# any change must be applied to X_test identically.
X_train = np.divide(X_train.astype(np.float32), 255)

In [11]:
# Apply exactly the same float32 / 255 scaling used for the training images.
X_test = np.divide(X_test.astype(np.float32), 255)

In [12]:
# Display (X_train, y_train, X_test, y_test) shapes as a sanity check.
tuple(arr.shape for arr in (X_train, y_train, X_test, y_test))

# VGG16

In [13]:
# Reset Keras' global state so the model below is built in a fresh session.
keras.backend.clear_session()

In [14]:
# number of possible label values: one per class folder
nb_classes = len(classes)

## Defining the input
# Everything here comes from tensorflow.keras (`layers` / `models` are imported
# at the top); mixing standalone `keras` objects with tf.keras ones can fail on
# versions where the two packages are distinct.
vgg16_input = layers.Input(shape=(224, 224, 3), name='Image_input')

# VGG16 convolutional backbone with ImageNet weights and no classifier head,
# built directly on `vgg16_input`. Nothing here freezes its layers, so the
# whole network is fine-tuned during training.
base_model = vgg16.VGG16(weights='imagenet', include_top=False, input_tensor=vgg16_input)

# The backbone was built on `vgg16_input`, so its output tensor already is the
# feature map for that input; calling base_model(vgg16_input) again would only
# wrap the whole backbone as a redundant nested layer.
output = base_model.output

# Classifier head: flatten the feature map and map it to class probabilities.
x = Flatten(name='flatten')(output)
x = Dense(nb_classes, activation='softmax', name='predictions')(x)

model = models.Model(inputs=vgg16_input, outputs=x)



In [15]:
# Print a layer-by-layer summary of the model (output shapes, parameter counts).
model.summary()

In [16]:
# Plain SGD (learning rate 1e-3) with categorical cross-entropy, which matches
# the one-hot labels and the softmax output; accuracy is tracked per epoch.
sgd_optimizer = SGD(learning_rate=0.001)
model.compile(
    loss='categorical_crossentropy',
    optimizer=sgd_optimizer,
    metrics=['accuracy'],
)

In [17]:
# Train on augmented batches: random shifts of up to 50% of height/width,
# zoom in the range [0.5, 1.5] and rotations of up to 30 degrees are applied
# on the fly to each training batch.
BATCH_SIZE = 32
EPOCHS = 100

datagen = ImageDataGenerator(
    height_shift_range=0.5,
    width_shift_range=0.5,
    zoom_range=0.5,
    rotation_range=30,
)
# seed makes the batch shuffling / augmentation sequence reproducible
data_generator = datagen.flow(X_train, y_train, batch_size=BATCH_SIZE, seed=42)

# `fit_generator` is deprecated (removed in recent Keras); `fit` accepts the
# iterator directly. steps_per_epoch must be an integer:
# len(data_generator) == ceil(len(X_train) / BATCH_SIZE), so the final partial
# batch is still seen every epoch.
# NOTE(review): X_test doubles as validation data, so the val_* metrics are
# the held-out test metrics observed during training.
history = model.fit(
    data_generator,
    steps_per_epoch=len(data_generator),
    epochs=EPOCHS,
    validation_data=(X_test, y_test),
)

In [18]:
# Persist the trained model to an HDF5 file.
model_path = 'model_vgg16.h5'
model.save(model_path)

In [19]:
# Class-probability predictions for every test image (one row per image).
ypred = model.predict(X_test, batch_size=32)

In [20]:
# Preview the first test images, titled with the predicted class and the true
# class for comparison. Uses an explicit figure instead of the deprecated
# `pylab` interface and does not mutate the global rcParams.
pred_labels = np.argmax(ypred, axis=1)
n_show = min(10, len(X_test))  # tolerate test sets smaller than the grid

fig, axes = plt.subplots(5, 2, figsize=(12, 20))
for idx, ax in enumerate(axes.flat):
    ax.axis('off')
    if idx >= n_show:
        continue
    ax.imshow(X_test[idx])
    predicted = situations[classes[pred_labels[idx]]]
    actual = situations[classes[y_test_true[idx]]]
    ax.set_title(f'pred: {predicted}\ntrue: {actual}')
plt.show()

# Examine the model

In [21]:
# Training vs. validation accuracy per epoch.
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(history.history['accuracy'], label='training accuracy')
ax.plot(history.history['val_accuracy'], label='validation accuracy')
ax.set(title='VGG16 accuracy per epoch', xlabel='epochs', ylabel='accuracy')
ax.legend();


In [22]:
# Training vs. validation loss per epoch.
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(history.history['loss'], label='training loss')
ax.plot(history.history['val_loss'], label='validation loss')
ax.set(title='VGG16 loss per epoch', xlabel='epochs', ylabel='loss')
ax.legend();

In [23]:
# Test-set accuracy: fraction of test images whose arg-max prediction matches
# the true integer label (vectorized instead of a Python loop).
correct_mask = np.argmax(ypred, axis=1) == y_test_true

# indices of correctly classified test images
true_list = np.flatnonzero(correct_mask).tolist()

# NOTE: despite its name, `probability` is the classification accuracy.
# An empty test set yields NaN instead of raising ZeroDivisionError.
probability = float(correct_mask.mean()) if correct_mask.size else float('nan')
round(probability, 3)

In [24]:
# Confusion matrix of true vs. predicted class on the test set.
# `labels` pins the matrix to all len(classes) classes so its size always
# matches `display_labels`, even if some class never occurs in y_true/y_pred.
cm = confusion_matrix(
    y_true=y_test_true,
    y_pred=np.argmax(ypred, axis=1),
    labels=np.arange(len(classes)),
)

disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=classes)
fig, ax = plt.subplots(figsize=(10, 10))

disp.plot(ax=ax);

In [25]:
# Re-render the first test images with predicted and true class titles.
# NOTE(review): this repeats the earlier preview cell; consider removing it or
# repurposing it (e.g. to show misclassified images).
# Uses an explicit figure instead of the deprecated `pylab` interface and
# leaves the global rcParams untouched.
pred_labels = np.argmax(ypred, axis=1)
n_show = min(10, len(X_test))  # tolerate test sets smaller than the grid

fig, axes = plt.subplots(5, 2, figsize=(12, 20))
for idx, ax in enumerate(axes.flat):
    ax.axis('off')
    if idx >= n_show:
        continue
    ax.imshow(X_test[idx])
    predicted = situations[classes[pred_labels[idx]]]
    actual = situations[classes[y_test_true[idx]]]
    ax.set_title(f'pred: {predicted}\ntrue: {actual}')
plt.show()

In [26]:
# Raw per-epoch training history: 'loss', 'accuracy', 'val_loss', 'val_accuracy'.
# NOTE(review): this dumps every epoch's value; a summary (e.g. the final
# epoch) is usually enough for a finished notebook.
history.history

In [28]:
# Hard-coded copy of a previous run's `history.history`: 100 per-epoch values
# each for loss, accuracy, val_loss and val_accuracy. Presumably pasted from
# the output above so the curves can be re-plotted without retraining.
# NOTE(review): these numbers are not regenerated by this notebook; consider
# saving `history.history` to a JSON file after training and loading it here.
# Note the isolated val_loss spike (~9.33, val_accuracy ~0.28) at epoch 75.
history_x = {'loss': [2.4607796669006348,
  2.3907833099365234,
  2.3278446197509766,
  2.316110849380493,
  2.304399013519287,
  2.3039824962615967,
  2.303065299987793,
  2.303892135620117,
  2.3007566928863525,
  2.298985719680786,
  2.3001739978790283,
  2.298713207244873,
  2.294557809829712,
  2.293314218521118,
  2.2928340435028076,
  2.295558214187622,
  2.2924935817718506,
  2.2932448387145996,
  2.2913436889648438,
  2.2913832664489746,
  2.2821907997131348,
  2.284726142883301,
  2.2772276401519775,
  2.2643542289733887,
  2.2722830772399902,
  2.2618632316589355,
  2.2663521766662598,
  2.253303289413452,
  2.2394728660583496,
  2.238931655883789,
  2.222933292388916,
  2.207343101501465,
  2.258585214614868,
  2.201601982116699,
  2.2439563274383545,
  2.201796054840088,
  2.15795636177063,
  2.1523356437683105,
  2.1179394721984863,
  2.082789421081543,
  2.0781619548797607,
  2.3045718669891357,
  2.28593373298645,
  2.2581980228424072,
  2.190247058868408,
  2.249967575073242,
  2.0971126556396484,
  2.0062291622161865,
  1.9929314851760864,
  2.0936944484710693,
  1.9494249820709229,
  1.8810077905654907,
  1.9242205619812012,
  1.8638781309127808,
  1.7865209579467773,
  1.9194906949996948,
  1.7297580242156982,
  1.6443147659301758,
  1.9422836303710938,
  2.1592438220977783,
  2.012373447418213,
  1.7579797506332397,
  1.6619956493377686,
  1.579310655593872,
  1.6545639038085938,
  1.5547351837158203,
  1.4524853229522705,
  1.9497839212417603,
  1.9452341794967651,
  1.9497634172439575,
  1.6748918294906616,
  1.5026040077209473,
  1.353031039237976,
  1.3291347026824951,
  1.25484037399292,
  1.440677523612976,
  1.3158560991287231,
  1.2334985733032227,
  1.1790199279785156,
  1.420926809310913,
  1.1463016271591187,
  1.0834308862686157,
  1.8734995126724243,
  1.4823644161224365,
  1.6103508472442627,
  1.140376329421997,
  1.036054253578186,
  1.006953477859497,
  0.997637152671814,
  1.035748839378357,
  1.1083489656448364,
  1.0205765962600708,
  0.9905422329902649,
  0.9197331666946411,
  0.9390596747398376,
  0.8545984029769897,
  0.8548481464385986,
  0.8665569424629211,
  0.8290330171585083,
  0.8535390496253967],
 'accuracy': [0.10819855332374573,
  0.10262130200862885,
  0.10819855332374573,
  0.09760178625583649,
  0.11321806907653809,
  0.10931400209665298,
  0.1193530410528183,
  0.10987172275781631,
  0.11489124596118927,
  0.11991076171398163,
  0.1193530410528183,
  0.11210262030363083,
  0.11656441539525986,
  0.12158393859863281,
  0.11600669473409653,
  0.11879531294107437,
  0.12381483614444733,
  0.12437255680561066,
  0.12102621048688889,
  0.1154489666223526,
  0.13552704453468323,
  0.13273842632770538,
  0.13050752878189087,
  0.15504740178585052,
  0.1505856066942215,
  0.1583937555551529,
  0.1388733983039856,
  0.1583937555551529,
  0.17735639214515686,
  0.16341327130794525,
  0.17177914083003998,
  0.1818181872367859,
  0.176240935921669,
  0.19631901383399963,
  0.18851087987422943,
  0.1818181872367859,
  0.20468488335609436,
  0.20468488335609436,
  0.2191857248544693,
  0.2152816504240036,
  0.24037925899028778,
  0.1450083702802658,
  0.1271611750125885,
  0.16397099196910858,
  0.199665367603302,
  0.16731734573841095,
  0.22922475636005402,
  0.2660345733165741,
  0.2643614113330841,
  0.23145565390586853,
  0.3034021258354187,
  0.30284440517425537,
  0.31790295243263245,
  0.3195761442184448,
  0.3535973131656647,
  0.32571110129356384,
  0.36698269844055176,
  0.39598438143730164,
  0.3329615294933319,
  0.1974344700574875,
  0.26324597001075745,
  0.36307865381240845,
  0.3742331266403198,
  0.41606247425079346,
  0.3931957483291626,
  0.43558281660079956,
  0.4634690582752228,
  0.37869492173194885,
  0.3089793622493744,
  0.37479084730148315,
  0.41606247425079346,
  0.4891243577003479,
  0.5181260704994202,
  0.5153374075889587,
  0.5527049899101257,
  0.5186837911605835,
  0.5482431650161743,
  0.5705521702766418,
  0.5744562149047852,
  0.5041829347610474,
  0.6001115441322327,
  0.6051310896873474,
  0.350808709859848,
  0.5203569531440735,
  0.450641393661499,
  0.6001115441322327,
  0.6279977560043335,
  0.6469603776931763,
  0.6542108058929443,
  0.6263245940208435,
  0.6319018602371216,
  0.6363636255264282,
  0.6480758786201477,
  0.6793084144592285,
  0.6715002655982971,
  0.7055214643478394,
  0.6965978741645813,
  0.6971555948257446,
  0.7233686447143555,
  0.6954824328422546],
 'val_loss': [2.3117504119873047,
  2.295529842376709,
  2.2930638790130615,
  2.3018295764923096,
  2.315221071243286,
  2.305750846862793,
  2.3022847175598145,
  2.290461540222168,
  2.3001959323883057,
  2.2925519943237305,
  2.294278621673584,
  2.289424419403076,
  2.291395664215088,
  2.2892117500305176,
  2.298950672149658,
  2.2895302772521973,
  2.2797904014587402,
  2.274573564529419,
  2.268775701522827,
  2.2663216590881348,
  2.260516881942749,
  2.2601258754730225,
  2.2327775955200195,
  2.221834421157837,
  2.2136905193328857,
  2.189504623413086,
  2.190242290496826,
  2.1474037170410156,
  2.1352717876434326,
  2.112584114074707,
  2.074176549911499,
  2.047581434249878,
  2.044334650039673,
  2.0180606842041016,
  2.1125004291534424,
  2.049088954925537,
  2.0020575523376465,
  1.8632450103759766,
  1.8501527309417725,
  1.7862648963928223,
  1.752921223640442,
  2.283445358276367,
  2.234009265899658,
  2.1142940521240234,
  2.271036386489868,
  1.9693834781646729,
  1.7619189023971558,
  1.7086586952209473,
  1.4741259813308716,
  1.7992892265319824,
  1.5042226314544678,
  1.3308629989624023,
  1.249316692352295,
  1.1922492980957031,
  2.5117533206939697,
  1.70883047580719,
  1.01026451587677,
  1.075566053390503,
  2.187572479248047,
  1.8548033237457275,
  1.321243405342102,
  1.063166856765747,
  0.8948644995689392,
  0.8719474077224731,
  0.7885288596153259,
  1.2964619398117065,
  0.6337472200393677,
  1.7389307022094727,
  1.2501674890518188,
  1.416735291481018,
  0.882928729057312,
  0.7061488628387451,
  0.6019383668899536,
  0.6775381565093994,
  9.334821701049805,
  0.5419487953186035,
  0.6479842066764832,
  0.4723030924797058,
  0.4795842468738556,
  0.6171011924743652,
  0.532192587852478,
  0.4076594114303589,
  0.5839046239852905,
  0.456885427236557,
  0.4336778223514557,
  0.4040141701698303,
  0.4401548504829407,
  0.3698643147945404,
  0.38164281845092773,
  0.3456398546695709,
  0.3351221978664398,
  0.3040621280670166,
  0.33465075492858887,
  0.33619505167007446,
  0.26880717277526855,
  0.2457936704158783,
  0.2577451169490814,
  0.2678213119506836,
  0.2568076550960541,
  0.25144484639167786],
 'val_accuracy': [0.10244988650083542,
  0.1358574628829956,
  0.11135857552289963,
  0.08685968816280365,
  0.08685968816280365,
  0.1269487738609314,
  0.1314031183719635,
  0.10467705875635147,
  0.09131403267383575,
  0.10022271424531937,
  0.11135857552289963,
  0.1269487738609314,
  0.09576837718486786,
  0.09576837718486786,
  0.10244988650083542,
  0.13363029062747955,
  0.11135857552289963,
  0.1314031183719635,
  0.13808463513851166,
  0.15367482602596283,
  0.1314031183719635,
  0.11581292003393173,
  0.1737193763256073,
  0.17594654858112335,
  0.18485523760318756,
  0.20044542849063873,
  0.17149220407009125,
  0.2204899787902832,
  0.2583518922328949,
  0.24721603095531464,
  0.2717149257659912,
  0.2249443233013153,
  0.2494432032108307,
  0.24053451418876648,
  0.1737193763256073,
  0.2249443233013153,
  0.24053451418876648,
  0.320712685585022,
  0.2828507721424103,
  0.28953230381011963,
  0.3318485617637634,
  0.11135857552289963,
  0.18040089309215546,
  0.23385301232337952,
  0.09131403267383575,
  0.26057907938957214,
  0.3697104752063751,
  0.3741648197174072,
  0.4699331820011139,
  0.35857459902763367,
  0.4454343020915985,
  0.5412026643753052,
  0.5545657277107239,
  0.5434298515319824,
  0.30289533734321594,
  0.3986636996269226,
  0.6080178022384644,
  0.6080178022384644,
  0.22717149555683136,
  0.3095768392086029,
  0.5412026643753052,
  0.6035634875297546,
  0.692650318145752,
  0.6681514382362366,
  0.7171491980552673,
  0.5679287314414978,
  0.781737208366394,
  0.37193763256073,
  0.5968819856643677,
  0.5567928552627563,
  0.7082405090332031,
  0.7461024522781372,
  0.795100212097168,
  0.7728285193443298,
  0.2761692702770233,
  0.8017817139625549,
  0.7728285193443298,
  0.8262805938720703,
  0.8129175901412964,
  0.7706013321876526,
  0.8062360882759094,
  0.8374164700508118,
  0.7795100212097168,
  0.839643657207489,
  0.8485523462295532,
  0.8775055408477783,
  0.8574610352516174,
  0.8685969114303589,
  0.839643657207489,
  0.8708240389823914,
  0.8819599151611328,
  0.8953229188919067,
  0.8819599151611328,
  0.8953229188919067,
  0.9109131693840027,
  0.9198217988014221,
  0.9220489859580994,
  0.8997772932052612,
  0.9175946712493896,
  0.9086859822273254]}