ReFINE: Reinforced Fast Weights with Next Sequence Prediction

This repository contains the official implementation of ReFINE, introduced in the paper:

Reinforced Fast Weights with Next Sequence Prediction
Hee Seung Hwang*, Xindi Wu*, Sanghyuk Chun, Olga Russakovsky


🔍 Overview

Fast weight architectures (e.g., LaCT, DeltaNet, GatedDeltaNet) are typically pre-trained with next-token prediction (NTP), which provides only token-level supervision. ReFINE addresses this limitation by optimizing for Next-Sequence Prediction (NSP) via reinforcement learning.
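
To make the contrast concrete, here is a minimal PyTorch sketch of the two supervision signals; the shapes and tensors are illustrative placeholders, not code from this repository:

    import torch
    import torch.nn.functional as F

    # Next-token prediction (NTP): one cross-entropy term per position,
    # i.e. purely token-level supervision.
    logits = torch.randn(1, 128, 32000)          # (batch, seq_len, vocab)
    targets = torch.randint(0, 32000, (1, 128))  # next-token labels
    ntp_loss = F.cross_entropy(logits.view(-1, 32000), targets.view(-1))

    # Next-sequence prediction (NSP): a single scalar reward for an entire
    # generated continuation (e.g. its similarity to the reference suffix),
    # which reinforcement learning then maximizes.
    rollout_emb = torch.randn(768)    # embedding of the generated continuation
    reference_emb = torch.randn(768)  # embedding of the ground-truth suffix
    nsp_reward = F.cosine_similarity(rollout_emb, reference_emb, dim=0)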

ReFINE is phase-agnostic and can be applied during:

  • Mid-training - Training on long-context corpora
  • Post-training - Task-specific fine-tuning
  • Test-time training - Adaptation at inference time

🧠 Method Summary

ReFINE improves sequence-level understanding through four key steps (a schematic sketch follows the list):

  1. Entropy-Based Token Selection → Select informative positions based on NTP entropy

  2. Rollout Generation → Generate multi-token continuations from truncated prefixes

  3. Reward Assignment → Compute sequence-level rewards using cosine similarity (or exact match)

  4. Optimization with RL → Optimize NSP using GRPO, combined with standard NTP loss
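
For intuition, the sketch below shows steps 1, 3, and the advantage computation from step 4 in isolation. This is a simplified illustration, not the repository's implementation: the tensor shapes, the top_k value, the embedding dimension, and the group size are assumptions, and the actual GRPO policy update (and the combined NTP loss) is handled by the verl trainer.

    import torch
    import torch.nn.functional as F

    def select_positions_by_entropy(logits: torch.Tensor, top_k: int) -> torch.Tensor:
        # Step 1: rank prefix positions by the entropy of the next-token
        # distribution and keep the top_k most uncertain ones as rollout starts.
        log_probs = F.log_softmax(logits, dim=-1)             # (seq_len, vocab)
        entropy = -(log_probs.exp() * log_probs).sum(dim=-1)  # (seq_len,)
        return torch.topk(entropy, k=top_k).indices

    def cosine_reward(rollout_emb: torch.Tensor, reference_emb: torch.Tensor) -> torch.Tensor:
        # Step 3: sequence-level reward comparing the embedding of a generated
        # continuation with the embedding of the reference suffix.
        return F.cosine_similarity(rollout_emb, reference_emb, dim=-1)

    def group_normalized_advantages(rewards: torch.Tensor) -> torch.Tensor:
        # Step 4 (partial): GRPO-style advantages, normalizing each rollout's
        # reward against the group of rollouts sampled from the same prefix.
        return (rewards - rewards.mean()) / (rewards.std() + 1e-6)

    # Toy usage with random tensors standing in for model outputs and embeddings.
    logits = torch.randn(256, 32000)  # next-token logits over a prefix
    positions = select_positions_by_entropy(logits, top_k=4)
    rewards = torch.stack([cosine_reward(torch.randn(768), torch.randn(768)) for _ in range(8)])
    advantages = group_normalized_advantages(rewards)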

📊 Results

ReFINE consistently improves long-context performance over supervised fine-tuning (SFT) on Needle-in-a-Haystack (RULER), Multi-Document QA (RULER), and LongBench (12 tasks with up to 16K context length).

See the paper for detailed tables and ablations.

🚀 Getting Started

Installation

  1. Create conda environment:

    conda create -n refine python=3.12 -y
    conda activate refine

  2. Install verl dependencies:

    bash ./verl/scripts/install_refine.sh

  3. Install additional dependencies:

    pip install -r requirements.txt

Models

Download the pre-trained fast weight models:

Model           Parameters   Code     Checkpoints
LaCT            760M         GitHub   HuggingFace
DeltaNet-1.3B   1.3B         GitHub   HuggingFace

Mid-Training

Train ReFINE on long-context data:

  1. Prepare Dataset: The original Long-Data-Collections dataset is no longer available, so we recommend using the SlimPajama-6B dataset instead:

    • Download the parquet files.
    • Filter for samples with at least 16K tokens (training split only); a filtering sketch follows these steps.
  2. Configure Script: Update the variables in verl/examples/refine_trainer/demo/run_midtrain_demo.sh

  3. Run Training:

    cd verl/examples/refine_trainer/demo
    bash run_midtrain_demo.sh
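
A minimal sketch of the 16K-token filtering step, assuming the SlimPajama-6B parquet files have already been downloaded locally, that each record stores its raw text in a "text" field, and that the Hugging Face datasets and transformers libraries are installed; the paths, field name, placeholder tokenizer, and output file are assumptions, not values taken from this repository:

    from datasets import load_dataset
    from transformers import AutoTokenizer

    MIN_TOKENS = 16_384  # "at least 16K tokens"; applied to the training split only

    # Placeholder tokenizer; use the tokenizer of the fast-weight model you train.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    # Load the locally downloaded parquet shards.
    ds = load_dataset("parquet", data_files="slimpajama-6b/train/*.parquet", split="train")

    def long_enough(example):
        # Tokenize without truncation and keep only sufficiently long samples.
        return len(tokenizer(example["text"], truncation=False)["input_ids"]) >= MIN_TOKENS

    ds = ds.filter(long_enough, num_proc=8)
    ds.to_parquet("slimpajama_6b_train_16k.parquet")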

Post-Training

Fine-tune on task-specific long-context data:

  1. Use Provided Datasets: Post-training datasets are available in data/ruler/

  2. Configure Script: Update the variables in verl/examples/refine_trainer/demo/run_posttrain_demo.sh

  3. Run Training:

    cd verl/examples/refine_trainer/demo
    bash run_posttrain_demo.sh

Test-Time Training

Adapt the model at test time for specific tasks:

  1. Use Provided Dataset: The LongBench dataset (filtered to samples under 16K tokens) is included

  2. Configure Script: Update the variables in verl/examples/refine_trainer/demo/run_testtimetrain_demo.sh

  3. Run Training:

    cd verl/examples/refine_trainer/demo
    bash run_testtimetrain_demo.sh

Evaluation

We recommend using the demo scripts for validation (e.g., RULER SQuAD QA, HotpotQA, and LongBench). For evaluation with LM-Eval-Harness (e.g., RULER NIAH), please follow the instructions here.

📝 Citation

If you find this work helpful, please cite our paper:

@article{refine2026,
  title={Reinforced Fast Weights with Next Sequence Prediction},
  author={TBD},
  journal={TBD},
  year={2026}
}

(BibTeX will be updated upon publication)

🙏 Acknowledgments

This project builds upon verl for distributed RL training infrastructure.

📚 References

  • Zhang, Tianyuan, et al. "Test-Time Training Done Right." arXiv preprint arXiv:2505.23884 (2025).

  • Yang, Songlin, et al. "Parallelizing Linear Transformers with the Delta Rule over Sequence Length." Advances in Neural Information Processing Systems 37 (2024): 115491-115522.

  • Yang, Songlin, Jan Kautz, and Ali Hatamizadeh. "Gated Delta Networks: Improving Mamba2 with Delta Rule." arXiv preprint arXiv:2412.06464 (2024).

  • Gao, Leo, et al. "The Pile: An 800GB Dataset of Diverse Text for Language Modeling." arXiv preprint arXiv:2101.00027 (2020).

  • Hsieh, Cheng-Ping, et al. "RULER: What's the Real Context Size of Your Long-Context Language Models?" arXiv preprint arXiv:2404.06654 (2024).

  • Bai, Yushi, et al. "LongBench: A Bilingual, Multitask Benchmark for Long Context Understanding." arXiv preprint arXiv:2308.14508 (2023).
