In [3]:
import numpy as np
import pandas as pd
import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter
import os
import shutil
import scipy.misc

import sys
sys.path.append('../modules')

## Experiments with `state_dict`

In [4]:
class Callback():
    """Toy base class whose ``state_dict`` lives on the class, not on instances."""

    # Built once when the class body runs. Subclasses and instances that never
    # rebind the name all resolve to this very same dict object.
    state_dict = dict()

In [5]:
class CallbackA(Callback):
    """Subclass that can print the ``state_dict`` it inherits from ``Callback``."""

    def print_state_dict(self):
        """Print whatever dict attribute lookup on ``self`` resolves ``state_dict`` to."""
        shared = self.state_dict
        print(shared)

In [6]:
class CallbackB(Callback):
    """Subclass that adds nothing of its own; it only inherits from ``Callback``."""

In [7]:
# Neither subclass defines its own `state_dict`, so both lookups fall through to Callback's dict.
CallbackA.state_dict == CallbackB.state_dict

True

In [8]:
# In-place mutation through CallbackB; the name `state_dict` is not rebound on CallbackB.
CallbackB.state_dict.update({'apple':1})

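Because `state_dict` is a class attribute defined once on `Callback` and never rebound in the subclasses, `CallbackA` and `CallbackB` share the same dict object: when `CallbackB` updates its `state_dict` in place, `CallbackA` sees that change too!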

In [9]:
# Display the dict as seen through CallbackA after the update made via CallbackB.
CallbackA.state_dict

{'apple': 1}

In [10]:
# One instance of each subclass, to repeat the experiment at instance level.
cb_a, cb_b = CallbackA(), CallbackB()

In [11]:
# In-place mutation through an instance; `cb_a` has no instance attribute named `state_dict`.
cb_a.state_dict.update({'pear':2})

In [12]:
# Print the dict as seen from the instance that performed the update.
cb_a.print_state_dict()

{'apple': 1, 'pear': 2}


In [13]:
# Display the dict as seen from the other subclass's instance.
cb_b.state_dict

{'apple': 1, 'pear': 2}

In [14]:
# Display the dict at class level after the instance-level update.
CallbackA.state_dict

{'apple': 1, 'pear': 2}

## Learner
A container for the training data, model, loss function and optimizer (plus optional validation data).

In [15]:
#export
class Learner():
    """Plain container bundling everything a training loop needs."""

    def __init__(self, train_data, model, loss, optim, valid_data=None):
        """
        Store each argument on the instance under the same name.

        :param train_data: training data source
        :param model: the model to train
        :param loss: the loss function
        :param optim: the optimizer
        :param valid_data: optional validation data source (default None)
        """
        self.train_data = train_data
        self.model = model
        self.loss = loss
        self.optim = optim
        self.valid_data = valid_data

## Callback base class, custom callbacks, CallbackHandler

In [16]:
class Callback():
    """
    Base class for training callbacks: one no-op hook per training event.

    Subclasses override only the hooks they care about.
    """

    # Class attribute: a single dict shared by every callback (and every
    # subclass) that does not rebind `sd`. Entries persist between training
    # runs, so reset it in `on_train_begin` if stale entries matter.
    sd = dict()

    def on_train_begin(self):
        """No-op hook for the `on_train_begin` event."""

    def on_epoch_begin(self):
        """No-op hook for the `on_epoch_begin` event."""

    def on_batch_begin(self):
        """No-op hook for the `on_batch_begin` event."""

    def on_loss_begin(self):
        """No-op hook for the `on_loss_begin` event."""

    def on_backward_begin(self):
        """No-op hook for the `on_backward_begin` event."""

    def on_backward_end(self):
        """No-op hook for the `on_backward_end` event."""

    def on_step_end(self):
        """No-op hook for the `on_step_end` event."""

    def on_batch_end(self):
        """No-op hook for the `on_batch_end` event."""

    def on_epoch_end(self):
        """No-op hook for the `on_epoch_end` event."""

    def on_train_end(self):
        """No-op hook for the `on_train_end` event."""

In [17]:
class CallbackHandler(Callback):
    """
    Fan each training event out to a list of callbacks.

    Calling the handler with a hook name invokes that hook on every callback,
    in ascending `_order`. Each `on_*` method is a thin wrapper that dispatches
    its own name.
    """

    def __init__(self, cbs):
        """
        :param cbs: list of callbacks to dispatch events to
        """
        self.cbs = cbs

    def __call__(self, cb_category:str):
        """
        Invoke the hook named `cb_category` on every callback.

        :param cb_category: name of the hook method to call, e.g. 'on_epoch_end'
        """
        # Sorting happens on every dispatch (not in __init__) because subclasses
        # may assign `self.cbs` themselves. Callbacks that do not declare
        # `_order` default to 0 instead of raising AttributeError; `sorted` is
        # stable, so callbacks with equal order keep their listed order.
        self.cbs = sorted(self.cbs, key=lambda cb: getattr(cb, '_order', 0))
        for cb in self.cbs:
            getattr(cb, cb_category)()

    def on_train_begin(self): self('on_train_begin')

    def on_epoch_begin(self): self('on_epoch_begin')

    def on_batch_begin(self): self('on_batch_begin')

    def on_loss_begin(self): self('on_loss_begin')

    def on_backward_begin(self): self('on_backward_begin')

    def on_backward_end(self): self('on_backward_end')

    def on_step_end(self): self('on_step_end')

    def on_batch_end(self): self('on_batch_end')

    def on_epoch_end(self): self('on_epoch_end')

    def on_train_end(self): self('on_train_end')

In [18]:
def ifoverwrite(overwrite:bool, file_navigator:str, file_navigator_type:str)->None:
    """
    Prepare an output location, refusing to clobber it unless `overwrite` is True.

    :param overwrite: if False, raise when `file_navigator` already exists
    :param file_navigator: path of the output file or directory
    :param file_navigator_type: 'path' when `file_navigator` is a file path,
        'dir' when it is a directory; any other value only gets the existence check
    :raises AssertionError: if `overwrite` is False and `file_navigator` already exists

    For 'path', the parent directory is created when missing so that a later
    write to the file cannot fail on a missing folder. For 'dir' with
    `overwrite=True`, an existing directory is emptied (removed and recreated);
    a missing directory is left for its eventual writer to create.
    """
    already_there = os.path.isfile(file_navigator) or os.path.isdir(file_navigator)
    if not overwrite and already_there:
        # Raised explicitly rather than via `assert`, which `python -O` would strip.
        raise AssertionError(f'{file_navigator} already exists. To overwrite it, pass True to argument `overwrite`.')

    if file_navigator_type == 'path':
        parent_dir = os.path.dirname(file_navigator)
        # A bare filename has no parent component; os.makedirs('') would raise.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    elif file_navigator_type == 'dir':
        # Only wipe when asked to overwrite and there is something to wipe.
        if overwrite and os.path.isdir(file_navigator):
            shutil.rmtree(file_navigator)
            os.makedirs(file_navigator, exist_ok=False)

In [19]:
class TensorboardCreator(Callback):
    """
    Own the tensorboard writer's lifetime: create it when training begins and
    close it when training ends.

    The writer is published in the shared state dict under the key 'writer'.
    """
    _order = 0

    def __init__(self, log_dir, overwrite:bool):
        """
        :param log_dir: directory handed to `SummaryWriter`
        :param overwrite: forwarded to `ifoverwrite` for `log_dir` (type 'dir')
        """
        self.log_dir = log_dir
        ifoverwrite(overwrite, log_dir, 'dir')

    def on_train_begin(self):
        """Create the writer and store it in the shared state dict."""
        self.sd['writer'] = SummaryWriter(log_dir=self.log_dir)

    def on_train_end(self):
        """Flush and close the writer, then remove it from the shared state dict."""
        writer = self.sd['writer']
        writer.flush()
        writer.close()
        del self.sd['writer']

In [20]:
class MetricLogger(Callback):
    """
    Turn a per-batch metric into a per-epoch value, keep its history in a list,
    and optionally plot it on tensorboard.

    Reads `'<metric_name>_b'`, `'batch_size'` and `'epoch'` from the shared
    state dict; writes `'last_<metric_name>'` and `'<metric_name>s'`.
    """
    _order = 1

    def __init__(self, metric_name:str, group:str, on_tensorboard:bool):
        """
        :param metric_name: base name used to build the state-dict keys
        :param group: prefix of the tensorboard tag (`'<group>/<metric_name>'`)
        :param on_tensorboard: whether to also send each epoch value to `self.sd['writer']`
        """
        self.metric_name, self.group, self.on_tensorboard = metric_name, group, on_tensorboard

    def on_train_begin(self):
        """Create (or reset) this metric's accumulators in the shared state dict."""
        self.sd.update({
            f'last_{self.metric_name}': None,
            f'{self.metric_name}s': [],
        })

    def on_epoch_begin(self):
        """Zero the running totals for the new epoch."""
        self.total, self.num_examples = 0, 0

    def on_batch_end(self):
        """Add the latest batch's metric value and batch size to the running totals."""
        # The total is later divided by the example count, so the batch value is
        # presumably a per-batch sum rather than a mean -- TODO confirm with the loss.
        self.total += self.sd[f'{self.metric_name}_b']
        self.num_examples += self.sd['batch_size']

    def on_epoch_end(self):
        """Store total / num_examples as the epoch value and optionally plot it."""
        name = self.metric_name
        epoch_value = self.total / self.num_examples
        self.sd[f'last_{name}'] = epoch_value
        self.sd[f'{name}s'].append(epoch_value)

        if not self.on_tensorboard:
            return
        self.sd['writer'].add_scalar(f'{self.group}/{name}', epoch_value, self.sd['epoch'])

In [21]:
class MetricsSaver(Callback):
    """
    Write the per-epoch metric histories to a csv that can be loaded and
    visualized within a jupyter notebook.

    Depends on the lists that MetricLogger accumulates in the shared state dict.
    """
    _order = 2

    def __init__(self, metrics_to_log:list, csv_path:str, overwrite:bool):
        """
        :param metrics_to_log: names of the state-dict lists to save, one csv column each
            - an accumulator for each name must be available in self.sd
            - each name carries the trailing 's' that MetricLogger appends (e.g. 'bce_losss')
            - do not add 'epochs' to this list, since it will be obvious during plotting
        :param csv_path: destination file of the csv
        :param overwrite: forwarded to `ifoverwrite` for `csv_path` (type 'path')
        """
        self.metrics_to_log = metrics_to_log
        self.csv_path = csv_path
        ifoverwrite(overwrite, csv_path, 'path')

    def on_train_end(self):
        """Stack the metric histories column-wise and write them to `csv_path`."""
        histories = [self.sd[name] for name in self.metrics_to_log]
        # One row per metric before the transpose, one column per metric after it.
        table = pd.DataFrame(np.array(histories).T)
        table.columns = self.metrics_to_log
        table.to_csv(self.csv_path)

In [22]:
class GenLogger(Callback):
    """
    Send generated arrays (for generative models) to tensorboard as images
    at the end of every epoch.

    Reads `self.sd[gen_name]`, `self.sd['epoch']` and `self.sd['writer']`.
    """
    _order = 1

    def __init__(self, gen_name:str, group:str):
        """
        :param gen_name: state-dict key under which the generated arrays are stored
        :param group: prefix of the tensorboard tag (`'<group>/<gen_name>'`)
        """
        self.gen_name, self.group = gen_name, group

    def on_epoch_end(self):
        """Log the current generated arrays, using the epoch as the global step."""
        writer = self.sd['writer']
        tag = f'{self.group}/{self.gen_name}'
        images = self.sd[self.gen_name]
        writer.add_images(tag, images, global_step=self.sd['epoch'])

In [23]:
# class GensSaver(Callback):
#     """Log (in a directory) generated arrays for generative models."""
#     _order=2
#     def __init__(self, )
    
#     for gen in self.sd[gen_name]:  # self.sd[gen_name] is of shape (N, C, H, W)
#             scipy.misc.imsave('outfile.jpg', image_array)
#             plt.savefig(gen.detach()., dpi=100)

In [24]:
class ModelSaver(Callback):
    """Save the model's weights to disk once training has finished."""
    _order = 2

    def __init__(self, model_path:str, overwrite:bool):
        """
        :param model_path: destination file for the weights
        :param overwrite: forwarded to `ifoverwrite` for `model_path` (type 'path')
        """
        self.model_path = model_path
        ifoverwrite(overwrite, model_path, 'path')

    def on_train_end(self):
        """Write `self.sd['model'].state_dict()` to `model_path` with `torch.save`."""
        weights = self.sd['model'].state_dict()
        torch.save(weights, self.model_path)

In [25]:
def fake_train(cb_handler, num_epochs:int=255, num_batches:int=10, batch_size:int=64):
    """
    Drive a callback handler through a training loop that does no real training.

    Fires every hook in the usual order and fills `cb_handler.sd` with
    placeholder values so that callbacks can be smoke-tested quickly.

    :param cb_handler: object exposing the `on_*` hooks and a dict attribute `sd`
    :param num_epochs: number of fake epochs (default 255, the original hard-coded
        value -- presumably chosen so the fake images sweep pixel values 0..254; TODO confirm)
    :param num_batches: number of fake batches per epoch (default 10)
    :param batch_size: reported batch size, also used to scale the fake loss (default 64)
    """
    # Built once: the weights are never trained, so rebuilding a same-shaped
    # layer on every batch only wasted time.
    model = torch.nn.Linear(5, 10)

    cb_handler.on_train_begin()
    for epoch in range(num_epochs):
        cb_handler.on_epoch_begin()
        for _ in range(num_batches):
            cb_handler.on_batch_begin()
            cb_handler.on_loss_begin()
            cb_handler.sd.update({
                'model': model,
                'epoch': epoch,
                'batch_size': batch_size,
                'bce_loss_b': np.random.random()*batch_size  # random per-batch loss total
            })
            cb_handler.on_backward_begin()
            cb_handler.on_backward_end()
            cb_handler.on_step_end()
            cb_handler.on_batch_end()
        # Use the loop variable directly (and publish it here too) so the epoch
        # is available even when an epoch has no batches.
        cb_handler.sd.update({
            'epoch': epoch,
            'gen': torch.ones((5, 3, 2, 2))*epoch
        })
        cb_handler.on_epoch_end()
    cb_handler.on_train_end()

In [26]:
# Settings for the callback smoke test below; these may later serve as arguments for VAETrainer.
# Caution: overwrites do not show up immediately on tensorboard; you may need to restart it in the terminal.
exp_name = 'testing_callbacks'  # sub-folder name used in every output path
trial = 4  # run identifier appended to every output path
overwrite_vis = True  # overwrite flag for the tensorboard log directory
overwrite_csv = True  # overwrite flag for the metrics csv
overwrite_pth = True  # overwrite flag for the saved model weights

In [27]:
# Smoke test: wire one callback of each kind into a handler and drive it with the fake loop.
# The handler re-sorts by `_order` on every dispatch, so list order only breaks ties
# (TensorboardCreator=0, MetricLogger/GenLogger=1, MetricsSaver/ModelSaver=2).
cb_handler = CallbackHandler([
    TensorboardCreator(log_dir=f'runs/{exp_name}/{trial}', overwrite=overwrite_vis),  # automatically mkdir
    MetricLogger(metric_name='bce_loss', group='train', on_tensorboard=True),
    MetricsSaver(metrics_to_log=['bce_losss'], csv_path=f'training_csv/{exp_name}/{trial}.csv', overwrite=overwrite_csv),
    GenLogger(gen_name='gen', group='group1'),
    ModelSaver(model_path=f'trained_models/{exp_name}/{trial}.pth', overwrite=overwrite_pth)
])
fake_train(cb_handler)

## VAETrainer
A container for the training loop: it combines a `Learner` object with custom callbacks and, as a `CallbackHandler` subclass, fires the callback events itself.

In [28]:
class VAETrainer(CallbackHandler):
    """
    Training loop for a VAE that fires callback hooks at every stage.

    Inherits the hook dispatching from CallbackHandler and publishes per-batch
    and per-epoch values to the shared state dict `sd` for the callbacks to read.
    """

    def __init__(self, learn, cbs):
        """
        :param learn: a Learner providing `train_data`, `model`, `loss` and `optim`
        :param cbs: list of callbacks to dispatch events to
        """
        self.learn = learn
        self.cbs = cbs

    def train(self, num_epochs):
        """
        Run the training loop for `num_epochs` epochs.

        Per batch, `sd` receives: 'model', 'epoch' (1-based), 'batch_size',
        'loss_b', 'bce_b' and 'kld_b'. Per epoch, `sd` receives 'gens', the
        output of `model.generate(n=10)`.

        :param num_epochs: number of passes over `learn.train_data`
        """
        self.on_train_begin()  # create empty accumulators
        for epoch in range(num_epochs):
            self.on_epoch_begin()
            for xb, yb in self.learn.train_data:

                self.on_batch_begin()

                # Clear the gradients left by the previous batch; without this
                # PyTorch adds each batch's gradients onto the last one's.
                # Assumes `optim` is a torch.optim optimizer.
                self.learn.optim.zero_grad()
                recon, mu, logvar = self.learn.model(xb)

                self.on_loss_begin()
                loss, bce, kld = self.learn.loss(recon, yb, mu, logvar)

                self.on_backward_begin()
                loss.backward()
                self.on_backward_end()
                self.learn.optim.step()
                self.on_step_end()

                # Published after the optimizer step, so only `on_batch_end`
                # (and later hooks) see this batch's values.
                self.sd.update({
                    'model':self.learn.model,
                    'epoch':epoch+1,
                    'batch_size':int(xb.size(0)),
                    'loss_b':float(loss),
                    'bce_b':float(bce),
                    'kld_b':float(kld)
                })

                self.on_batch_end()
            # NOTE(review): `generate` is defined outside this notebook; whether it
            # should run under torch.no_grad() / eval mode needs confirming.
            gens = self.learn.model.generate(n=10)
            self.sd.update({'gens':gens})
            self.on_epoch_end()  # calculate average loss per example, visualize metrics, perform validation if required
        self.on_train_end()  # close tensorboard writer, output csv of metrics

Load data as dataloaders.

In [29]:
from keras.datasets import mnist
from fast_train import DataPipeline

Using TensorFlow backend.


In [30]:
# Load MNIST through keras. Only `x_train` is used in the cells that follow.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

In [31]:
# Standardize with the global mean and standard deviation of the whole array.
# NOTE: this cell rebinds `x_train`, so it is not idempotent -- re-running it
# standardizes again and inserts another axis.
x_train = (x_train - x_train.mean()) / x_train.std()
# Insert a new axis at position 1 (the batches printed further down have shape [16, 1, 28, 28]).
x_train = np.expand_dims(x_train, axis=1)
# Fixed seed so the in-place shuffle along the first axis is reproducible.
np.random.seed(42)
np.random.shuffle(x_train)

In [32]:
# Keep only the first 10000 examples of the shuffled array.
x_train = x_train[:10000]

In [33]:
# Build the training dataloader with batch size 16 (`DataPipeline` comes from the local `fast_train` module).
train_dl = DataPipeline.float_vae(x_train, bs=16)
# Drop the name `x_train`; the preprocessing cells above cannot be re-run without reloading the data first.
del x_train

In [34]:
# Pull one batch to check the shapes of the two tensors the dataloader yields.
# NOTE: `a` is rebound to a different tensor in a later cell.
a, b = next(iter(train_dl))
print(a.shape, b.shape)

torch.Size([16, 1, 28, 28]) torch.Size([16, 1, 28, 28])


Design VAE.

In [35]:
from vae_designer import VAEDesigner
from custom_vae import VAEDesign
# Switch for the design cells below: when False they are all skipped, and the
# design is read from 'designs/mnist_vae.json' when the model is built.
designer_on = False

In [36]:
# NOTE(review): despite the `dec_` prefix, this designer is built with
# up_sample=False and its design is later passed as `down_sampler_design`;
# the name looks swapped with `enc_designer` -- confirm.
if designer_on:
    dec_designer = VAEDesigner(input_shape=(1, 28, 28), up_sample=False)

In [37]:
# Show the design produced by the down-sampling designer.
if designer_on:
    print(dec_designer.design)

In [38]:
# NOTE(review): despite the `enc_` prefix, this designer is built with
# up_sample=True and its design is later passed as `up_sampler_design`;
# the name looks swapped with `dec_designer` -- confirm.
if designer_on:
    enc_designer = VAEDesigner(input_shape=(3, 5, 5), up_sample=True)

In [39]:
# Show the design produced by the up-sampling designer.
if designer_on:
    print(enc_designer.design)

In [40]:
# Combine both sampler designs with the hidden and latent sizes into one VAE design.
# Note the naming: `dec_designer` supplies the down-sampler and `enc_designer` the up-sampler.
if designer_on:
    vae_design = VAEDesign(
        down_sampler_design=dec_designer.design, 
        up_sampler_design=enc_designer.design, 
        h_dim=64, 
        z_dim=3, 
        unflatten_out_shape=(64, 1, 1)
    )

In [41]:
# Persist the design; this is the same JSON path the model-building cell loads from.
if designer_on:
    vae_design.save_as_json('designs/mnist_vae.json')

Instantiate a learner.

In [42]:
from custom_vae import get_vae_and_opt
from fast_train import Loss

In [43]:
# Build the VAE and its optimizer from the saved design, on CPU.
vae, opt = get_vae_and_opt('designs/mnist_vae.json', dev='cpu')
# Bundle data, model, loss and optimizer; no validation data is used.
learner = Learner(train_data=train_dl, model=vae, loss=Loss.float_loss_fn, optim=opt, valid_data=None)

In [44]:
# Display the decoder's layers.
vae.decoder

Sequential(
  (unflatten): UnFlatten()
  (block0-convtranpose2d): ConvTranspose2d(3, 3, kernel_size=(4, 4), stride=(2, 2), output_padding=(1, 1), bias=False)
  (block0-bn): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (block0-relu): ReLU()
  (block1-convtranpose2d): ConvTranspose2d(3, 1, kernel_size=(4, 4), stride=(2, 2), bias=False)
  (block1-bn): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

In [45]:
# Dummy batch of 10 all-zero inputs in float64; rebinds `a` from the batch-inspection cell.
# NOTE(review): `.double()` may not match the model's parameter dtype -- confirm.
a = torch.zeros((10, 1, 28, 28)).double()

In [None]:
# Shape of the encoder output for the dummy batch.
# NOTE(review): this cell has no execution count in the saved notebook.
vae.encoder(input=a).shape

In [11]:
# Shape of the encoder output for the first 10 inputs of one real batch.
# NOTE(review): the saved output is a NameError and the execution count is out of
# sequence -- re-run the notebook top to bottom before trusting this cell.
vae.encoder(next(iter(train_dl))[0][:10]).shape

NameError: name 'train_dl' is not defined

Instantiate callbacks and a trainer.

In [60]:
# Settings for the real MNIST run; these rebind the names used by the callback smoke test.
exp_name = 'mnist'  # sub-folder name used in every output path
trial = 1  # run identifier appended to every output path
overwrite_vis = True  # overwrite flag for the tensorboard log directory
overwrite_csv = True  # overwrite flag for the metrics csv
overwrite_pth = True  # overwrite flag for the saved model weights

In [61]:
# Callbacks for the real training run. The metric names match the '<name>_b'
# keys VAETrainer.train publishes ('loss_b', 'bce_b', 'kld_b'), and the
# MetricsSaver columns are those names plus the trailing 's' MetricLogger appends.
vae_cbs = [
    TensorboardCreator(log_dir=f'runs/{exp_name}/{trial}', overwrite=overwrite_vis),  # automatically mkdir
    MetricLogger(metric_name='loss', group='train', on_tensorboard=True),
    MetricLogger(metric_name='bce', group='train', on_tensorboard=True),
    MetricLogger(metric_name='kld', group='train', on_tensorboard=True),
    MetricsSaver(metrics_to_log=['losss', 'bces', 'klds'], csv_path=f'training_csv/{exp_name}/{trial}.csv', overwrite=overwrite_csv),
    # VAETrainer.train stores the generated samples under 'gens'. The key 'gen'
    # is only ever written by fake_train, so using it here would log stale fake
    # images left in the shared state dict (or raise KeyError on a fresh kernel).
    GenLogger(gen_name='gens', group='group1'),
    ModelSaver(model_path=f'trained_models/{exp_name}/{trial}.pth', overwrite=overwrite_pth)
]

In [63]:
# NOTE(review): this discards every callback configured in the previous cell, so the
# trainer below runs with no tensorboard logging, no csv and no saved model.
# Looks like a debugging leftover -- remove this cell for a real run.
vae_cbs = []

In [64]:
# Combine the learner and the callbacks into a trainer.
vae_trainer = VAETrainer(learn=learner, cbs=vae_cbs)

In [1]:
# Train for a single epoch.
# NOTE(review): the saved output is a NameError and the execution count is out of
# sequence -- re-run the notebook top to bottom before trusting this cell.
vae_trainer.train(num_epochs=1)

NameError: name 'vae_trainer' is not defined