In [None]:
import tensorflow as tf

In [None]:
import math

In [None]:
class NaiveDense:
  """Fully connected layer computing ``activation(inputs @ W + b)``.

  ``W`` has shape ``(input_size, output_size)`` and is drawn from a uniform
  distribution on [0, 0.1); ``b`` has shape ``(output_size,)`` and starts at
  zero. Both are stored as ``tf.Variable`` objects so gradients can be taken
  with respect to them.
  """

  def __init__(self, input_size, output_size, activation):
    self.activation = activation
    # Weight matrix first, then bias, so the random-draw order is stable.
    self.W = tf.Variable(
        tf.random.uniform((input_size, output_size), minval=0, maxval=1e-1))
    self.b = tf.Variable(tf.zeros((output_size, )))

  def __call__(self, inputs):
    pre_activation = tf.matmul(inputs, self.W) + self.b
    return self.activation(pre_activation)

  @property
  def weights(self):
    """Parameters of this layer, in the order ``[W, b]``."""
    return [self.W, self.b]

In [None]:
class NaiveSequential:
  """Chains layers so that each layer's output is the next layer's input.

  Each layer must be callable and expose a ``weights`` iterable (for
  example ``NaiveDense``).
  """

  def __init__(self, layers):
    self.layers = layers

  def __call__(self, inputs):
    activations = inputs
    for current_layer in self.layers:
      activations = current_layer(activations)
    return activations

  @property
  def weights(self):
    """Flat list of every layer's weights, concatenated in layer order."""
    return [param for current_layer in self.layers
            for param in current_layer.weights]

In [None]:
# Two-layer MLP: 28*28 flattened inputs -> 512 ReLU units -> 10 softmax outputs.
model = NaiveSequential(layers=[
    NaiveDense(input_size=28 * 28, output_size=512, activation=tf.nn.relu),
    NaiveDense(input_size=512, output_size=10, activation=tf.nn.softmax),
])

# Each NaiveDense contributes exactly two variables (W and b).
assert len(model.weights) == 4

In [None]:
class BatchGenerator:
  """Serves consecutive, non-shuffled mini-batches of ``(images, labels)``.

  Each ``next()`` call returns the slice starting at the current cursor and
  then advances the cursor by ``batch_size``; the last batch may be shorter.
  The cursor is never rewound, so build a new generator for each epoch. For
  list/NumPy-style inputs, calls past ``num_batches`` yield empty slices.
  """

  def __init__(self, images, labels, batch_size=128):
    assert len(images) == len(labels)

    self.images = images
    self.labels = labels
    self.batch_size = batch_size
    self.index = 0
    # Round up so a trailing partial batch is still counted.
    self.num_batches = math.ceil(len(images) / batch_size)

  def next(self):
    """Return the next ``(images, labels)`` batch and advance the cursor."""
    start = self.index
    stop = start + self.batch_size
    batch = (self.images[start:stop], self.labels[start:stop])
    self.index += self.batch_size
    return batch

    

In [None]:
def one_training_step(model, images_batch, labels_batch):
  """Run one gradient-descent step on a single mini-batch.

  Args:
    model: callable returning class probabilities for ``images_batch`` and
      exposing a ``weights`` list of trainable variables.
    images_batch: inputs fed to ``model``.
    labels_batch: integer class labels (sparse encoding) for the batch.

  Returns:
    Mean sparse categorical cross-entropy over the batch, measured before the
    weights are updated.
  """
  # The forward pass and loss must run inside the tape so they are recorded.
  with tf.GradientTape() as tape:
    probabilities = model(images_batch)
    batch_loss = tf.reduce_mean(
        tf.keras.losses.sparse_categorical_crossentropy(
            labels_batch, probabilities))

  grads = tape.gradient(batch_loss, model.weights)
  # ``update_weights`` is resolved at call time, so the most recently run
  # definition of it decides the update rule.
  update_weights(grads, model.weights)

  return batch_loss

