In [3]:
import keras
import tensorflow as tf
import numpy as np
from keras import backend as K

In [5]:
class TB_writer(keras.callbacks.Callback):
    """TensorBoard callback whose histogram summaries are computed from a validation generator.

    A variant of ``keras.callbacks.TensorBoard`` intended for generator-based
    training: instead of ``self.validation_data`` it pulls one batch from
    ``val_gen`` every ``histogram_freq`` epochs and evaluates the merged
    weight / gradient / activation / image summaries on it, in chunks of
    ``batch_size`` samples. Scalar values from ``logs`` are written every epoch.

    # Arguments
        log_dir: directory the TensorBoard event files are written to.
        histogram_freq: epoch interval for histogram (and image) summaries;
            0 disables them.
        batch_size: maximum number of validation samples fed per
            ``sess.run`` call when evaluating the histogram summaries.
        write_graph: whether to store the session graph in the event file.
        write_grads: whether to add gradient histograms
            (only when ``histogram_freq > 0``).
        write_images: whether to add weight images
            (only when ``histogram_freq > 0``).
        val_gen: iterator yielding ``(x, y)`` or ``(x, y, sample_weight)``
            batches, where each element is an array or a list of arrays
            (one per model input/output). Histograms are only computed
            when this is given.
    """

    def __init__(self, log_dir='/data/tensorflow/logs',
                 histogram_freq=0,
                 batch_size=32,
                 write_graph=True,
                 write_grads=True,
                 write_images=False,
                 val_gen=None):
        super(TB_writer, self).__init__()
        self.log_dir = log_dir
        self.histogram_freq = histogram_freq
        self.write_graph = write_graph
        self.write_grads = write_grads
        self.write_images = write_images
        self.batch_size = batch_size
        self.val_gen = val_gen
        self.merged = None
        # Created in set_model; kept as None until then so on_train_end can
        # tell whether there is anything to close.
        self.writer = None

    def set_model(self, model):
        """Attach ``model``, build the summary ops and open the event-file writer.

        # Arguments
            model: the compiled Keras model being trained.
        """
        self.model = model
        self.sess = K.get_session()
        if self.histogram_freq and self.merged is None:
            for layer in self.model.layers:
                for weight in layer.weights:
                    mapped_weight_name = weight.name.replace(':', '_')
                    tf.summary.histogram(mapped_weight_name, weight)

                    if self.write_grads:
                        grads = model.optimizer.get_gradients(model.total_loss, weight)
                        # Sparse gradients (e.g. from embeddings) come back as
                        # IndexedSlices, which cannot be summarised directly.
                        grads = [g.values if isinstance(g, tf.IndexedSlices) else g
                                 for g in grads]
                        tf.summary.histogram('{}_grad'.format(mapped_weight_name), grads)

                    if self.write_images:
                        w_img = self._weight_to_image(weight)
                        if w_img is not None:
                            tf.summary.image(mapped_weight_name, w_img)

                if hasattr(layer, 'output'):
                    tf.summary.histogram('{}_out'.format(layer.name), layer.output)

        # Always merge and open the writer: on_epoch_end writes scalar
        # metrics every epoch, even when histograms are disabled.
        self.merged = tf.summary.merge_all()
        if self.write_graph:
            self.writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)
        else:
            self.writer = tf.summary.FileWriter(self.log_dir)

    @staticmethod
    def _weight_to_image(weight):
        """Reshape ``weight`` into a 4-D ``[N, H, W, 1]`` image tensor, or return None if its rank is unsupported."""
        w_img = tf.squeeze(weight)
        shape = K.int_shape(w_img)
        if len(shape) == 2:  # dense kernel
            if shape[0] > shape[1]:
                # put the longer dimension along the image width
                w_img = tf.transpose(w_img)
                shape = K.int_shape(w_img)
            w_img = tf.reshape(w_img, [1, shape[0], shape[1], 1])
        elif len(shape) == 3:  # convnet kernel
            if K.image_data_format() == 'channels_last':
                # move the last axis first so every slice becomes its own image
                w_img = tf.transpose(w_img, perm=[2, 0, 1])
                shape = K.int_shape(w_img)
            w_img = tf.reshape(w_img, [shape[0], shape[1], shape[2], 1])
        elif len(shape) == 1:  # bias vector
            w_img = tf.reshape(w_img, [1, shape[0], 1, 1])
        else:
            # rank 0 or rank > 3 after squeezing (e.g. multi-channel or 3-D
            # conv kernels): no image layout is attempted
            return None
        shape = K.int_shape(w_img)
        assert len(shape) == 4 and shape[-1] in [1, 3, 4]
        return w_img

    @staticmethod
    def _as_list(data):
        """Return ``data`` as a list: lists/tuples are copied, a single array is wrapped."""
        return list(data) if isinstance(data, (list, tuple)) else [data]

    def _write_histograms(self, epoch):
        """Evaluate ``self.merged`` on one batch from ``val_gen`` and write it at step ``epoch``.

        The batch is split into chunks of at most ``self.batch_size`` samples;
        each chunk adds its own summary entry at the same step.
        """
        val_data = next(self.val_gen)
        xs = self._as_list(val_data[0])
        ys = self._as_list(val_data[1])
        assert len(xs) == len(self.model.inputs), \
            'val_gen yielded a different number of input arrays than model.inputs'
        assert len(ys) == len(self.model.targets), \
            'val_gen yielded a different number of target arrays than model.targets'
        val_size = xs[0].shape[0]

        placeholders = list(self.model.inputs) + list(self.model.targets)
        arrays = xs + ys

        # Gradient summaries depend on total_loss, which reads the
        # sample-weight placeholders, so those must be fed too. Fall back to
        # uniform weights when the generator does not supply any.
        sw_placeholders = list(getattr(self.model, 'sample_weights', None) or [])
        sws = []
        if len(val_data) > 2 and val_data[2] is not None:
            sws = self._as_list(val_data[2])
        sws += [None] * (len(sw_placeholders) - len(sws))
        for ph, sw in zip(sw_placeholders, sws):
            if ph is None:
                continue
            if sw is None:
                ndim = K.ndim(ph) or 1
                sw = np.ones((val_size,) + (1,) * (ndim - 1), dtype=K.floatx())
            placeholders.append(ph)
            arrays.append(sw)

        for start in range(0, val_size, self.batch_size):
            stop = min(start + self.batch_size, val_size)
            feed_dict = {ph: arr[start:stop]
                         for ph, arr in zip(placeholders, arrays)
                         if ph is not None}
            if self.model.uses_learning_phase:
                # The learning-phase flag is a scalar (not sliced); 0 = test mode.
                feed_dict[K.learning_phase()] = 0
            summary_str = self.sess.run(self.merged, feed_dict=feed_dict)
            self.writer.add_summary(summary_str, epoch)

    def on_epoch_end(self, epoch, logs=None):
        """Write histogram summaries (every ``histogram_freq`` epochs) and all scalar ``logs`` entries.

        # Arguments
            epoch: epoch index, used as the TensorBoard step.
            logs: dict of metric name -> scalar value; 'batch' and 'size'
                entries are skipped.
        """
        logs = logs or {}
        if (self.val_gen is not None and self.histogram_freq
                and self.merged is not None
                and epoch % self.histogram_freq == 0):
            self._write_histograms(epoch)

        for name, value in logs.items():
            if name in ['batch', 'size']:
                continue
            summary = tf.Summary()
            summary_value = summary.value.add()
            # float() accepts numpy scalars and plain Python floats alike;
            # .item() only exists on numpy scalars.
            summary_value.simple_value = float(value)
            summary_value.tag = name
            self.writer.add_summary(summary, epoch)
        self.writer.flush()

    def on_train_end(self, _=None):
        """Close the event-file writer if set_model opened one."""
        if self.writer is not None:
            self.writer.close()

In [6]:
# Inspect the Keras Callback base class: the on_* / set_* entries are the
# hooks TB_writer can override. dir() already returns a sorted list.
sorted(dir(keras.callbacks.Callback))

['__class__',
 '__delattr__',
 '__dict__',
 '__doc__',
 '__format__',
 '__getattribute__',
 '__hash__',
 '__init__',
 '__module__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 'on_batch_begin',
 'on_batch_end',
 'on_epoch_begin',
 'on_epoch_end',
 'on_train_begin',
 'on_train_end',
 'set_model',
 'set_params']