WailordHe/DenseSSM

DenseSSM

This is a PyTorch implementation of DenseRetNet, as described in the paper DenseMamba: State Space Models with Dense Hidden Connection for Efficient Large Language Models.

Overview

Large language models (LLMs) face a daunting challenge: the commonly used Transformer architecture has excessive computational and memory requirements. State space models (SSMs) are a newer type of foundational network architecture with lower computational complexity, but their performance has yet to fully rival that of Transformers. This paper introduces DenseSSM, a novel approach that enhances the flow of hidden information between layers in SSMs. By selectively integrating shallow-layer hidden states into deeper layers, DenseSSM retains fine-grained information crucial for the final output. Despite the added dense connections, DenseSSM maintains training parallelizability and inference efficiency. The method is widely applicable to various SSM types, such as RetNet and Mamba. At a similar model size, DenseSSM achieves significant improvements; for example, DenseRetNet outperforms the original RetNet by up to 5% accuracy on public benchmarks.
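
The selective integration described in the overview can be written schematically as follows. This is only a sketch of the idea, not the exact formulation used in this repository: the symbols h_l (hidden state of layer l), m (number of preceding layers that are connected), \phi (a selective transition applied to each shallow-layer state), and Fuse are notation assumed here for illustration, and additive fusion is shown as one possible choice.

```latex
% General form: the hidden state of layer l is enriched with
% selectively transformed hidden states of the m preceding layers.
h'_l = \mathrm{Fuse}\bigl(h_l,\ \phi(h_{l-1}),\ \phi(h_{l-2}),\ \dots,\ \phi(h_{l-m})\bigr)

% One simple instance (an assumption for illustration): additive fusion.
h'_l = h_l + \sum_{i=1}^{m} \phi(h_{l-i})
```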

DenseMamba

DenseRetNet

How to run

Training datasets

The paper uses a subset of 15B tokens of The Pile Dataset (https://pile.eleuther.ai/) for training.

Requirements

To run our code, first install requirements:

pip install -r requirements.txt

Command

For example, to pretrain the dense_gau_retnet_350m model on a single GPU, run:

python -m torch.distributed.launch --use_env --nproc_per_node 1 --nnodes 1 --node_rank 0 --master_port=<your_random_port> train.py \
  --model_name_or_path modeling/dense_gau_retnet_350m \
  --dataset_dir <your_path_to_dataset> \
  --data_cache_dir <your_path_to_data_cache_dataset> \
  --validation_split_percentage 0.001 \
  --per_device_train_batch_size 2 \
  --per_device_eval_batch_size 1 \
  --do_train \
  --num_train_epochs 1 \
  --seed 1995 \
  --data_seed 1995 \
  --lr_scheduler_type polynomial \
  --learning_rate 6e-4 \
  --warmup_ratio 0.015 \
  --weight_decay 0.01 \
  --logging_strategy steps \
  --save_strategy steps \
  --save_steps 1000 \
  --save_total_limit 10 \
  --gradient_accumulation_steps 16 \
  --preprocessing_num_workers 4 \
  --block_size 2048 \
  --output_dir <your_path_to_output_dir> \
  --overwrite_output_dir \
  --evaluation_strategy "no" \
  --report_to tensorboard \
  --logging_dir <your_path_to_tensorboard_dir> \
  --logging_steps 1 \
  --model_max_length 2048 \
  --debug_mode False \
  --fp16 True \
  --ddp_find_unused_parameters True \
  --adam_beta2 0.98 \
  --prediction_loss_only True
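
To get a feel for what these flags imply, the following sketch estimates the number of tokens per optimizer step, the total number of steps, and the learning rate over training. It is a rough illustration under stated assumptions, not a description of what train.py does internally: it assumes every training sequence is packed to exactly block_size tokens, that the full 15B-token subset mentioned above is seen once, and that the polynomial schedule uses linear warmup followed by a decay with power 1.0 down to a final learning rate of 1e-7 (the names total_train_tokens, power, and lr_end are introduced here for illustration).

```python
import math

# Values taken from the example command above.
nproc_per_node = 1
per_device_train_batch_size = 2
gradient_accumulation_steps = 16
block_size = 2048
learning_rate = 6e-4
warmup_ratio = 0.015

# Assumption: one pass over a 15B-token training subset.
total_train_tokens = 15_000_000_000

# Sequences and tokens consumed per optimizer step.
sequences_per_step = (
    nproc_per_node * per_device_train_batch_size * gradient_accumulation_steps
)
tokens_per_step = sequences_per_step * block_size

# Number of optimizer steps and warmup steps for one epoch.
total_steps = math.ceil(total_train_tokens / tokens_per_step)
warmup_steps = math.ceil(total_steps * warmup_ratio)


def polynomial_lr(step: int, power: float = 1.0, lr_end: float = 1e-7) -> float:
    """Linear warmup, then polynomial decay to lr_end (assumed defaults)."""
    if step < warmup_steps:
        return learning_rate * step / max(1, warmup_steps)
    if step > total_steps:
        return lr_end
    decay_steps = total_steps - warmup_steps
    pct_remaining = 1.0 - (step - warmup_steps) / decay_steps
    return (learning_rate - lr_end) * pct_remaining ** power + lr_end


if __name__ == "__main__":
    print(f"sequences per optimizer step: {sequences_per_step}")
    print(f"tokens per optimizer step:    {tokens_per_step}")
    print(f"optimizer steps (1 epoch):    {total_steps}")
    print(f"warmup steps:                 {warmup_steps}")
    for step in (0, warmup_steps, total_steps // 2, total_steps):
        print(f"lr at step {step}: {polynomial_lr(step):.3e}")
```

With more GPUs, raising nproc_per_node increases the tokens per step proportionally, so gradient_accumulation_steps can be lowered to keep the effective batch size unchanged.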

Inference speed

Use inference_test.py to run inference. The figure shows end-to-end generation speed for different numbers of generated tokens with batch_size=6.
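
A measurement of this kind can be sketched as a small timing harness. This is a generic illustration, not the contents of inference_test.py: measure_generation_speed and dummy_generate are hypothetical names, and dummy_generate is a CPU stand-in that should be replaced by a call into the real model. When timing on a GPU, the device should also be synchronized before reading the clock, which this sketch omits.

```python
import time
from typing import Callable, Dict, Sequence


def measure_generation_speed(
    generate_fn: Callable[[int, int], None],
    batch_size: int,
    new_token_counts: Sequence[int],
    repeats: int = 3,
) -> Dict[int, float]:
    """Return generated tokens per second for each requested output length."""
    results: Dict[int, float] = {}
    for num_new_tokens in new_token_counts:
        best = float("inf")
        for _ in range(repeats):
            start = time.perf_counter()
            generate_fn(batch_size, num_new_tokens)
            best = min(best, time.perf_counter() - start)
        # Use the fastest run to reduce noise; guard against a zero interval.
        results[num_new_tokens] = batch_size * num_new_tokens / max(best, 1e-12)
    return results


def dummy_generate(batch_size: int, max_new_tokens: int) -> None:
    """Hypothetical stand-in for a model's generate call (does busy work)."""
    for _ in range(max_new_tokens):
        sum(range(batch_size * 100))


if __name__ == "__main__":
    speeds = measure_generation_speed(
        dummy_generate, batch_size=6, new_token_counts=[128, 256, 512]
    )
    for num_new_tokens, tokens_per_second in speeds.items():
        print(f"{num_new_tokens} new tokens: {tokens_per_second:.1f} tokens/s")
```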

Related projects

  • Unofficial Hugging Face-compatible pretrained checkpoints for DenseRetNet-350M
  • Unofficial Hugging Face-compatible pretrained checkpoints for DenseRetNet-1.3B

Future work

  • Training code for DenseRetNet
  • Better inference code
  • Chunk-wise Recurrence Mode
  • Code for DenseMamba

Citation

@misc{he2024densemamba,
    title={DenseMamba: State Space Models with Dense Hidden Connection for Efficient Large Language Models},
    author={Wei He and Kai Han and Yehui Tang and Chengcheng Wang and Yujie Yang and Tianyu Guo and Yunhe Wang},
    year={2024},
    eprint={2403.00818},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}

Acknowledgements

We would like to thank the following projects for their contributions to our work:
