
Temperature Balancing, Layer-wise Weight Analysis, and Neural Network Training [NeurIPS 2023 Spotlight]

Yefan Zhou, Tianyu Pang, Keqin Liu, Charles H. Martin, Michael W. Mahoney, Yaoqing Yang

Full paper

Introduction

TempBalance is a layer-wise adaptive learning rate scheduler built on "model diagnostics". Most deep neural networks have complex multilayer structures, often seen as a barrier to transparency. Our research reveals a significant insight: these layers are not uniformly well-trained. We introduce a "model diagnostic" tool that identifies underperforming layers and addresses them by adjusting layer-wise learning rates during training. We demonstrate its effectiveness across various benchmarks, datasets, and network architectures, outperforming more than five existing methods, all rooted in our ability to dissect and diagnose network imbalances.
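
To make the mechanism concrete, below is a minimal sketch of the layer-wise scheduling idea, not the package's internal implementation: a hypothetical layer_metric function stands in for an ESD-based diagnostic, and each layer's metric rank is mapped linearly into [lr_min_ratio, lr_max_ratio].

# Minimal sketch of the layer-wise scheduling idea. `layer_metric` is a
# hypothetical stand-in for an ESD-based diagnostic; this is not the
# tempbalance package's internal code.
import torch
import torch.nn as nn

def layer_metric(weight: torch.Tensor) -> float:
    # Illustrative diagnostic: spectral norm of the (flattened) weight matrix.
    w = weight.flatten(1) if weight.dim() > 2 else weight
    return torch.linalg.matrix_norm(w, ord=2).item()

def layerwise_lr_scales(model: nn.Module,
                        lr_min_ratio: float = 0.5,
                        lr_max_ratio: float = 1.5) -> dict:
    # Rank layers by the diagnostic and map ranks linearly into
    # [lr_min_ratio, lr_max_ratio]; the direction of the mapping
    # depends on the chosen metric.
    metrics = {name: layer_metric(m.weight)
               for name, m in model.named_modules()
               if isinstance(m, (nn.Linear, nn.Conv2d))}
    order = sorted(metrics, key=metrics.get)
    n = max(len(order) - 1, 1)
    return {name: lr_min_ratio + (lr_max_ratio - lr_min_ratio) * i / n
            for i, name in enumerate(order)}

In this repository, the Tempbalance scheduler handles this bookkeeping automatically; see Usage below.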


Dataset and Environment Setup

conda create -n ww_train python=3.8
conda activate ww_train
pip install -r requirements.txt

# install TinyImageNet
export data_path=PATH_TO_DATA
bash dataset_setup.sh

Usage

from tempbalance import Tempbalance
import torch
import torch.optim as optim

model = ...
# initialize the scheduler
tb_scheduler = Tempbalance(net=model,
                lr_min_ratio=0.5,
                lr_max_ratio=1.5
                )
# initialize optimizer parameter group
tb_param_group = tb_scheduler.build_optimizer_param_group(untuned_lr=0.1)
optimizer = optim.SGD(
    tb_param_group,
    ...
)
# training loop
for epoch in range(1, ...):
    ...
    train()
    test()
    # get global decayed learning rate (optimizer lr should not be updated here)
    untuned_global_lr = some_lr_decay_function(epoch)
    # temperature balancing
    tb_scheduler.step(optimizer, untuned_global_lr)
    ...
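
Here, some_lr_decay_function can be any global schedule that returns the decayed, untuned learning rate for the current epoch; tb_scheduler.step then rescales it per layer and writes the result into the optimizer's parameter groups. A minimal sketch using cosine decay (a hypothetical helper, not part of the package):

import math

# Hypothetical example of `some_lr_decay_function`: cosine-decay the
# untuned global learning rate from base_lr to min_lr over total_epochs.
def cosine_decay_lr(epoch, total_epochs=200, base_lr=0.1, min_lr=0.0):
    t = min(epoch, total_epochs) / total_epochs
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * t))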

Experiments

Set up absolute paths for the dataset, the code repository, and checkpoints.

export data_path=PATH_TO_DATA
export src_path=PATH_TO_CODE_REPO
export ckpt_src_path=PATH_TO_CKPT

Compare TempBalance (TB) to CAL and SNR

# CIFAR100
bash bash_scripts/cal_snr/cifar100_tb_snr_cal_baseline.sh

# CIFAR10
bash bash_scripts/cal_snr/cifar10_tb_snr_cal_baseline.sh

# TinyImageNet
bash bash_scripts/cal_snr/tin_tb_snr_cal_baseline.sh

# SVHN
bash bash_scripts/cal_snr/svhn_tb_snr_cal_baseline.sh

Compare TempBalance (TB) to more baseline optimizers

bash bash_scripts/compare_opts/cifar100_compare_opt.sh

Ablation study

# Varying learning rate
bash bash_scripts/ablation/lr_ablation.sh

# Varying model width
bash bash_scripts/ablation/width_ablation.sh

# Varying ESD metric
bash bash_scripts/ablation/metric_ablation.sh
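
For context, the ESD-metric ablation varies which statistic of a layer's empirical spectral density (the eigenvalue distribution of its weight matrix) drives the diagnosis. As a rough illustration only (this is not the repository's estimator), a power-law-style exponent can be estimated from the eigenvalue tail with a Hill-type estimator:

import torch

# Rough, illustrative sketch of one ESD statistic: a Hill-style estimate
# of the power-law exponent of a layer's eigenvalue spectrum. Not the
# estimator used by this repository.
def esd_alpha(weight, k_frac=0.5):
    w = weight.flatten(1) if weight.dim() > 2 else weight
    evals = torch.linalg.svdvals(w).pow(2)   # eigenvalues of W W^T, descending
    k = max(int(k_frac * evals.numel()), 2)
    tail = evals[:k]                          # largest-k eigenvalues
    denom = torch.log(tail / tail[-1]).sum().item()
    return 1.0 + k / max(denom, 1e-12)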

Other tasks

A preliminary implementation applying TempBalance to language modeling and object detection is under development.

Acknowledgement

We thank the developers of the open-source packages weightwatcher, pytorch-optimizer, and adamp.

Citation

If you find this repository useful for your work, please consider citing the following paper:

@inproceedings{zhou2023temperature,
title={Temperature Balancing, Layer-wise Weight Analysis, and Neural Network Training},
author={Yefan Zhou and Tianyu Pang and Keqin Liu and Charles H. Martin and Michael W. Mahoney and Yaoqing Yang},
booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
year={2023}
}
