
# k-NN Classifier


+ basic sklearn calls
+ Search over neighborhood size to boost performance


In [53]:
import numpy as np
import pylab as plt
from sklearn import svm
import pandas as pd
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn import metrics
from sklearn.neighbors import NearestNeighbors
from sklearn.neighbors import KNeighborsClassifier
from sklearn.ensemble import RandomForestClassifier

from tqdm import tqdm
%matplotlib inline
import warnings
warnings.filterwarnings('ignore')
plt.style.use('fivethirtyeight')

In [54]:
# getdata
import pandas as pd
df=pd.read_csv('../datasets/CAD-PTSDData.csv',index_col=0)
y=df.iloc[:,0]
X=df.iloc[:,1:]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)
y_train.shape



(212,)

In [55]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4)
clf = RandomForestClassifier(max_depth=10, class_weight='balanced',n_estimators=1000)


clf.fit(X_train, y_train)

y_pred = clf.predict(X_test)
print("Number of mislabeled points out of a total %d points : %d"% (y_test.shape[0],(y_test != y_pred).sum()))

acc=clf.score(X_test,y_test)
y_pred=clf.predict_proba(X_test)
y_pred_insample=clf.predict_proba(X_train)
fpr, tpr, thresholds = metrics.roc_curve(y_test,y_pred[:,1], pos_label=1)
fpr_insample, tpr_insample, thresholds_insample = metrics.roc_curve(y_train,y_pred_insample[:,1], pos_label=1)
auc=metrics.auc(fpr, tpr)
auc_insample=metrics.auc(fpr_insample, tpr_insample)

print('Accuracy: ',acc,'AUC oos: ',auc,'AUC in sample: ',auc_insample)

Number of mislabeled points out of a total 122 points : 18
Accuracy:  0.8524590163934426 AUC oos:  0.8770032051282052 AUC in sample:  0.9999999999999999


In [56]:
clfn = KNeighborsClassifier(n_neighbors=4)
clfn.fit(X_train,y_train)

y_pred = clfn.predict(X_test)
print("Number of mislabeled points out of a total %d points : %d"% (y_test.shape[0],(y_test != y_pred).sum()))

acc=clfn.score(X_test,y_test)
y_pred=clfn.predict_proba(X_test)
y_pred_insample=clfn.predict_proba(X_train)
fpr, tpr, thresholds = metrics.roc_curve(y_test,y_pred[:,1], pos_label=1)
fpr_insample, tpr_insample, thresholds_insample = metrics.roc_curve(y_train,y_pred_insample[:,1], pos_label=1)
auc=metrics.auc(fpr, tpr)
auc_insample=metrics.auc(fpr_insample, tpr_insample)

print('Accuracy: ',acc,'AUC oos: ',auc,'AUC in sample: ',auc_insample)

Number of mislabeled points out of a total 122 points : 18
Accuracy:  0.8524590163934426 AUC oos:  0.8481570512820513 AUC in sample:  0.9409153005464481


# search for n to beat random forest

In [57]:
Perf={}
for n in tqdm(np.arange(20,40)):
    clfn = KNeighborsClassifier(n_neighbors=n)
    clfn.fit(X_train,y_train)

    y_pred = clfn.predict(X_test)
    miss=(y_test != y_pred).sum()

    acc=clfn.score(X_test,y_test)
    y_pred=clfn.predict_proba(X_test)
    y_pred_insample=clfn.predict_proba(X_train)
    fpr, tpr, thresholds = metrics.roc_curve(y_test,y_pred[:,1], pos_label=1)
    fpr_insample, tpr_insample, thresholds_insample = metrics.roc_curve(y_train,y_pred_insample[:,1], pos_label=1)
    auc=metrics.auc(fpr, tpr)
    auc_insample=metrics.auc(fpr_insample, tpr_insample)

    Perf[auc]=(miss,n)

Perf=pd.DataFrame(Perf).transpose()
Perf.columns=['miss','n']
Perf.index.name='auc'
n=Perf.reset_index().sort_values('auc',ascending=False).head(1).n.values[0]
n

100%|██████████| 20/20 [00:01<00:00, 17.68it/s]


21

In [58]:
clfn = KNeighborsClassifier(n_neighbors=n,weights='distance')
clfn.fit(X_train,y_train)

y_pred = clfn.predict(X_test)
print("Number of mislabeled points out of a total %d points : %d"% (y_test.shape[0],(y_test != y_pred).sum()))

acc=clfn.score(X_test,y_test)
y_pred=clfn.predict_proba(X_test)
y_pred_insample=clfn.predict_proba(X_train)
fpr, tpr, thresholds = metrics.roc_curve(y_test,y_pred[:,1], pos_label=1)
fpr_insample, tpr_insample, thresholds_insample = metrics.roc_curve(y_train,y_pred_insample[:,1], pos_label=1)
auc=metrics.auc(fpr, tpr)
auc_insample=metrics.auc(fpr_insample, tpr_insample)

print('Accuracy: ',acc,'AUC oos: ',auc,'AUC in sample: ',auc_insample)

Number of mislabeled points out of a total 122 points : 18
Accuracy:  0.8524590163934426 AUC oos:  0.889022435897436 AUC in sample:  1.0
