This repository contains the PyTorch implementation for the results presented in the paper: Vanishing Feature: Diagnosing Model Merging and Beyond.
Set up the environment by running:
conda create -n vf python=3.9 cupy pkg-config libjpeg-turbo opencv numba -c conda-forge -c pytorch && conda activate vf && conda update ffmpeg
pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116
pip install -r requirements.txt
The current code for preserve-first merging (PFM) and pruning is adapted from ZipIt! and WoodFisher, respectively, which may make it less convenient to build on for further research. We plan to release an improved version in the future. The CCA merging experiments are also adapted from ZipIt!, so the original implementation of the CCA matching algorithm is plugged directly into our code in PFM/matching_functions.py.
All dependencies of the ZipIt! code are covered by the installation above, whereas the code in WoodFisher/ requires additional setup. Please refer to WoodFisher/README.md for details.
- main_notebooks/ contains notebooks to reproduce most results in the main paper.
- run_ex.py is the main file for training models. It trains a pair of models simultaneously.
- run_training.py is similar to run_ex.py, but trains only a single model at a time.
- training_scripts/ contains bash scripts for training models, including the hyper-parameter settings.
- source/ contains source code for models, datasets, training, merging, etc.
  - source/utils/opts.py contains code for parsing arguments.
  - source/utils/weight_matching/ contains code for weight matching.
  - source/utils/activation_matching/ contains code for activation matching.
  - source/utils/connect/ contains code for merging and post-merging normalization.
- PFM/ is adapted from ZipIt! and contains code to reproduce the evaluations of our preserve-first merging (PFM) framework. A complete description of the ZipIt! code can be found in the original repo.
  - PFM/run_cifar.sh is the bash script to run PFM experiments on CIFAR datasets.
  - PFM/run_imagenet.sh is the bash script to run PFM experiments on the ImageNet dataset.
  - PFM/run_cifar_cca.sh is the bash script to run PFM experiments with CCA merging.
  - PFM/visualize_results.ipynb is the notebook to visualize the results after running the PFM experiments.
  - PFM/calculate_params_flops.ipynb contains the code to calculate the number of parameters and FLOPs of the models.
  - PFM/get_zipit_premuted_models.ipynb is used to obtain and save permuted models after applying ZipIt!. These models are then used to evaluate the performance of normalization methods, as shown in main_notebooks/improve_normalization_from_vf.ipynb.
- WoodFisher/ is adapted from WoodFisher for pruning experiments. A complete description of the repo can be found there.
  - WoodFisher/main.py is the main entry point for running pruning.
  - WoodFisher/transfer_checkpoint.ipynb contains the code to convert pre-trained checkpoints produced by our code into the format expected by the WoodFisher pruning code.
  - WoodFisher/checkpoints stores the pre-trained models used for subsequent pruning.
  - WoodFisher/configs contains YAML config files specifying training and pruning schedules. In our work, we only use the pruning schedules.
  - WoodFisher/scripts contains all bash scripts for pruning needed to reproduce the results in the paper.
  - WoodFisher/record_pruning contains the code for visualizing the results after pruning.
  - WoodFisher/lmc_source contains code edited from our repo for applying re-normalization after pruning.
We use the Weights & Biases (wandb) platform for logging results during training. To use wandb for the first time, you need to create an account and log in. The --wandb-mode flag specifies the wandb mode. The default mode is online, which logs the results to the wandb server. To run the code without logging to wandb, set --wandb-mode to disabled. To log the results without connecting to the wandb server (e.g., without logging in), set --wandb-mode to offline. In this case, the results are logged to a local directory wandb/ and can be uploaded to wandb later. For more details, please refer to the wandb documentation.
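The sketch below shows one way a flag like --wandb-mode could be parsed and handed to wandb. It is a hypothetical stand-in, not the parser in source/utils/opts.py: the flag names come from this README, but the defaults, help strings, and the helper names build_parser and wandb_init_kwargs are assumptions. It deliberately avoids importing wandb and only builds the keyword arguments; wandb.init does accept project, name, and mode.

```python
import argparse


def build_parser():
    """Hypothetical subset of the training flags; not source/utils/opts.py."""
    parser = argparse.ArgumentParser(description="wandb logging options (sketch)")
    parser.add_argument("--project", type=str, default=None, help="wandb project name")
    parser.add_argument("--run-name", type=str, default=None, help="wandb run name")
    parser.add_argument(
        "--wandb-mode",
        default="online",
        choices=["online", "offline", "disabled"],
        help="online: sync to the server; offline: log to ./wandb; disabled: no logging",
    )
    return parser


def wandb_init_kwargs(args):
    """Keyword arguments that could be passed on as wandb.init(**kwargs)."""
    return {"project": args.project, "name": args.run_name, "mode": args.wandb_mode}


args = build_parser().parse_args(["--project", "vf", "--wandb-mode", "offline"])
print(wandb_init_kwargs(args))  # {'project': 'vf', 'name': None, 'mode': 'offline'}
```

Restricting the flag with choices makes argparse reject typos such as --wandb-mode ofline up front. Runs logged in offline mode under wandb/ can later be uploaded with the wandb sync command-line tool.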
We use bash scripts, located in training_scripts/, to specify all training settings. All arguments are defined, with explanations, in source/utils/opts.py. Here we only list some important ones.
- --project: The wandb project name.
- --run-name: The wandb run name.
- --dataset: The dataset to use. We use mnist, cifar10, and cifar100 in our work.
- --data-dir: The dataset directory.
- --model: The model to use, including VGG- and ResNet-type models.
  - Standard plain VGG models include cifar_vgg11, cifar_vgg13, cifar_vgg16, and cifar_vgg19. VGG models with batch normalization are named with the _bn suffix, e.g., cifar_vgg11_bn.
  - Standard ResNet models are named cifar_resnet[xx], e.g., cifar_resnet20. Plain/Fixup ResNet models are named with the plain_/fixup_ prefix, e.g., plain_cifar_resnet20 and fixup_cifar_resnet32.
  - Models with layer normalization are named with the _ln suffix, e.g., cifar_vgg11_ln and cifar_resnet20_ln.
  - Models without biases are named with the _nobias suffix, e.g., cifar_vgg11_nobias.
  - Models with a larger width are named with _[width_multiplier]x at the end, e.g., cifar_vgg16_bn_4x.
- --diff-init: Whether to use different initializations for the two models. If True, the two models are initialized with different random seeds.
- --special-init: Whether to use a special initialization for the models. The default is None, in which case the models use PyTorch's default Kaiming uniform initialization. If set to vgg_init, Kaiming normal initialization is used.
- --train-only: Whether to only train the models without measuring the linear interpolation between the two models during training. If not set, the linear interpolation is measured every --lmc-freq percent of the training.
- --reset-bn: Whether to reset BN statistics when measuring the linear interpolation during training.
- --repair: Whether to apply REPAIR/RESCALE when measuring the linear interpolation during training. The default is None, meaning no re-normalization is applied. If set to repair, REPAIR is applied; if set to rescale, RESCALE is applied.
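To make the interpolation measurement controlled by --train-only and --lmc-freq more concrete, here is a minimal NumPy sketch of evaluating a loss along the straight line between two parameter sets and summarizing it as a loss barrier, one common way to report linear mode connectivity. It is an illustration under assumptions, not the code in run_ex.py: parameters are plain dicts of arrays standing in for PyTorch state_dicts, toy_loss and the 11-point grid are hypothetical, and neither the BN reset (--reset-bn) nor REPAIR/RESCALE (--repair) is modeled.

```python
import numpy as np


def interpolate(params_a, params_b, alpha):
    """Return (1 - alpha) * theta_A + alpha * theta_B, computed key by key."""
    return {k: (1.0 - alpha) * params_a[k] + alpha * params_b[k] for k in params_a}


def loss_barrier(params_a, params_b, loss_fn, num_points=11):
    """Evaluate loss_fn along the linear path and return (barrier, losses).

    The barrier is the largest gap between the loss of the interpolated model
    and the straight line joining the two endpoint losses.
    """
    alphas = np.linspace(0.0, 1.0, num_points)
    losses = np.array([loss_fn(interpolate(params_a, params_b, a)) for a in alphas])
    baseline = (1.0 - alphas) * losses[0] + alphas * losses[-1]
    return float(np.max(losses - baseline)), losses


# Toy "models": a single weight whose loss has minima at w = +1 and w = -1.
def toy_loss(params):
    return float(((params["w"] ** 2 - 1.0) ** 2).sum())


theta_a = {"w": np.array([1.0])}
theta_b = {"w": np.array([-1.0])}
barrier, curve = loss_barrier(theta_a, theta_b, toy_loss)
print(barrier)  # 1.0: the midpoint w = 0 sits on the bump between the two minima
```

In an actual run, the two endpoints would be the pair of networks trained by run_ex.py, and the loss would be computed on data rather than by a closed-form function.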
For a complete description, we refer to the original repo: WoodFisher. Here we only list the args in the pruning scripts (located in WoodFisher/scripts/) that are important for reproducing the results in the paper.
- MODULES: The modules to prune.
- ROOT_DIR: The root directory to store the results.
- DATA_DIR: The dataset directory.
- PRUNERS: The pruners to use. Options: woodfisherblock, globalmagni, magni, diagfisher.
- --num-samples: The number of samples used for applying re-normalization. The default is None.
- --from-checkpoint-path: The path to the pre-trained checkpoint.
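Judging only by the option names, globalmagni and magni presumably correspond to global and per-module magnitude pruning. The NumPy sketch below illustrates that distinction under this assumption; it is not WoodFisher's implementation, and the function names and toy weights are hypothetical. With one global threshold, modules holding many small weights end up sparser than others, whereas per-module pruning gives every module the same sparsity.

```python
import numpy as np


def global_magnitude_masks(weights, sparsity):
    """Keep-masks that prune the `sparsity` fraction of smallest-magnitude
    entries, ranked jointly across all modules (a single global threshold)."""
    mags = np.concatenate([np.abs(w).ravel() for w in weights.values()])
    num_prune = int(round(sparsity * mags.size))
    keep = np.ones(mags.size, dtype=bool)
    # A stable argsort makes tie-breaking deterministic.
    keep[np.argsort(mags, kind="stable")[:num_prune]] = False
    masks, offset = {}, 0
    for name, w in weights.items():
        masks[name] = keep[offset:offset + w.size].reshape(w.shape)
        offset += w.size
    return masks


def layerwise_magnitude_masks(weights, sparsity):
    """Same criterion, but each module is pruned to `sparsity` on its own."""
    return {name: global_magnitude_masks({name: w}, sparsity)[name]
            for name, w in weights.items()}


weights = {
    "conv1": np.array([[0.1, -2.0], [0.3, 0.05]]),
    "fc": np.array([1.0, -0.2, 0.4, 3.0]),
}
gm = global_magnitude_masks(weights, 0.5)
lm = layerwise_magnitude_masks(weights, 0.5)
print({k: int(m.sum()) for k, m in gm.items()})  # {'conv1': 1, 'fc': 3}
print({k: int(m.sum()) for k, m in lm.items()})  # {'conv1': 2, 'fc': 2}
```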
The PFM framework is adapted from ZipIt!, which provides a highly flexible implementation of model merging based on graph representations of models. We modify the original repo to support our models, datasets, and the preserve-first merging framework. The setup for the PFM code can be found in PFM/README.md.
The pruning results reported in the paper are obtained with the WoodFisher framework, and the corresponding code is stored in WoodFisher/. We manually edited some code in the original repo to force one-shot pruning and remove some irrelevant features, especially in WoodFisher/policies/manager.py, although this can also be done by modifying the pruning settings in the scripts. The original file is retained as WoodFisher/policies/manager_ori.py. To apply re-normalization after pruning, we merged a modified version of our code into the repo, stored in WoodFisher/lmc_source/. Several lines of code were also added to WoodFisher/policies/manager.py. This can serve as an example of how to integrate our code with other pruning frameworks.
We will release pre-trained checkpoints for reproducing the pruning results in the future. These checkpoints have already been converted, so there is no need to run WoodFisher/transfer_checkpoint.ipynb on them.
We also provide a simple, self-contained example of applying re-normalization after pruning in main_notebooks/renormalize_pruned_model.ipynb. It uses torch.nn.utils.prune to prune the model and then applies re-normalization.
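As an illustration of the idea only, the NumPy sketch below re-normalizes a single pruned linear layer. It estimates the per-output-channel mean and standard deviation on a batch of sample inputs (the role --num-samples plays in the WoodFisher scripts) and folds an affine correction into the pruned weights and bias so that those statistics match the unpruned layer, in the spirit of REPAIR. The target statistics, the layer type, and every name here (renormalize_linear, the random toy layer) are assumptions; the notebook works with PyTorch modules pruned by torch.nn.utils.prune, and its exact procedure may differ.

```python
import numpy as np


def renormalize_linear(w_orig, b, w_pruned, x, eps=1e-8):
    """Fold a per-channel affine correction into a pruned linear layer so that,
    on the sample inputs x, each output channel recovers the mean and std of
    the unpruned layer. Layers compute y = x @ w.T + b, with w of shape (out, in)."""
    y_orig = x @ w_orig.T + b
    y_pruned = x @ w_pruned.T + b  # only the weights are pruned in this sketch
    mu, sigma = y_orig.mean(axis=0), y_orig.std(axis=0)
    mu_p, sigma_p = y_pruned.mean(axis=0), y_pruned.std(axis=0)
    scale = sigma / (sigma_p + eps)
    # scale * (y_pruned - mu_p) + mu, rewritten as new weights and bias.
    w_new = scale[:, None] * w_pruned
    b_new = scale * (b - mu_p) + mu
    return w_new, b_new


rng = np.random.default_rng(0)
x = rng.normal(size=(256, 8))  # stands in for the --num-samples inputs
w = rng.normal(size=(4, 8))
b = rng.normal(size=4)
# Per-row magnitude pruning: keep the larger half of each row's weights.
mask = np.abs(w) >= np.quantile(np.abs(w), 0.5, axis=1, keepdims=True)
w_p = w * mask
w_r, b_r = renormalize_linear(w, b, w_p, x)
y, y_r = x @ w.T + b, x @ w_r.T + b_r
print(np.allclose(y.mean(axis=0), y_r.mean(axis=0)),
      np.allclose(y.std(axis=0), y_r.std(axis=0)))  # True True
```

Folding the correction into the weights and bias, rather than adding a separate normalization layer, keeps the pruned entries at exactly zero, because scaling a row leaves its zeros in place.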
The setup for the WoodFisher repo can be found in WoodFisher/README.md.
This codebase corresponds to the paper: Vanishing Feature: Diagnosing Model Merging and Beyond. If you use any of the code or provided models for your research, please consider citing the paper as
@article{qu2024vanishing,
title={Vanishing Feature: Diagnosing Model Merging and Beyond},
author={Qu, Xingyu and Horvath, Samuel},
journal={arXiv preprint arXiv:2402.05966},
year={2024}
}