# Distilling a Neural Network into a Soft Decision Tree

* Implementation based on [[Frosst & Hinton, 2017](http://arxiv.org/abs/1711.09784)]

## Imports

In [1]:
import numpy as np
import os
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
from tensorflow.keras.callbacks import EarlyStopping, Callback



In [2]:
from importlib import import_module

distillnntree = import_module('distill-nn-tree.models')
distillnntreeutils = import_module('distill-nn-tree.models.utils')

ConvNet, SoftBinaryDecisionTree = distillnntree.ConvNet, distillnntree.SoftBinaryDecisionTree
brand_new_tfsession, draw_tree = distillnntreeutils.brand_new_tfsession, distillnntreeutils.draw_tree


sess = brand_new_tfsession()




## Dataset

In [3]:
# load MNIST data
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# add channel dim
x_train, x_test = x_train[..., np.newaxis], x_test[..., np.newaxis]

# hold out last 10000 training samples for validation
x_valid, y_valid = x_train[-10000:], y_train[-10000:]
x_train, y_train = x_train[:-10000], y_train[:-10000]

print(x_train.shape, y_train.shape, x_valid.shape, y_valid.shape, x_test.shape, y_test.shape)

(50000, 28, 28, 1) (50000,) (10000, 28, 28, 1) (10000,) (10000, 28, 28, 1) (10000,)


In [4]:
# retrieve image and label shapes from training data
img_rows, img_cols, img_chans = x_train.shape[1:]
n_classes = np.unique(y_train).shape[0]

print(img_rows, img_cols, img_chans, n_classes)

28 28 1 10


In [5]:
# convert labels to 1-hot vectors
y_train = tf.keras.utils.to_categorical(y_train, n_classes)
y_valid = tf.keras.utils.to_categorical(y_valid, n_classes)
y_test = tf.keras.utils.to_categorical(y_test, n_classes)

print(y_train.shape, y_valid.shape, y_test.shape)

(50000, 10) (10000, 10) (10000, 10)


In [6]:
# normalize inputs and cast to float
x_train = (x_train / np.max(x_train)).astype(np.float32)
x_valid = (x_valid / np.max(x_valid)).astype(np.float32)
x_test = (x_test / np.max(x_test)).astype(np.float32)

## Neural Network

In [7]:
nn = ConvNet(img_rows, img_cols, img_chans, n_classes)
nn.maybe_train(data_train=(x_train, y_train),
               data_valid=(x_valid, y_valid),
               batch_size=16, epochs=12)
nn.evaluate(x_train, y_train)

Loading trained model from assets/nn-model.hdf5.
accuracy: 99.90% | loss: 0.0032113500596674203


In [8]:
nn.model.layers[0].input_shape

[(None, 28, 28)]

In [9]:
nn.evaluate(x_valid, y_valid)
nn.evaluate(x_test, y_test)

accuracy: 99.14% | loss: 0.03600480322143594
accuracy: 99.20% | loss: 0.028137104455670123


### Extraction of soft labels for distillation

In [10]:
y_train_soft = nn.predict(x_train)
y_train_soft.shape

(50000, 10)

## Binary Soft Decision Tree

Flatten the dataset in advance, since the tree takes flat feature vectors as input.

In [11]:
x_train_flat = x_train.reshape((x_train.shape[0], -1))
x_valid_flat = x_valid.reshape((x_valid.shape[0], -1))
x_test_flat = x_test.reshape((x_test.shape[0], -1))

# import matplotlib.pyplot as plt
# %matplotlib inline
# plt.imshow(x_test_flat.reshape((x_test_flat.shape[0], img_rows, img_cols))[1])

x_train_flat.shape, x_valid_flat.shape, x_test_flat.shape

((50000, 784), (10000, 784), (10000, 784))

### Hyperparameters
* `tree_depth`: as in the [[paper](https://arxiv.org/pdf/1711.09784.pdf)], depth counts levels of inner nodes only (leaves excluded), indexed from `0`
* `penalty_strength`: strength of the regularization penalty
* `penalty_decay`: decay of the regularization penalty with depth; the paper's authors found 0.5 optimal (note that $2^{-d} = 0.5^d$ as we use it)
* `ema_win_size`: scaling factor applied to the "default size of the window" used to compute moving averages of node and path probabilities (the default window grows exponentially with depth)
* `inv_temp`: scales the logits of inner nodes to "avoid very soft decisions" [[paper](https://arxiv.org/pdf/1711.09784.pdf)]
    * pass `0` to make this a learned parameter (a single scalar applied to all nodes in the tree)
* `learning_rate`: hopefully no need to explain, but let's be cool and use the [Karpathy constant](https://www.urbandictionary.com/define.php?term=Karpathy%20Constant) ([source](https://twitter.com/karpathy/status/801621764144971776)) :D as the default in `tree.__init__()`
* `batch_size`: we use a small one because, as the depth and thus the number of leaf bigots grows, larger batch sizes scale the loss terms of individual leaves down too much through averaging, which results in poor optimization properties
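
To get a feel for how `penalty_strength` and `penalty_decay` interact, here is a minimal sketch of the per-depth penalty coefficient. It assumes the coefficient at inner-node depth $d$ is `penalty_strength * penalty_decay ** d`, which matches the $0.5^d$ note above but is not taken from the `SoftBinaryDecisionTree` source; `penalty_per_depth` is a hypothetical helper, not part of the repo.

```python
def penalty_per_depth(penalty_strength, penalty_decay, tree_depth):
    """Return one penalty coefficient per inner-node depth 0..tree_depth.

    Assumption (not confirmed by the library source): the coefficient at
    depth d is penalty_strength * penalty_decay ** d.
    """
    return [penalty_strength * penalty_decay ** d for d in range(tree_depth + 1)]


# Values used later in this notebook: strength 10, decay 0.25, depth 4.
schedule = penalty_per_depth(penalty_strength=10.0, penalty_decay=0.25, tree_depth=4)
for d, coeff in enumerate(schedule):
    print(f"depth {d}: penalty coefficient {coeff:g}")
```

Under this assumption, a decay of 0.25 makes the penalty fade with depth faster than the paper's 0.5, so nodes near the leaves are regularized far less than the root.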

In [12]:
n_features = img_rows * img_cols * img_chans
tree_depth = 4
penalty_strength = 1e+1
penalty_decay = 0.25
ema_win_size = 1000
inv_temp = 0.01
learning_rate = 5e-03
batch_size = 4

### Regular training with hard labels

In [13]:
sess = brand_new_tfsession(sess)

tree = SoftBinaryDecisionTree(tree_depth, n_features, n_classes,
    penalty_strength=penalty_strength, penalty_decay=penalty_decay,
    inv_temp=inv_temp, ema_win_size=ema_win_size, learning_rate=learning_rate)
tree.build_model()

Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: module 'gast' has no attribute 'Index'

In [35]:
epochs = 40

es = EarlyStopping(monitor='val_acc', patience=20, verbose=1)

# To train your own model instead of loading one from a checkpoint,
# remove the checkpoint first (uncomment the lines below; they need `glob`).
# import glob
# os.remove('assets/non-distilled/checkpoint')
# for f in glob.glob('assets/non-distilled/tree-model*'):
#     os.remove(f)

tree.maybe_train(
    sess=sess, data_train=(x_train_flat, y_train), data_valid=(x_valid_flat, y_valid),
    batch_size=batch_size, epochs=epochs, callbacks=[es])

Loading trained model from assets/non-distilled/tree-model.
assets/non-distilled/tree-model is not a valid checkpoint. Training from scratch.
Train on 50000 samples, validate on 10000 samples
Epoch 1/40
...
Epoch 40/40
Saving trained model to assets/non-distilled/tree-model.


In [36]:
tree.evaluate(x=x_valid_flat, y=y_valid, batch_size=batch_size)
tree.evaluate(x=x_test_flat, y=y_test, batch_size=batch_size)

accuracy: 92.78% | loss: 7.908185776901245
accuracy: 92.74% | loss: 7.897307242202759


### Distillation: training with soft labels

In [37]:
sess = brand_new_tfsession(sess)

tree = SoftBinaryDecisionTree(tree_depth, n_features, n_classes,
    penalty_strength=penalty_strength, penalty_decay=penalty_decay,
    inv_temp=inv_temp, ema_win_size=ema_win_size, learning_rate=learning_rate)
tree.build_model()

Built tree has 32 leaves out of 63 nodes
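
The counts reported above follow from the depth convention used for `tree_depth` (levels of inner nodes, indexed from `0`). A small sketch, assuming a full binary tree; `tree_size` is a hypothetical helper written for this illustration:

```python
def tree_size(tree_depth):
    """Count nodes of a full binary tree whose inner nodes span depths 0..tree_depth."""
    n_inner = 2 ** (tree_depth + 1) - 1   # inner nodes on levels 0..tree_depth
    n_leaves = 2 ** (tree_depth + 1)      # one level of leaves below the deepest inner nodes
    return n_inner, n_leaves, n_inner + n_leaves


n_inner, n_leaves, n_total = tree_size(4)
print(f"{n_leaves} leaves out of {n_total} nodes ({n_inner} inner nodes)")
```

Each extra level of depth roughly doubles the number of masks to learn and to draw.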


In [38]:
epochs = 40

es = EarlyStopping(monitor='val_acc', patience=20, verbose=1)

# To train your own model instead of loading one from a checkpoint,
# remove the checkpoint first (uncomment the lines below; they need `glob`).
# import glob
# os.remove('assets/distilled/checkpoint')
# for f in glob.glob('assets/distilled/tree-model*'):
#     os.remove(f)

tree.maybe_train(
    sess=sess, data_train=(x_train_flat, y_train_soft), data_valid=(x_valid_flat, y_valid),
    batch_size=batch_size, epochs=epochs, callbacks=[es], distill=True)

Loading trained model from assets/distilled/tree-model.


NotFoundError: FindFirstFile failed for: assets/distilled : The system cannot find the path specified.
; No such process

In [None]:
tree.evaluate(x=x_valid_flat, y=y_valid, batch_size=batch_size)
tree.evaluate(x=x_test_flat, y=y_test, batch_size=batch_size)

### Visualizing learned parameters

In [None]:
draw_tree(sess, tree, img_rows, img_cols, img_chans)

#### How to read the visual

Exactly as in the [[paper](https://arxiv.org/pdf/1711.09784.pdf)]:
* The number **below** each **leaf** is the `argmax()` of its learned distribution, i.e. the final static **prediction** of that leaf (a bigot, not an expert!).
* The numbers **above** each **inner node** are the **set of possible predictions** in the sub-tree of that node.
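
The two annotations can be reproduced from the leaf distributions alone. The sketch below is illustrative only: it assumes a full binary tree with leaves ordered left to right and uses a made-up 4-leaf, 3-class example; how `draw_tree` computes them internally is not shown here, and both helper names are hypothetical.

```python
import numpy as np


def leaf_predictions(leaf_distributions):
    """Static prediction of each leaf: the argmax of its class distribution."""
    return np.argmax(leaf_distributions, axis=1)


def subtree_prediction_sets(leaf_preds):
    """For each level of inner nodes, list the predictions reachable from each node.

    Assumes a full binary tree whose leaves are ordered left to right.
    """
    n_leaves = len(leaf_preds)
    n_levels = n_leaves.bit_length() - 1          # log2 of the leaf count
    levels = []
    for depth in range(n_levels):
        span = n_leaves // 2 ** depth             # leaves under each node at this depth
        levels.append([sorted(set(leaf_preds[i:i + span].tolist()))
                       for i in range(0, n_leaves, span)])
    return levels


# Toy example: 4 leaves, 3 classes (made-up numbers).
toy_leaves = np.array([[0.7, 0.2, 0.1],
                       [0.1, 0.8, 0.1],
                       [0.2, 0.2, 0.6],
                       [0.1, 0.3, 0.6]])
preds = leaf_predictions(toy_leaves)
sets_per_level = subtree_prediction_sets(preds)
print(preds.tolist())     # numbers drawn below the leaves
print(sets_per_level)     # numbers drawn above the inner nodes, root level first
```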

### Visualizing decision path

In [None]:
digit = 9

# get (reproducibly) pseudo-random example of chosen digit
np.random.seed(0)
sample_index = np.random.choice(np.where(np.argmax(y_test, axis=1)==digit)[0])
input_img = x_test[sample_index]

In [None]:
draw_tree(sess, tree, img_rows, img_cols, img_chans, input_img=input_img)

#### How to read the visual

* The <span style="color:green">**maximum probability path**</span> leading **to the final prediction** is now denoted by <span style="color:green">**green arrows**</span>.
* The number **below** any given **inner node** on this <span style="color:green">**path**</span> denotes the **pre-activation logit** $\beta(\mathbf{x}\mathbf{w}_i + b_i)$.
    * This is basically just a **biased** ($b_i$) and **scaled** ($\beta$) **correlation** of the **input** ($\mathbf{x}$) with the given **mask** ($\mathbf{w}_i$).
    * Since $\sigma(0) = 0.5$, the choice of branch flips at a logit of `0`.
    * From the definition of **branching** in the [[paper](https://arxiv.org/pdf/1711.09784.pdf)], **negative** correlations branch **to the left**, while **positive** correlations branch **to the right**.
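
The decision at a single inner node can be sketched in a few lines of NumPy. This is a toy illustration with made-up numbers, assuming (as described above) that $\sigma$ of the logit is the probability of taking the right branch; `inner_node_decision` is a hypothetical helper, not the tree model's own code.

```python
import numpy as np


def sigmoid(z):
    """Logistic function."""
    return 1.0 / (1.0 + np.exp(-z))


def inner_node_decision(x, w, b, inv_temp):
    """Return the logit, the probability of branching right, and the chosen branch."""
    logit = inv_temp * (np.dot(x, w) + b)   # beta * (x . w_i + b_i)
    p_right = sigmoid(logit)                # assumed: probability of the right branch
    branch = "right" if p_right > 0.5 else "left"
    return logit, p_right, branch


# Toy 4-pixel "image" and mask (made-up values).
x = np.array([0.0, 1.0, 1.0, 0.0])
w = np.array([-2.0, 3.0, 1.0, -1.0])
b = -1.0

logit, p_right, branch = inner_node_decision(x, w, b, inv_temp=0.5)
print(f"logit = {logit:.2f}, p(right) = {p_right:.3f}, branch = {branch}")
```

A small `inv_temp` shrinks the logit toward `0`, which is what makes the decisions softer.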

<img src="assets/img/branching.png" width="35%"/>

In [None]:
draw_tree(sess, tree, img_rows, img_cols, img_chans, input_img=input_img, show_correlation=True)

#### How to read the visual

* On the <span style="color:green">**maximum probability path**</span> there are now **correlations** of the **input image** with the **node masks**.
* The **homogeneous area** gives a frame of reference for the color of `0`s.
    * It always corresponds to the **black area in the input image**, but due to the lack of normalization (yes, I'm the lazy one here), it ends up as a different shade of gray in each subplot.
    * All pixels **lighter** than this reference correspond to **positive correlation coefficients**.
    * All pixels **darker** than it correspond to **negative correlation coefficients**.

_Note: In the last input-masked kernel on the path to prediction, notice how the model distinguishes `9`s from `7`s._
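
The varying shade of gray for `0` could be avoided by scaling each correlation map symmetrically around zero. The following is a minimal sketch under the assumption that the correlation view shows the elementwise product of the input and the node mask; `correlation_image` is a hypothetical helper and the mask values are random placeholders, not learned weights.

```python
import numpy as np


def correlation_image(x_flat, w_flat, img_rows, img_cols):
    """Elementwise input-mask product, rescaled so that 0 always maps to mid-gray (0.5)."""
    corr = (x_flat * w_flat).reshape(img_rows, img_cols)
    scale = np.max(np.abs(corr))
    if scale == 0:
        return np.full_like(corr, 0.5)      # nothing but zeros: uniform mid-gray
    return 0.5 + 0.5 * corr / scale         # symmetric mapping of [-scale, scale] to [0, 1]


rng = np.random.default_rng(0)
x_flat = np.zeros(28 * 28, dtype=np.float32)
x_flat[300:400] = 1.0                                   # placeholder "stroke" of lit pixels
w_flat = rng.normal(size=28 * 28).astype(np.float32)    # placeholder mask

img = correlation_image(x_flat, w_flat, 28, 28)
print(img.shape, float(img.min()), float(img.max()), float(img[0, 0]))
```

Plotting `img` with a fixed gray range of 0 to 1 would then keep the background shade identical across subplots.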

To save the inference example as animation, run the cell below.

In [None]:
if not os.path.isdir('assets/img/infer'):
    os.mkdir('assets/img/infer')

draw_tree(sess, tree, img_rows, img_cols, img_chans,
          input_img=input_img,
          savepath='assets/img/infer/0.png')
draw_tree(sess, tree, img_rows, img_cols, img_chans,
          input_img=input_img, show_correlation=True,
          savepath='assets/img/infer/1.png')

!convert -delay 100 -loop 0 assets/img/infer/*.png assets/img/infer.gif

### Capturing the progress of learning

In [None]:
if not os.path.isdir('assets/img/epoch'):
    os.mkdir('assets/img/epoch')

if not os.path.isdir('assets/img/sample'):
    os.mkdir('assets/img/sample')

In [None]:
sess = brand_new_tfsession(sess)

tree = SoftBinaryDecisionTree(tree_depth, n_features, n_classes,
    penalty_strength=penalty_strength, penalty_decay=penalty_decay,
    inv_temp=inv_temp, ema_win_size=ema_win_size, learning_rate=learning_rate)
tree.build_model()

tree.initialize_variables(sess, x_train_flat, batch_size)

In [None]:
class ModelImageSaver(Callback):
    """Saves drawings of the tree at train start, after each epoch and every `display` examples."""
    def __init__(self, display, limit):
        super().__init__()
        self.seen = 0            # number of training examples seen so far
        self.display = display   # save an image every `display` examples
        self.limit = limit       # stop sample-wise saving after `limit` examples

    def on_train_begin(self, logs=None):
        # snapshot of the freshly initialized tree for both image series
        draw_tree(sess, tree, img_rows, img_cols, img_chans,
                  savepath='assets/img/epoch/{:04}.png'.format(0))
        draw_tree(sess, tree, img_rows, img_cols, img_chans,
                  savepath='assets/img/sample/{:07}.png'.format(0))

    def on_epoch_end(self, epoch, logs=None):
        draw_tree(sess, tree, img_rows, img_cols, img_chans,
                  savepath='assets/img/epoch/{:04}.png'.format(epoch+1))

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        self.seen += logs.get('size', 0)
        if self.seen % self.display == 0 and self.seen <= self.limit:
            draw_tree(sess, tree, img_rows, img_cols, img_chans,
                      savepath='assets/img/sample/{:07}.png'.format(self.seen))

# save an image after every 1000 training examples
# save at most 250 images (corresponds to the first 5 training epochs)
image_saver = ModelImageSaver(1000, 250000)

tree.model.fit(x=x_train_flat, y=y_train_soft, validation_data=(x_valid_flat, y_valid),
               batch_size=batch_size, epochs=40, callbacks=[image_saver]);

#### Compiling snapshots into animation
**Note**: converting the captured series of PNG images into a GIF animation with `makegif.sh` requires a `bash` environment with the ImageMagick `convert` CLI tool available.
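
If `bash` is not available, roughly the same thing can be driven from Python. The contents of `makegif.sh` are not shown in this notebook, so the sketch below simply mirrors the `convert -delay 100 -loop 0` call used for the inference animation earlier; `build_convert_command` is a hypothetical helper and the delay value is an assumption.

```python
import glob
import shutil
import subprocess


def build_convert_command(frame_dir, out_path, delay=100):
    """Build an ImageMagick `convert` call that joins sorted PNG frames into a looping GIF."""
    frames = sorted(glob.glob(f"{frame_dir}/*.png"))
    return ["convert", "-delay", str(delay), "-loop", "0", *frames, out_path]


cmd = build_convert_command("assets/img/epoch", "assets/img/epoch.gif")
has_frames = len(cmd) > 6          # 5 fixed arguments + output path = 6 entries without frames

# Only run when ImageMagick is installed and there is at least one frame to convert.
if shutil.which("convert") and has_frames:
    subprocess.run(cmd, check=True)
```

Sorting the frame names matters here because the zero-padded file names written during training encode the frame order.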

##### Epoch-wise compilation

In [None]:
!./makegif.sh epoch

![epoch.gif](assets/img/epoch.gif)

##### Sample-wise compilation

In [None]:
!./makegif.sh sample

![sample.gif](assets/img/sample.gif)

## Elaborating

![deeper.jpg](assets/img/deeper.jpg)

By now, you should know what's coming...

In [None]:
tree_depth = 5

In [None]:
sess = brand_new_tfsession(sess)

tree = SoftBinaryDecisionTree(tree_depth, n_features, n_classes,
    penalty_strength=penalty_strength, penalty_decay=penalty_decay,
    inv_temp=inv_temp, ema_win_size=ema_win_size, learning_rate=learning_rate)
tree.build_model()

tree.initialize_variables(sess, x_train_flat, batch_size)

In [None]:
tree.model.fit(x=x_train_flat, y=y_train_soft, validation_data=(x_valid_flat, y_valid),
               batch_size=batch_size, epochs=3);

# os.mkdir('assets/depth-{}'.format(tree_depth))
# tree.save_variables(sess, 'assets/depth-{}/tree-model'.format(tree_depth))

In [None]:
draw_tree(sess, tree, img_rows, img_cols, img_chans)

Sorry, but trees deeper than this were not as visually appealing and would take much longer to train to a performance reasonable enough to motivate closer examination.

## Final word

If you're reading this, I believe you are interested in this implementation, so please don't hesitate to **try it yourself** :)

* tune hyperparameters of the tree model
    * try out different depths and penalty parameters (strength, decay)
    * implement dynamic inverse temperature ($\beta$), scheduled as a function of training step / epoch
* try out a different dataset; the approach is generic enough!
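
As a starting point for the scheduled inverse temperature idea, here is one possible schedule: a geometric interpolation from a small $\beta$ to a larger one over the course of training. Everything in it is an assumption made for illustration (the function name, the end value of `1.0`, the per-epoch granularity); how to feed the value into `SoftBinaryDecisionTree` is left open, since the model currently takes a fixed `inv_temp`.

```python
def inv_temp_schedule(epoch, total_epochs, start=0.01, end=1.0):
    """Geometrically interpolate the inverse temperature from `start` to `end`.

    Hypothetical schedule: epoch 0 yields `start`, the last epoch yields `end`.
    """
    if total_epochs <= 1:
        return end
    frac = min(max(epoch / (total_epochs - 1), 0.0), 1.0)   # training progress in [0, 1]
    return start * (end / start) ** frac


for epoch in (0, 10, 20, 39):
    print(f"epoch {epoch:2d}: inv_temp = {inv_temp_schedule(epoch, 40):.4f}")
```

Starting soft and hardening the decisions later is only one option; the opposite direction, or a schedule per training step, fits the same function shape.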

If you get any interesting results with this implementation, feel free to share them as an [issue](https://github.com/lmartak/distill-nn-tree/issues). Also feel free to improve this repo by submitting a [PR](https://github.com/lmartak/distill-nn-tree/pulls) or just making your own [fork](https://github.com/lmartak/distill-nn-tree/network/members).


If you feel adventurous, you could:
* improve `draw_tree`'s correlation mode by normalizing the shades of gray around a fixed color for `0`
* add a similar notebook covering the whole training, distillation & evaluation lifecycle on a different dataset (e.g. `cifar-10.ipynb` for the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset)
    * This would probably require colored masks and some experimenting with their normalization for visualization purposes, but could be fun!