In [120]:
import os

import numpy as np
import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter

In [1]:
class Callback:
    """Empty placeholder for the callback base class."""

## Experiments with `state_dict`

In [51]:
class Callback:
    """Base class whose `state_dict` is a single dict stored on the class.

    The dict is created once, in the class body, so every subclass and every
    instance that does not rebind `state_dict` reads and writes the same object.
    """

    state_dict = dict()

In [62]:
class CallbackA(Callback):
    """Subclass used to inspect the `state_dict` inherited from `Callback`."""

    def print_state_dict(self):
        """Print the `state_dict` visible from this instance."""
        visible_state = self.state_dict
        print(visible_state)

In [63]:
class CallbackB(Callback):
    """Second subclass of `Callback`; it adds nothing of its own."""

In [64]:
# Neither subclass defines its own `state_dict`, so both lookups resolve to the
# attribute inherited from `Callback`.
CallbackA.state_dict == CallbackB.state_dict

True

In [65]:
# In-place update of the dict, reached through CallbackB.
CallbackB.state_dict.update({'apple':1})

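When `CallbackB` updates its `state_dict`, `CallbackA` sees the change too! Neither subclass defines its own `state_dict`, so both names refer to the single dict created in the body of `Callback`. (The `'pear'` key in the output below appears to be left over from an earlier run of the cells that follow; the execution counts suggest the cells were not run in one top-to-bottom pass.)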

In [66]:
# Read the dict back through CallbackA to see whether the update is visible here.
CallbackA.state_dict

{'apple': 1, 'pear': 2}

In [67]:
# Repeat the experiment with one instance of each subclass instead of the classes.
cb_a, cb_b = CallbackA(), CallbackB()

In [68]:
# In-place update reached through an instance; the instance has no `state_dict`
# of its own, so the lookup falls through to the class attribute.
cb_a.state_dict.update({'pear':2})

In [69]:
# Print the dict as seen from the CallbackA instance.
cb_a.print_state_dict()

{'apple': 1, 'pear': 2}


In [70]:
# Read the dict through the CallbackB instance, which was never updated directly.
cb_b.state_dict

{'apple': 1, 'pear': 2}

In [60]:
# Read the dict through the CallbackA class rather than an instance.
CallbackA.state_dict

{'apple': 1, 'pear': 2}

## Build callbacks and test them with tensorboard

In [203]:
class Callback():
    """Base class for training callbacks; every hook is a no-op by default.

    Subclasses override only the hooks they need. `CallbackHandler` sorts
    callbacks by `_order` before calling a hook on each of them.
    """
    # Default sort position. Defined on the base class so that a callback which
    # does not declare its own `_order` can still be sorted by the handler.
    _order = 0
    # Created once in the class body: every callback that does not rebind `sd`
    # reads and writes this same dict, which is how callbacks exchange state.
    sd = {}
    # Hooks, listed in the order `fake_train` calls them.
    def on_train_begin(self): pass
    def on_epoch_begin(self): pass
    def on_batch_begin(self): pass
    def on_loss_begin(self): pass
    def on_backward_begin(self): pass
    def on_backward_end(self): pass
    def on_step_end(self): pass
    def on_batch_end(self): pass
    def on_epoch_end(self): pass
    def on_train_end(self): pass

In [269]:
class CallbackHandler(Callback):
    """Fan each training event out to a list of callbacks.

    The handler exposes the same hooks as `Callback`. Calling one of them calls
    the hook of the same name on every registered callback, in ascending
    `_order`.
    """

    def __init__(self, cbs):
        """Store the callbacks to dispatch to.

        Args:
            cbs: iterable of callback objects.
        """
        self.cbs = cbs

    def __call__(self, cb_category:str):
        """Call the hook named `cb_category` on every registered callback.

        Args:
            cb_category: name of the hook method to call, e.g. 'on_epoch_end'.
        """
        # Sorted on every dispatch, so callbacks added to `self.cbs` after
        # construction still run in order. A callback without `_order` is
        # treated as order 0 instead of raising AttributeError. `sorted` is
        # stable, so callbacks with equal order keep their listed order.
        self.cbs = sorted(self.cbs, key=lambda cb: getattr(cb, '_order', 0))
        for cb in self.cbs: getattr(cb, cb_category)()

    # Each hook forwards its own name to `__call__`.
    def on_train_begin(self): self('on_train_begin')

    def on_epoch_begin(self): self('on_epoch_begin')

    def on_batch_begin(self): self('on_batch_begin')

    def on_loss_begin(self): self('on_loss_begin')

    def on_backward_begin(self): self('on_backward_begin')

    def on_backward_end(self): self('on_backward_end')

    def on_step_end(self): self('on_step_end')

    def on_batch_end(self): self('on_batch_end')

    def on_epoch_end(self): self('on_epoch_end')

    def on_train_end(self): self('on_train_end')

In [349]:
class TensorboardCreator(Callback):
    """Create a tensorboard `SummaryWriter` for the duration of training.

    The writer is stored in the shared state dict under 'writer' when training
    begins and is closed and removed when training ends.
    """
    _order=0

    def __init__(self, log_dir):
        """Remember `log_dir` and refuse to reuse an existing directory.

        Args:
            log_dir: directory the writer will log to; it must not exist yet.

        Raises:
            AssertionError: if `log_dir` is already a directory.
        """
        self.log_dir = log_dir
        # Raised explicitly rather than with `assert`, so the check is still
        # performed when Python runs with -O (which strips assert statements).
        if os.path.isdir(log_dir):
            raise AssertionError('log_dir already exists; try another one')

    def on_train_begin(self):
        """Open the writer and publish it in the shared state dict."""
        self.sd.update({'writer':SummaryWriter(log_dir=self.log_dir)})

    def on_train_end(self):
        """Flush and close the writer and remove it from the shared state dict.

        Raises:
            KeyError: if no writer is registered under 'writer'.
        """
        # Pop first so a closed writer is never left behind in the shared dict,
        # and close in `finally` so the writer is released even if flush fails.
        writer = self.sd.pop('writer')
        try:
            writer.flush()
        finally:
            writer.close()

In [350]:
class MetricLogger(Callback):
    """Record the per-epoch average of one metric from the shared state dict.

    After each batch the value stored under `metric_name` is added to a running
    total and the value stored under 'batch_size' to a running example count.
    At the end of the epoch, total / count is stored under
    'last_<metric_name>', appended to the list under '<metric_name>s' and,
    if requested, written to the writer found under 'writer'.
    """
    # Sorts after TensorboardCreator (_order=0).
    _order=1

    def __init__(self, metric_name:str, group:str, on_tensorboard:bool=False):
        """Configure which metric to track and where to report it.

        Args:
            metric_name: key in the shared state dict holding the batch value.
            group: prefix of the tensorboard tag, '<group>/<metric_name>'.
            on_tensorboard: if True, also log the epoch value to the writer
                stored under 'writer' in the shared state dict.
        """
        self.metric_name = metric_name
        self.group = group
        self.on_tensorboard = on_tensorboard
        # Start the accumulators here as well, so `on_batch_end` does not fail
        # with AttributeError if it runs before the first `on_epoch_begin`.
        self.total = 0
        self.num_examples = 0

    def on_train_begin(self):
        """Reset the recorded history for this metric."""
        self.sd[f'last_{self.metric_name}'] = None
        self.sd[f'{self.metric_name}s'] = []

    def on_epoch_begin(self):
        """Reset the running total and example count for the new epoch."""
        self.total = 0
        self.num_examples = 0

    def on_batch_end(self):
        """Add the current batch's metric value and size to the accumulators."""
        # NOTE(review): dividing by the example count at epoch end presumes the
        # stored value is a sum over the batch's examples; confirm with the caller.
        self.total += self.sd[self.metric_name]
        self.num_examples += self.sd['batch_size']

    def on_epoch_end(self):
        """Store the epoch average and optionally log it to tensorboard."""
        # An epoch with no batches has nothing to average; record NaN instead
        # of raising ZeroDivisionError.
        if self.num_examples:
            last = self.total / self.num_examples
        else:
            last = float('nan')
        self.sd[f'last_{self.metric_name}'] = last
        self.sd[f'{self.metric_name}s'].append(last)

        if self.on_tensorboard:
            self.sd['writer'].add_scalar(
                f'{self.group}/{self.metric_name}',
                self.sd[f'last_{self.metric_name}'],
                self.sd['epoch']
            )

In [351]:
def fake_train(cb_handler, n_epochs=100, n_batches=10, batch_size=64):
    """Drive the callback hooks through a simulated training run.

    No model is trained. For every batch, the epoch index, the batch size and
    a random loss in [0, batch_size) are written to `cb_handler.sd` between
    the `on_loss_begin` and `on_backward_begin` hooks.

    Args:
        cb_handler: object exposing the `Callback` hooks and an `sd` dict.
        n_epochs: number of simulated epochs.
        n_batches: number of simulated batches per epoch.
        batch_size: value published under 'batch_size'; also scales the loss.
    """
    cb_handler.on_train_begin()
    for epoch in range(n_epochs):
        cb_handler.on_epoch_begin()
        for _ in range(n_batches):
            cb_handler.on_batch_begin()
            cb_handler.on_loss_begin()
            cb_handler.sd.update({
                'epoch': epoch,
                'batch_size': batch_size,
                'bce_loss': np.random.random() * batch_size,
            })
            cb_handler.on_backward_begin()
            cb_handler.on_backward_end()
            cb_handler.on_step_end()
            cb_handler.on_batch_end()
        cb_handler.on_epoch_end()
    cb_handler.on_train_end()

In [352]:
# Experiment name and run number; the next cell combines them into the
# tensorboard log directory runs/<exp_name>/<trial>.
exp_name = 'testing_callbacks'
trial = 1

In [353]:
# TensorboardCreator rejects a `log_dir` that already exists, so change `trial`
# (or remove the old run directory) before re-running this cell.
cb_handler = CallbackHandler([
    TensorboardCreator(log_dir=f'runs/{exp_name}/{trial}'),  # automatically mkdir
    MetricLogger(metric_name='bce_loss', group='train', on_tensorboard=True)
])
fake_train(cb_handler)

AssertionError: log_dir already exists; try another one