Skip to content

Commit

Permalink
Add script.
Browse files Browse the repository at this point in the history
  • Loading branch information
liujingcs committed Nov 23, 2021
1 parent cb5bc36 commit 2f51f3f
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 0 deletions.
112 changes: 112 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Sharpness-aware Quantization for Deep Neural Networks

[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)

## Recent Update

**`2021.11.23`**: We release the source code of SAQ.

## Setup the environments

1. Clone the repository locally:

```
git clone https://github.com/zhuang-group/SAQ
```

2. Install pytorch 1.8+, tensorboard and prettytable

```
conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
pip install tensorboard
pip install prettytable
```

## Data preparation

### ImageNet

Download the ImageNet 2012 dataset from [here](http://image-net.org/), and prepare the dataset based on this [script](https://gist.github.com/BIGBALLON/8a71d225eff18d88e469e6ea9b39cef4).

### CIFAR-100

Download the CIFAR-100 dataset from [here](https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz).

After downloading ImageNet and CIFAR-100, the file structure should look like:

```
dataset
├── imagenet
├── train
│ ├── class1
│ │ ├── img1.jpeg
│ │ ├── img2.jpeg
│ │ └── ...
│ ├── class2
│ │ ├── img3.jpeg
│ │ └── ...
│ └── ...
└── val
├── class1
│ ├── img4.jpeg
│ ├── img5.jpeg
│ └── ...
├── class2
│ ├── img6.jpeg
│ └── ...
└── ...
├── cifar100
├── cifar-100-python
│ ├── meta
│ ├── test
│ ├── train
│ └── ...
└── ...
```


## Training

### Fixed-precision quantization

1. Download the pre-trained full-precision models from the [model zoo](https://github.com/zhuang-group/SAQ/wiki/Model-Zoo).

2. Train low-precision models.

To train low-precision ResNet-20 on CIFAR-100, run:

```bash
sh script/train_qsam_cifar_r20.sh
```

To train low-precision ResNet-18 on ImageNet, run:

```bash
sh script/train_qsam_imagenet_r18.sh
```

### Mixed-precision quantization

1. Download the pre-trained full-precision models from the [model zoo](https://github.com/zhuang-group/SAQ/wiki/Model-Zoo).

2. Train the configuration generator.

To train the configuration generator of ResNet-20 on CIFAR-100, run:

```bash
sh script/train_generator_cifar_r20.sh
```

To train the configuration generator on ImageNet, run:

```bash
sh script/train_generator_imagenet_r18.sh
```

## License

This repository is released under the Apache 2.0 license as found in the [LICENSE](LICENSE) file.

## Acknowledgement

This repository has adopted codes from [SAM](https://github.com/davda54/sam), [ASAM](https://github.com/SamsungLabs/ASAM) and [ESAM](https://github.com/dydjw9/efficient_sam), we thank the authors for their open-sourced code.
1 change: 1 addition & 0 deletions script/train_cifar_r20.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
python train_sam.py --save_path ./output/cifar100/qresnet20/w4a4/ --data_path XXXXX --dataset cifar100 --lr 0.01 --clip_lr 0.01 --opt_type QSAM_SGD --network qsampreresnet20 --rho 0.4 --pretrained XXXXX --qw 4.0 --qa 4.0 --quan_type LIQ_wn_qsam --experiment_id 01 --seed 01 --gpu 0 --include_aclip True
1 change: 1 addition & 0 deletions script/train_generator_cifar_r20.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
python train_controller.py --save_path ./output/cifar100/generator/r20/ --data_path path_of_dataset --dataset cifar100 --lr 0.01 --clip_lr 0.01 --opt_type QSAM_SGD --network qsamspreresnet20 --rho 0.4 --pretrained path_of_pretrained_model --qw 3.0 --qa 3.0 --quan_type switchable_LIQ_wn_qsam --gpu 0 --lr_scheduler_type multi_step --n_epochs 100 --loss_lambda 1e-4 --suffix generator_01 --c_lr 5e-4 --entropy_coeff 5e-3 --target_bops 674 --include_aclip True --bits_choice 2,3,4,5 --bit_warmup_epochs 10 --wa_same_bit True
1 change: 1 addition & 0 deletions script/train_generator_imagenet_r18.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
python -m torch.distributed.launch --nproc_per_node=4 --master_port=66630 --use_env train_controller.py --save_path ./output/imagenet/generator/r18/ --data_path path_of_dataset --dataset imagenet100 --lr 0.01 --clip_lr 0.01 --opt_type QSAM_SGD --network qsamsresnet18 --rho 0.3 --pretrained /home/liujing/models/mobilenet_v2-convert.pth --qw 3.0 --qa 3.0 --quan_type switchable_LIQ_wn_qsam --gpu 4,5,6,7 --lr_scheduler_type multi_step --n_epochs 100 --loss_lambda 5e-3 --suffix controller_rho0.3_unshare_include_aclip_tb5.32_multi_step_lr0.01_warmup10 --c_lr 5e-4 --entropy_coeff 5e-3 --target_bops 5.32 --include_aclip True --bits_choice 2,3,4,5 --bit_warmup_epochs 10 --batch_size 64 --val_num 50000 --wa_same_bit True
1 change: 1 addition & 0 deletions script/train_imagenet_r18.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
python -m torch.distributed.launch --nproc_per_node=4 --master_port=65535 --use_env train_sam.py --save_path ./output/imagenet/qresnet18/w4a4/ --data_path path_of_dataset --dataset imagenet --lr 0.02 --clip_lr 0.02 --opt_type QSAM_SGD --network qsamresnet18 --rho 0.3 --pretrained path_of_pretrained_model --qw 4.0 --qa 4.0 --quan_type LIQ_wn_qsam --seed 01 --gpu 0,1,2,3 --include_aclip True --batch_size 128 --n_epochs 90 --lr_scheduler_type cosine --n_threads 8 --experiment_id 01

0 comments on commit 2f51f3f

Please sign in to comment.