In [1]:
import tensorflow as tf

from neuro.nn import activation, layer, losses, models, optimizer


In [2]:
# Load the MNIST train/test splits, then divide pixel intensities by 255.0 so
# they land in [0, 1] (MNIST pixels are stored as 0-255 integers).
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0


In [3]:
# Seed TensorFlow's global RNG before building the model so that weight
# initialisation and Dropout masks are intended to repeat across
# Restart & Run All.
# NOTE(review): assumes neuro draws its randomness from TensorFlow; if it uses
# NumPy (or another RNG) internally, seed that as well.
SEED = 42
tf.random.set_seed(SEED)

# MLP 784 -> 512 -> 128 -> 16 -> 10: LeakyReLU on the hidden layers and a
# softmax over the 10 digit classes at the output. Each Dense(in, out) must
# take the previous layer's output width as its input width.
model = models.Sequential(
    layer.Flatten(),                                     # 28x28 image -> 784-vector
    layer.Dense(28 * 28, 512), activation.LeakyReLU(),
    layer.Dropout(0.25),                                 # the network's only regularisation
    layer.Dense(512, 128), activation.LeakyReLU(),
    layer.Dense(128, 16), activation.LeakyReLU(),
    layer.Dense(16, 10), activation.Softmax(),
)


In [4]:
# Labels (y_train / y_test) are integer class indices rather than one-hot
# vectors, hence the "Sparse" cross-entropy; Adam runs with neuro's defaults.
loss, optim = losses.SparseCategoricalCrossentropy(), optimizer.Adam()


In [5]:
# Full-batch training: each epoch feeds all of x_train through the model in a
# single forward pass (no mini-batching), computes the loss, prints it, and then
# calls the optimizer once.
# NOTE(review): the evaluation cell sets model.trainable = False and nothing
# resets it, so re-running this cell after evaluation trains in that state —
# restart and run from the top instead.
epochs = 100  # number of full passes over the training set
for i in range(epochs):
    # Forward Propagation over the whole training set
    y_pred = model(x_train)

    # Calculation of Loss (neuro's argument order: predictions first, then labels)
    train_loss = loss(y_pred, y_train)
    print(f"Epoch: {i + 1}, Loss: {train_loss.numpy()}")

    # Back Propagation + Optimizing
    # NOTE(review): optim only receives the model and the loss object, so it
    # presumably relies on state recorded by the loss(...) call above — keep this
    # call after it. TODO confirm against neuro's optimizer API.
    optim(model, loss)


Epoch: 1, Loss: 2.304018259048462
Epoch: 2, Loss: 2.3015072345733643
Epoch: 3, Loss: 2.2973036766052246
Epoch: 4, Loss: 2.2901761531829834
Epoch: 5, Loss: 2.2790720462799072
Epoch: 6, Loss: 2.2618954181671143
Epoch: 7, Loss: 2.236630439758301
Epoch: 8, Loss: 2.2013068199157715
Epoch: 9, Loss: 2.1545631885528564
Epoch: 10, Loss: 2.0953071117401123
Epoch: 11, Loss: 2.0240886211395264
Epoch: 12, Loss: 1.9435429573059082
Epoch: 13, Loss: 1.8556946516036987
Epoch: 14, Loss: 1.7639503479003906
Epoch: 15, Loss: 1.674763560295105
Epoch: 16, Loss: 1.5846307277679443
Epoch: 17, Loss: 1.4937654733657837
Epoch: 18, Loss: 1.3940019607543945
Epoch: 19, Loss: 1.2937705516815186
Epoch: 20, Loss: 1.2182345390319824
Epoch: 21, Loss: 1.1497609615325928
Epoch: 22, Loss: 1.0799413919448853
Epoch: 23, Loss: 1.0191490650177002
Epoch: 24, Loss: 0.9698159098625183
Epoch: 25, Loss: 0.9288426637649536
Epoch: 26, Loss: 0.8893264532089233
Epoch: 27, Loss: 0.8532916903495789
Epoch: 28, Loss: 0.8222622275352478
Epoc

In [6]:
# Put the model into non-trainable mode before inference.
# NOTE(review): presumably this disables Dropout in neuro — confirm in neuro's
# docs. The flag is not reset afterwards, so any later training sees it too.
model.trainable = False

# Softmax probabilities per test image -> index of the most likely digit.
test_probs = model(x_test)
predictions = tf.argmax(test_probs, axis=1).numpy()

# Vectorized accuracy: the mean of the boolean match array is the fraction of
# correct predictions (builtin sum() would loop over the array in Python).
acc = (predictions == y_test).mean()
# Format as a percentage with 2 decimals to avoid float noise like 93.58999999999999%.
print(f"Test Accuracy: {acc:.2%}")


Test Accuracy: 93.58999999999999%
