# Matriz de Confusão Multiclasse

### Importar o dataset MNIST

In [None]:
from sklearn.datasets import fetch_openml
import numpy as np

# Fetch the 'mnist_784' dataset from OpenML as plain NumPy arrays
# (as_frame=False) and store its labels as 8-bit integers.
mnist = fetch_openml('mnist_784', version=1, cache=True, as_frame=False)
mnist.target = mnist.target.astype(np.int8)

X = mnist["data"]
y = mnist["target"]

# Persist both arrays to disk (np.save appends the '.npy' extension).
np.save('mnistX', X)
np.save('mnisty', y)

# The first 60000 samples form the training set, the remainder the test set.
X_train, X_test = X[:60000], X[60000:]
y_train, y_test = y[:60000], y[60000:]

# Apply one random permutation to features and labels alike so that each
# training sample stays paired with its label after shuffling.
shuffle_index = np.random.permutation(60000)
X_train = X_train[shuffle_index]
y_train = y_train[shuffle_index]

### Ajustando um Classificador

In [None]:
from sklearn.linear_model import SGDClassifier
# Imported here because this cell calls cross_val_score before the later
# cell that imports it has been executed.
from sklearn.model_selection import cross_val_score

# Logistic-regression classifier trained with stochastic gradient descent.
# tol=None switches the tolerance-based stopping criterion off, so training
# always runs for exactly `max_iter` epochs. (The previous `tol=-np.infty`
# relied on an alias removed in NumPy 2.0 and on a negative tolerance that
# scikit-learn's parameter validation rejects.)
sgd_clf = SGDClassifier(loss='log_loss', max_iter=5, tol=None, random_state=42)
sgd_clf.fit(X_train, y_train)

# Multiclass cross-validation: accuracy on each of the 3 folds.
print("Validação Cruzada - 3-folds: ")
print(cross_val_score(sgd_clf, X_train, y_train, cv=3, scoring="accuracy"))

### Reescalando os dados para treinar o modelo

In [None]:
from sklearn.model_selection import cross_val_score, cross_val_predict
from sklearn.preprocessing import StandardScaler

# Standardise the training features (converted to float64 first) and keep
# the fitted scaler around in `scaler`.
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(
    X_train.astype(np.float64),
)

# Repeat the 3-fold accuracy evaluation, this time on the rescaled features.
print("Validação Cruzada reescalado - 3-folds: ")
print(
    cross_val_score(
        sgd_clf,
        X_train_scaled,
        y_train,
        cv=3,
        scoring="accuracy",
    )
)

### Matriz de Confusão

In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

# Out-of-fold predictions: every training sample is predicted by a model
# fitted on the other folds of the 3-fold split.
y_train_pred = cross_val_predict(
    sgd_clf,
    X_train_scaled,
    y_train,
    cv=3,
)

# Tabulate true labels (first argument) against predicted labels (second).
conf_mx = confusion_matrix(y_train, y_train_pred)
print(conf_mx)

# Show the raw counts as a greyscale image.
plt.matshow(conf_mx, cmap=plt.cm.gray)
plt.show()

### Função para plotar a Matriz de Confusão

In [None]:
def plot_confusion_matrix(matrix):
    """Draw ``matrix`` as a colour-mapped image with a colour bar.

    A new 8x8 figure holding a single subplot is created; the matrix is
    rendered on it with the default colour map and a colour bar is attached.
    Nothing is returned and the figure is not shown: the caller is expected
    to call ``plt.show()`` afterwards.
    """
    figure = plt.figure(figsize=(8, 8))
    axes = figure.add_subplot(111)
    image = axes.matshow(matrix)
    figure.colorbar(image)

### Plotar a Matriz

In [None]:
# Render the confusion matrix in colour with the helper, then display it.
plot_confusion_matrix(conf_mx)
plt.show()

### Matriz de Erros

In [None]:
# Total of each row of the confusion matrix, kept as a column vector
# (keepdims=True) so the division below broadcasts row by row.
row_sums = np.sum(conf_mx, axis=1, keepdims=True)

# Turn counts into per-row proportions.
norm_conf_mx = np.true_divide(conf_mx, row_sums)

# Zero the diagonal so that only the off-diagonal entries remain visible.
np.fill_diagonal(norm_conf_mx, 0)

plt.matshow(norm_conf_mx, cmap=plt.cm.gray)
plt.show()