Skip to content

Commit

Permalink
Custom callback (#288)
Browse files Browse the repository at this point in the history
* Rework callback structure

* Add ChainCallback

* Fix ModelSaver args order
  • Loading branch information
tgallice authored and aymericdamien committed Aug 19, 2016
1 parent a1c0c37 commit 210d5d2
Show file tree
Hide file tree
Showing 3 changed files with 194 additions and 128 deletions.
183 changes: 102 additions & 81 deletions tflearn/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,51 +9,78 @@

class Callback(object):
""" Callback base class. """

def __init__(self, **kwargs):
def on_train_begin(self, training_state):
pass

def on_epoch_begin(self, **kwargs):
def on_epoch_begin(self, training_state):
pass

def on_epoch_end(self, **kwargs):
def on_batch_begin(self, training_state):
pass

def on_sub_epoch_begin(self, **kwargs):
def on_sub_batch_begin(self, training_state):
pass

def on_sub_epoch_end(self, **kwargs):
def on_sub_batch_end(self, training_state, train_index=0):
pass

def on_batch_begin(self, **kwargs):
def on_batch_end(self, training_state, snapshot=False):
pass

def on_batch_end(self, **kwargs):
def on_epoch_end(self, training_state):
pass

def on_sub_batch_begin(self, **kwargs):
def on_train_end(self, training_state):
pass

def on_sub_batch_end(self, **kwargs):
pass
class ChainCallback(Callback):
def __init__(self, callbacks=[]):
self.callbacks = callbacks

def on_train_begin(self, **kwargs):
pass
def on_train_begin(self, training_state):
for callback in self.callbacks:
callback.on_train_begin(training_state)

def on_train_end(self, **kwargs):
pass
def on_epoch_begin(self, training_state):
for callback in self.callbacks:
callback.on_epoch_begin(training_state)

def on_batch_begin(self, training_state):
for callback in self.callbacks:
callback.on_batch_begin(training_state)

def on_sub_batch_begin(self, training_state):
for callback in self.callbacks:
callback.on_sub_batch_begin(training_state)

def on_sub_batch_end(self, training_state, train_index=0):
for callback in self.callbacks:
callback.on_sub_batch_end(training_state, train_index)

def on_batch_end(self, training_state, snapshot=False):
for callback in self.callbacks:
callback.on_batch_end(training_state, snapshot)

def on_epoch_end(self, training_state):
for callback in self.callbacks:
callback.on_epoch_end(training_state)

def on_train_end(self, training_state):
for callback in self.callbacks:
callback.on_train_end(training_state)

def add(self, callback):
if not isinstance(callback, Callback):
raise Exception(str(callback) + " is an invalid Callback object")

self.callbacks.append(callback)

class TermLogger(Callback):
def __init__(self, training_step=0):
super(TermLogger, self).__init__()
def __init__(self):
self.data = []
self.has_curses = True
self.has_ipython = True
self.display_type = "multi"
self.global_loss = None
self.global_acc = None
self.global_step = training_step
self.global_data_size = 0
self.global_val_data_size = 0
self.snapped = False
Expand Down Expand Up @@ -84,51 +111,43 @@ def add(self, data_size, val_size=0, metric_name=None, name=None):
self.global_data_size += data_size
self.global_val_data_size += val_size

def on_epoch_begin(self):
def on_epoch_begin(self, training_state):
pass

def on_epoch_end(self):
def on_epoch_end(self, training_state):
pass

def on_sub_epoch_begin(self):
def on_batch_begin(self, training_state):
pass

def on_sub_epoch_end(self, snapshot=False):
if snapshot:
self.snapshot_termlogs()
def on_batch_end(self, training_state, snapshot=False):

def on_batch_begin(self):
pass
self.print_termlogs(training_state)

def on_batch_end(self, global_loss=None, global_acc=None, snapshot=False):
self.global_step += 1
self.global_loss = global_loss
self.global_acc = global_acc
self.print_termlogs()
if snapshot:
self.snapshot_termlogs()
self.snapshot_termlogs(training_state)

def on_sub_batch_start(self):
def on_sub_batch_start(self, training_state):
pass

def on_sub_batch_end(self, train_op_i, epoch, step, loss=None, acc=None,
val_loss=None, val_acc=None):
self.data[train_op_i]['loss'] = loss
self.data[train_op_i]['acc'] = acc
self.data[train_op_i]['val_loss'] = val_loss
self.data[train_op_i]['val_acc'] = val_acc
self.data[train_op_i]['epoch'] = epoch
self.data[train_op_i]['step'] = step
def on_sub_batch_end(self, training_state, train_index=0):

def on_train_begin(self):
self.data[train_index]['loss'] = training_state.loss_value
self.data[train_index]['acc'] = training_state.acc_value
self.data[train_index]['val_loss'] = training_state.val_loss
self.data[train_index]['val_acc'] = training_state.val_acc
self.data[train_index]['epoch'] = training_state.epoch
self.data[train_index]['step'] = training_state.current_iter

def on_train_begin(self, training_state):
print("---------------------------------")
print("Training samples: " + str(self.global_data_size))
print("Validation samples: " + str(self.global_val_data_size))
print("--")
if len(self.data) == 1:
self.display_type = "single"

def on_train_end(self):
def on_train_end(self, training_state):
# Reset caret to last position
to_be_printed = ""
if self.has_curses: #if not self.has_ipython #TODO:check bug here
Expand All @@ -143,14 +162,14 @@ def on_train_end(self):
if self.has_curses:
sys.stdout.write(curses.tigetstr('cvvis').decode())

def termlogs(self):
def termlogs(self, step=0, global_loss=None, global_acc=None):

termlogs = "Training Step: " + str(self.global_step) + " "
if self.global_loss:
termlogs = "Training Step: " + str(step) + " "
if global_loss:
termlogs += " | total loss: \033[1m\033[32m" + \
"%.5f" % self.global_loss + "\033[0m\033[0m"
if self.global_acc and not self.display_type == "single":
termlogs += " - avg acc: %.4f" % float(self.global_acc)
"%.5f" % global_loss + "\033[0m\033[0m"
if global_acc and not self.display_type == "single":
termlogs += " - avg acc: %.4f" % float(global_acc)
termlogs += "\n"
for i, data in enumerate(self.data):
print_loss = ""
Expand Down Expand Up @@ -184,9 +203,13 @@ def termlogs(self):

return termlogs

def print_termlogs(self):
def print_termlogs(self, training_state):

termlogs = self.termlogs(
step=training_state.step,
global_loss=training_state.global_loss,
global_acc=training_state.global_acc)

termlogs = self.termlogs()
if self.has_ipython and not self.has_curses:
clear_output(wait=True)
else:
Expand All @@ -196,67 +219,65 @@ def print_termlogs(self):
sys.stdout.write(termlogs)
sys.stdout.flush()

def snapshot_termlogs(self):
def snapshot_termlogs(self, training_state):

termlogs = self.termlogs(
step=training_state.step,
global_loss=training_state.global_loss,
global_acc=training_state.global_acc)

termlogs = self.termlogs()
termlogs += "--\n"

sys.stdout.write(termlogs)
sys.stdout.flush()
self.snapped = True


class ModelSaver(object):
def __init__(self, save_func, training_step, snapshot_path, best_snapshot_path,
class ModelSaver(Callback):
def __init__(self, save_func, snapshot_path, best_snapshot_path,
best_val_accuracy, snapshot_step, snapshot_epoch):
self.save_func = save_func
self.training_step = training_step
self.snapshot_path = snapshot_path
self.snapshot_epoch = snapshot_epoch
self.best_snapshot_path = best_snapshot_path
self.snapshot_step = snapshot_step
self.best_val_accuracy = best_val_accuracy
self.snapshot_step = snapshot_step

def on_epoch_begin(self):
def on_epoch_begin(self, training_state):
pass

def on_epoch_end(self):
def on_epoch_end(self, training_state):
if self.snapshot_epoch:
self.save()
self.save(training_state.step)

def on_sub_epoch_begin(self):
def on_batch_begin(self, training_state):
pass

def on_sub_epoch_end(self):
pass
def on_batch_end(self, training_state, snapshot=False):

def on_batch_begin(self):
pass
if snapshot & (self.snapshot_step is not None):
self.save(training_state.step)

def on_batch_end(self, snapshot_model=False, best_checkpoint_path=None, val_accuracy=None):
self.training_step += 1
if snapshot_model & (self.snapshot_step is not None):
self.save()
if None not in (best_checkpoint_path, val_accuracy, self.best_val_accuracy):
if val_accuracy > self.best_val_accuracy:
self.best_val_accuracy = val_accuracy
self.save_best(int(10000 * round(val_accuracy, 4)))
if None not in (self.best_snapshot_path, self.best_val_accuracy, training_state.val_acc):
if training_state.val_acc > self.best_val_accuracy:
self.best_val_accuracy = training_state.val_acc
self.save_best(int(10000 * round(training_state.val_acc, 4)))

def on_sub_batch_begin(self):
def on_sub_batch_begin(self, training_state):
pass

def on_sub_batch_end(self):
def on_sub_batch_end(self, training_state, train_index=0):
pass

def on_train_begin(self):
def on_train_begin(self, training_state):
pass

def on_train_end(self):
def on_train_end(self, training_state):
pass

def save(self):
def save(self, training_step=0):
if self.snapshot_path:
self.save_func(self.snapshot_path, self.training_step)
self.save_func(self.snapshot_path, training_step)

def save_best(self, val_accuracy):
if self.best_snapshot_path:
Expand Down

0 comments on commit 210d5d2

Please sign in to comment.