# sklearn intro

For a more detailed introduction, see e.g. https://scikit-learn.org/stable/tutorial/basic/tutorial.html

In [1]:
import pandas as pd
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.decomposition import PCA

## Loading a tabular dataset

In [2]:
# Load the iris dataset shipped with scikit-learn. Later cells read its
# .data, .target, .feature_names and .target_names attributes.
data = load_iris()

In [3]:
# Pull the feature matrix and the label vector out of the loaded dataset.
X, y = data.data, data.target

# Preview the first five rows of the features with their column names.
pd.DataFrame(data=X, columns=data.feature_names).head(5)

Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm)
0,5.1,3.5,1.4,0.2
1,4.9,3.0,1.4,0.2
2,4.7,3.2,1.3,0.2
3,4.6,3.1,1.5,0.2
4,5.0,3.6,1.4,0.2


In [4]:
# Names of the target classes. Presumably index i names integer label i in y;
# confirm against the scikit-learn load_iris documentation.
data.target_names

array(['setosa', 'versicolor', 'virginica'], dtype='<U10')

In [5]:
# Distinct label values that occur in y.
np.unique(y)

array([0, 1, 2])

## Classification

In [6]:
# Hold out part of the data for evaluation (default test_size: 25% of samples).
# random_state pins the shuffle so a fresh "Restart & Run All" reproduces the
# same split and therefore the same scores; stratify=y keeps the class
# proportions the same in the training and test parts.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, random_state=42, stratify=y
)

In [7]:
# Baseline: k-nearest-neighbours with default settings.
# fit() returns the estimator itself, so construction and training can be chained.
clf = KNeighborsClassifier().fit(X_train, y_train)

# Score the predictions on the held-out test set.
y_pred = clf.predict(X_test)
accuracy_score(y_true=y_test, y_pred=y_pred)

0.9736842105263158

In [8]:
# Random forest with 100 trees. Training is randomised (bootstrap samples and
# random feature subsets), so random_state is fixed to make the fitted model
# and its accuracy reproducible across kernel restarts.
clf = RandomForestClassifier(n_estimators=100, random_state=42)
clf.fit(X_train, y_train)

# Accuracy of the forest on the held-out test set.
y_pred = clf.predict(X_test)
accuracy_score(y_test, y_pred)

1.0

## Tuning hyper-parameters

In [9]:
# Base estimator whose n_neighbors hyper-parameter is tuned.
clf = KNeighborsClassifier()

# Candidate values; each is evaluated with 5-fold cross-validation on the
# training data only.
param_grid = {'n_neighbors': [1, 2, 3, 4, 5, 10, 20]}
gscv = GridSearchCV(estimator=clf, param_grid=param_grid, cv=5).fit(X_train, y_train)

# The fitted search object predicts with the best estimator it found.
y_pred = gscv.predict(X_test)
accuracy_score(y_true=y_test, y_pred=y_pred)



0.9736842105263158

In [10]:
# The estimator chosen by the grid search, i.e. the parameter combination
# with the best cross-validated score.
gscv.best_estimator_

KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
           metric_params=None, n_jobs=None, n_neighbors=3, p=2,
           weights='uniform')

## Composite models and pipelining

In [11]:
# Chain a PCA transform and a k-NN classifier into one estimator: fit() trains
# both steps in order, predict() applies the transform before classifying.
pipe = Pipeline(
    steps=[
        ('transform', PCA()),
        ('classify', KNeighborsClassifier()),
    ]
).fit(X_train, y_train)

# Evaluate the whole pipeline on the held-out test set.
y_pred = pipe.predict(X_test)
accuracy_score(y_true=y_test, y_pred=y_pred)

0.9736842105263158