In [3]:
import tensorflow as tf
import keras
import numpy as np

In [4]:
(X_train, y_train), (X_test, y_test) = keras.datasets.cifar10.load_data()
num_classes = np.unique(y_test).size

# from integers in [0,255] to float in [0,1]
X_train = X_train.astype('float32') / 255
X_test  = X_test.astype('float32') / 255

# store the labels in 1D arrays, not 2D
y_train = np.squeeze(y_train)  # could do this with reshape
y_test = np.squeeze(y_test)


# Convert class vectors to one-hot encoded
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)


In [5]:
class DynamicSizeBatchGenerator(keras.utils.Sequence):
  """ Generate batches that get smaller as training proceeds. """

  def __init__(self, X, y, *, start_size, end_size, num_epochs):
    super().__init__()
    self.X = X
    self.y = y
    self.start_size = start_size
    self.end_size = end_size
    self.num_epochs = num_epochs
    self.current_epoch = 0
    self.batch_size = start_size
  
  
  
  def __len__(self):
    """ Return the number of batches that can be produced. """
    
    # we can generate any number of batches, so provide a large number
    return 1000000
  
  @staticmethod
  def make_batch(X, y, batch_size):
    """ Make a random batch. """
    
    idx = np.random.choice(X.shape[0], batch_size)
    return X[idx], y[idx]

  def __getitem__(self, idx):
    """ Return a batch. """
    
    # Generate the batch
    return self.make_batch(self.X, self.y, self.batch_size)


  def on_epoch_end(self):
    """ Choose a different random subset of the data. """
    
    self.current_epoch += self.current_epoch
    self.batch_size = self.start_size + \
                      int(self.end_size * (self.current_epoch / self.num_epochs))


In [6]:
num_epochs = 15
train_generator = DynamicSizeBatchGenerator(
  X_train, y_train,
  start_size=128, end_size=4,
  num_epochs=num_epochs)


In [11]:
# Build the model
model = keras.Sequential([
  keras.layers.Input(shape=(32, 32, 3)),
  keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu'),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Conv2D(64, kernel_size=(3, 3), activation='relu'),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(128, activation='relu'),
  keras.layers.Dense(num_classes, activation='softmax')
])

# Compile the model
model.compile(optimizer='adam',
  loss='categorical_crossentropy',
  metrics=['accuracy'])

checkpointing = keras.callbacks.ModelCheckpoint(
  filepath="simple_cnn.keras",
  save_best_only=True,
  monitor="val_loss")


history = model.fit(
  train_generator, 
  epochs=num_epochs,
  steps_per_epoch=500, 
  batch_size=32,
  validation_data=(X_test, y_test),
  callbacks=[checkpointing]
)


Epoch 1/15
[1m500/500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 21ms/step - accuracy: 0.3660 - loss: 1.7535 - val_accuracy: 0.5468 - val_loss: 1.2796
Epoch 2/15
[1m500/500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 20ms/step - accuracy: 0.5754 - loss: 1.2075 - val_accuracy: 0.6127 - val_loss: 1.1168
Epoch 3/15
[1m500/500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 21ms/step - accuracy: 0.6286 - loss: 1.0621 - val_accuracy: 0.6352 - val_loss: 1.0457
Epoch 4/15
[1m500/500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 20ms/step - accuracy: 0.6667 - loss: 0.9610 - val_accuracy: 0.6642 - val_loss: 0.9784
Epoch 5/15
[1m500/500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 19ms/step - accuracy: 0.6955 - loss: 0.8786 - val_accuracy: 0.6748 - val_loss: 0.9488
Epoch 6/15
[1m500/500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 21ms/step - accuracy: 0.7127 - loss: 0.8224 - val_accuracy: 0.6827 - val_loss: 0.9492
Epoch 7/15
[1m5

In [8]:
history.history

{'accuracy': [0.46943750977516174,
  0.6072812676429749,
  0.6542656421661377,
  0.6893749833106995,
  0.713140606880188,
  0.7420156002044678,
  0.7612968683242798,
  0.7829687595367432,
  0.8005781173706055,
  0.8176875114440918,
  0.8350468873977661,
  0.8525312542915344,
  0.8645312786102295,
  0.8787343502044678,
  0.892328143119812],
 'loss': [1.4870588779449463,
  1.1257489919662476,
  1.0000250339508057,
  0.8990630507469177,
  0.8344011902809143,
  0.7520132660865784,
  0.694590151309967,
  0.6378592252731323,
  0.5833396315574646,
  0.5329967141151428,
  0.4884811341762543,
  0.43645063042640686,
  0.3984494209289551,
  0.3601619601249695,
  0.3220538794994354],
 'val_accuracy': [0.5468999743461609,
  0.629800021648407,
  0.6421999931335449,
  0.6757000088691711,
  0.6780999898910522,
  0.6894999742507935,
  0.6923999786376953,
  0.6948000192642212,
  0.6985999941825867,
  0.7063000202178955,
  0.7008000016212463,
  0.7063000202178955,
  0.7069000005722046,
  0.69929999113082