techmonsterwang/iLLaMA

Image credit: DALL·E

This is a PyTorch implementation of iLLaMA, proposed in our paper "Adapting LLaMA Decoder to Vision Transformer".

Figure 1: Left: iLLaMA architecture. Right: our design roadmap. Colored and gray bars represent the results of the tiny and base regimes, respectively, and the red line depicts the training loss of the tiny regime. iLLaMA strives to process visual tokens using standard LLaMA components, e.g., causal self-attention. The proposed PS [cls] and soft mask strategies help overcome training challenges.


Figure 2: (a) Mask in causal self-attention. (b) Mask in causal self-attention with our post-sequence class token (PS [cls]) method. (c) Modified causal mask.
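For intuition, here is a minimal sketch of the idea behind Figure 2 (an illustration only, not the repo's implementation): under a standard lower-triangular causal mask, a class token prepended at position 0 can attend only to itself, whereas a post-sequence class token occupies the last row of the same mask and can attend to every image token.

import torch

def causal_mask(seq_len: int) -> torch.Tensor:
    # Lower-triangular causal mask: True means "query may attend to key".
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

num_patches = 4  # toy example: 4 image tokens + 1 class token

# (a) [cls] prepended at position 0: under a causal mask it sees only itself.
prepended = causal_mask(num_patches + 1)
print(prepended[0])    # tensor([ True, False, False, False, False])

# (b) PS [cls]: the class token is appended after the image tokens, so the
#     last row of the same causal mask lets it attend to every token.
appended = causal_mask(num_patches + 1)
print(appended[-1])    # tensor([True, True, True, True, True])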


Figure 3: (a) The soft mask gradually transitions from a bi-directional mask into a causal mask during training through a constant or linear schedule. (b) Ablation results of training loss and test accuracy.
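A rough sketch of the soft mask idea in Figure 3(a), assuming a linear schedule (the constant schedule only changes how the blend weight evolves over training); this illustrates the concept and is not the exact code in this repo:

import torch

def soft_causal_mask(seq_len: int, step: int, transition_steps: int) -> torch.Tensor:
    """Soft mask: a blend of a bi-directional and a causal attention mask.

    Returns a float mask in [0, 1]; 1 keeps an attention edge, 0 removes it.
    With a linear schedule the blend weight alpha grows from 0 to 1 over
    `transition_steps`, after which the mask is strictly causal.
    """
    bidirectional = torch.ones(seq_len, seq_len)
    causal = torch.tril(torch.ones(seq_len, seq_len))
    alpha = min(step / max(transition_steps, 1), 1.0)
    return (1.0 - alpha) * bidirectional + alpha * causal

print(soft_causal_mask(4, step=0, transition_steps=100))    # bi-directional
print(soft_causal_mask(4, step=50, transition_steps=100))   # halfway blend
print(soft_causal_mask(4, step=100, transition_steps=100))  # strictly causal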

Requirements

PyTorch and timm 0.5.4 (pip install timm==0.5.4).

Data preparation: ImageNet with the following folder structure; you can extract ImageNet using this script.

│imagenet/
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......
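As a quick sanity check of the layout above, torchvision's ImageFolder should find 1000 classes in each split (the path below is a placeholder):

from torchvision import datasets

# Placeholder path; point this at your own ImageNet root.
root = "/your/path/to/imagenet"

train_set = datasets.ImageFolder(f"{root}/train")
val_set = datasets.ImageFolder(f"{root}/val")

# ImageNet-1K should yield 1000 classes per split,
# ~1.28M training images and 50,000 validation images.
print(len(train_set.classes), len(train_set))
print(len(val_set.classes), len(val_set))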

Models

iLLaMA on ImageNet-1K

Model        | Pre-trained dataset | Resolution | Params | MACs   | Top-1 Acc
illama_tiny  | -                   | 224        | 5.7M   | 1.3G   | 75.0
illama_small | -                   | 224        | 21.9M  | 4.6G   | 79.9
illama_base  | -                   | 224        | 86.3M  | 17.6G  | 81.6
illama_base  | -                   | 384        | 86.3M  | 55.5G  | 83.0
illama_base  | ImageNet-21K        | 224        | 86.3M  | 17.6G  | 83.6
illama_base  | ImageNet-21K        | 384        | 86.3M  | 55.5G  | 85.0
illama_large | ImageNet-21K        | 224        | 310.2M | 62.8G  | 84.8
illama_large | ImageNet-21K        | 384        | 310.2M | 194.7G | 86.0
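Since the implementation builds on timm, the variants above are typically created through timm's model registry. The sketch below is hedged: the module name models and the checkpoint key "model" are assumptions; check the repo's model file and the released .pth files for the exact names.

import torch
import timm

# Importing the repo's model file registers the iLLaMA variants with timm
# (assumption: the repo follows timm's @register_model convention).
import models  # noqa: F401  # hypothetical module name from this repo

model = timm.create_model("illama_tiny", num_classes=1000)

# Checkpoints are commonly saved as {"model": state_dict}; adjust the key
# if the released files are structured differently.
ckpt = torch.load("/your/path/to/model.pth", map_location="cpu")
state_dict = ckpt.get("model", ckpt)
model.load_state_dict(state_dict)
model.eval()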

Evaluate

To evaluate models at 224 resolution, run:

MODEL=illama_tiny
RESUME='/your/path/to/model.pth'
root_imagenet='/your/path/to/imagenet'

python -m torch.distributed.launch --nproc_per_node=2 main.py \
    --model $MODEL --eval true \
    --data_path $root_imagenet \
    --resume $RESUME

To evaluate models at 384 resolution, run:

MODEL=illama_base
RESUME='/your/path/to/model.pth'
root_imagenet='/your/path/to/imagenet'

python -m torch.distributed.launch --nproc_per_node=2 main_soft_fthr.py \
    --model $MODEL --input_size 384 --eval true \
    --data_path $root_imagenet \
    --resume $RESUME

Train

We use a batch size of 4096 by default with 8 GPUs.

bash scripts/train_illama_tiny_in1k.sh

Training scripts for the other models are provided in scripts.

Initialization Using LLaMA2-7B (Optional)

We use the weight selection method to select weights from LLaMA2-7B.

python llama2/weight_selection.py
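For intuition, weight selection initializes a smaller model by keeping a subset of each weight tensor of a larger pretrained model. The sketch below shows one simple variant (evenly spaced slices along each dimension) and is not the actual llama2/weight_selection.py; layer matching and parameter naming in the real script will differ.

import torch

def select_weights(large_sd: dict, small_shapes: dict) -> dict:
    """Sketch of weight selection: for every parameter of the small model,
    take evenly spaced slices of the corresponding large-model tensor along
    each dimension. Layer matching/renaming is omitted for brevity."""
    small_sd = {}
    for name, shape in small_shapes.items():
        tensor = large_sd[name]
        for dim, (big, small) in enumerate(zip(tensor.shape, shape)):
            idx = torch.linspace(0, big - 1, steps=small).round().long()
            tensor = tensor.index_select(dim, idx)
        small_sd[name] = tensor.clone()
    return small_sd

# Toy example: shrink an 8x6 projection down to 4x3.
large = {"proj.weight": torch.randn(8, 6)}
small = select_weights(large, {"proj.weight": (4, 3)})
print(small["proj.weight"].shape)  # torch.Size([4, 3])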

Then we use the selected weights to initialize our iLLaMA-T/S/B.

bash scripts/train_illama_tiny_from_llama2.sh

Training scripts for the other models are provided in scripts.

BibTeX

@article{wang2024adapting,
  title={Adapting LLaMA Decoder to Vision Transformer},
  author={Wang, Jiahao and Shao, Wenqi and Chen, Mengzhao and Wu, Chengyue and Liu, Yong and Zhang, Kaipeng and Zhang, Songyang and Chen, Kai and Luo, Ping},
  journal={arXiv preprint arXiv:2404.06773},
  year={2024}
}

Acknowledgment

Our implementation is based on pytorch-image-models, llama, dropout, ConvNeXt, weight-selection, and MambaOut.
