Deconstructing Lottery Tickets: Zeros, Signs, and the Supermask

Authors

Hattie Zhou, Janice Lan, Rosanne Liu, Jason Yosinski

Introduction

This codebase implements the experiments in Deconstructing Lottery Tickets: Zeros, Signs, and the Supermask. The paper performs various ablation studies to shed light on the Lottery Ticket (LT) phenomenon observed by Frankle & Carbin in The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks.

@inproceedings{zhou_2019_dlt,
  title={Deconstructing Lottery Tickets: Zeros, Signs, and the Supermask},
  author={Zhou, Hattie and Lan, Janice and Liu, Rosanne and Yosinski, Jason},
  booktitle={Advances in Neural Information Processing Systems},
  year={2019}
}

For more on this project, see the Uber Eng Blog post.

Codebase structure

  • data/download_mnist.py and data/download_cifar10.py download the MNIST/CIFAR10 data, split it into train, val, and test sets, and save them as h5 files in the data folder
  • get_weight_init.py computes various mask criteria
  • masked_layers.py defines new layer classes with masking options (see the sketch after this list)
  • masked_networks.py defines new layers and networks used in training Supermasks
  • network_builders.py defines the four network architectures evaluated in the paper (FC, Conv2, Conv4, Conv6)
  • train.py trains original unmasked networks
  • train_lottery.py reads in initial and final weights from a previously trained model, calculates the mask, and trains a lottery-style network
  • train_supermask.py trains a Supermask directly using Bernoulli sampling
  • get_init_loss_train_lottery.py derives masks and calculates the initial accuracy of the masked network for various pruning percentages and mask criteria. Note that this uses a one-shot approach rather than an iterative approach.

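For orientation, the masking used throughout the codebase amounts to an element-wise multiplication of each weight tensor with a fixed binary mask. Below is a minimal NumPy sketch of the idea; it is an illustration only, not the TensorFlow implementation in masked_layers.py, and the function and variable names are made up.

```python
import numpy as np

def masked_dense(x, weights, bias, mask):
    """Forward pass of a dense layer whose weight matrix is multiplied
    element-wise by a fixed binary mask (1 = keep, 0 = prune)."""
    return x @ (weights * mask) + bias

# Toy example: an 8-in, 4-out layer with roughly half of its weights pruned.
rng = np.random.default_rng(0)
w = rng.normal(size=(8, 4))
b = np.zeros(4)
mask = (rng.random(w.shape) < 0.5).astype(w.dtype)
x = rng.normal(size=(2, 8))
print(masked_dense(x, w, b, mask).shape)  # -> (2, 4)
```
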
This codebase uses the GitResultsManager package to keep track of experiments. See: https://github.com/yosinski/GitResultsManager

Example commands for running experiments

The following commands provide examples for running experiments in Deconstructing Lottery Tickets.

Train the original, unpruned network

  • Train an FC network (300-100-10) on MNIST: ./print_train_command.sh iter fc test 0 t

Alternative mask criteria experiments (using FC on MNIST and large final as an example)

  • Perform iterative LT training for an FC network on MNIST using the large_final mask criterion: ./print_train_lottery_iterative_command.sh fc test 0 large_final -1 mask none t
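For reference, the large_final criterion keeps, within each layer, the fraction of weights with the largest magnitude at the end of training and prunes the rest. A minimal NumPy sketch of that selection follows; it is an illustration only, and the names and details are not the implementation in get_weight_init.py.

```python
import numpy as np

def large_final_mask(w_final, keep_frac):
    """Keep the keep_frac fraction of weights with the largest magnitude
    after training (mask = 1); prune the rest (mask = 0)."""
    magnitudes = np.abs(w_final)
    k = max(int(round(keep_frac * magnitudes.size)), 1)
    threshold = np.sort(magnitudes.ravel())[-k]  # k-th largest magnitude
    return (magnitudes >= threshold).astype(w_final.dtype)

w_final = np.array([[0.3, -1.2], [0.05, 0.8]])
print(large_final_mask(w_final, keep_frac=0.5))
# [[0. 1.]
#  [0. 1.]]
```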

Mask-1 experiments

  • Randomly reinitialize weights prior to each round of iterative retraining: ./print_train_lottery_iterative_command.sh fc test 0 large_final -1 mask random_reinit t

  • Randomly reshuffle the initial values of remaining weights prior to each round of iterative retraining: ./print_train_lottery_iterative_command.sh fc test 0 large_final -1 mask random_reshuffle t

  • Convert the initial values of the remaining weights to a signed constant, then randomly reshuffle them prior to each round of iterative retraining: ./print_train_lottery_iterative_command.sh fc test 0 large_final -1 mask rand_signed_constant t

  • For versions that maintain the same sign, see signed_reinit, signed_reshuffle, and signed_constant.
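The signed-constant treatments replace each remaining weight's initial value with a layer-wide constant magnitude; the rand_* variants assign the signs at random, while the signed_* variants keep each weight's original sign. A minimal sketch of the two sign choices, under the assumption that the constant is the standard deviation of the original initialization; this is an illustration, not the repository's implementation.

```python
import numpy as np

def signed_constant_init(w_init):
    """Replace every initial weight with a layer-wide constant magnitude
    (here: the std of the original init), keeping each weight's sign."""
    return np.std(w_init) * np.sign(w_init)

def rand_signed_constant_init(w_init, rng):
    """Same constant magnitude, but with the signs assigned at random."""
    signs = rng.choice([-1.0, 1.0], size=w_init.shape)
    return np.std(w_init) * signs

rng = np.random.default_rng(0)
w_init = rng.normal(scale=0.1, size=(3, 3))
print(signed_constant_init(w_init))
```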

Mask-0 experiments

  • Freeze pruned weights at initial values: ./print_train_lottery_iterative_command.sh fc test 0 large_final -1 freeze_init none t

  • Freeze pruned weights that increased in magnitude at initial values: ./print_train_lottery_iterative_command.sh fc test 0 large_final -1 freeze_init_zero_mask none t

  • Initialize weights that decreased in magnitude to 0, and freeze pruned weights at their initial values: ./print_train_lottery_iterative_command.sh fc test 0 large_final -1 freeze_init_zero_all none t
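The Mask-0 experiments vary what happens to pruned weights during retraining: standard LT training sets them to 0, while freeze_init holds them at their initial values; the *_zero variants choose between 0 and the initial value per weight, depending on whether its magnitude grew during the original training. A minimal NumPy sketch of the first two treatments follows; it is an illustration only, the names are made up, and gradient freezing is omitted.

```python
import numpy as np

def lottery_init(w_init, mask, treatment="zero"):
    """Build the weights a lottery-style network starts retraining from.
    Kept weights (mask == 1) always start at w_init and are trained;
    pruned weights (mask == 0) are set to 0 ("zero", standard LT) or
    held at their initial values ("freeze_init")."""
    pruned_values = np.zeros_like(w_init) if treatment == "zero" else w_init
    return mask * w_init + (1 - mask) * pruned_values

rng = np.random.default_rng(0)
w_init = rng.normal(size=(2, 3))
mask = np.array([[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])
print(lottery_init(w_init, mask, treatment="freeze_init"))
```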

Supermask experiments

  • Evaluate the initial test accuracy of all alternative mask criteria: python get_init_loss_train_lottery.py --output_dir ./results/iter_lot_fc_orig/test_seed_0/ --train_h5 ./data/mnist_train.h5 --test_h5 ./data/mnist_test.h5 --arch fc_lot --seed 0 --opt adam --lr 0.0012 --exp none --layer_cutoff 4,6 --prune_base 0.8,0.9 --prune_power 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24

  • Train a Supermask directly: python train_supermask.py --output_dir ./results/iter_lot_fc_orig/learned_supermasks/run1/ --train_h5 ./data/mnist_train.h5 --test_h5 ./data/mnist_test.h5 --arch fc_mask --opt sgd --lr 100 --num_epochs 2000 --print_every 220 --eval_every 220 --log_every 220 --save_weights --save_every 22000
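A Supermask is learned by giving every weight a real-valued score, sampling a Bernoulli mask with keep-probability sigmoid(score) on each forward pass, and training only the scores while the weights stay frozen at their initial values. Below is a minimal NumPy sketch of the sampling step; the gradient estimator and training loop in train_supermask.py are omitted, and the names here are illustrative.

```python
import numpy as np

def sample_supermask(scores, rng):
    """Sample a binary mask: each entry is kept with probability
    sigmoid(score), as in Bernoulli-sampled Supermask training."""
    p_keep = 1.0 / (1.0 + np.exp(-scores))
    return (rng.random(scores.shape) < p_keep).astype(np.float32)

rng = np.random.default_rng(0)
w_init = rng.normal(size=(4, 4)).astype(np.float32)  # weights stay frozen at init
scores = np.zeros_like(w_init)                       # learnable per-weight mask logits
mask = sample_supermask(scores, rng)
effective_w = mask * w_init                          # the forward pass uses these weights
print(mask)
```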
