# Image Classification using `sklearn.svm`

In [74]:
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
%matplotlib notebook
from sklearn import svm, metrics, datasets
from sklearn.utils import Bunch
from sklearn.model_selection import GridSearchCV, train_test_split

from skimage.io import imread
from skimage.transform import resize

### Load images in structured directory like it's sklearn sample dataset

In [154]:
def load_image_files(container_path, dimension=(400, 300)):
    image_dir = Path(container_path)
    folders = [directory for directory in image_dir.iterdir() if directory.is_dir()]
    flat_data = []
    target = []
    for i, direc in enumerate(folders):
        for file in direc.iterdir():
            img = imread(file)
            img_resized = resize(img, dimension, anti_aliasing=True, mode='reflect')
            print(i, direc, img.shape, img_resized.shape)
            flat_data.append(img_resized.flatten()) 
           
            target.append(i)
    flat_data = np.array(flat_data)
    target = np.array(target)
    
    return [flat_data, target]
#     Bunch(data=flat_data,
#                  target=target,
#                  target_names=categories,
#                  images=images,
#                  DESCR=descr)

In [159]:
image_dataset = load_image_files("images/")


0 images/squirtle (450, 324, 3) (400, 300, 3)
0 images/squirtle (450, 324, 3) (400, 300, 3)
0 images/squirtle (450, 324, 3) (400, 300, 3)
1 images/charmander (500, 359, 3) (400, 300, 3)
1 images/charmander (500, 359, 3) (400, 300, 3)
1 images/charmander (500, 359, 3) (400, 300, 3)
2 images/bulbasaur (500, 359, 3) (400, 300, 3)
2 images/bulbasaur (500, 359, 3) (400, 300, 3)
2 images/bulbasaur (500, 359, 3) (400, 300, 3)


### Split data

In [160]:
# X_train, X_test, y_train, y_test = train_test_split(
#     image_dataset[0], image_dataset[1], test_size=0.3,random_state=109)
X_train, y_train = image_dataset[0], image_dataset[1]

### Train data with parameter optimization

In [161]:
param_grid = [
  {'C': [1, 10, 100, 1000], 'kernel': ['linear']},
  {'C': [1, 10, 100, 1000], 'gamma': [0.001, 0.0001], 'kernel': ['rbf']},
 ]
print("Training svm")
svc = svm.SVC()
print("Getting optimal param")
#clf = svm.SVC(gamma='auto')
clf = GridSearchCV(svc, param_grid,cv=3)
print("fitting model")
clf.fit(X_train, y_train)

Training svm
Getting optimal param
fitting model


GridSearchCV(cv=3, error_score='raise-deprecating',
       estimator=SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
  decision_function_shape='ovr', degree=3, gamma='auto_deprecated',
  kernel='rbf', max_iter=-1, probability=False, random_state=None,
  shrinking=True, tol=0.001, verbose=False),
       fit_params=None, iid='warn', n_jobs=None,
       param_grid=[{'C': [1, 10, 100, 1000], 'kernel': ['linear']}, {'C': [1, 10, 100, 1000], 'gamma': [0.001, 0.0001], 'kernel': ['rbf']}],
       pre_dispatch='2*n_jobs', refit=True, return_train_score='warn',
       scoring=None, verbose=0)

### Predict

In [165]:
import skimage

pred_0 = resize(np.load("squirtle_hold_short.npy")[0,:,:,:], (400,300), anti_aliasing=True, mode='reflect').flatten()
pred_2 = resize(np.load("bulbasaur_hold_short1.npy")[1,:,:,:], (400,300), anti_aliasing=True, mode='reflect').flatten()

In [134]:
X_test = np.array([pred_1, pred_2])

In [167]:
y_pred = clf.predict(pred_2.reshape(1,-1))
# y_pred_2 = clf.predict(pred_2)
print(y_pred)
# 1 - charmander
# 0 - squirtle
# 2 - bulbasaur

[0]


### Report

In [145]:
print("Classification report for - \n{}:\n{}\n".format(
    clf, metrics.classification_report(y_test, y_pred)))

print("Confusion matrix")
conf = metrics.confusion_matrix(y_test, y_pred)
acc = np.diag(conf).sum()/conf.sum()
conf, acc


