In [6]:
#export
import numpy as np
import pandas as pd
import torch
from torch.utils.tensorboard import SummaryWriter
import os
import shutil

In [None]:
import sys
sys.path.append('../modules')

In [2]:
def log_progress(sequence, every=None, size=None, name='Items'):
    from ipywidgets import IntProgress, HTML, VBox
    from IPython.display import display

    is_iterator = False
    if size is None:
        try:
            size = len(sequence)
        except TypeError:
            is_iterator = True
    if size is not None:
        if every is None:
            if size <= 200:
                every = 1
            else:
                every = int(size / 200)     # every 0.5%
    else:
        assert every is not None, 'sequence is iterator, set every'

    if is_iterator:
        progress = IntProgress(min=0, max=1, value=1)
        progress.bar_style = 'info'
    else:
        progress = IntProgress(min=0, max=size, value=0)
    label = HTML()
    box = VBox(children=[label, progress])
    display(box)

    index = 0
    try:
        for index, record in enumerate(sequence, 1):
            if index == 1 or index % every == 0:
                if is_iterator:
                    label.value = '{name}: {index} / ?'.format(
                        name=name,
                        index=index
                    )
                else:
                    progress.value = index
                    label.value = u'{name}: {index} / {size}'.format(
                        name=name,
                        index=index,
                        size=size
                    )
            yield record
    except:
        progress.bar_style = 'danger'
        raise
    else:
        progress.bar_style = 'success'
        progress.value = index
        label.value = "{name}: {index}".format(
            name=name,
            index=str(index or '?')
        )

In [3]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

## Experiments with `state_dict`

In [72]:
class Callback():
    
    state_dict = {}

In [73]:
class CallbackA(Callback): 
    
    def print_state_dict(self):
        print(self.state_dict)

In [74]:
class CallbackB(Callback): pass

The state_dict in CallbackA and CallbackB are the same object; it got inherited.

In [75]:
CallbackA.state_dict == CallbackB.state_dict

True

When CallbackB makes a change to its state_dict, CallbackA also perceives that change! This again confirms that the same dict got inherited.

In [8]:
CallbackB.state_dict.update({'apple':1})

In [9]:
CallbackA.state_dict

{'apple': 1}

Instances can also read and write the state_dict shared by their parent classes.

In [10]:
cb_a, cb_b = CallbackA(), CallbackB()

In [11]:
cb_a.state_dict.update({'pear':2})

In [12]:
cb_a.print_state_dict()

{'apple': 1, 'pear': 2}


In [13]:
cb_b.state_dict

{'apple': 1, 'pear': 2}

In [14]:
CallbackA.state_dict

{'apple': 1, 'pear': 2}

## Learner
container for data, model, loss_fn, optim

In [1]:
class Learner():
    def __init__(self, train_data, model, loss, optim, valid_data=None):
        self.train_data, self.model, self.loss, self.optim, self.valid_data = train_data, model, loss, optim, valid_data

## Callback base class, custom callbacks, CallbackHandler

In [16]:
#export
class Callback(): 
    sd = {}
    def on_train_begin(self): pass  # make sure to reset self.sd because it is a class attribute
    def on_epoch_begin(self): pass
    def on_batch_begin(self): pass
    def on_loss_begin(self): pass
    def on_backward_begin(self): pass
    def on_backward_end(self): pass
    def on_step_end(self): pass
    def on_batch_end(self): pass
    def on_epoch_end(self): pass
    def on_train_end(self): pass

In [17]:
#export
class CallbackHandler(Callback): 
    
    def __init__(self, cbs):
        self.cbs = cbs
    
    def __call__(self, cb_category:str):
        self.cbs = sorted(self.cbs, key=lambda cb : cb._order)
        for cb in self.cbs: getattr(cb, cb_category)()
    
    def on_train_begin(self): self('on_train_begin')
        
    def on_epoch_begin(self): self('on_epoch_begin')
    
    def on_batch_begin(self): self('on_batch_begin')
        
    def on_loss_begin(self): self('on_loss_begin')
        
    def on_backward_begin(self): self('on_backward_begin')
        
    def on_backward_end(self): self('on_backward_end')
        
    def on_step_end(self): self('on_step_end')
        
    def on_batch_end(self): self('on_batch_end')
    
    def on_epoch_end(self): self('on_epoch_end')
        
    def on_train_end(self): self('on_train_end')

In [18]:
#export
def ifoverwrite(overwrite:bool, file_navigator:str, file_navigator_type:str)->None:
    
    if overwrite:
        
        if file_navigator_type == 'path':
            file_dir = ''.join([folder + '/' for folder in file_navigator.split('/')[:-1]])
            if not (os.path.isfile(file_navigator_type) or os.path.isdir(file_dir)):
                os.makedirs(file_dir, exist_ok=True)
            
        elif file_navigator_type == 'dir':
            file_dir = file_navigator
            if os.path.isdir(file_dir):  # so that overwrite also works when there's nothing to overwrite
                shutil.rmtree(file_dir)
                os.makedirs(file_dir, exist_ok=False)
            
    else:
        assert not (os.path.isfile(file_navigator) or os.path.isdir(file_navigator)), \
        AssertionError(f'{file_navigator} already exists. To overwrite it, pass True to argument `overwrite`.')

In [19]:
#export
class TensorboardCreator(Callback):
    _order=0
    
    def __init__(self, log_dir, overwrite:bool):
        self.log_dir = log_dir
        ifoverwrite(overwrite, log_dir, 'dir')
    
    def on_train_begin(self):
        self.sd.update({'writer':SummaryWriter(log_dir=self.log_dir, flush_secs=2, max_queue=2)})
        
    def on_train_end(self):
        self.sd['writer'].flush()
        self.sd['writer'].close()
        del self.sd['writer']

In [20]:
#export
class MetricLogger(Callback):
    """Log (in a list) and visualize metric values over time."""
    _order=1
    
    def __init__(self, metric_name:str, group:str, on_tensorboard:bool):
        
        self.metric_name = metric_name
        self.group = group
        self.on_tensorboard = on_tensorboard

    def on_train_begin(self):
        self.sd[f'last_{self.metric_name}'] = None
        self.sd[f'{self.metric_name}s'] = []

    def on_epoch_begin(self):
        self.total = 0
        self.num_examples = 0
    
    def on_batch_end(self):
        self.total += self.sd[f'{self.metric_name}_b']
        self.num_examples += self.sd['batch_size'] 
        
    def on_epoch_end(self):
        
        last = self.total / self.num_examples
        self.sd[f'last_{self.metric_name}'] = last
        self.sd[f'{self.metric_name}s'].append(last)
        
        if self.on_tensorboard:
            self.sd['writer'].add_scalar(
                f'{self.group}/{self.metric_name}', 
                self.sd[f'last_{self.metric_name}'], 
                self.sd['epoch']
            )

In [21]:
#export
class MetricsSaver(Callback):
    """
    Log metric values over time into a csv that can be loaded and visualized within a jupyter notebook.
    
    Depend on MetricRecorder to work properly.
    """
    _order=2
    
    def __init__(self, metrics_to_save:list, csv_path:str, overwrite:bool):
        """
        :param metrics_to_save: a list of names of the metrics to log
            - make sure that an accumulator for each metric is available in self.sd
            - make sure to add an 's' to each metric name
            - do not add 'epochs' to this list, since it will be obvious during plotting
        """
        self.metrics_to_save = metrics_to_save
        self.csv_path = csv_path
        ifoverwrite(overwrite, csv_path, 'path')
    
    def on_train_end(self):
        loss_df = pd.DataFrame(np.array([self.sd[m] for m in self.metrics_to_save]).T)
        loss_df.columns = self.metrics_to_save
        loss_df.to_csv(self.csv_path)

In [2]:
class GenLogger(Callback):
    """Visualize generated arrays for generative models."""
    _order=1
    
    def __init__(self, gen_name:str, group:str):
        self.gen_name = gen_name
        self.group = group

    def on_epoch_end(self):        
        self.sd['writer'].add_images(
            f'{self.group}/{self.gen_name}', 
            self.sd[self.gen_name],
            global_step=self.sd['epoch']
        )

NameError: name 'Callback' is not defined

In [4]:
#export
class Debugger(Callback):
    _order=999
    def __init__(self, on:bool): self.on = on
    def on_train_begin(self): 
        if self.on: print('finished on_train_begin')
    def on_epoch_begin(self): 
        if self.on: print('finished on_epoch_begin')
    def on_batch_begin(self): 
        if self.on: print('finished on_batch_begin')
    def on_loss_begin(self): 
        if self.on: print('finished on_loss_begin')
    def on_backward_begin(self): 
        if self.on: print('finished on_backward_begin')
    def on_backward_end(self): 
        if self.on: print('finished on_backward_end')
    def on_step_end(self): 
        if self.on: print('finished on_step_end')
    def on_batch_end(self): 
        if self.on: print('finished on_batch_end')
    def on_epoch_end(self): 
        if self.on: print('finished on_epoch_end')
    def on_train_end(self): 
        if self.on: print('finished on_train_end')

NameError: name 'Callback' is not defined

### MetricsPrinter

In [5]:
#export
class MetricsPrinter(Callback):
    _order=2
    
    def __init__(self, metrics_to_print):
        self.metrics_to_print = metrics_to_print
    
    def on_epoch_end(self):
        for m in self.metrics_to_print:
            print(f'{m}: {self.sd[m]}', end='|')
        print('')

NameError: name 'Callback' is not defined

In [25]:
# class GensSaver(Callback):
#     """Log (in a directory) generated arrays for generative models."""
#     _order=2
#     def __init__(self, )
    
#     for gen in self.sd[gen_name]:  # self.sd[gen_name] is of shape (N, C, H, W)
#             scipy.misc.imsave('outfile.jpg', image_array)
#             plt.savefig(gen.detach()., dpi=100)

In [26]:
#export
class ModelSaver(Callback):
    _order=2
    
    def __init__(self, model_path:str, overwrite:bool): 
        self.model_path = model_path
        ifoverwrite(overwrite, model_path, 'path')
    
    def on_train_end(self):
        torch.save(self.sd['model'].state_dict(), self.model_path)

In [27]:
def fake_train(cb_handler):
    cb_handler.on_train_begin()
    for i in range(255):
        cb_handler.on_epoch_begin()
        for j in range(10):
            cb_handler.on_batch_begin()
            cb_handler.on_loss_begin()
            cb_handler.sd.update({
                'model':torch.nn.Linear(5, 10),
                'epoch':i, 
                'batch_size':64, 
                'bce_loss_b':np.random.random()*64
            })
            cb_handler.on_backward_begin()
            cb_handler.on_backward_end()
            cb_handler.on_step_end()
            cb_handler.on_batch_end()
        cb_handler.sd.update({'gen':torch.ones((5, 3, 2, 2))*cb_handler.sd['epoch']})
        cb_handler.on_epoch_end()
    cb_handler.on_train_end()

In [28]:
# these may serve as arguments for VAETrainer
# caution: overwrites do not show up immediately on tensorboard; you may need to restart it in terminal :<
exp_name = 'testing_callbacks'
trial = 4
overwrite_vis = True
overwrite_csv = True
overwrite_pth = True

In [29]:
cb_handler = CallbackHandler([
    TensorboardCreator(log_dir=f'runs/{exp_name}/{trial}', overwrite=overwrite_vis),  # automatically mkdir
    MetricLogger(metric_name='bce_loss', group='train', on_tensorboard=True),
    MetricsSaver(metrics_to_save=['bce_losss'], csv_path=f'training_csv/{exp_name}/{trial}.csv', overwrite=overwrite_csv),
    GenLogger(gen_name='gen', group='group1'),
    ModelSaver(model_path=f'trained_models/{exp_name}/{trial}.pth', overwrite=overwrite_pth)
])
fake_train(cb_handler)

## VAETrainer
container for a training loop involving a Learner object, custom callbacks and a CallbackHandler object

In [30]:
def pixelwise_loss(input_image:torch.DoubleTensor, output_image:torch.DoubleTensor):
    return ((output_image - input_image) ** 2).sum() / input_image.size(0)

def kld_loss(logvars:torch.DoubleTensor, mus:torch.DoubleTensor):
    return -0.5 * torch.mean(1 + logvars - mus.pow(2) - logvars.exp())

In [31]:
class VAETrainer(CallbackHandler):
    
    def __init__(self, learn, cbs):
        self.learn = learn
        self.cbs = cbs
        
    def train(self, num_epochs):
        self.on_train_begin()  # create empty accumulators
        for epoch in range(num_epochs):
            self.on_epoch_begin()
            for xb, yb in log_progress(self.learn.train_data, name=f'epoch: {epoch+1}'):
                
                self.on_batch_begin()
                
                recon, mu, logvar = self.learn.model(xb)
                
                self.on_loss_begin()
                
                bce = pixelwise_loss(recon, yb)
                kld = kld_loss(logvar, mu)
                loss = bce + kld
                
#                 loss, bce, kld = self.learn.loss(recon, yb, mu, logvar)
#                 loss, bce, kld = loss / recon.size(0), bce / recon.size(0), kld / (recon.size(0) * 64)
                
                self.on_backward_begin()
                loss.backward()
                self.on_backward_end()
                self.learn.optim.step()
                self.on_step_end()
                
                self.sd.update({
                    'model':self.learn.model,
                    'epoch':epoch+1, 
                    'batch_size':int(xb.size(0)), 
                    'loss_b':float(loss), 
                    'bce_b':float(bce), 
                    'kld_b':float(kld)
                })
                
                self.on_batch_end()
                
            self.sd.update({'orgs': xb, 'recons': recon})
                
            gens = self.learn.model.generate(n=64)
            gens = torch.cat((gens, gens, gens), axis=1)
            self.sd.update({'gens':gens})
            
            self.on_epoch_end()  # calculate average loss per example, visualize metrics, perform validation if required
        self.on_train_end()  # close tensorboard writer, output csv of metrics

Load data as dataloaders.

In [32]:
from keras.datasets import mnist
from fast_train import DataPipeline

Using TensorFlow backend.


In [33]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()

In [34]:
x_train = x_train / x_train.max()
x_train = np.expand_dims(x_train, axis=1)
np.random.seed(42)
np.random.shuffle(x_train)

In [35]:
assert x_train.min() >= 0

In [36]:
train_dl = DataPipeline.float_vae(x_train, bs=64)
del x_train

Design VAE.

In [37]:
from vae_designer import VAEDesigner
from custom_vae import VAEDesign
designer_on = False

In [38]:
if designer_on:
    enc_designer = VAEDesigner(input_shape=(1, 28, 28), up_sample=False)

In [39]:
if designer_on:
    print(enc_designer.design)

In [40]:
if designer_on:
    dec_designer = VAEDesigner(input_shape=(64, 1, 1), up_sample=True)

In [41]:
if designer_on:
    print(dec_designer.design)

In [42]:
if designer_on:
    vae_design = VAEDesign(
        down_sampler_design=enc_designer.design, 
        up_sampler_design=dec_designer.design, 
        h_dim=64, 
        z_dim=10, 
        unflatten_out_shape=(64, 1, 1)
    )

In [43]:
if designer_on:
    vae_design.up_sampler_design['final_activation'] = 'sigmoid'

In [44]:
if designer_on:
    vae_design.save_as_json('designs/mnist_vae.json')

In [45]:
if designer_on:
    vae_design.up_sampler_design

In [46]:
designer_on = False

Instantiate a learner.

In [47]:
from custom_vae import get_vae_and_opt
from fast_train import Loss

In [48]:
# import torch
# print(torch.__version__)

# from platform import python_version

# print(python_version())

In [49]:
vae, opt = get_vae_and_opt('designs/mnist_vae.json', dev='cuda')
learner = Learner(train_data=train_dl, model=vae, loss=Loss.binary_loss_fn, optim=opt, valid_data=None)

In [50]:
vae.encoder

Sequential(
  (block0-conv2d): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), bias=False)
  (block0-bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (block0-lrelu): LeakyReLU(negative_slope=0.2)
  (block1-conv2d): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), bias=False)
  (block1-bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (block1-lrelu): LeakyReLU(negative_slope=0.2)
  (block2-conv2d): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), bias=False)
  (block2-bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (block2-None): LeakyReLU(negative_slope=0.2)
  (flatten): Flatten()
)

In [51]:
vae.decoder

Sequential(
  (unflatten): UnFlatten()
  (block0-convtranpose2d): ConvTranspose2d(64, 64, kernel_size=(4, 4), stride=(2, 2), output_padding=(1, 1), bias=False)
  (block0-bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (block0-relu): ReLU()
  (block1-convtranpose2d): ConvTranspose2d(64, 64, kernel_size=(4, 4), stride=(2, 2), output_padding=(1, 1), bias=False)
  (block1-bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (block1-relu): ReLU()
  (block2-convtranpose2d): ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=(2, 2), bias=False)
  (block2-sigmoid): Sigmoid()
)

Running some tests.

In [52]:
xs = torch.zeros((10, 1, 28, 28)).double().to('cpu')

In [53]:
xs = torch.zeros((10, 1, 28, 28)).double().to('cuda')

In [54]:
vae.encoder(input=xs).shape

torch.Size([10, 64])

In [55]:
zs = torch.zeros((10, 64, 1, 1)).double()

In [56]:
zs = torch.zeros((10, 64, 1, 1)).double().to('cuda')

In [57]:
vae.decoder(input=zs).shape

torch.Size([10, 1, 28, 28])

Re-instantiate a learner, instantiate callbacks and instantiate a trainer.

In [71]:
exp_name = 'mnist'
trial = 1
overwrite_vis = False
overwrite_csv = False
overwrite_pth = False
debugger_on = False

vae, opt = get_vae_and_opt('designs/mnist_vae.json', dev='cuda')
learner = Learner(train_data=train_dl, model=vae, loss=Loss.binary_loss_fn, optim=opt, valid_data=None)

vae_cbs = [
    TensorboardCreator(log_dir=f'runs/{exp_name}/{trial}', overwrite=overwrite_vis),  # automatically mkdir
    MetricLogger(metric_name='loss', group='train', on_tensorboard=True),
    MetricLogger(metric_name='bce', group='train', on_tensorboard=True),
    MetricLogger(metric_name='kld', group='train', on_tensorboard=True),
    MetricsPrinter(metrics_to_print=['last_loss', 'last_bce', 'last_kld']),
    MetricsSaver(metrics_to_save=['losss', 'bces', 'klds'], csv_path=f'training_csv/{exp_name}/{trial}.csv', overwrite=overwrite_csv),
    GenLogger(gen_name='orgs', group='group1'),
    GenLogger(gen_name='recons', group='group1'),
    GenLogger(gen_name='gens', group='group1'),
    ModelSaver(model_path=f'trained_models/{exp_name}/{trial}.pth', overwrite=overwrite_pth),
    Debugger(on=debugger_on),
]

vae_trainer = VAETrainer(learn=learner, cbs=vae_cbs)

AssertionError: runs/mnist/1 already exists. To overwrite it, pass True to argument `overwrite`.

##### Log 19/12/18
- problem: VAE enters failure mode
- solution: change optimizer setting from `Adam lr=0.001` to `Adam lr=0.0002 betas=(0.5, 0.999)`

In [68]:
vae_trainer.train(num_epochs=10)

VBox(children=(HTML(value=''), IntProgress(value=0, max=938)))

last_loss: 0.6622072416856188|last_bce: 0.6307394116180236|last_kld: 0.031467830067595716|


VBox(children=(HTML(value=''), IntProgress(value=0, max=938)))

last_loss: 0.4978093115949576|last_bce: 0.46360520531330074|last_kld: 0.03420410628165731|


VBox(children=(HTML(value=''), IntProgress(value=0, max=938)))

last_loss: 0.4717706640839737|last_bce: 0.43742478002740526|last_kld: 0.034345884056569166|


VBox(children=(HTML(value=''), IntProgress(value=0, max=938)))

last_loss: 0.45531119041655566|last_bce: 0.4204466570540692|last_kld: 0.034864533362486266|


VBox(children=(HTML(value=''), IntProgress(value=0, max=938)))

last_loss: 0.4472107952890796|last_bce: 0.4124553222601381|last_kld: 0.034755473028941435|


VBox(children=(HTML(value=''), IntProgress(value=0, max=938)))

last_loss: 0.4429565607151172|last_bce: 0.40722832089018046|last_kld: 0.03572823982493656|


VBox(children=(HTML(value=''), IntProgress(value=0, max=938)))

last_loss: 0.4422718756497001|last_bce: 0.4053122245813841|last_kld: 0.036959651068315995|


VBox(children=(HTML(value=''), IntProgress(value=0, max=938)))

last_loss: 0.43778772133716387|last_bce: 0.4003332442150552|last_kld: 0.03745447712210879|


VBox(children=(HTML(value=''), IntProgress(value=0, max=938)))

last_loss: 0.43068579393797|last_bce: 0.3940166780354025|last_kld: 0.0366691159025671|


VBox(children=(HTML(value=''), IntProgress(value=0, max=938)))

last_loss: 0.4280698620984104|last_bce: 0.3909913543666048|last_kld: 0.037078507731806015|
