In [35]:
from sklearn.datasets import fetch_mldata
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import scale
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.svm import SVC
from sklearn.metrics import classification_report

In [36]:
if __name__ == "__main__":
    # Train and evaluate an RBF-kernel SVM on MNIST, tuning `gamma` and `C`
    # with a cross-validated grid search on a subset of the training split.

    # Tunables gathered in one place so they are easy to find and adjust.
    SEED = 42                 # fixes the train/test shuffle so reruns are comparable
    N_SEARCH_SAMPLES = 10000  # rows of the training split used by the grid search
    N_CV_FOLDS = 3            # pinned explicitly: the library default for `cv`
                              # differs between scikit-learn releases

    # NOTE(review): `fetch_mldata` was deprecated in scikit-learn 0.20 and
    # removed in 0.22; on newer versions switch to
    # `fetch_openml('mnist_784', version=1)` (its labels are strings, so they
    # would need casting) and update the import cell accordingly.
    data = fetch_mldata('MNIST original', data_home='data/mnist')
    X, y = data.data, data.target

    # Rescale features linearly so that 0 maps to -1 and 255 maps to +1
    # (assumes 8-bit pixel intensities -- TODO confirm against the dataset).
    X = X/255.0*2 - 1

    # Seeded split: without `random_state` every run shuffles differently and
    # the reported scores are not reproducible.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, random_state=SEED)

    # Single-step pipeline; the `clf__` prefix in the grid below addresses the
    # SVC's parameters. The gamma/C given here are only starting values that
    # the grid search overrides.
    pipeline = Pipeline([
        ('clf', SVC(kernel='rbf', gamma=0.01, C=100))
    ])
    print(X_train.shape)
    parameters = {
        'clf__gamma': (0.01, 0.03, 0.1, 0.3, 1),
        'clf__C': (0.1, 0.3, 1, 3, 10, 30),
    }
    grid_search = GridSearchCV(pipeline, parameters, n_jobs=2,
                               verbose=1, scoring='accuracy',
                               cv=N_CV_FOLDS)

    # Search on the first N_SEARCH_SAMPLES rows only (the split above already
    # shuffled them) to keep the number of SVM fits tractable.
    grid_search.fit(X_train[:N_SEARCH_SAMPLES], y_train[:N_SEARCH_SAMPLES])
    print('Best score: %0.3f' % grid_search.best_score_)
    print('Best parameters set:')
    # `best_params_` holds exactly the winning values of the searched
    # parameters, so there is no need to go through the refit estimator.
    best_parameters = grid_search.best_params_
    for param_name in sorted(parameters.keys()):
        print('\t%s: %r' % (param_name, best_parameters[param_name]))

    # Evaluate the best configuration found on the held-out test split.
    predictions = grid_search.predict(X_test)
    print(classification_report(y_test, predictions))



(52500, 784)
Fitting 3 folds for each of 30 candidates, totalling 90 fits


[Parallel(n_jobs=2)]: Using backend LokyBackend with 2 concurrent workers.
[Parallel(n_jobs=2)]: Done  46 tasks      | elapsed: 41.5min
[Parallel(n_jobs=2)]: Done  90 out of  90 | elapsed: 81.8min finished


Best score: 0.965
Best parameters set:
	clf__C: 3
	clf__gamma: 0.01
              precision    recall  f1-score   support

         0.0       0.98      0.99      0.98      1760
         1.0       0.98      0.98      0.98      1989
         2.0       0.95      0.97      0.96      1721
         3.0       0.96      0.96      0.96      1807
         4.0       0.96      0.97      0.97      1713
         5.0       0.96      0.96      0.96      1535
         6.0       0.98      0.98      0.98      1718
         7.0       0.98      0.96      0.97      1773
         8.0       0.96      0.96      0.96      1721
         9.0       0.96      0.95      0.96      1763

   micro avg       0.97      0.97      0.97     17500
   macro avg       0.97      0.97      0.97     17500
weighted avg       0.97      0.97      0.97     17500

