Stochastic Weight Averaging for Low-Precision Training (SWALP)

This repository contains a PyTorch implementation of the paper:

SWALP: Stochastic Weight Averaging for Low-Precision Training.

Guandao Yang, Tianyi Zhang, Polina Kirichenko, Junwen Bai, Andrew Gordon Wilson, Christopher De Sa


Introduction

Low precision operations can provide scalability, memory savings, portability, and energy efficiency. This paper proposes SWALP, an approach to low precision training that averages low-precision SGD iterates with a modified learning rate schedule. SWALP is easy to implement and can match the performance of full-precision SGD even with all numbers quantized down to 8 bits, including the gradient accumulators. Additionally, we show that SWALP converges arbitrarily close to the optimal solution for quadratic objectives, and to a noise ball asymptotically smaller than low precision SGD in strongly convex settings.

This repo contains the code to replicate our experiments on the CIFAR datasets with VGG16 and PreResNet164.
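The idea described above can be sketched in a few lines of NumPy. This is an illustrative toy on a quadratic objective, not this repo's PyTorch implementation: the fixed-point stochastic quantizer, constant learning rate, and hyperparameters below are simplified assumptions (the paper uses block floating point and also quantizes gradient accumulators).

```python
import numpy as np

def quantize(x, bits=8, scale=2.0):
    # Stochastic fixed-point quantization (illustrative, not the paper's exact scheme):
    # round down, then round up with probability equal to the fractional part.
    step = scale / (2 ** (bits - 1))
    scaled = x / step
    floor = np.floor(scaled)
    rounded = floor + (np.random.rand(*x.shape) < (scaled - floor))
    return np.clip(rounded, -(2 ** (bits - 1)), 2 ** (bits - 1) - 1) * step

def swalp(grad_fn, w0, lr=0.1, steps=200, avg_start=100):
    w = quantize(w0)                          # low-precision weights
    w_avg, n = np.zeros_like(w0), 0
    for t in range(steps):
        w = quantize(w - lr * grad_fn(w))     # low-precision SGD step
        if t >= avg_start:                    # average iterates in full precision
            w_avg, n = w_avg + w, n + 1
    return w_avg / n

# Toy quadratic objective f(w) = 0.5 * ||w||^2, whose gradient is w; the
# averaged iterate lands much closer to the optimum (0) than any single
# low-precision iterate can.
w_star = swalp(lambda w: w, np.ones(4))
```

Because quantization noise is roughly zero-mean, averaging the late iterates shrinks the noise ball around the optimum, which is the convergence behavior the paper proves for quadratic objectives.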

Citing this Work

Please cite our work if you find this approach useful in your research:

@misc{gu2019swalp,
    title={SWALP : Stochastic Weight Averaging in Low-Precision Training},
    author={Guandao Yang and Tianyi Zhang and Polina Kirichenko and Junwen Bai and Andrew Gordon Wilson and Christopher De Sa},
    year={2019},
    eprint={1904.11943},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}

Dependencies

Install the requirements with `$ pip install -r requirements.txt`.

Usage

We provide scripts to run Small-block Block Floating Point (BFP) experiments on CIFAR10 and CIFAR100 with VGG16 or PreResNet164. The following scripts reproduce our experimental results.

seed=100                                      # Specify experiment seed.
bash exp/block_vgg_swa.sh CIFAR10 ${seed}     # SWALP training on VGG16 with Small-block BFP in CIFAR10
bash exp/block_vgg_swa.sh CIFAR100 ${seed}    # SWALP training on VGG16 with Small-block BFP in CIFAR100
bash exp/block_resnet_swa.sh CIFAR10 ${seed}  # SWALP training on PreResNet164 with Small-block BFP in CIFAR10
bash exp/block_resnet_swa.sh CIFAR100 ${seed} # SWALP training on PreResNet164 with Small-block BFP in CIFAR100
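For readers unfamiliar with the number format these scripts use, here is a minimal NumPy sketch of small-block block floating point: values are grouped into fixed-size blocks, each block shares one exponent derived from its largest magnitude, and individual mantissas get a small word length. The word length, block size, and rounding below are illustrative assumptions, not the repo's exact quantizer.

```python
import numpy as np

def block_quantize(x, wl=8, block=16):
    # Small-block BFP: each block of `block` numbers shares one power-of-two
    # exponent chosen from the block's largest magnitude; each number keeps a
    # `wl`-bit signed mantissa at that shared scale.
    flat = x.ravel().astype(float)
    pad = (-flat.size) % block                 # pad so the array splits evenly
    blocks = np.concatenate([flat, np.zeros(pad)]).reshape(-1, block)
    out = np.empty_like(blocks)
    limit = 2 ** (wl - 1) - 1
    for i, b in enumerate(blocks):
        max_mag = np.max(np.abs(b))
        exp = np.floor(np.log2(max_mag)) if max_mag > 0 else 0.0
        step = 2.0 ** (exp - (wl - 2))         # mantissa resolution for this block
        out[i] = np.clip(np.round(b / step), -limit, limit) * step
    return out.ravel()[:x.size].reshape(x.shape)
```

Sharing one exponent per small block keeps the hardware cost close to fixed point while adapting the dynamic range per block, which is why values within a block are represented accurately as long as their magnitudes are similar.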

Results

The low-precision results (SGD-LP and SWALP) are produced by running the scripts in the `exp` folder. The full-precision results (SGD-FP and SWA-FP) are produced by running the SWA repo.

Test error (%):

| Dataset | Model | SGD-FP | SWA-FP | SGD-LP | SWALP |
|---|---|---|---|---|---|
| CIFAR10 | VGG16 | 6.81±0.09 | 6.51±0.14 | 7.61±0.15 | 6.70±0.12 |
| CIFAR10 | PreResNet164 | 4.63±0.18 | 4.03±0.10 | | |
| CIFAR100 | VGG16 | 27.23±0.17 | 25.93±0.21 | 29.59±0.32 | 26.65±0.29 |
| CIFAR100 | PreResNet164 | 22.20±0.57 | 19.95±0.19 | | |

Other implementations

Tianyi Zhang provides an implementation using QPyTorch, a low-precision training framework, at this link.

References

We used the SWA repo as a starter template. Network architecture implementations are adapted from:
