In [1]:
import tensorflow as tf

from neuro.nn import activation, layer, losses, models, optimizer


In [2]:
# Load MNIST as (images, labels) pairs for the train and test splits.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Rescale raw pixel values from the 0..255 range into 0..1.
x_train, x_test = x_train / 255.0, x_test / 255.0

# Only the training labels are one-hot encoded (10 classes), because the
# categorical cross-entropy loss consumes them in that form. y_test is left as
# integer class ids so the evaluation cell can compare it against argmax output.
y_train = tf.one_hot(y_train, depth=10)

In [3]:
# Fully-connected MNIST classifier. The stack is spelled out as a list so the
# architecture reads top-to-bottom, then unpacked into `Sequential` in order.
mnist_layers = [
    layer.Flatten(),                                  # 28x28 image -> 784 inputs
    layer.Dense(784, 512), activation.ReLU(),         # hidden layer 1
    layer.Dropout(0.25),                              # dropout after hidden layer 1
    layer.Dense(512, 128), activation.ReLU(),         # hidden layer 2
    layer.Dense(128, 16), activation.ReLU(),          # hidden layer 3
    layer.Dense(16, 10), activation.StableSoftmax(),  # 10-way softmax output
]
model = models.Sequential(*mnist_layers)


In [4]:
# Training objective (categorical cross-entropy, matching the one-hot y_train)
# and the optimizer, both constructed with no arguments (library defaults).
loss, optim = losses.CategoricalCrossentropy(), optimizer.Adam()


In [5]:
# Full-batch training: each epoch pushes the entire training set through the
# model, reports the loss, then asks the optimizer to update the model.
epochs = 100

# The evaluation cell sets `model.trainable = False` and never resets it, so
# re-running this cell afterwards would silently train in that mode. Switch it
# back explicitly so this cell behaves the same on every run.
# NOTE(review): assumes `trainable = True` is neuro's training-mode setting — confirm.
model.trainable = True

for i in range(epochs):
    # Forward Propagation
    y_pred = model(x_train)

    # Calculation of Loss
    # NOTE(review): called as (prediction, target), the reverse of Keras'
    # (y_true, y_pred) convention — confirm against neuro.nn.losses.
    train_loss = loss(y_pred, y_train)
    print(f"Epoch: {i + 1}, Loss: {train_loss.numpy()}")

    # Back Propagation + Optimizing
    # NOTE(review): `optim` receives only the model and the loss object, so it
    # presumably relies on state cached by the `loss(...)` call above; keep this
    # statement after the loss computation.
    optim(model, loss)


Epoch: 1, Loss: 2.302288770675659
Epoch: 2, Loss: 2.3001456260681152
Epoch: 3, Loss: 2.2966978549957275
Epoch: 4, Loss: 2.290881395339966
Epoch: 5, Loss: 2.2811803817749023
Epoch: 6, Loss: 2.2654495239257812
Epoch: 7, Loss: 2.2421159744262695
Epoch: 8, Loss: 2.2083616256713867
Epoch: 9, Loss: 2.1612305641174316
Epoch: 10, Loss: 2.0998427867889404
Epoch: 11, Loss: 2.0252325534820557
Epoch: 12, Loss: 1.938755750656128
Epoch: 13, Loss: 1.8438398838043213
Epoch: 14, Loss: 1.7441458702087402
Epoch: 15, Loss: 1.6383346319198608
Epoch: 16, Loss: 1.5295225381851196
Epoch: 17, Loss: 1.4193378686904907
Epoch: 18, Loss: 1.3038674592971802
Epoch: 19, Loss: 1.2033330202102661
Epoch: 20, Loss: 1.115446925163269
Epoch: 21, Loss: 1.037292242050171
Epoch: 22, Loss: 0.9618557095527649
Epoch: 23, Loss: 0.8948253393173218
Epoch: 24, Loss: 0.8426576852798462
Epoch: 25, Loss: 0.8037654757499695
Epoch: 26, Loss: 0.7706952095031738
Epoch: 27, Loss: 0.7324863076210022
Epoch: 28, Loss: 0.7047907710075378
Epoch:

In [6]:
# Evaluate the trained model on the held-out test split.
# NOTE(review): presumably `trainable = False` puts neuro layers such as Dropout
# into inference mode — confirm. The flag remains False after this cell.
model.trainable = False

# Model outputs for every test image, then the highest-scoring class id per row.
probabilities = model(x_test)
predictions = tf.argmax(probabilities, axis=1).numpy()

# y_test holds integer class ids (only y_train was one-hot encoded), so it is
# compared element-wise with the predicted ids. The mean of the boolean match
# array is the fraction of correct predictions, computed vectorized instead of
# iterating with the builtin sum().
acc = (predictions == y_test).mean()
print(f"Test Accuracy: {acc:.2%}")


Test Accuracy: 92.80000000000001%