Classification report for - 
GridSearchCV(cv=3, error_score='raise-deprecating',
       estimator=SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
  decision_function_shape='ovr', degree=3, gamma='auto_deprecated',
  kernel='rbf', max_iter=-1, probability=False, random_state=None,
  shrinking=True, tol=0.001, verbose=False),
       fit_params=None, iid='warn', n_jobs=None,
       param_grid=[{'C': [1, 10, 100, 1000], 'kernel': ['linear']}, {'C': [1, 10, 100, 1000], 'gamma': [0.001, 0.0001], 'kernel': ['rbf']}],
       pre_dispatch='2*n_jobs', refit=True, return_train_score='warn',
       scoring=None, verbose=0):
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        92
           1       1.00      1.00      1.00        91
           2       1.00      1.00      1.00        92

   micro avg       1.00      1.00      1.00       275
   macro avg       1.00      1.00      1.00       275
weighted avg       1.00      1.00      1.00  

(array([[92,  0,  0],
        [ 0, 91,  0],
        [ 0,  0, 92]]), 1.0)

In [149]:
test_dataset = load_image_files("test/")

0 test/squirtle (288, 432, 3) (64, 64, 3)
0 test/squirtle (288, 432, 3) (64, 64, 3)
0 test/squirtle (288, 432, 3) (64, 64, 3)
0 test/squirtle (288, 432, 3) (64, 64, 3)
0 test/squirtle (288, 432, 3) (64, 64, 3)
0 test/squirtle (288, 432, 3) (64, 64, 3)
0 test/squirtle (288, 432, 3) (64, 64, 3)
0 test/squirtle (288, 432, 3) (64, 64, 3)
0 test/squirtle (288, 432, 3) (64, 64, 3)
0 test/squirtle (288, 432, 3) (64, 64, 3)
0 test/squirtle (288, 432, 3) (64, 64, 3)
0 test/squirtle (288, 432, 3) (64, 64, 3)
0 test/squirtle (288, 432, 3) (64, 64, 3)
0 test/squirtle (288, 432, 3) (64, 64, 3)
0 test/squirtle (288, 432, 3) (64, 64, 3)
0 test/squirtle (288, 432, 3) (64, 64, 3)
0 test/squirtle (288, 432, 3) (64, 64, 3)
0 test/squirtle (288, 432, 3) (64, 64, 3)
0 test/squirtle (288, 432, 3) (64, 64, 3)
0 test/squirtle (288, 432, 3) (64, 64, 3)
0 test/squirtle (288, 432, 3) (64, 64, 3)
0 test/squirtle (288, 432, 3) (64, 64, 3)
0 test/squirtle (288, 432, 3) (64, 64, 3)
0 test/squirtle (288, 432, 3) (64,

1 test/charmander (288, 432, 3) (64, 64, 3)
1 test/charmander (288, 432, 3) (64, 64, 3)
1 test/charmander (288, 432, 3) (64, 64, 3)
1 test/charmander (288, 432, 3) (64, 64, 3)
1 test/charmander (288, 432, 3) (64, 64, 3)
1 test/charmander (288, 432, 3) (64, 64, 3)
1 test/charmander (288, 432, 3) (64, 64, 3)
1 test/charmander (288, 432, 3) (64, 64, 3)
1 test/charmander (288, 432, 3) (64, 64, 3)
1 test/charmander (288, 432, 3) (64, 64, 3)
1 test/charmander (288, 432, 3) (64, 64, 3)
1 test/charmander (288, 432, 3) (64, 64, 3)
1 test/charmander (288, 432, 3) (64, 64, 3)
1 test/charmander (288, 432, 3) (64, 64, 3)
1 test/charmander (288, 432, 3) (64, 64, 3)
1 test/charmander (288, 432, 3) (64, 64, 3)
1 test/charmander (288, 432, 3) (64, 64, 3)
1 test/charmander (288, 432, 3) (64, 64, 3)
1 test/charmander (288, 432, 3) (64, 64, 3)
1 test/charmander (288, 432, 3) (64, 64, 3)
1 test/charmander (288, 432, 3) (64, 64, 3)
1 test/charmander (288, 432, 3) (64, 64, 3)
1 test/charmander (288, 432, 3) 

2 test/bulbasaur (288, 432, 3) (64, 64, 3)
2 test/bulbasaur (288, 432, 3) (64, 64, 3)
2 test/bulbasaur (288, 432, 3) (64, 64, 3)
2 test/bulbasaur (288, 432, 3) (64, 64, 3)
2 test/bulbasaur (288, 432, 3) (64, 64, 3)
2 test/bulbasaur (288, 432, 3) (64, 64, 3)
2 test/bulbasaur (288, 432, 3) (64, 64, 3)
2 test/bulbasaur (288, 432, 3) (64, 64, 3)
2 test/bulbasaur (288, 432, 3) (64, 64, 3)
2 test/bulbasaur (288, 432, 3) (64, 64, 3)
2 test/bulbasaur (288, 432, 3) (64, 64, 3)
2 test/bulbasaur (288, 432, 3) (64, 64, 3)
2 test/bulbasaur (288, 432, 3) (64, 64, 3)
2 test/bulbasaur (288, 432, 3) (64, 64, 3)
2 test/bulbasaur (288, 432, 3) (64, 64, 3)
2 test/bulbasaur (288, 432, 3) (64, 64, 3)
2 test/bulbasaur (288, 432, 3) (64, 64, 3)
2 test/bulbasaur (288, 432, 3) (64, 64, 3)
2 test/bulbasaur (288, 432, 3) (64, 64, 3)
2 test/bulbasaur (288, 432, 3) (64, 64, 3)
2 test/bulbasaur (288, 432, 3) (64, 64, 3)
2 test/bulbasaur (288, 432, 3) (64, 64, 3)
2 test/bulbasaur (288, 432, 3) (64, 64, 3)
2 test/bulb

2 test/bulbasaur (288, 432, 3) (64, 64, 3)
2 test/bulbasaur (288, 432, 3) (64, 64, 3)
2 test/bulbasaur (288, 432, 3) (64, 64, 3)
2 test/bulbasaur (288, 432, 3) (64, 64, 3)
2 test/bulbasaur (288, 432, 3) (64, 64, 3)
2 test/bulbasaur (288, 432, 3) (64, 64, 3)
2 test/bulbasaur (288, 432, 3) (64, 64, 3)
2 test/bulbasaur (288, 432, 3) (64, 64, 3)
2 test/bulbasaur (288, 432, 3) (64, 64, 3)
2 test/bulbasaur (288, 432, 3) (64, 64, 3)
2 test/bulbasaur (288, 432, 3) (64, 64, 3)
2 test/bulbasaur (288, 432, 3) (64, 64, 3)
2 test/bulbasaur (288, 432, 3) (64, 64, 3)
2 test/bulbasaur (288, 432, 3) (64, 64, 3)


In [150]:
X_rot_test, y_rot_test = test_dataset[0], test_dataset[1]
print("loaded")
y_rot_pred = clf.predict(X_rot_test)
print(y_rot_pred)
print("Confusion matrix")
conf_rot = metrics.confusion_matrix(y_rot_test, y_rot_pred)
acc_rot = np.diag(conf_rot).sum()/conf_rot.sum()
conf_rot, acc_rot

loaded
[1 1 1 2 2 2 2 2 2 1 1 1 1 1 2 2 1 1 2 2 1 1 1 1 1 2 2 2 2 2 2 1 1 1 1 1 1
 2 2 1 1 2 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 2 2 1 2 2 2 1 1 1 1 1 1 1 1 2 1 2 1 2 2 1 2 1 2 1 1 1 1 1 1 2 1 1 2 2
 2 2 1 2 2 1 1 1 1 1 1 1 1 2 2 2 2 2 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 0 0 0 0 0 2 0 2 0 0 0 0 2 0 0 0 0 0 0 0 2 0 0 0 0 0 1 0 2 0 2 2 0 0 0 1
 0 0 0 0 0 2 0 0 2 0 0 0 0 0 0 0 2 2 0 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 2 0
 0 0 0 2 0 0 0 0 0 0 0 0 2 2 2 2 1 1 1 1 1 1 1 1 2 2 2 2 0 0 2 0 2 2 2 1 1
 1 1 2 2 2 0 2 2 0 2 2 2 2 1 1 2 2 1 1 2 2 2 2 0 0 0 2 2 1 1 1 1 2 2 1 1 1
 1 2 2 0 0 2 2 2 1 1 1 2 1 1 1 1 1 1 2 2 2 2 2 2 1 1 1 1 1 1 1 1 2 2 0 2 2
 0 1 1 1 1 1 1 1 1

(array([[  0, 128,  44],
        [  0, 125,   0],
        [148,  78,  86]]), 0.3464696223316913)

In [92]:
X_rot_test, y_rot_test = test_dataset[0], test_dataset[1]
print("loaded")
y_rot_pred = clf.predict(X_rot_test)
print("Confusion matrix")
conf_rot = metrics.confusion_matrix(y_rot_test, y_rot_pred)
acc_rot = np.diag(conf_rot).sum()/conf_rot.sum()
conf_rot, acc_rot

loaded
Confusion matrix


(array([[ 31,   0,  84],
        [  0,   6, 127],
        [  0,   0, 117]]), 0.42191780821917807)