Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
hattie
committed
Aug 8, 2019
0 parents
commit ccfc10e
Showing
21 changed files
with
4,730 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
*.pyc |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
# Deconstructing Lottery Tickets: Zeros, Signs, and the Supermask | ||
|
||
## Authors | ||
Hattie Zhou, Janice Lan, Rosanne Liu, Jason Yosinski | ||
|
||
## Introduction | ||
This codebase implements the experiments in [Deconstructing Lottery Tickets: Zeros, Signs, and the Supermask](https://arxiv.org/abs/1905.01067). This paper performs various ablation studies to shine light into the Lottery Tickets (LT) phenomenon observed by Frankle & Carbin in [The Lottery Ticket Hypothesis: Finding Small, Trainable Neural Networks](https://arxiv.org/abs/1803.03635). | ||
|
||
``` | ||
@inproceedings{dtl | ||
title={Deconstructing Lottery Tickets: Zeros, Signs, and the Supermask}, | ||
author={Zhou, Hattie and Lan, Janice and Liu, Rosanne and Yosinski, Jason}, | ||
booktitle={Identifying and Understanding Deep Learning Phenomena Workshop, International Conference on Machine Learning}, | ||
year={2019} | ||
} | ||
``` | ||
|
||
For more on this project, see the [Uber Eng Blog post](https://eng.uber.com/deconstructing-lottery-tickets/). | ||
|
||
|
||
## Codebase structure | ||
- `data/download_mnist.py` downloads MNIST data and splits it into train, val, and test | ||
- `get_weight_init.py` computes various mask criteria | ||
- `masked_layers.py` defines new layer classes with masking options | ||
- `masked_networks.py` defines new layers and networks used in training Supermasks | ||
- `network_builders.py` defines the four network architecture evaluated in the paper (FC, Conv2, Conv4, Conv6) | ||
- `train.py` trains original unmasked networks | ||
- `train_lottery.py` reads in initial and final weights from a previously trained model, calculates the mask, and train a lottery style network | ||
- `train_supermask` trains a supermask directly using Bernoulli sampling | ||
- `get_init_loss_train_lottery.py` derives masks and calculates the initial accuracy of the masked network for various pruning percentages and mask criteria. Note that this uses a one-shot approach rather than an iterative approach. | ||
|
||
This codebase uses the `GitResultsManager` package to keep track of experiments. See: https://github.com/yosinski/GitResultsManager | ||
|
||
## Example commands for running experiments | ||
The following commands provide examples for running experiments in Deconstructing Lottery Tickets. | ||
|
||
#### Train the original, unpruned network | ||
- Train a FC network (300-100-10) on MNIST: `./print_train_command.sh iter fc test 0 t` | ||
|
||
#### Alternative mask criteria experiments (using FC on MNIST and large final as an example) | ||
- Perform iterative LT training for a FC network on MNIST using large final mask criterion: `./print_train_lottery_iterative_command.sh fc test 0 large_final -1 mask none t` | ||
|
||
#### Mask-1 experiments | ||
- Randomly reinitialize weights prior to each round of iterative retraining: `./print_train_lottery_iterative_command.sh fc test 0 large_final -1 mask random_reinit t` | ||
|
||
- Randomly reshuffle the initial values of remaining weights prior to each round of iterative retraining: `./print_train_lottery_iterative_command.sh fc test 0 large_final -1 mask random_reshuffle t` | ||
|
||
- Convert the initial values of weights to a signed constant before randomly reshuffle the initial values of remaining weights prior to each round of iterative retraining: `./print_train_lottery_iterative_command.sh fc test 0 large_final -1 mask rand_signed_constant t` | ||
|
||
- For versions that maintain the same sign, see `signed_reinit`, `signed_reshuffle`, and `signed_constant`. | ||
|
||
#### Mask-0 experiments | ||
- Freeze pruned weights at initial values: `./print_train_lottery_iterative_command.sh fc test 0 large_final -1 freeze_init none t` | ||
|
||
- Freeze pruned weights that increased in magnitude at initial values: `./print_train_lottery_iterative_command.sh fc test 0 large_final -1 freeze_init_zero_mask none t` | ||
|
||
- Initialize weights that decreased in magnitude at 0, and freeze pruned weights at initial value: `./print_train_lottery_iterative_command.sh fc test 0 large_final -1 freeze_init_zero_all none t` | ||
|
||
#### Supermask experiments | ||
- Evaluate the initial test accuracy of all alternative mask criteria: `python get_init_loss_train_lottery.py --output_dir ./results/iter_lot_fc_orig/test_seed_0/ --train_h5 ./data/mnist_train --test_h5 ./data/mnist_test --arch fc_lot --seed 0 --opt adam --lr 0.0012 --exp none --layer_cutoff 4,6 --prune_base 0.8,0.9 --prune_power 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24` | ||
|
||
- Train a Supermask directly: `python train_supermask.py --output_dir ./results/iter_lot_fc_orig/learned_supermasks/run1/ --train_h5 ./data/mnist_train --test_h5 ./data/mnist_test --arch fc_mask --opt sgd --lr 100 --num_epochs 2000 --print_every 220 --eval_every 220 --log_every 220 --save_weights --save_every 22000` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import tensorflow as tf | ||
import h5py | ||
import numpy as np | ||
np.random.seed(seed=0) | ||
|
||
|
||
def main(): | ||
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data() | ||
|
||
valset_ind = np.random.choice(range(60000), size=5000, replace=False) | ||
trainset_ind = np.array([i for i in range(60000) if i not in valset_ind]) | ||
|
||
train_set_images = x_train[trainset_ind] | ||
train_set_images = train_set_images.reshape((55000,28,28,1)) | ||
train_set_labels = y_train[trainset_ind] | ||
|
||
val_set_images = x_train[valset_ind] | ||
val_set_images = val_set_images.reshape((5000,28,28,1)) | ||
val_set_labels = y_train[valset_ind] | ||
|
||
x_test = x_test.reshape((10000,28,28,1)) | ||
|
||
f = h5py.File("mnist_train", "w") | ||
f.create_dataset('images', data=train_set_images) | ||
f.create_dataset('labels', data=train_set_labels) | ||
f.close() | ||
|
||
f = h5py.File("mnist_val", "w") | ||
f.create_dataset('images', data=val_set_images) | ||
f.create_dataset('labels', data=val_set_labels) | ||
f.close() | ||
|
||
f = h5py.File("mnist_test", "w") | ||
f.create_dataset('images', data=x_test) | ||
f.create_dataset('labels', data=y_test) | ||
f.close() | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
Oops, something went wrong.