Context-aware Biases for Length Extrapolation


The official source code of "Context-aware Biases for Length Extrapolation" (Cable).

πŸš€ News

  • [2025.02.03] Code release

Upcoming

  • Cleaning up the codebase
  • Efficient implementation of the method
  • Adding scripts for training ALiBi, RoPE, and T5-bias baselines

Datasets and Models

Download the datasets from HuggingFace and use src/dataset_preparation.py to save the tokenized dataset.
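
For orientation, here is a minimal sketch of what the download-and-tokenize step can look like. The dataset name, tokenizer choice, and output directory are illustrative assumptions; src/dataset_preparation.py is the authoritative script and may differ:

import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer

# Assumption: the 10B-token FineWeb-Edu sample and the GPT-2 tokenizer;
# the repository's src/dataset_preparation.py may make different choices.
dataset = load_dataset("HuggingFaceFW/fineweb-edu", name="sample-10BT", split="train")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

def tokenize(example):
    ids = tokenizer(example["text"])["input_ids"]
    ids.append(tokenizer.eos_token_id)  # separate documents with an EOS token
    return {"ids": ids}

tokenized = dataset.map(tokenize, remove_columns=dataset.column_names)
tokenized.save_to_disk("fineweb_edu_tokenized")  # hypothetical output directory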

Some of the trained models:

| Dataset | Model | Parameters | Sequence Length | Checkpoint |
| --- | --- | --- | --- | --- |
| Fineweb-Edu (10B) | GPT-Medium | 334M | 1024 | Model |
| Fineweb-Edu (10B) | GPT-Medium | 334M | 512 | Model |
| WikiText-103 | GPT-Tiny | 44M | 1024 | Model |
| WikiText-103 | GPT-Tiny | 44M | 512 | Model |

You can also load our pre-trained models from the Hugging Face Hub using the transformers AutoModel API:

from transformers import AutoModel

cable_fineweb_md_1024 = AutoModel.from_pretrained("axiomlaborg/Cable", trust_remote_code=True, revision="cable-edufineweb-md-1024")
cable_fineweb_md_512 = AutoModel.from_pretrained("axiomlaborg/Cable", trust_remote_code=True, revision="cable-edufineweb-md-512")
cable_wiki_tiny_1024 = AutoModel.from_pretrained("axiomlaborg/Cable", trust_remote_code=True, revision="cable-wiki-tiny-1024")
cable_wiki_tiny_512 = AutoModel.from_pretrained("axiomlaborg/Cable", trust_remote_code=True, revision="cable-wiki-tiny-512")
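
Once loaded, a quick sanity check might look like the sketch below. It assumes the remote-code model consumes GPT-2-style input_ids and returns an object with a logits attribute; check the model card if the interface differs:

import torch
from transformers import AutoTokenizer

# Assumption: the checkpoints use the GPT-2 vocabulary (the padded
# vocab_size=50304 below suggests this), so the GPT-2 tokenizer applies.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
input_ids = tokenizer("Length extrapolation means", return_tensors="pt").input_ids

with torch.no_grad():
    out = cable_fineweb_md_1024(input_ids)  # model loaded above

next_id = out.logits[0, -1].argmax()  # greedy choice of the next token
print(tokenizer.decode(next_id.item()))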

Alternatively, you can download the state dicts directly:

import torch
from huggingface_hub import hf_hub_download

# Cable and CableConfig are the model classes defined in this repository.

# Specify the model ID and the checkpoint filenames to download
model_id = "axiomlaborg/Cable"
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

sd_wiki_512_path = hf_hub_download(repo_id=model_id, filename='Cable_wiki_512.pt')
sd_wiki_1024_path = hf_hub_download(repo_id=model_id, filename='Cable_wiki_1024.pt')
sd_fineweb_512_path = hf_hub_download(repo_id=model_id, filename='Cable_fineweb_512.pt')
sd_fineweb_1024_path = hf_hub_download(repo_id=model_id, filename='Cable_fineweb_1024.pt')

Cable_wiki_512 = Cable(CableConfig(vocab_size=50304, n_layer=6, n_head=8, n_embd=512, block_size=512))
Cable_wiki_1024 = Cable(CableConfig(vocab_size=50304, n_layer=6, n_head=8, n_embd=512, block_size=1024))
Cable_fineweb_512 = Cable(CableConfig(vocab_size=50304, n_layer=24, n_head=16, n_embd=1024, block_size=512))
Cable_fineweb_1024 = Cable(CableConfig(vocab_size=50304, n_layer=24, n_head=16, n_embd=1024, block_size=1024))

Cable_wiki_512.load_state_dict(torch.load(sd_wiki_512_path, map_location=device))
Cable_wiki_1024.load_state_dict(torch.load(sd_wiki_1024_path, map_location=device))
Cable_fineweb_512.load_state_dict(torch.load(sd_fineweb_512_path, map_location=device))
Cable_fineweb_1024.load_state_dict(torch.load(sd_fineweb_1024_path, map_location=device))

Training

  • Single GPU

    python train.py --dataset-dir <path to dataset> --model <tiny|small|medium> --save-dir <dir for logs>
  • Multiple GPUs

    torchrun --standalone --nproc_per_node=2 train.py

For the HellaSwag benchmark and for evaluating extrapolation, please use the src/evaluation.ipynb notebook.

Length Extrapolation

A Cable model trained on T=1024 can extrapolate to T=8192, achieving better perplexity (PPL=22.22) than a sinusoidal-PE model trained directly on T=8192 (PPL=22.81).
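
As a rough illustration of how such numbers can be computed, the following is a generic perplexity-at-length loop, not the repository's exact evaluation code (see src/evaluation.ipynb for that). It assumes a model returning logits and a 1-D tensor of token ids eval_ids:

import torch
import torch.nn.functional as F

def perplexity_at_length(model, eval_ids, T):
    """Average next-token perplexity over non-overlapping windows of length T.

    eval_ids: 1-D LongTensor of token ids. Generic sketch; the repository's
    notebook may use a different windowing scheme.
    """
    model.eval()
    nll, n_tokens = 0.0, 0
    with torch.no_grad():
        for start in range(0, eval_ids.numel() - T, T):
            window = eval_ids[start:start + T].unsqueeze(0)
            logits = model(window).logits  # assumed output attribute
            # predict token t+1 from positions 0..T-2
            loss = F.cross_entropy(logits[0, :-1], window[0, 1:], reduction="sum")
            nll += loss.item()
            n_tokens += T - 1
    return torch.exp(torch.tensor(nll / n_tokens)).item()

# e.g. perplexity_at_length(cable_fineweb_md_1024, eval_ids, T=8192)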

Runtime and Memory Overhead

Cable significantly improves the model's extrapolation ability with negligible time and memory overhead compared to the vanilla transformer. Furthermore, compared to existing RPE methods, our approach maintains nearly identical training time and GPU memory usage, while its inference overhead remains negligible or comparable, depending on the sequence length.
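
To sanity-check overhead on your own hardware, a generic PyTorch probe like the one below can measure forward latency and peak GPU memory (an illustrative sketch assuming a CUDA device, not the paper's benchmark harness):

import torch

def profile_forward(model, vocab_size=50304, batch=1, seq_len=1024, device="cuda"):
    """Report forward latency and peak GPU memory for one batch. Generic probe."""
    model = model.to(device).eval()
    x = torch.randint(vocab_size, (batch, seq_len), device=device)
    torch.cuda.reset_peak_memory_stats(device)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    with torch.no_grad():
        start.record()
        model(x)
        end.record()
    torch.cuda.synchronize(device)
    ms = start.elapsed_time(end)
    peak_mb = torch.cuda.max_memory_allocated(device) / 2**20
    print(f"seq_len={seq_len}: {ms:.1f} ms, peak memory {peak_mb:.0f} MiB")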

Citation

If you use this repository for your research or wish to refer to our positional encoding method, please use the following BibTeX entry:

@article{veisi2025context,
  title={Context-aware Biases for Length Extrapolation},
  author={Ali Veisi and Amir Mansourian},
  journal={arXiv preprint arXiv:2503.08067},
  year={2025}
}

Acknowledgement

This repo is based on karpathy/build-nanogpt. Thanks for their excellent work.

License

MIT