
Model Spider: Learning to Rank Pre-Trained Models Efficiently (NeurIPS 2023 Spotlight)

📑 [Paper] [Code]

Detailed Introduction

Figuring out which Pre-Trained Model (PTM) from a model zoo fits the target task is essential to take advantage of plentiful model resources. With the availability of numerous heterogeneous PTMs from diverse fields, efficiently selecting the most suitable PTM is challenging due to the time-consuming costs of carrying out forward or backward passes over all PTMs. In this paper, we propose Model Spider, which tokenizes both PTMs and tasks by summarizing their characteristics into vectors to enable efficient PTM selection.
By leveraging the approximated performance of PTMs on a separate set of training tasks, Model Spider learns to construct these representations and to measure the fitness score of a model-task pair from them. The ability to rank relevant PTMs above the others generalizes to new tasks. With the top-ranked PTM candidates, we further learn to enrich the task representation with PTM-specific semantics to re-rank the PTMs for better selection. Model Spider balances efficiency and selection ability, making PTM selection like a spider preying on a web.
Model Spider demonstrates promising performance across various model categories, including visual models and Large Language Models (LLMs). In this repository, we have built a comprehensive and user-friendly PyTorch-based model ranking toolbox for evaluating the future generalization performance of models. It aids in selecting the most suitable foundation pre-trained models for achieving optimal performance in real-world tasks after fine-tuning. In this benchmark for selecting/ranking PTMs, we have reproduced relevant model selection methods such as H-Score, LEEP, LogME, NCE, NLEEP, OTCE, PACTran, GBC, and LFC.
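The idea above can be sketched as follows: once PTMs and a task are summarized into token vectors, selection reduces to scoring each model token against the task token and sorting. The snippet below is an illustrative sketch only; `rank_ptms` is a hypothetical name, and plain cosine similarity stands in for the learned fitness measure.

```python
import torch
import torch.nn.functional as F

def rank_ptms(task_token: torch.Tensor, model_tokens: torch.Tensor):
    """Score every PTM token against the task token and return the
    PTM indices sorted from best to worst fit."""
    task = F.normalize(task_token, dim=-1)       # (d,)
    models = F.normalize(model_tokens, dim=-1)   # (m, d)
    scores = models @ task                       # one fitness score per PTM
    order = torch.argsort(scores, descending=True)
    return order.tolist(), scores.tolist()

# toy example: 4 hypothetical PTM tokens and one task token, 8-dim
torch.manual_seed(0)
model_tokens = torch.randn(4, 8)
task_token = torch.randn(8)
ranking, scores = rank_ptms(task_token, model_tokens)
print(ranking)  # PTM indices, best candidate first
```

In the actual method the fitness measure is learned, so cosine similarity here is only a placeholder for how the ranking step is wired together.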

  1. We introduce a single-source model zoo, building 10 PTMs on ImageNet across five architecture families, i.e., Inception, ResNet, DenseNet, MobileNet, and MNASNet. These models can be evaluated on 10 downstream datasets with measures like weighted Kendall's tau: Aircraft, Caltech101, Cars, CIFAR10, CIFAR100, DTD, Pet, and SUN397 for classification, and UTKFace and dSprites for regression.
  2. We construct a multi-source model zoo where 42 heterogeneous PTMs are pre-trained from multiple datasets, with 3 architectures of similar magnitude, i.e., Inception-V3, ResNet-50, and DenseNet-201, pre-trained on 14 datasets, including animals, general and 3D objects, plants, scene-based, remote sensing, and multi-domain recognition. We evaluate the ability to select PTMs on Aircraft, DTD, and Pet datasets.

In this repo, you will find:

  • Implementations of Pre-trained Model Selection / Ranking (for unseen data) with an accompanying benchmark evaluation, including H-Score, LEEP, LogME, NCE, NLEEP, OTCE, PACTran, GBC, and LFC.
  • Get started quickly with our method Model Spider, and enjoy its user-friendly inference capabilities.
  • Feel free to customize the application scenarios of Model Spider!
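As a concrete example of one such baseline, H-Score (Bao et al., 2019) estimates transferability as tr(pinv(cov_F) · cov_Z), where F are the PTM's features on the target data and Z replaces each feature with its class-conditional mean. The NumPy sketch below is illustrative only; the implementation in this repo may differ in details such as covariance regularization.

```python
import numpy as np

def h_score(features: np.ndarray, labels: np.ndarray) -> float:
    """H-Score transferability estimate: tr(pinv(cov_f) @ cov_z).
    cov_f is the feature covariance; cov_z is the covariance of the
    class-conditional mean features. Higher means a better fit."""
    f = features - features.mean(axis=0)     # center the features
    cov_f = f.T @ f / len(f)
    z = np.zeros_like(f)
    for c in np.unique(labels):
        mask = labels == c
        z[mask] = f[mask].mean(axis=0)       # class mean replaces each sample
    cov_z = z.T @ z / len(z)
    # pseudo-inverse guards against a singular feature covariance
    return float(np.trace(np.linalg.pinv(cov_f) @ cov_z))

# toy check: two well-separated classes in a 4-dim feature space
rng = np.random.default_rng(0)
feats = np.concatenate([rng.normal(-3, 1, (50, 4)), rng.normal(3, 1, (50, 4))])
labels = np.array([0] * 50 + [1] * 50)
```

With these separable features, `h_score(feats, labels)` is clearly larger than the score under labels with no class structure (e.g. all zeros), matching the intuition that higher H-Score means easier transfer.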

 


Pre-trained Model Ranking Performance

Performance comparisons of 9 baseline approaches and Model Spider on the single-source model zoo with weighted Kendall's tau. We denote the best-performing results in bold.

| Method | Aircraft | Caltech101 | Cars | CIFAR10 | CIFAR100 | DTD | Pets | SUN397 | Mean |
| :-- | --: | --: | --: | --: | --: | --: | --: | --: | --: |
| H-Score | 0.328 | 0.738 | 0.616 | 0.797 | 0.784 | 0.395 | 0.610 | 0.918 | 0.648 |
| NCE | 0.501 | 0.752 | 0.771 | 0.694 | 0.617 | 0.403 | 0.696 | 0.892 | 0.666 |
| LEEP | 0.244 | 0.014 | 0.704 | 0.601 | 0.620 | -0.111 | 0.680 | 0.509 | 0.408 |
| N-LEEP | -0.725 | 0.599 | 0.622 | 0.768 | 0.776 | 0.074 | 0.787 | 0.730 | 0.454 |
| LogME | **0.540** | 0.666 | 0.677 | 0.802 | 0.798 | 0.429 | 0.628 | 0.870 | 0.676 |
| PACTran | 0.031 | 0.200 | 0.665 | 0.717 | 0.620 | -0.236 | 0.616 | 0.565 | 0.397 |
| OTCE | -0.241 | -0.011 | -0.157 | 0.569 | 0.573 | -0.165 | 0.402 | 0.218 | 0.149 |
| LFC | 0.279 | -0.165 | 0.243 | 0.346 | 0.418 | -0.722 | 0.215 | -0.344 | 0.034 |
| GBC | -0.744 | -0.055 | -0.265 | 0.758 | 0.544 | -0.102 | 0.163 | 0.457 | 0.095 |
| Model Spider (Ours) | 0.506 | **0.761** | **0.785** | **0.909** | **1.000** | **0.695** | **0.788** | **0.954** | **0.800** |
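Weighted Kendall's tau puts more weight on agreement among the top-ranked PTMs. The benchmark presumably relies on `scipy.stats.weightedtau`; the self-contained sketch below implements a simplified variant with additive hyperbolic weights on the ground-truth ranks, just to show how the metric behaves.

```python
def weighted_kendall_tau(truth, pred):
    """Simplified top-weighted Kendall's tau: each pair (i, j) is
    weighted by 1/(1+rank_i) + 1/(1+rank_j), with ranks taken from
    the ground-truth ordering (rank 0 = best model)."""
    n = len(truth)
    order = sorted(range(n), key=lambda i: -truth[i])
    rank = {idx: r for r, idx in enumerate(order)}
    sign = lambda a: (a > 0) - (a < 0)
    num = den = 0.0
    for i in range(n):
        for j in range(i + 1, n):
            w = 1.0 / (1 + rank[i]) + 1.0 / (1 + rank[j])
            num += sign(truth[i] - truth[j]) * sign(pred[i] - pred[j]) * w
            den += w
    return num / den

# ground-truth accuracies of 5 hypothetical PTMs vs. a method's scores
truth = [0.91, 0.85, 0.78, 0.70, 0.66]
pred = [0.80, 0.82, 0.60, 0.55, 0.50]  # swaps the top two models
tau = weighted_kendall_tau(truth, pred)
```

Here `tau` comes out positive but below 1: the method orders every pair correctly except the ground-truth top two, and that single swap is penalized most heavily because it involves the best-ranked models.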

 

Code Implementation

Quick Start & Reproduce

  • Set up the environment:

    conda create --name modelspider python=3.10 pytorch torchvision pytorch-cuda -c nvidia -c pytorch -y
    conda activate modelspider
    git clone https://github.com/zhangyikaii/Model-Spider.git
    cd Model-Spider
    pip install -r requirements.txt
  • Choose your path xxx/xx to store data & model:

    source ./scripts/modify-path.sh xxx/xx
  • Download the data and the pre-trained Model Spider checkpoint here into the previously chosen path xxx/xx. (Note that the training set for Model Spider is sampled from EuroSAT, OfficeHome, PACS, SmallNORB, STL10, and VLCS.)

  • Unzip c_data.zip to path xxx/xx/data/ and then run:

    bash scripts/test-model-spider.sh xxx/xx/best.pth

    The results will be displayed on the screen.

 

Reproduce for Other Baseline Methods

We provide the results of the baseline methods in the assests/baseline_results.csv file. Ensure the test datasets (Aircraft, Cars, CIFAR10, CIFAR100, DTD, Pet, SUN397) are in xxx/xx/data, and run the following command to reproduce them:

bash scripts/reproduce-baseline-methods.sh

 

Contributing

Model Spider is under active development, and we warmly welcome any contributions that enhance its capabilities. Whether you have insights on pre-trained models, data, or innovative ranking methods, we invite you to help make Model Spider even better.

 

Citing Model Spider

@inproceedings{ModelSpiderNeurIPS23,
  author    = {Yi{-}Kai Zhang and
               Ting{-}Ji Huang and
               Yao{-}Xiang Ding and
               De{-}Chuan Zhan and
               Han{-}Jia Ye},
  title     = {Model Spider: Learning to Rank Pre-Trained Models Efficiently},
  booktitle = {Advances in Neural Information Processing Systems 36: Annual Conference
               on Neural Information Processing Systems 2023, NeurIPS 2023, New Orleans,
               LA, USA, December 10 - 16, 2023},
  year      = {2023},
}