# MNIST digit classification

## Get the dataset

In [None]:
from sklearn.datasets import fetch_openml

# as_frame=False returns NumPy arrays instead of a pandas DataFrame/Series.
# Later cells rely on positional indexing (X[0]) and ndarray methods
# (.reshape), which a DataFrame does not support; newer scikit-learn
# versions default to as_frame='auto', which would yield a DataFrame here.
mnist = fetch_openml('mnist_784', version=1, as_frame=False)
mnist.keys()

In [None]:
#mnist['DESCR']

In [None]:
# Feature matrix and target vector from the fetched dataset bundle.
X = mnist['data']
y = mnist['target']
(X.shape, y.shape)

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

some_digit = X[0]
some_digit_image = some_digit.reshape(28, 28)
plt.imshow(some_digit_image, cmap='binary')
plt.axis('off')
y[0]

## Data prep

### Make targets integers

In [None]:
import numpy as np

# Convert the labels to small unsigned integers so they can be compared to ints (e.g. == 5).
y = y.astype(dtype=np.uint8)

### Create the train/test partitions

In [None]:
boundary = 60000  # index of the first row that belongs to the test set

# Rows before `boundary` are used for training, the remaining rows for testing.
X_train, y_train = X[:boundary], y[:boundary]
X_test, y_test = X[boundary:], y[boundary:]

### Create labels for a binary classifier - detect '5'

In [None]:
# Boolean targets for the "is it a 5?" task: True for 5s, False otherwise.
y_train_5 = y_train == 5
y_test_5 = y_test == 5

## Train a stochastic gradient descent classifier

In [None]:
from sklearn.linear_model import SGDClassifier

# Fixed random_state makes the stochastic training reproducible.
sgd_clf = SGDClassifier(random_state=42)
sgd_clf.fit(X=X_train, y=y_train_5)

### Make a prediction

In [None]:
sgd_clf.predict(some_digit.reshape(1, -1))  # predict() expects a 2-D batch: one row, all pixels

## Evaluate the model

In [None]:
from sklearn.model_selection import cross_val_score

# 3-fold cross-validated accuracy of the SGD "5 detector".
cross_val_score(sgd_clf, X_train, y_train_5, scoring='accuracy', cv=3)

### Show that accuracy is misleading on skewed classes: compare with a classifier that always predicts 'not 5'

In [None]:
from never_5_classifier import Never5Classifier

# Baseline that (per the heading above) always answers "not 5"; its accuracy
# is the score obtainable without looking at the images at all.
never_5_clf = Never5Classifier()
cross_val_score(never_5_clf, X_train, y_train_5, scoring='accuracy', cv=3)

### Confusion matrix

In [None]:
from sklearn.model_selection import cross_val_predict

# Out-of-fold predictions: every training sample is predicted by a model
# fitted on the other folds, so these are "clean" predictions.
y_train_pred = cross_val_predict(estimator=sgd_clf, X=X_train, y=y_train_5, cv=3)
y_train_pred

In [None]:
from sklearn.metrics import confusion_matrix

# Layout: rows = actual class, columns = predicted class, both ordered
# (not 5, 5). Top-left counts true negatives, bottom-right true positives.
confusion_matrix(y_true=y_train_5, y_pred=y_train_pred)

### Precision & recall

In [None]:
from sklearn.metrics import precision_score, recall_score

# Precision: fraction of predicted 5s that really are 5s.
# Recall: fraction of real 5s that were detected.
print(f'Precision\t{precision_score(y_train_5, y_train_pred):.2f}')
print(f'Recall\t\t{recall_score(y_train_5, y_train_pred):.2f}')

### F1 Score

In [None]:
from sklearn.metrics import f1_score

# Harmonic mean of precision and recall, in a single number.
f1_score(y_true=y_train_5, y_pred=y_train_pred)

### Precision/recall curve

In [None]:
from sklearn.metrics import precision_recall_curve

# Request raw decision-function scores instead of class labels, so the
# decision threshold can be chosen freely afterwards.
y_scores = cross_val_predict(
    sgd_clf, X_train, y_train_5, cv=3, method='decision_function'
)
precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores)

In [None]:
from plotting_fns import (
    plot_precision_recall,
    plot_precision_recall_vs_threshold,
)

# Precision and recall as the decision threshold varies.
plot_precision_recall_vs_threshold(precisions, recalls, thresholds)
plt.show()

In [None]:
# Presumably precision plotted directly against recall (helper lives in plotting_fns).
plot_precision_recall(precisions, recalls)
plt.show()

#### Find the lowest threshold that gives at least 90% precision, and use it to make predictions

In [None]:
# precision_recall_curve returns one more precision value than thresholds:
# the final precision (always 1.0) has no matching threshold. Searching the
# full array would let argmax land on that sentinel and raise an opaque
# IndexError, so only the entries aligned with `thresholds` are searched.
reaches_90 = precisions[:-1] >= 0.90
if not reaches_90.any():
    raise ValueError('No decision threshold reaches a precision of 0.90')

# Thresholds are increasing, so the first match is the lowest qualifying one.
threshold_90_precision = thresholds[np.argmax(reaches_90)]

y_train_pred_90 = (y_scores >= threshold_90_precision)

In [None]:
# Check the chosen threshold: precision should be >= 0.90; recall shows the cost.
print(precision_score(y_true=y_train_5, y_pred=y_train_pred_90))
print(recall_score(y_true=y_train_5, y_pred=y_train_pred_90))

### ROC curve

In [None]:
from sklearn.metrics import roc_curve

# Keep the ROC thresholds under their own name: reusing `thresholds` would
# overwrite the precision/recall thresholds from precision_recall_curve, and
# re-running the 90%-precision cell afterwards would then index the wrong
# array. Mirrors the `thresholds_forest` naming used for the random forest.
fpr, tpr, roc_thresholds = roc_curve(y_train_5, y_scores)

In [None]:
from plotting_fns import plot_roc_curve

# ROC curve of the SGD classifier: true positive rate vs. false positive rate.
plot_roc_curve(fpr, tpr)
plt.show()

In [None]:
from sklearn.metrics import roc_auc_score

# Area under the ROC curve: 1.0 is a perfect ranking, 0.5 is chance level.
roc_auc_score(y_true=y_train_5, y_score=y_scores)

#### Repeat with a random forest classifier

In [None]:
from sklearn.ensemble import RandomForestClassifier

forest_clf = RandomForestClassifier(n_estimators=10, random_state=42)

# A random forest has no decision_function, so out-of-fold class
# probabilities are used as scores instead.
y_probas_forest = cross_val_predict(
    forest_clf, X_train, y_train_5, cv=3, method='predict_proba'
)

# Column 1 holds the probability of the positive class (True, i.e. "is a 5").
y_scores_forest = y_probas_forest[:, 1]

fpr_forest, tpr_forest, thresholds_forest = roc_curve(y_train_5, y_scores_forest)

In [None]:
# ROC curve of the random forest, for comparison with the SGD curve.
plot_roc_curve(fpr_forest, tpr_forest)
plt.show()

In [None]:
roc_auc_score(y_true=y_train_5, y_score=y_scores_forest)  # forest ROC AUC

In [None]:
threshold = 0.5

# Predict "5" whenever the forest's probability for the positive class reaches
# the threshold; compute the predictions once and reuse them for both metrics.
y_train_pred_forest = y_scores_forest >= threshold
print(precision_score(y_train_5, y_train_pred_forest))
print(recall_score(y_train_5, y_train_pred_forest))