# Baseline: SVM Classification

In [1]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn import svm
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
import os
import glob
from collections import defaultdict
from PIL import Image

In [2]:
def PIE_path():
    """
    output: mainpath, where PIE exists.
    """
    return os.getcwd() + '/PIE'

def load_data(mainpath):
    """"
    input:
    mainpath: str, PIE dir, which contains 68 people (photos).
    --------------------------------------------------------------------
    output:
    label: list, name
    data: list, pictures.
    """
    label = []
    data = []
    for i in range(1, 69):
        pictures_path = glob.glob(mainpath + '/' + str(i) + '/*.jpg')
        for p in  pictures_path:
            ### pictures to matrix
            data.append(np.asarray(Image.open(p)).flatten())
            label.append(str(i))
    return label, data


## data load in

In [3]:
label, dataset = load_data(PIE_path())


## train test split

In [4]:
X_train, X_test, Y_train, Y_test = train_test_split(
   dataset, label, test_size=0.3, random_state=1024, shuffle=True 
)


## apply LDA to train data 
### apply the trained LDA model to test data to avoid data leckage.
### n_components may affect...

In [5]:
# set LDA model
LDA_model = LinearDiscriminantAnalysis(n_components=10) # 10,20,30,40,....
# train LDA model
LDA_model.fit(X_train, Y_train)
# transform
lda_X_train = LDA_model.transform(X_train)
lda_X_test = LDA_model.transform(X_test)


## SVM: train and predict 
#### Girdsearch with kernel = "rbf"

In [7]:
for penalty in [0.01,0.1,1,10,100]:
    for ga in [0.01,0.1,0.5,1.,2,4]:
        # set svm model
        SVM_model = svm.SVC(C=penalty, kernel="rbf", gamma=ga)
        # train svm model
        SVM_model.fit(lda_X_train, Y_train)
        # predcit on test data
        pre = SVM_model.predict(lda_X_test)

        # accuracy
        acc_num = 0
        for res, tareget in zip(pre, Y_test):
            if res == tareget:
                acc_num += 1

        print(r"hyperparameters with C={} and gamma={}.The accuracy is {}".format(penalty, ga, acc_num/len(pre)))


hyperparameters with C=0.01 and gamma=0.01.The accuracy is 0.20853764061147967
hyperparameters with C=0.01 and gamma=0.1.The accuracy is 0.010095183155465821
hyperparameters with C=0.01 and gamma=0.5.The accuracy is 0.00980674935102394
hyperparameters with C=0.01 and gamma=1.0.The accuracy is 0.00980674935102394
hyperparameters with C=0.01 and gamma=2.The accuracy is 0.00980674935102394
hyperparameters with C=0.01 and gamma=4.The accuracy is 0.00980674935102394
hyperparameters with C=0.1 and gamma=0.01.The accuracy is 0.8520334583213153
hyperparameters with C=0.1 and gamma=0.1.The accuracy is 0.8658782809345256
hyperparameters with C=0.1 and gamma=0.5.The accuracy is 0.02913181424862994
hyperparameters with C=0.1 and gamma=1.0.The accuracy is 0.00980674935102394
hyperparameters with C=0.1 and gamma=2.The accuracy is 0.00980674935102394
hyperparameters with C=0.1 and gamma=4.The accuracy is 0.00980674935102394
hyperparameters with C=1 and gamma=0.01.The accuracy is 0.8918373233342948
hy