In [None]:
def update_weights(gradients, weights, learning_rate=1e-3):
  """Apply one plain SGD step in place: ``w <- w - learning_rate * g``.

  Args:
    gradients: per-variable gradients, aligned element-wise with ``weights``.
    weights: variables supporting ``assign_sub`` (e.g. ``tf.Variable``).
    learning_rate: step size (default 1e-3).

  Note: the next cell redefines ``update_weights`` on top of a Keras SGD
  optimizer, so once that cell runs, this hand-written version is shadowed.
  """
  for gradient, weight in zip(gradients, weights):
    weight.assign_sub(gradient * learning_rate)

In [None]:
from tensorflow.keras import optimizers
# Keras SGD with its default momentum of 0 performs the same w -= lr * g update
# as the hand-written loop in the previous cell, at the same learning rate.
optimizer = optimizers.SGD(learning_rate=1e-3)

def update_weights(gradients, weights):
  """Apply one SGD step through the module-level Keras ``optimizer``.

  NOTE: this intentionally redefines ``update_weights`` from the previous cell.
  ``one_training_step`` looks the name up at call time, so whichever cell ran
  last determines the update rule used during training.
  """
  grads_and_vars = zip(gradients, weights)
  optimizer.apply_gradients(grads_and_vars)

In [None]:
def fit(model, images, labels, epochs, batch_size=128, log_every=100):
  """Train ``model`` with mini-batch gradient descent, printing progress.

  Args:
    model: model passed through to ``one_training_step``.
    images: training inputs, sliceable along the first axis.
    labels: training labels aligned with ``images``.
    epochs: number of full passes over the data.
    batch_size: number of samples per mini-batch.
    log_every: print the batch loss every ``log_every`` batches; a falsy
      value (0 or None) disables per-batch loss logging.
  """
  for epoch_ctr in range(epochs):
    print(f"Epoch: {epoch_ctr + 1}")

    # A fresh generator per epoch: BatchGenerator's cursor never rewinds.
    # batch_size must be forwarded, otherwise the generator's default is used.
    batch_generator = BatchGenerator(images, labels, batch_size=batch_size)
    for batch_ctr in range(batch_generator.num_batches):
      images_batch, labels_batch = batch_generator.next()
      loss = one_training_step(model, images_batch, labels_batch)
      if log_every and batch_ctr % log_every == 0:
        print(f"Loss at batch {batch_ctr}: {loss:.4f}")

In [None]:
from tensorflow.keras.datasets import mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Flatten each 28x28 image into a 784-long row and rescale pixel values by
# 1/255 as float32. Train is processed before test, as in the original order.
train_images, test_images = (
    split.reshape((len(split), 28 * 28)).astype("float32") / 255
    for split in (train_images, test_images)
)

fit(model, train_images, train_labels, epochs=20, batch_size=128)

Epoch: 0
Loss at batch 0: 0.4869
Loss at batch 100: 0.4863
Loss at batch 200: 0.4194
Loss at batch 300: 0.4922
Loss at batch 400: 0.5937
Epoch: 1
Loss at batch 0: 0.4734
Loss at batch 100: 0.4706
Loss at batch 200: 0.4059
Loss at batch 300: 0.4791
Loss at batch 400: 0.5833
Epoch: 2
Loss at batch 0: 0.4612
Loss at batch 100: 0.4566
Loss at batch 200: 0.3939
Loss at batch 300: 0.4673
Loss at batch 400: 0.5740
Epoch: 3
Loss at batch 0: 0.4501
Loss at batch 100: 0.4440
Loss at batch 200: 0.3832
Loss at batch 300: 0.4567
Loss at batch 400: 0.5656
Epoch: 4
Loss at batch 0: 0.4400
Loss at batch 100: 0.4325
Loss at batch 200: 0.3736
Loss at batch 300: 0.4471
Loss at batch 400: 0.5580
Epoch: 5
Loss at batch 0: 0.4308
Loss at batch 100: 0.4220
Loss at batch 200: 0.3648
Loss at batch 300: 0.4384
Loss at batch 400: 0.5511
Epoch: 6
Loss at batch 0: 0.4224
Loss at batch 100: 0.4125
Loss at batch 200: 0.3569
Loss at batch 300: 0.4304
Loss at batch 400: 0.5447
Epoch: 7
Loss at batch 0: 0.4146
Loss at 