In [1]:
import tensorflow as tf
import keras
import numpy as np

In [2]:
(X_train, y_train), (X_test, y_test) = keras.datasets.cifar10.load_data()
num_classes = np.unique(y_test).size

# from integers in [0,255] to float in [0,1]
X_train = X_train.astype('float32') / 255
X_test  = X_test.astype('float32') / 255

# store the labels in 1D arrays, not 2D
y_train = np.squeeze(y_train)  # could do this with reshape
y_test = np.squeeze(y_test)


# Convert class vectors to one-hot encoded
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)


In [3]:
class DynamicSizeBatchGenerator(keras.utils.Sequence):
  """ Generate batches that get smaller as training proceeds. """

  def __init__(self, X, y, *, start_size, end_size, num_epochs):
    super().__init__()
    self.X = X
    self.y = y
    self.start_size = start_size
    self.end_size = end_size
    self.num_epochs = num_epochs
    self.current_epoch = 0
    self.batch_size = start_size
  
  
  
  def __len__(self):
    """ Return the number of batches that can be produced. """
    
    # we can generate any number of batches, so provide a large number
    return 1000000
  
  @staticmethod
  def make_batch(X, y, batch_size):
    """ Make a random batch. """
    
    idx = np.random.choice(X.shape[0], batch_size)
    return X[idx], y[idx]

  def __getitem__(self, idx):
    """ Return a batch. """
    
    # Generate the batch
    return self.make_batch(self.X, self.y, self.batch_size)


  def on_epoch_end(self):
    """ Choose a different random subset of the data. """
    
    self.current_epoch += self.current_epoch
    self.batch_size = self.start_size + \
                      int(self.end_size * (self.current_epoch / self.num_epochs))


In [4]:
num_epochs = 15
train_generator = DynamicSizeBatchGenerator(
  X_train, y_train,
  start_size=128, end_size=4,
  num_epochs=num_epochs)


In [5]:
# Build the model
model = keras.Sequential([
  keras.layers.Input(shape=(32, 32, 3)),
  keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu'),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Conv2D(64, kernel_size=(3, 3), activation='relu'),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(128, activation='relu'),
  keras.layers.Dense(num_classes, activation='softmax')
])

# Compile the model
model.compile(optimizer='adam',
  loss='categorical_crossentropy',
  metrics=['accuracy'])


history = model.fit(
  train_generator, 
  epochs=num_epochs,
  steps_per_epoch=500, 
  batch_size=32,
  validation_data=(X_test, y_test)
)


Epoch 1/15


2024-10-20 20:45:21.308729: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M2 Max
2024-10-20 20:45:21.308762: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 32.00 GB
2024-10-20 20:45:21.308769: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 10.67 GB
2024-10-20 20:45:21.308785: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2024-10-20 20:45:21.308795: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)
2024-10-20 20:45:21.643645: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:117] Plugin optimizer for device_type GPU is enabled.


[1m500/500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 17ms/step - accuracy: 0.3927 - loss: 1.6802 - val_accuracy: 0.5752 - val_loss: 1.2073
Epoch 2/15
[1m500/500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 16ms/step - accuracy: 0.6054 - loss: 1.1340 - val_accuracy: 0.6199 - val_loss: 1.0870
Epoch 3/15
[1m500/500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 16ms/step - accuracy: 0.6656 - loss: 0.9692 - val_accuracy: 0.6481 - val_loss: 1.0190
Epoch 4/15
[1m500/500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 16ms/step - accuracy: 0.6981 - loss: 0.8800 - val_accuracy: 0.6777 - val_loss: 0.9510
Epoch 5/15
[1m500/500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 16ms/step - accuracy: 0.7232 - loss: 0.7993 - val_accuracy: 0.6805 - val_loss: 0.9419
Epoch 6/15
[1m500/500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 16ms/step - accuracy: 0.7532 - loss: 0.7246 - val_accuracy: 0.6840 - val_loss: 0.9571
Epoch 7/15
[1m500/500[0m [32m

In [7]:
history.history

{'accuracy': [0.48807811737060547,
  0.620718777179718,
  0.6732968688011169,
  0.707156240940094,
  0.7323125004768372,
  0.7605156302452087,
  0.785406231880188,
  0.8020312786102295,
  0.8264843821525574,
  0.8483906388282776,
  0.8657812476158142,
  0.8845937252044678,
  0.898812472820282,
  0.9125937223434448,
  0.9277499914169312],
 'loss': [1.4356911182403564,
  1.090832233428955,
  0.9471660256385803,
  0.8534831404685974,
  0.7767249345779419,
  0.7011699080467224,
  0.637131929397583,
  0.5771342515945435,
  0.5132456421852112,
  0.45047658681869507,
  0.40002307295799255,
  0.3470534682273865,
  0.3082069158554077,
  0.26369985938072205,
  0.2253839522600174],
 'val_accuracy': [0.5752000212669373,
  0.6198999881744385,
  0.6481000185012817,
  0.6776999831199646,
  0.6804999709129333,
  0.6840000152587891,
  0.7020999789237976,
  0.6974999904632568,
  0.7067000269889832,
  0.6934999823570251,
  0.704800009727478,
  0.7017999887466431,
  0.6944000124931335,
  0.684199988842010