In [33]:
# To support both python 2 and python 3
from __future__ import division, print_function, unicode_literals

# To plot pretty figures
import matplotlib as mpl

# Common imports
import numpy as np

from sklearn.datasets import fetch_mldata
from sklearn.linear_model import SGDClassifier
from IPython.core.interactiveshell import InteractiveShell
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import cross_val_predict
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score



%matplotlib inline
InteractiveShell.ast_node_interactivity = "all"
mpl.rc('axes', labelsize=14)
mpl.rc('xtick', labelsize=12)
mpl.rc('ytick', labelsize=12)

# Where to save the figures
PROJECT_ROOT_DIR = ".."
CHAPTER_ID = "classification"

# 加载 MNIST 数据

In [2]:
try:
    from sklearn.datasets import fetch_openml
    mnist = fetch_openml('mnist_784', version=1, cache=True)
    mnist.target = mnist.target.astype(np.int8) # fetch_openml() returns targets as strings
except ImportError:
    mnist = fetch_mldata('MNIST original')
mnist["data"], mnist["target"]

(array([[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
 array([0., 0., 0., ..., 9., 9., 9.]))

In [3]:
X, y = mnist["data"], mnist["target"]
shuffle_index = np.random.permutation(60000)
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]

In [13]:
pd.Series(y_train).value_counts()

1.0    6742
7.0    6265
3.0    6131
2.0    5958
9.0    5949
0.0    5923
6.0    5918
8.0    5851
4.0    5842
5.0    5421
dtype: int64

# 模型训练

In [14]:
y_train_5 = (y_train == 5)
y_test_5 = (y_test == 5)

False    54579
True      5421
dtype: int64

In [5]:
sgd_clf = SGDClassifier(max_iter=5, tol=-np.infty, random_state=42)
sgd_clf.fit(X_train, y_train_5)

SGDClassifier(alpha=0.0001, average=False, class_weight=None, epsilon=0.1,
       eta0=0.0, fit_intercept=True, l1_ratio=0.15,
       learning_rate='optimal', loss='hinge', max_iter=5, n_iter=None,
       n_jobs=1, penalty='l2', power_t=0.5, random_state=42, shuffle=True,
       tol=-inf, verbose=0, warm_start=False)

In [6]:
cross_val_score(sgd_clf, X_train, y_train_5, cv=8, scoring="accuracy")

array([0.81522464, 0.97986935, 0.9580056 , 0.94013333, 0.95906667,
       0.96106147, 0.92759035, 0.97239632])

In [16]:
y_train_pred = cross_val_predict(sgd_clf,  X_train, y_train_5, cv=3)

In [36]:
# 行表示实际类别 列表示预测类别
# 精度 = TP /(TP + FP)
# 召回 = TP /(TP + FN)
confusion_matrix(y_train_5, y_train_pred)
precision_score(y_train_5, y_train_pred)
recall_score(y_train_5, y_train_pred)
f1_score(y_train_5, y_train_pred)

array([[47700,  6879],
       [  889,  4532]])

0.3971606344755061

0.8360081165836561

0.5384980988593155

In [59]:
xx = fetch_mldata('MNIST original')["data"]
some_digit = xx[600]
y_score = sgd_clf.decision_function([some_digit])
y_score


array([-883494.39697114])