In [3]:
import os
import sys
import logging
import cv2
import json
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
from ast import literal_eval
from tqdm import tqdm_notebook as tqdm
import numpy as np

from sklearn import svm
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score

import pickle

%matplotlib inline
!jupyter nbextension enable --py --sys-prefix widgetsnbextension

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('TrainClassifier')

module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

from vehicle_detector import extract_features
from vehicle_detector.utils import dataset

Enabling notebook extension jupyter-js-widgets/extension...
      - Validating: [32mOK[0m


In [11]:
def get_params_from_dataset_name(dataset_name):
    
    tokens = dataset_name.split('_')
    offset = 0
    if tokens[0] == 'unfiltered':
        offset = 1
#     print(tokens)
    params = {
        'C': literal_eval(tokens[3 + offset]),
        'gamma': tokens[5 + offset] if tokens[5 + offset]=='auto' else literal_eval(tokens[5 + offset]),
        'color_hist': literal_eval(tokens[7 + offset]),
        'orientations': literal_eval(tokens[9 + offset]),
        'pixels_per_cell': literal_eval(tokens[11 + offset]),
        'cells_per_block': literal_eval(tokens[13 + offset]),
        'hog_color_space': os.path.splitext(tokens[15 + offset])[0]
    }
    return params

In [13]:
dataset_dir = '../data/datasets'
datasets = list(filter(lambda x: x.endswith('.hdf5') == True, os.listdir(dataset_dir)))
models_dir = '../data/models'

models = []
if not os.path.exists(models_dir):
    os.makedirs(models_dir)

for this_dataset in datasets:
    X, y = dataset.load_dataset(os.path.join(dataset_dir, this_dataset), 'carnd_p5')
    rand_state = np.random.randint(0,100)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33,
                                                        random_state=rand_state)
    X_scaler = StandardScaler().fit(X_train)
    
    # Apply the scaler to X
    X_train = X_scaler.transform(X_train)
    X_test = X_scaler.transform(X_test)
    params = get_params_from_dataset_name(this_dataset)
    
    clf = svm.SVC(kernel='rbf', C=params['C'], gamma=params['gamma'],
                  probability=True, random_state=rand_state)
    
    clf.fit(X_train, y_train)
    model_name = '{name}.cpickle'.format(name=os.path.splitext(this_dataset)[0])
    scaler_name = '{name}_scaler.cpickle'.format(name=os.path.splitext(this_dataset)[0])
    model_filepath = os.path.join(models_dir,model_name)
    scaler_filepath = os.path.join(models_dir, scaler_name)
    
    with open(model_filepath, 'wb') as fp:
        fp.write(pickle.dumps(clf))
    with open(scaler_filepath, 'wb') as fp:
        fp.write(pickle.dumps(X_scaler))
    
    y_pred = clf.predict(X_test)
    accuracy = round(accuracy_score(y_test, y_pred) * 100, 3)
    print(model_name, accuracy)
    
    models.append({
        'name': model_name,
        'file_path': model_filepath,
        'accuracy': accuracy,
        'scaler_file_path': scaler_filepath
    })
    
with open('../data/top_models.json', 'w') as fp:
    json.dump(models, fp)

unfiltered_A_99.125_C_10_gamma_auto_CH_False_O_10_P_(16, 16)_C_(2, 2)_CS_BGR2HSV.cpickle 98.72
unfiltered_A_99.125_C_10_gamma_auto_CH_True_O_10_P_(16, 16)_C_(4, 4)_CS_BGR2YCrCb.cpickle 99.556
A_99.0_C_10_gamma_auto_CH_True_O_10_P_(8, 8)_C_(4, 4)_CS_BGR2HSV.cpickle 99.199
A_99.5_C_10_gamma_auto_CH_True_O_12_P_(16, 16)_C_(4, 4)_CS_BGR2YCrCb.cpickle 99.545
A_99.0_C_10_gamma_auto_CH_True_O_10_P_(8, 8)_C_(4, 4)_CS_BGR2YCrCb.cpickle 99.272
A_99.125_C_10_gamma_auto_CH_False_O_10_P_(16, 16)_C_(2, 2)_CS_BGR2HSV.cpickle 98.854
unfiltered_A_99.5_C_10_gamma_auto_CH_True_O_12_P_(16, 16)_C_(4, 4)_CS_BGR2YCrCb.cpickle 99.573
unfiltered_A_99.0_C_10_gamma_auto_CH_True_O_10_P_(8, 8)_C_(4, 4)_CS_BGR2HSV.cpickle 98.908
A_99.125_C_10_gamma_auto_CH_True_O_9_P_(16, 16)_C_(4, 4)_CS_BGR2YCrCb.cpickle 99.691
unfiltered_A_99.125_C_10_gamma_auto_CH_False_O_9_P_(16, 16)_C_(2, 2)_CS_BGR2YCrCb.cpickle 98.925
unfiltered_A_99.125_C_10_gamma_auto_CH_True_O_12_P_(16, 16)_C_(2, 2)_CS_BGR2HSV.cpickle 99.266
A_99.0_C_10_ga