In [1]:
import numpy as np
import tensorflow as tf
import time
import matplotlib.pyplot as plt
import pandas as pd
import os


In [2]:
# Seed a dedicated NumPy generator so the shuffle, the noise and therefore the
# train/test split are identical on every Restart & Run All.
SEED = 42
rng = np.random.default_rng(SEED)

data_size = 1000
# 80% of the data is for training.
train_pct = 0.8

train_size = int(data_size * train_pct)

# Create some input data between -1 and 1 and randomize it.
x = np.linspace(-1, 1, data_size)
rng.shuffle(x)

# Generate the output data.
# y = 0.5x + 2 + noise, where noise ~ Normal(mean=0, std=0.05)
y = 0.5 * x + 2 + rng.normal(0, 0.05, (data_size, ))

# Split into test and train pairs (x is already shuffled, so a plain slice
# gives a random split).
x_train, y_train = x[:train_size], y[:train_size]
x_test, y_test = x[train_size:], y[train_size:]

In [3]:
class CSVLogger(tf.keras.callbacks.Callback):
    """Collect per-epoch Keras logs and write them to a CSV file when training ends.

    Each CSV row is one epoch (the index is the epoch number Keras passes to
    ``on_epoch_end``). The columns are the keys found in the epoch logs plus a
    ``time`` column holding seconds elapsed since training began.

    Args:
        path: Destination CSV path. Missing parent directories are created
            before writing.
    """

    def __init__(self, path):
        super().__init__()
        self.path = path
        # Also reset in on_train_begin. Initialised here so on_epoch_end works
        # even when the callback is driven without a train-begin event.
        self.timetaken = time.time()
        # Maps epoch number -> copy of that epoch's logs (plus 'time').
        self.state = {}

    def on_train_begin(self, logs=None):
        # Measure elapsed time from the start of training, not from when the
        # callback object was constructed (which may be much earlier).
        self.timetaken = time.time()

    def on_epoch_end(self, epoch, logs=None):
        # Copy rather than mutate: the logs dict may be shared with other
        # callbacks (e.g. History), and keeping a reference would let later
        # mutations rewrite rows that were already recorded.
        row = dict(logs or {})
        row['time'] = time.time() - self.timetaken
        self.state[epoch] = row

    def on_train_end(self, logs=None):
        # orient='index' turns {epoch: {metric: value}} into one row per epoch.
        # Columns come from the metric names, so rows stay aligned even if an
        # epoch reports a different set of keys.
        df = pd.DataFrame.from_dict(self.state, orient='index')
        directory = os.path.dirname(self.path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        df.to_csv(self.path)

In [5]:
# NOTE: CallbacksFactory is defined in the In [4] cell at the end of this
# notebook. Run (or move) that cell before this one, otherwise a fresh kernel
# raises NameError here.
factory = CallbacksFactory()

# Display the callback list that model.fit receives further down.
preview_callbacks = factory.create()
preview_callbacks

[<__main__.CSVLogger at 0x29254dd13d0>]

In [6]:
# Seed TensorFlow's global RNG so the initial Dense weights are reproducible
# across Restart & Run All.
tf.random.set_seed(42)

# Two stacked Dense layers; neither is given an activation argument.
model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(16, input_dim=1),
    tf.keras.layers.Dense(1),
])

model.compile(
    loss='mse',  # keras.losses.mean_squared_error
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.2),
)

print("Training ... With default parameters, this takes less than 10 seconds.")

training_history = model.fit(
    x_train,  # input
    y_train,  # output
    # One batch holding the whole training set: a single SGD step per epoch.
    batch_size=train_size,
    # Suppress per-epoch console output. Metrics are recorded by the callbacks
    # from factory.create() (a CSV log) instead.
    verbose=0,
    epochs=100,
    validation_data=(x_test, y_test),
    callbacks=factory.create(),
)

Training ... With default parameters, this takes less than 10 seconds.


In [4]:
class CallbacksFactory:
    """Build the list of Keras callbacks used for training in this notebook."""

    def create(self, directory: str = '.', filename: str = 'data.csv') -> list:
        """Return fresh callback instances for one ``model.fit`` call.

        Args:
            directory: Folder that the CSV log is written to.
            filename: Name of the CSV log file inside ``directory``.

        Returns:
            A list containing a single ``CSVLogger`` whose path is
            ``os.path.join(directory, filename)``.
        """
        # os.path.join (not os.sep.join) avoids doubled separators when
        # `directory` already ends with one, e.g. 'logs/' + 'data.csv'.
        return [CSVLogger(os.path.join(directory, filename))]