Code for the paper "SPARTAN: Sparse Hierarchical Memory for Parameter-Efficient Transformers" [[Paper](https://arxiv.org/abs/2211.16634)]
Fine-tuning pre-trained language models (PLMs) achieves impressive performance on a range of downstream tasks, and model sizes have consequently kept growing. Since a separate copy of the model is required for each task, this paradigm is infeasible for storage-constrained edge devices like mobile phones. In this paper, we propose SPARTAN, a parameter-efficient (PE) and computationally fast architecture for edge devices that adds hierarchically organized sparse memory after each Transformer layer. SPARTAN freezes the PLM parameters and fine-tunes only its memory, thus significantly reducing storage costs by re-using the PLM backbone for different tasks. SPARTAN contains two levels of memory: for each input, only a sparse subset of parents is chosen in the first level, and the children cells corresponding to those parents are used to compute an output representation. This sparsity, combined with other architecture optimizations, improves SPARTAN's throughput by over 90% during inference on a Raspberry Pi 4 compared to PE baselines (adapters), while also outperforming the latter by 0.1 points on the GLUE benchmark. Further, it can be trained 34% faster in a few-shot setting, while performing within 0.9 points of adapters. Qualitative analysis shows that different parent cells in SPARTAN specialize in different topics, thus dividing responsibility efficiently.
Step 0: We recommend creating a conda environment

```bash
conda create -n spartan python=3.8
```
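Activate the environment before installing (using the environment name `spartan` from the command above):

```bash
conda activate spartan
```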
Step 1: Install the requirements

```bash
pip install -r requirements.txt
```
Step 2: Setup the SPARTAN package

```bash
pip install .
```
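If the install succeeded, the package should be importable. The module name `spartan` below is an assumption based on the package name; check the repository's `setup.py` if it differs:

```bash
python -c "import spartan"
```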
Please use the following citation if you find the paper useful!
```bibtex
@article{deshpande2022spartan,
  title={SPARTAN: Sparse Hierarchical Memory for Parameter-Efficient Transformers},
  author={Deshpande, Ameet and Sultan, Md Arafat and Ferritto, Anthony and Kalyan, Ashwin and Narasimhan, Karthik and Sil, Avirup},
  journal={arXiv preprint arXiv:2211.16634},
  year={2022}
}
```