
Adaptive Neural Trees

MIT License

This repository contains our PyTorch implementation of Adaptive Neural Trees (ANTs).

The code was written by Ryutaro Tanno and supported by Kai Arulkumaran.

Paper (ICML'19) | Video from London ML meetup


  • Linux or macOS
  • Python 2.7
  • Anaconda >= 4.5
  • CPU or NVIDIA GPU + CUDA 8.0


  • Clone this repo:
git clone
cd AdaptiveNeuralTrees
  • (Optional) create a new Conda environment and activate it:
conda create -n ANT python=2.7
source activate ANT
  • Run the following to install required packages.
bash ./


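An example command for training/testing an ANT is given below. The trailing # comments only annotate each option; strip them before running, since a shell ignores everything after #, including the line-continuation backslash.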

python --experiment test_ant_cifar10  #name of experiment \
               --subexperiment myant  #name of subexperiment \
               --dataset cifar10   #dataset \
                # Model details:    \
               --router_ver 3        #type of router module \
               --router_ngf 128      #no. of kernels in routers \
               --router_k 3          #spatial size of kernels in routers \
               --transformer_ver 5   #type of transformer module \
               --transformer_ngf 128 #no. of kernels in transformers \
               --transformer_k 3     #spatial size of kernels in transformers \
               --solver_ver 6        #type of solver module \
               --batch_norm          #apply batch-norm \
               --maxdepth 10         #maximum depth of the tree-structure \
                # Training details: \
               --batch-size 512    #batch size \
               --augmentation_on   #apply data augmentation \
               --scheduler step_lr #learning rate scheduling \
               --criteria avg_valid_loss #splitting criteria \
               --epochs_patience 5 #no. of patience per node for growth phase \
               --epochs_node 100   #max no. of epochs per node for growth phase \
               --epochs_finetune 200 #no. of epochs for fine-tuning phase \
               # Others: \
               --seed 0            #randomisation seed \
               --num_workers 0     #no. of CPU subprocesses used for data loading \
               --visualise_split   #save the tree structure every epoch
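
The same options can also be kept as data and turned into an argument list programmatically, which avoids mixing comments with line continuations. The sketch below is only an illustration: `OPTIONS` and `build_ant_args` are hypothetical names that do not come from this repository, and the flag values are copied from the example command.

```python
from __future__ import print_function

# (flag, value) pairs copied from the example command.
# Switches that take no value (e.g. --batch_norm) are stored as True.
OPTIONS = [
    ("--experiment", "test_ant_cifar10"),  # name of experiment
    ("--subexperiment", "myant"),          # name of subexperiment
    ("--dataset", "cifar10"),              # dataset
    ("--router_ver", 3),                   # type of router module
    ("--router_ngf", 128),                 # no. of kernels in routers
    ("--router_k", 3),                     # spatial size of kernels in routers
    ("--transformer_ver", 5),              # type of transformer module
    ("--transformer_ngf", 128),            # no. of kernels in transformers
    ("--transformer_k", 3),                # spatial size of kernels in transformers
    ("--solver_ver", 6),                   # type of solver module
    ("--batch_norm", True),                # apply batch-norm
    ("--maxdepth", 10),                    # maximum depth of the tree structure
    ("--batch-size", 512),                 # batch size
    ("--augmentation_on", True),           # apply data augmentation
    ("--scheduler", "step_lr"),            # learning rate scheduling
    ("--criteria", "avg_valid_loss"),      # splitting criteria
    ("--epochs_patience", 5),              # patience per node, growth phase
    ("--epochs_node", 100),                # max epochs per node, growth phase
    ("--epochs_finetune", 200),            # epochs for fine-tuning phase
    ("--seed", 0),                         # randomisation seed
    ("--num_workers", 0),                  # CPU subprocesses for data loading
    ("--visualise_split", True),           # save the tree structure every epoch
]


def build_ant_args(options):
    """Flatten (flag, value) pairs into an argv-style list of strings."""
    args = []
    for flag, value in options:
        args.append(flag)
        if value is not True:  # value-less switches contribute only the flag
            args.append(str(value))
    return args


if __name__ == "__main__":
    # Print the options on one line, ready to append to the python command.
    print(" ".join(build_ant_args(OPTIONS)))
```

To launch a run from Python, the list can be passed to `subprocess.call([sys.executable, script_path] + build_ant_args(OPTIONS))`, where `script_path` is a placeholder for the repository's training script.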

The model configurations and optimisation trajectory (e.g. the train/validation loss at each time point) are saved in records.json in the directory ./experiments/dataset/experiment/subexperiment/checkpoints, where dataset, experiment and subexperiment are the values passed to the corresponding options. Similarly, the tree structure and the best trained model are saved as tree_structures.json and model.pth, respectively, under the same directory. If the visualisation option --visualise_split is used, the tree architecture of the ANT is saved in PNG format in the directory ./experiments/dataset/experiment/subexperiment/cfigures.
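
As a rough illustration of how these outputs might be read back, the sketch below builds the checkpoint path from the three option values and loads the records file with the standard `json` module. The helper names `checkpoint_dir` and `load_records` are hypothetical, the file name is assumed to be `records.json`, and nothing is assumed about the keys stored inside it.

```python
import json
import os


def checkpoint_dir(dataset, experiment, subexperiment, root="./experiments"):
    """Build the checkpoint directory path for one run."""
    return os.path.join(root, dataset, experiment, subexperiment, "checkpoints")


def load_records(directory, filename="records.json"):
    """Read the saved records file (assumed name) and return its contents."""
    with open(os.path.join(directory, filename)) as handle:
        return json.load(handle)


# Example usage, once a run with these option values has finished:
#   records = load_records(checkpoint_dir("cifar10", "test_ant_cifar10", "myant"))
#   print(sorted(records))  # inspect which entries were saved
```

Listing the top-level keys first is a cautious way to discover what a given version of the code actually stores before relying on any particular field.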

By default, the average classification accuracy is also computed on the train/valid/test sets at every epoch and saved in the records.json file, so a single run of the command above suffices for both training and testing an ANT with a given configuration.

Jupyter Notebooks

We have also included two Jupyter notebooks, ./notebooks/example_mnist.ipynb and ./notebooks/example_cifar10.ipynb, which illustrate how this repository can be used to train ANTs on the MNIST and CIFAR-10 image recognition datasets.

Primitive modules

Defining an ANT amounts to specifying the forms of its primitive modules: routers, transformers and solvers. The table below lists the currently implemented primitive modules. You can try any combination of the three to construct an ANT.

| Type | Router | Transformer | Solver |
|------|--------|-------------|--------|
| 1 | 1 x Conv + GAP + Sigmoid | Identity function | Linear classifier |
| 2 | 1 x Conv + GAP + 1 x FC | 1 x Conv | MLP with 2 hidden layers |
| 3 | 2 x Conv + GAP + 1 x FC | 1 x Conv + 1 x MaxPool | MLP with 1 hidden layer |
| 4 | MLP with 1 hidden layer | Bottleneck residual block (He et al., 2015) | GAP + 2 FC layers + Softmax |
| 5 | GAP + 2 x FC layers (Veit et al., 2017) | 2 x Conv + 1 x MaxPool | MLP with 1 hidden layer in AlexNet (layers-80sec.cfg) |
| 6 | 1 x Conv + GAP + 2 x FC | Whole VGG13 architecture (without the linear layer) | GAP + 1 FC layers + Softmax |

For the detailed definitions of respective modules, please see and
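
To make the "any combination" point concrete, the sketch below transcribes the table into three dictionaries and looks up the modules selected by a given triple of type numbers. It is a reading aid only: `ROUTERS`, `TRANSFORMERS`, `SOLVERS` and `describe_ant` are hypothetical names, not identifiers from this repository, and the strings are descriptions rather than actual PyTorch modules.

```python
# Module descriptions transcribed from the table, keyed by type number.
ROUTERS = {
    1: "1 x Conv + GAP + Sigmoid",
    2: "1 x Conv + GAP + 1 x FC",
    3: "2 x Conv + GAP + 1 x FC",
    4: "MLP with 1 hidden layer",
    5: "GAP + 2 x FC layers (Veit et al., 2017)",
    6: "1 x Conv + GAP + 2 x FC",
}
TRANSFORMERS = {
    1: "Identity function",
    2: "1 x Conv",
    3: "1 x Conv + 1 x MaxPool",
    4: "Bottleneck residual block (He et al., 2015)",
    5: "2 x Conv + 1 x MaxPool",
    6: "Whole VGG13 architecture (without the linear layer)",
}
SOLVERS = {
    1: "Linear classifier",
    2: "MLP with 2 hidden layers",
    3: "MLP with 1 hidden layer",
    4: "GAP + 2 FC layers + Softmax",
    5: "MLP with 1 hidden layer in AlexNet (layers-80sec.cfg)",
    6: "GAP + 1 FC layers + Softmax",
}


def describe_ant(router_ver, transformer_ver, solver_ver):
    """Return the module descriptions for one (router, transformer, solver) choice."""
    return {
        "router": ROUTERS[router_ver],
        "transformer": TRANSFORMERS[transformer_ver],
        "solver": SOLVERS[solver_ver],
    }


# The choice made in the example command:
# --router_ver 3 --transformer_ver 5 --solver_ver 6
config = describe_ant(3, 5, 6)

# Each of the three module types can be chosen independently.
n_combinations = len(ROUTERS) * len(TRANSFORMERS) * len(SOLVERS)
```

With six options per module type, the table allows 6 x 6 x 6 = 216 distinct ANT configurations before any other hyperparameter is varied.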


If you use this code for your research, please cite our ICML paper:

@inproceedings{tanno2019adaptive,
  title={Adaptive Neural Trees},
  author={Tanno, Ryutaro and Arulkumaran, Kai and Alexander, Daniel and Criminisi, Antonio and Nori, Aditya},
  booktitle={Proceedings of the 36th International Conference on Machine Learning (ICML)},
  year={2019},
}


I would like to thank Daniel C. Alexander at University College London, UK, Antonio Criminisi at Amazon Research, and Aditya Nori at Microsoft Research Cambridge for their valuable contributions to this paper.
