This repository is the official PyTorch implementation of the paper "Towards Memory- and Time-Efficient Backpropagation for Training Spiking Neural Networks" (ICCV 2023).
- Python 3
- PyTorch, torchvision
- spikingjelly 0.0.0.0.12
- Python packages:
pip install tqdm progress torchtoolbox
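Before launching a long training run, it can help to confirm that the dependencies listed above are actually installed. The following is a minimal sketch, not part of this repository. The helper name `installed_versions` is hypothetical, and the distribution names (in particular `torch` for PyTorch) are assumed to match the names used at install time.

```python
from importlib import metadata

# Distribution names assumed from the requirements listed above.
REQUIRED = ["torch", "torchvision", "spikingjelly", "tqdm", "progress", "torchtoolbox"]


def installed_versions(names):
    """Map each distribution name to its installed version, or None if it is missing."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Print one line per requirement so a missing package is easy to spot.
    for name, version in installed_versions(REQUIRED).items():
        print(f"{name}: {version or 'NOT INSTALLED'}")
```

The script only reports versions. It does not check that the installed spikingjelly version matches the one listed above.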
We ran all experiments on a single V100 or A100 GPU. Multi-GPU training is not supported by the current code.
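On a machine with several GPUs, one common way to keep a run on a single device is to set the `CUDA_VISIBLE_DEVICES` environment variable for the training process. The sketch below is an illustration only. The helper name `single_gpu_env` is hypothetical, and whether the training scripts need this on a given machine is an assumption.

```python
import os


def single_gpu_env(gpu_index=0, base_env=None):
    """Return a copy of the environment in which CUDA sees only one GPU."""
    env = dict(os.environ if base_env is None else base_env)
    # CUDA reads this variable at start-up and hides every other device.
    env["CUDA_VISIBLE_DEVICES"] = str(gpu_index)
    return env


# Example use (not executed here):
#   import subprocess
#   subprocess.run(["python", "train_SLTT.py", "-dataset", "cifar10"],
#                  env=single_gpu_env(0), check=True)
```

The helper returns a new dictionary, so the environment of the calling process is left unchanged.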
CIFAR-10, CIFAR-100, DVS-CIFAR10, and DVS-Gesture:
# CIFAR-10
python train_SLTT.py -data_dir ./data -dataset cifar10 -model spiking_resnet18 -amp -T_max 200 -epochs 200 -weight_decay 5e-5
# CIFAR-100
python train_SLTT.py -data_dir ./data -dataset cifar100 -model spiking_resnet18 -amp -T_max 200 -epochs 200 -weight_decay 5e-4
# DVS-CIFAR10
python train_SLTT.py -data_dir ./data/CIFAR10DVS -dataset DVSCIFAR10 -T 10 -amp -drop_rate 0.3 -model spiking_vgg11_bn -lr=0.05 -weight_decay 5e-4 -mse_n_reg
# DVS-Gesture
python train_SLTT.py -data_dir ./data/dvsgesture -dataset dvsgesture -model spiking_vgg11_bn -T 20 -b 16 -amp -drop_rate 0.4 -weight_decay 5e-4
ImageNet:
### Stage 1: train the models with T=1.
# The following three commands correspond to ResNet-34, ResNet-50, and ResNet-101, respectively.
python train_SLTT.py -data_dir ./data/imagenet/ -dataset imagenet -T 1 -amp -epochs 100 -b 256 -T_max 100 -weight_decay 1e-5 -loss_lambda 0.0 -model spiking_nfresnet34
python train_SLTT.py -data_dir ./data/imagenet/ -dataset imagenet -T 1 -amp -epochs 100 -b 256 -T_max 100 -weight_decay 1e-5 -loss_lambda 0.0 -model spiking_nfresnet50
python train_SLTT.py -data_dir ./data/imagenet/ -dataset imagenet -T 1 -amp -epochs 100 -b 256 -T_max 100 -weight_decay 1e-5 -loss_lambda 0.0 -model spiking_nfresnet101
### Rename the stage-1 checkpoints to "imagenet_resnet34_t1.pth", "imagenet_resnet50_t1.pth", and "imagenet_resnet101_t1.pth", respectively.
### Move the renamed checkpoints to ./logs
### Stage 2: Fine-tune the stage-1 models with T=6 and fewer epochs.
python train_SLTT.py -data_dir ./data/imagenet/ -dataset imagenet -T 6 -amp -epochs 30 -b 256 -T_max 30 -weight_decay 0.0 -loss_lambda 0.0 -model spiking_nfresnet34 -lr 0.001 -pre_train ./logs/imagenet_resnet34_t1.pth
python train_SLTT.py -data_dir ./data/imagenet/ -dataset imagenet -T 6 -amp -epochs 30 -b 256 -T_max 30 -weight_decay 0.0 -loss_lambda 0.0 -model spiking_nfresnet50 -lr 0.001 -pre_train ./logs/imagenet_resnet50_t1.pth
python train_SLTT.py -data_dir ./data/imagenet/ -dataset imagenet -T 6 -amp -epochs 30 -b 256 -T_max 30 -weight_decay 0.0 -loss_lambda 0.0 -model spiking_nfresnet101 -lr 0.001 -pre_train ./logs/imagenet_resnet101_t1.pth
The stage-1 models for ImageNet can be downloaded here.
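The rename-and-move step between stage 1 and stage 2 can be scripted. The sketch below is a hedged example, not part of this repository. The target file names and the `./logs` directory come from the instructions above, while the helper name `stage_checkpoints` and the source paths are hypothetical placeholders, since the location where stage 1 saves its checkpoints is not stated here. The sketch copies rather than moves the files, so the originals are kept.

```python
import shutil
from pathlib import Path


def stage_checkpoints(sources, logs_dir="./logs"):
    """Copy stage-1 checkpoints into logs_dir under the names stage 2 expects.

    `sources` maps a ResNet depth (34, 50, or 101) to the path of the
    corresponding stage-1 checkpoint.
    """
    logs = Path(logs_dir)
    logs.mkdir(parents=True, exist_ok=True)
    staged = []
    for depth, src in sources.items():
        src = Path(src)
        if not src.is_file():
            raise FileNotFoundError(f"stage-1 checkpoint not found: {src}")
        dst = logs / f"imagenet_resnet{depth}_t1.pth"
        shutil.copy2(src, dst)  # copy instead of move, so the original is preserved
        staged.append(dst)
    return staged


# Example use with HYPOTHETICAL source paths (not executed here):
#   stage_checkpoints({34: "path/to/stage1_resnet34_checkpoint.pth"})
```

Each staged path matches the value passed to `-pre_train` in the stage-2 commands.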
For DVS-Gesture, DVS-CIFAR10, and ImageNet, run the following example commands, which use train_SLTTK.py with the -K option:
# DVS-Gesture
python train_SLTTK.py -K 4 -data_dir ./data/dvsgesture -dataset dvsgesture -model spiking_vgg11_bn -T 20 -b 16 -amp -drop_rate 0.4 -weight_decay 5e-4
# DVS-CIFAR10
python train_SLTTK.py -K 2 -data_dir ./data/CIFAR10DVS/ -dataset DVSCIFAR10 -T 10 -amp -drop_rate 0.3 -model spiking_vgg11_bn -lr=0.05 -weight_decay 5e-4 -mse_n_reg
# ImageNet
python train_SLTTK.py -K 2 -data_dir ./data/imagenet/ -dataset imagenet -T 6 -amp -epochs 30 -b 256 -T_max 30 -weight_decay 0.0 -loss_lambda 0.0 -model spiking_nfresnet101 -lr 0.001 -pre_train ./logs/imagenet_resnet101_t1.pth
We also provide a BPTT implementation for comparison. To run BPTT, reuse the SLTT commands above and replace "train_SLTT.py" with "train_BPTT.py". An example for CIFAR-10 follows.
python train_BPTT.py -data_dir ./data -dataset cifar10 -model spiking_resnet18 -amp -T_max 200 -epochs 200 -weight_decay 5e-5
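The script substitution described above can be applied mechanically to any of the SLTT commands. The following is a small sketch with a hypothetical helper name, `to_bptt_command`. It only swaps the script name and assumes, as the text states, that the remaining arguments carry over unchanged.

```python
import shlex


def to_bptt_command(sltt_command):
    """Swap train_SLTT.py for train_BPTT.py, leaving every other argument untouched."""
    args = shlex.split(sltt_command)
    if "train_SLTT.py" not in args:
        raise ValueError("expected a command that runs train_SLTT.py")
    swapped = ["train_BPTT.py" if arg == "train_SLTT.py" else arg for arg in args]
    return shlex.join(swapped)


if __name__ == "__main__":
    sltt = ("python train_SLTT.py -data_dir ./data -dataset cifar10 "
            "-model spiking_resnet18 -amp -T_max 200 -epochs 200 -weight_decay 5e-5")
    print(to_bptt_command(sltt))
```

Applied to the CIFAR-10 SLTT command, the helper reproduces the BPTT command shown above.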
The code for data preprocessing and neuron models is based on the spikingjelly repo. Some utility code is taken from the pytorch-classification repo.