I tried writing an ML pipeline for the MNIST dataset

In [2]:
import numpy as np
import pandas as pd
import matplotlib as mlb
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import SGDClassifier
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.metrics import confusion_matrix, classification_report

# Step 01: Download the "mnist_784" dataset (version 1) from OpenML
mnist = fetch_openml("mnist_784", version=1)

# Step 02: Pull the features and the labels out of the returned bunch
X = mnist["data"]
y = mnist["target"]

# Step 03: First 60000 rows are used for training, the remainder for testing
X_train, y_train = X[:60000], y[:60000]
X_test, y_test = X[60000:], y[60000:]

# Step 04: Creation of a custom transformer (binarizes the features against a threshold)
class CustomTransformer(BaseEstimator, TransformerMixin):
    """Stateless binarizer: maps every value to 1.0 or 0.0 against a threshold.

    Values strictly greater than ``threshold`` become 1.0, everything else
    becomes 0.0. Nothing is learned from the data, so ``fit`` is a no-op.

    Parameters
    ----------
    threshold : scalar
        Cut-off that input values are compared against with ``>``.
    """

    def __init__(self, threshold):
        # Stored as-is; it is only read later by transform().
        self.threshold = threshold

    def fit(self, X, y=None):
        """Do nothing and return the transformer itself.

        ``y`` is optional so the transformer can be fitted without labels,
        e.g. ``CustomTransformer(100).fit(X)`` or ``fit_transform(X)``;
        a required ``y`` would raise ``TypeError`` in those calls.

        Parameters
        ----------
        X : array-like
            Ignored.
        y : array-like, optional
            Ignored.

        Returns
        -------
        self
        """
        return self

    def transform(self, X):
        """Return ``X > threshold`` cast to ``np.float32``.

        Parameters
        ----------
        X : object supporting ``>`` against ``threshold`` and ``.astype``
            (e.g. a NumPy array or a pandas DataFrame).

        Returns
        -------
        Same container type the comparison yields, holding 1.0 where the
        value exceeds ``threshold`` and 0.0 elsewhere.
        """
        return (X > self.threshold).astype(np.float32)

# Step 05: Creation of a Pipeline: CustomTransformer (binarizer) -> StandardScaler -> SGDClassifier
mnist_pipeline = Pipeline(steps=[
    ("binarizer", CustomTransformer(threshold=100)),
    ("scaler", StandardScaler()),
    ("model", SGDClassifier(random_state=42)),
])

# Step 06: Fit every pipeline stage on the training split
mnist_pipeline.fit(X_train, y_train)

# Step 07: Predict on the test split and print the per-class metrics
y_pred = mnist_pipeline.predict(X_test)
print("Classification Report: ")
print(classification_report(y_test, y_pred))

Classification Report: 
              precision    recall  f1-score   support

           0       0.96      0.95      0.96       980
           1       0.99      0.96      0.97      1135
           2       0.94      0.86      0.90      1032
           3       0.92      0.87      0.89      1010
           4       0.94      0.86      0.90       982
           5       0.90      0.81      0.86       892
           6       0.94      0.92      0.93       958
           7       0.96      0.89      0.93      1028
           8       0.59      0.92      0.72       974
           9       0.92      0.84      0.88      1009

    accuracy                           0.89     10000
   macro avg       0.91      0.89      0.89     10000
weighted avg       0.91      0.89      0.89     10000

