In [17]:
import tensorflow as tf
import numpy as np
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical

In [18]:
# Load MNIST
(x_train, y_train), _ = mnist.load_data()

# Normalize and flatten images
x_train = x_train.reshape(-1, 784).astype(np.float32) / 255.0
y_train = to_categorical(y_train, 10)

# Use a small batch
x_batch = x_train[:64].T  # shape: [784, 64]
y_batch = y_train[:64].T  # shape: [10, 64]

In [19]:
input_size = 784
hidden_size = 128
output_size = 10

np.random.seed(0)

# Initialize weights
W1 = np.random.randn(hidden_size, input_size) * 0.01  # input -> hidden
W2 = np.random.randn(output_size, hidden_size) * 0.01  # hidden -> output

# Initialize states
x0 = x_batch  # input layer
x1 = np.zeros((hidden_size, x0.shape[1]))  # hidden layer
x2 = np.zeros((output_size, x0.shape[1]))  # output layer

# Learning rates
lr_w = 0.005
lr_s = 0.1

In [20]:
def softmax(x):
    e_x = np.exp(x - np.max(x, axis=0, keepdims=True))
    return e_x / np.sum(e_x, axis=0, keepdims=True)

In [21]:
for step in range(5):  # Iterative refinement
    # Predictions
    pred_x0 = W1.T @ x1
    pred_x1 = W2.T @ x2

    # Errors
    err_x0 = x0 - pred_x0
    err_x1 = x1 - pred_x1
    err_x2 = y_batch - softmax(x2)  # target-based prediction error

    # Update states
    x1 += lr_s * (W1 @ err_x0 - W2.T @ err_x2)
    x2 += lr_s * (W2 @ err_x1 + err_x2)

    #Corrected weight updates
    W1 += lr_w * (x1 @ err_x0.T)
    W2 += lr_w * (x2 @ err_x1.T)


In [22]:
predictions = np.argmax(softmax(x2), axis=0)
true_labels = np.argmax(y_batch, axis=0)
accuracy = np.mean(predictions == true_labels)

print(f"Batch Accuracy: {accuracy * 100:.2f}%")


Batch Accuracy: 100.00%


this is just to check one many batchs

In [23]:
lr_s = 0.01  # or even 0.001
lr_w = 0.0005

In [24]:
def softmax(x):
    e_x = np.exp(x - np.max(x, axis=0, keepdims=True))
    return e_x / (np.sum(e_x, axis=0, keepdims=True) + 1e-9)

In [25]:
softmax(x2)

array([[0.09432863, 0.15351399, 0.09415026, 0.09434243, 0.09417163,
        0.09441846, 0.09433132, 0.0943647 , 0.0941853 , 0.09416097,
        0.09440682, 0.0942309 , 0.09436793, 0.09429801, 0.09423532,
        0.09430127, 0.09443343, 0.09439276, 0.09424347, 0.09419211,
        0.09415506, 0.15358651, 0.0941452 , 0.09434225, 0.0943285 ,
        0.09445949, 0.09416818, 0.0945475 , 0.09443978, 0.09438285,
        0.09430853, 0.09447975, 0.09429772, 0.09422948, 0.15350792,
        0.09415763, 0.09439632, 0.15348361, 0.09434936, 0.09431512,
        0.09422409, 0.09439834, 0.09430589, 0.09424782, 0.09427093,
        0.09421302, 0.09432952, 0.09425607, 0.09420317, 0.09439225,
        0.09438574, 0.15362729, 0.09431653, 0.09415212, 0.09418086,
        0.0944264 , 0.15353231, 0.09420035, 0.09422318, 0.09431532,
        0.0942762 , 0.09417423, 0.09436537, 0.15360289],
       [0.09407588, 0.09408966, 0.09398612, 0.15292316, 0.09409637,
        0.09398211, 0.15298312, 0.09390925, 0.15304975, 0.0