# Using a Python generator for data-loading with Keras

Training a Keras model with `model.fit(...)` on NumPy arrays requires the whole dataset to be held in memory, which is not feasible for data that does not fit into memory. For this case, Keras offers the `model.fit_generator(...)` method, which takes a Python generator as the dataloader. This lets you load data on the fly from datasets that do not fit into memory. (In more recent versions of Keras, `model.fit(...)` accepts generators directly and `fit_generator` is deprecated.)
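
One way to load data on the fly from disk is a memory-mapped NumPy file. The sketch below is an illustration under assumptions, not part of the original notebook. It assumes a hypothetical layout of two `.npy` files (`inputs.npy`, `labels.npy`) with one row per example. The helper name `memmap_batch_generator` is likewise hypothetical. Only the rows of the current batch are copied into RAM.

```python
import os
import tempfile

import numpy as np


def memmap_batch_generator(inputs_path, labels_path, batch_size):
    """Yield (inputs, labels) batches read lazily from two .npy files.

    Hypothetical file layout: both files hold arrays with one row per example
    and the same number of rows.
    """
    # mmap_mode="r" maps the files read-only instead of loading them into RAM.
    inputs = np.load(inputs_path, mmap_mode="r")
    labels = np.load(labels_path, mmap_mode="r")
    n = inputs.shape[0]
    if n < batch_size:
        raise ValueError("dataset has fewer rows than batch_size")
    while True:  # loop forever; the caller decides how many batches to take
        # A trailing partial batch is skipped so every batch has the same size.
        for start in range(0, n - batch_size + 1, batch_size):
            stop = start + batch_size
            # np.array copies only this slice from the mapped file into memory.
            yield np.array(inputs[start:stop]), np.array(labels[start:stop])


# Usage: write a small toy dataset to a temporary directory, then stream it.
tmp_dir = tempfile.mkdtemp()
x_path = os.path.join(tmp_dir, "inputs.npy")
y_path = os.path.join(tmp_dir, "labels.npy")
np.save(x_path, np.random.normal(size=(1000, 2)))
np.save(y_path, np.random.randint(0, 2, size=(1000, 1)).astype("float32"))

gen = memmap_batch_generator(x_path, y_path, batch_size=100)
x_batch, y_batch = next(gen)
print(x_batch.shape, y_batch.shape)  # (100, 2) (100, 1)
```

After one full pass over the 1,000 rows, the generator starts again from the first batch. This is the endless behaviour that `fit_generator(...)` expects.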

In [1]:
import os
os.environ["KERAS_BACKEND"] = "tensorflow"  # must be set before Keras is imported
import numpy as np
np.random.seed(1234)  # make the randomly generated toy data reproducible

from keras.models import Sequential
from keras.layers import Dense

Using TensorFlow backend.


## Model definition

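This simple model serves as an example. It takes two-dimensional inputs, passes them through a hidden layer of 100 ReLU units, and ends in a single sigmoid unit for binary classification.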

In [2]:
model = Sequential()
model.add(Dense(100, activation="relu", input_shape=(2,)))
model.add(Dense(1, activation="sigmoid"))

model.compile(loss="binary_crossentropy", optimizer="adam", metrics=["accuracy"])
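
To make the size of this model concrete, its trainable parameters can be counted by hand. A `Dense` layer with `n_in` inputs and `n_out` units has an `n_in × n_out` weight matrix plus `n_out` biases. The helper `dense_params` is a hypothetical name used only for this arithmetic.

```python
def dense_params(n_in, n_out):
    # weight matrix (n_in x n_out) plus one bias per unit
    return n_in * n_out + n_out


hidden = dense_params(2, 100)   # 2 * 100 + 100 = 300
output = dense_params(100, 1)   # 100 * 1 + 1 = 101
total = hidden + output
print(total)  # 401
```

This total should match the parameter count that `model.summary()` reports for the model above.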

## Python generator as dataloader

The following generator implements a dataloader that generates toy examples on the fly. It draws two-dimensional Gaussian samples for a signal class (label 1) and a background class (label 0). You are equally free to load batches of examples from disk inside this function. Note that the generator may loop forever. Each `yield` statement only has to return a tuple of a batch of inputs and a batch of labels, which is used for a single gradient step during training.

In [3]:
def data_generator(batch_size):
    signal_mean = [1.0, 1.0]
    signal_cov = [[1.0, 0.0],
                  [0.0, 1.0]]
    background_mean = [-1.0, -1.0]
    background_cov = [[1.0, 0.0],
                      [0.0, 1.0]]
    
    while True:
        half = batch_size // 2  # integer division: sample counts must be ints in Python 3
        signal = np.random.multivariate_normal(signal_mean, signal_cov, half)
        background = np.random.multivariate_normal(background_mean, background_cov, half)
        
        # First half of the batch is signal (label 1), second half is background (label 0)
        inputs = np.vstack([signal, background])
        labels = np.vstack([np.ones((half, 1)), np.zeros((half, 1))])
        
        yield inputs, labels
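
The generator above draws from NumPy's global random state, which is seeded in the first cell. A variant sketch follows. It is our own assumption, not the notebook's code, and the name `data_generator_rng` is hypothetical. It gives the generator its own `np.random.default_rng` stream (NumPy 1.17 or later) with an optional seed. It also uses integer division `//` for the half-batch size. It then checks the contract described above by drawing a few batches with `itertools.islice`.

```python
import itertools

import numpy as np


def data_generator_rng(batch_size, seed=None):
    """Toy signal/background generator with its own random stream (sketch)."""
    if batch_size % 2 != 0:
        raise ValueError("batch_size must be even to split it into two halves")
    rng = np.random.default_rng(seed)  # local RNG instead of the global np.random state
    half = batch_size // 2             # integer division: sizes must be ints
    cov = np.eye(2)                    # same identity covariance for both classes
    while True:
        signal = rng.multivariate_normal([1.0, 1.0], cov, half)
        background = rng.multivariate_normal([-1.0, -1.0], cov, half)
        inputs = np.vstack([signal, background])
        labels = np.vstack([np.ones((half, 1)), np.zeros((half, 1))])
        yield inputs, labels


# Draw three batches and check the (inputs, labels) shapes Keras expects.
for inputs, labels in itertools.islice(data_generator_rng(100, seed=1234), 3):
    print(inputs.shape, labels.shape, labels.mean())  # (100, 2) (100, 1) 0.5
```

With a fixed `seed`, two generator instances produce identical batches. This makes a data pipeline easier to debug without touching global state.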

## Run the training

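As you can see below, the call is similar to `model.fit(...)`. Because the generator never terminates, `steps_per_epoch` tells Keras how many batches make up one epoch. Here that is 10 batches of 100 examples, i.e. 1,000 freshly generated examples per epoch. As an additional feature, Keras can run the generator in background workers (`workers`) that fill an in-memory buffer holding up to `max_queue_size` batches, similar to TensorFlow's queue system.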

In [4]:
model.fit_generator(
    data_generator(batch_size=100),
    steps_per_epoch=10,
    epochs=10,
    max_queue_size=10,
    workers=1);

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
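
The idea behind the `workers` and `max_queue_size` arguments can be sketched with Python's standard `threading` and `queue` modules. This is a simplified illustration of a background prefetch buffer, not Keras's internal implementation, and the names `ThreadSafeIterator`, `prefetch` and `toy_batches` are hypothetical. A plain Python generator cannot be advanced from several threads at once, because concurrent `next` calls raise `ValueError: generator already executing`. The sketch therefore serialises access with a lock. For multi-worker loading, the Keras documentation recommends `keras.utils.Sequence` instead of a generator. The sketch assumes an endless generator, as in this notebook.

```python
import queue
import threading

import numpy as np


class ThreadSafeIterator:
    """Serialise next() calls so several threads can share one generator."""

    def __init__(self, gen):
        self._gen = gen
        self._lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        with self._lock:
            return next(self._gen)


def prefetch(gen, max_queue_size=10, workers=2):
    """Fill a bounded queue from an endless `gen` in background threads (sketch)."""
    source = ThreadSafeIterator(gen)
    buffer = queue.Queue(maxsize=max_queue_size)

    def worker():
        for batch in source:
            buffer.put(batch)  # blocks while the buffer already holds max_queue_size batches

    for _ in range(workers):
        threading.Thread(target=worker, daemon=True).start()
    while True:
        yield buffer.get()  # consumer side: take the next ready batch


def toy_batches(batch_size):
    # Endless stand-in data source; its RNG is only touched under the lock above.
    rng = np.random.default_rng(0)
    while True:
        yield rng.normal(size=(batch_size, 2)), rng.integers(0, 2, size=(batch_size, 1))


loader = prefetch(toy_batches(100), max_queue_size=10, workers=2)
for step in range(3):
    x, y = next(loader)
    print(step, x.shape, y.shape)
```

The bounded queue caps memory use at roughly `max_queue_size` batches. Meanwhile the training loop consumes batches that the workers prepared in the background.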
