In [19]:
import numpy as np

from model.perceptron import Perceptron

In [20]:
from pathlib import Path

# Checkpoint of the trained classifier; the filename suggests ~0.96 accuracy.
# NOTE(review): a `.pkl` checkpoint is presumably unpickled inside `load_model`,
# so only load files that come from a trusted source.
MODEL_PATH = Path('stored_model') / 'acc_0_96.pkl'
perceptron = Perceptron.load_model(store_path=MODEL_PATH)

In [21]:
# Layer stack of the loaded model. It is read through Perceptron's private
# `_layers` attribute (no public accessor is used here), so this cell depends
# on Perceptron's internals.
layers = perceptron._layers
layers

[<model.layers.DenseLayer at 0x181c7c074f0>,
 <model.activation_functions.ReLU at 0x181c7c06fe0>,
 <model.layers.DenseLayer at 0x181c7c070d0>,
 <model.activation_functions.ReLU at 0x181c7c078b0>,
 <model.layers.DenseLayer at 0x181c7c07340>,
 <model.activation_functions.ReLU at 0x181c7c07190>]

In [22]:
# Backbone = every layer except the final DenseLayer + ReLU pair (the classifier
# head), so the truncated network outputs the hidden representation that feeds
# that head. Slicing from the end instead of a hard-coded `[0:4]` keeps this
# correct if the model has a different number of hidden layers.
backbone_layers = layers[:-2]
assert backbone_layers, "model has no layers before the DenseLayer + ReLU head"
backbone_layers

[<model.layers.DenseLayer at 0x181c7c074f0>,
 <model.activation_functions.ReLU at 0x181c7c06fe0>,
 <model.layers.DenseLayer at 0x181c7c070d0>,
 <model.activation_functions.ReLU at 0x181c7c078b0>]

In [23]:
# Feature extractor: a new Perceptron built from the backbone layers only.
# The slice holds the very same layer objects as `perceptron` (identical
# addresses in the outputs above), so the trained weights are reused.
# NOTE(review): assumes Perceptron does not copy or re-initialise the layers
# it is given; confirm in model.perceptron.
emb_perceptron = Perceptron(backbone_layers)

In [24]:
from load_mnist import mnist
from gym import split_reminder

X_train, y_train, X_test, y_test = mnist(path='./data')

# Recast the digit task as binary parity: label % 2 gives 0 for even digits and
# 1 for odd ones (assuming integer digit labels 0-9; the test supports in the
# report below match MNIST's even/odd counts).
y_train, y_test = y_train % 2, y_test % 2



In [29]:
def predict(X: np.ndarray, model: "Perceptron", batch_size: int = 1) -> np.ndarray:
    """Run ``model.forward`` over the rows of ``X`` in mini-batches and join the outputs.

    Args:
        X: Input samples, one per entry along axis 0.
        model: Network whose ``forward(batch)`` is expected to return an array
            of shape ``(len(batch), n)``.
        batch_size: Number of rows passed to ``model.forward`` per call.
            Defaults to 1 (one sample per call). Larger values are presumably
            faster if the model supports batched input.

    Returns:
        Array of shape ``(len(X), n)`` holding the per-sample outputs in input
        order. The result is always 2-D, even for a single sample or a single
        output unit.

    Raises:
        ValueError: If ``batch_size`` is not positive or ``X`` is empty.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if len(X) == 0:
        raise ValueError("predict() needs at least one sample")

    outputs = [
        model.forward(X[start:start + batch_size])  # shape (batch, n)
        for start in range(0, len(X), batch_size)
    ]
    # Join along the sample axis with concatenate rather than stack + squeeze:
    # squeeze would also drop a length-1 sample or feature axis and return a
    # 1-D array for a single-sample input or a single-output model.
    return np.concatenate(outputs, axis=0)


# Embed both splits with the truncated network; row i of each result is the
# backbone output for sample i of the corresponding split.
train_embeddings, test_embeddings = [
    predict(split, emb_perceptron) for split in (X_train, X_test)
]

In [33]:
from sklearn.linear_model import LogisticRegression

# Linear probe: logistic regression trained on the frozen backbone embeddings.
# `fit` returns the estimator itself, so both names are bound to one object.
fitted_log_reg = LogisticRegression(max_iter=500).fit(train_embeddings, y_train)
log_reg = fitted_log_reg


In [35]:
# Hard 0/1 parity predictions of the linear probe on the held-out test embeddings.
test_predicts = fitted_log_reg.predict(test_embeddings)

In [37]:
from sklearn.metrics import classification_report
# Per-class precision / recall / F1 of the probe on the test split
# (class 0 = even digit, class 1 = odd digit).
parity_report = classification_report(y_true=y_test, y_pred=test_predicts)
print(parity_report)

              precision    recall  f1-score   support

           0       0.93      0.92      0.92      4926
           1       0.92      0.93      0.93      5074

    accuracy                           0.92     10000
   macro avg       0.92      0.92      0.92     10000
weighted avg       0.92      0.92      0.92     10000
