A two-stage method that steers a frozen LLM through reasoning trees with hyperbolic geometric guidance.
- Stage 1 trains a small head that embeds reasoning-tree states into a Poincaré ball so that distance-to-origin tracks solution-proximity (target reachable in fewer remaining steps ⇔ point closer to origin).
- Stage 2 trains a fresh LoRA and a small up-projector on top of the frozen base LLM using DAgger with a tree oracle: the current policy rolls out trajectories, the oracle labels the winning operations at each reached state, and a cross-entropy loss supervises the LoRA on those labels. The frozen head's geometric embedding z is injected as a virtual token at every step boundary.
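The Stage 1 geometry can be illustrated with the standard Poincaré-ball formulas. This is a sketch that assumes curvature −1; the head's actual curvature, dimension, and loss are not stated here. The exponential map at the origin sends a Euclidean vector into the open unit ball, and the hyperbolic distance to the origin is 2·artanh(‖x‖), which grows monotonically with the Euclidean norm. `origin_ranking_violations` is a hypothetical helper that counts how often the target ordering (fewer remaining steps ⇒ closer to origin) is broken.

```python
import numpy as np

def expmap0(v: np.ndarray, eps: float = 1e-7) -> np.ndarray:
    """Map Euclidean vectors (rows of v) into the unit Poincaré ball (curvature -1)."""
    norm = np.linalg.norm(v, axis=-1, keepdims=True)
    norm = np.maximum(norm, eps)  # avoid division by zero when v == 0
    return np.tanh(norm) * v / norm

def dist_to_origin(x: np.ndarray, eps: float = 1e-7) -> np.ndarray:
    """Hyperbolic distance from the origin: d(0, x) = 2 * artanh(||x||)."""
    norm = np.linalg.norm(x, axis=-1)
    return 2.0 * np.arctanh(np.clip(norm, 0.0, 1.0 - eps))  # clip keeps it finite

def origin_ranking_violations(z: np.ndarray, steps_left: np.ndarray) -> int:
    """Count ordered pairs where a state with fewer remaining steps is NOT closer to the origin."""
    d = dist_to_origin(z)
    violations = 0
    for i in range(len(d)):
        for j in range(len(d)):
            if steps_left[i] < steps_left[j] and d[i] >= d[j]:
                violations += 1
    return violations

if __name__ == "__main__":
    # Hypothetical pre-projection head outputs for states with 0, 1 and 3 steps left.
    v = np.array([[0.1, 0.0], [0.5, 0.5], [2.0, -1.0]])
    z = expmap0(v)
    print(dist_to_origin(z))  # equals 2 * ||v|| row-wise for curvature -1
    print(origin_ranking_violations(z, np.array([0, 1, 3])))  # 0
```

A useful sanity check of the formulas: for curvature −1, d(0, expmap0(v)) = 2‖v‖, so a head that outputs Euclidean vectors can control hyperbolic depth through their norm alone.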
```bash
pip install -r requirements.txt
python data/generate_24_splits.py
python data/download_24_tot_test.py
```

1.2 Pre-cache the hidden-state trees (Stage-1 input)
```bash
bash scripts/run_gen_tree_data_g24.sh
bash scripts/run_train_head.sh configs/head_qwen14b_origin_ranking.yaml
bash scripts/run_train_stage2_dagger_g24.sh z 1234
```
```bash
# Ablations
bash scripts/run_train_stage2_dagger_g24.sh noz 1234
bash scripts/run_train_stage2_dagger_g24.sh randz 1234
```

```bash
bash scripts/run_tot_g24.sh
```

Defaults: n_generate=1, n_evaluate=3, n_select=5, sampling temperature T=0.7, a single model (the same base LLM both proposes and evaluates), and chat-template prompts.
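For orientation, one breadth-first ToT step under these defaults can be sketched as follows: each frontier state gets n_generate proposal calls, every candidate is scored by n_evaluate evaluator samples whose values are summed, and the n_select best candidates survive. The `propose`/`evaluate` callables stand in for the single base model, and the numeric weights for sure/likely/impossible are an assumption rather than the values run_tot_g24.sh necessarily uses.

```python
import heapq
from dataclasses import dataclass
from typing import Callable, List

# Assumed label-to-value weights; the repo's actual mapping may differ.
VALUE_MAP = {"sure": 20.0, "likely": 1.0, "impossible": 0.001}

@dataclass
class ToTConfig:
    n_generate: int = 1       # proposal calls per frontier state
    n_evaluate: int = 3       # evaluator samples per candidate
    n_select: int = 5         # beam width kept after each step
    temperature: float = 0.7  # passed to the model's sampler (not shown here)

def bfs_step(frontier: List[str],
             propose: Callable[[str], List[str]],
             evaluate: Callable[[str], str],
             cfg: ToTConfig) -> List[str]:
    """Expand every frontier state, score the candidates, keep the top n_select."""
    candidates = []
    for state in frontier:
        for _ in range(cfg.n_generate):
            candidates.extend(propose(state))
    scored = []
    for cand in candidates:
        # Sum the values of n_evaluate independent sure/likely/impossible labels.
        value = sum(VALUE_MAP.get(evaluate(cand), 0.0) for _ in range(cfg.n_evaluate))
        scored.append((value, cand))
    # nlargest behaves like a stable sort here, so ties keep proposal order.
    best = heapq.nlargest(cfg.n_select, scored, key=lambda pair: pair[0])
    return [cand for _, cand in best]
```

Summing several sampled judgments, rather than trusting one, smooths out evaluator noise at temperature 0.7; a single "sure" outweighs many "likely" votes by design of the assumed weights.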
```bash
python data/prepare_pt_g24_data.py
bash scripts/run_train_pt_sft.sh configs/sft_pt_24_qwen14b.yaml
python -m src.eval_pt_g24 \
  --lora_adapter checkpoints/sft_pt_24_qwen14b \
  --test_data data/24_test_tot.jsonl \
  --output results/pt_sft_24/generations.jsonl
python -m src.score_ood --task g24 --input results/pt_sft_24/generations.jsonl
```

```bash
python data/generate_data_rulechain.py
bash scripts/run_gen_tree_data_rulechain.sh
```

The tree-generation script writes per-problem tree metadata plus hidden-state memmaps to data/rulechain_trees_qwen14b/{train,val,test}/.
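The on-disk layout is not documented here, so the following is only a sketch of how per-node hidden states could be stored and read back with numpy.memmap. The file naming, the JSON sidecar, the float16 dtype, and the `steps_left` field are hypothetical; a real Qwen-14B run would use the model's own hidden width instead of the toy size below.

```python
import json
import os
import tempfile
import numpy as np

def write_tree(out_dir, problem_id, hidden, meta):
    """Store one problem's node hidden states (n_nodes x d) plus JSON metadata."""
    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, f"{problem_id}.f16")
    mm = np.memmap(path, dtype=np.float16, mode="w+", shape=hidden.shape)
    mm[:] = hidden.astype(np.float16)
    mm.flush()  # make sure the bytes reach the file
    meta = dict(meta, shape=list(hidden.shape), dtype="float16")
    with open(os.path.join(out_dir, f"{problem_id}.json"), "w") as f:
        json.dump(meta, f)
    return path

def read_node(out_dir, problem_id, node_idx):
    """Read a single node's hidden state without loading the whole tree."""
    with open(os.path.join(out_dir, f"{problem_id}.json")) as f:
        meta = json.load(f)
    mm = np.memmap(os.path.join(out_dir, f"{problem_id}.f16"),
                   dtype=meta["dtype"], mode="r", shape=tuple(meta["shape"]))
    return np.array(mm[node_idx])  # copy the row out of the mapping

if __name__ == "__main__":
    demo_dir = tempfile.mkdtemp()
    h = np.random.default_rng(0).standard_normal((3, 8))
    write_tree(demo_dir, "problem_0000", h, {"steps_left": [2, 1, 0]})
    print(read_node(demo_dir, "problem_0000", 2).shape)  # (8,)
```

Memory-mapping lets a training loader index individual tree nodes on demand, which matters when a whole split of 14B-model hidden states would not fit in RAM.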
```bash
bash scripts/run_train_head.sh configs/head_rulechain_qwen14b.yaml
```

```bash
bash scripts/run_train_stage2_dagger_rulechain.sh z      # main HyperGuide
bash scripts/run_train_stage2_dagger_rulechain.sh noz    # ablation
bash scripts/run_train_stage2_dagger_rulechain.sh randz  # noise control
```

The launcher trains with DDP across the detected GPUs, runs inference via src.eval_ood_generic, and scores with src.score_ood.
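The three launcher modes differ in what is injected at each step boundary. As a hedged sketch, the virtual token could be built and spliced into the embedding sequence as below; the up-projector being a single linear map (`W_up`, `b_up`), the dimensions, and the choice of a norm-matched Gaussian direction for randz are all assumptions, not the repo's confirmed implementation.

```python
import numpy as np

def virtual_token(mode, z, W_up, b_up, rng):
    """Return the embedding to inject at a step boundary, or None for 'noz'.

    mode: 'z' (head output), 'noz' (no injection), 'randz' (noise control).
    z: head output in the Poincaré ball, shape (d_head,).
    W_up, b_up: hypothetical linear up-projector, d_head -> d_model.
    """
    if mode == "noz":
        return None
    if mode == "randz":
        # Assumption: a random direction rescaled to the same norm as z.
        noise = rng.standard_normal(z.shape)
        z = noise / np.linalg.norm(noise) * np.linalg.norm(z)
    elif mode != "z":
        raise ValueError(f"unknown mode: {mode}")
    return z @ W_up + b_up

def inject(step_embeds, mode, zs, W_up, b_up, rng):
    """Append one virtual token after each step's token embeddings."""
    out = []
    for emb, z in zip(step_embeds, zs):
        out.append(emb)  # (n_tokens_in_step, d_model)
        tok = virtual_token(mode, z, W_up, b_up, rng)
        if tok is not None:
            out.append(tok[None, :])  # one extra row per step boundary
    return np.concatenate(out, axis=0)
```

Keeping the randz token the same shape (and, under this assumption, the same norm) as the real z isolates the effect of the geometric content: any gain of z over randz cannot come from merely having an extra token.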
```bash
bash scripts/run_tot_rulechain.sh
python -m src.score_ood --task rulechain --input results/tot_rulechain/generations.jsonl
```

The rulechain ToT adapter proposes one forward-chaining rule application per step, scores each candidate as sure/likely/impossible, and selects the top-5 by value sum.
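The proposal side of that adapter amounts to finding rules whose premises are all established facts. A minimal sketch, assuming a problem is a set of known facts plus Horn-style rules (premises → conclusion); the `Rule` type and helper names are hypothetical, and in the real adapter the LLM proposes the step, after which each proposal is scored sure/likely/impossible.

```python
from typing import FrozenSet, List, NamedTuple, Set

class Rule(NamedTuple):
    premises: FrozenSet[str]
    conclusion: str

def applicable_steps(facts: Set[str], rules: List[Rule]) -> List[Rule]:
    """Rules that fire on the current facts and would add something new."""
    return [r for r in rules if r.premises <= facts and r.conclusion not in facts]

def apply_step(facts: Set[str], rule: Rule) -> Set[str]:
    """One forward-chaining step: add the rule's conclusion to the fact set."""
    return facts | {rule.conclusion}

def reachable(facts: Set[str], rules: List[Rule], goal: str, max_steps: int = 100) -> bool:
    """Apply rules in order until the goal appears or nothing new fires.

    Forward chaining is monotone (facts are only added), so applying any
    applicable rule never blocks a later derivation.
    """
    for _ in range(max_steps):
        if goal in facts:
            return True
        steps = applicable_steps(facts, rules)
        if not steps:
            return False
        facts = apply_step(facts, steps[0])
    return goal in facts
```

Because each candidate is exactly one rule application, the search tree's depth equals the length of the derivation, which keeps per-step scoring comparable across problems.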
```bash
# 1. Build PT-augmented training trajectories
python data/prepare_pt_rulechain_data.py
# 2. SFT a LoRA
bash scripts/run_train_pt_sft.sh configs/sft_pt_rulechain_qwen14b.yaml
# 3. Evaluate + score
python -m src.eval_pt_ood \
  --task rulechain \
  --lora_adapter checkpoints/sft_pt_rulechain_qwen14b \
  --test_data data/rulechain_test.jsonl \
  --output results/pt_sft_rulechain/generations.jsonl
python -m src.score_ood --task rulechain --input results/pt_sft_rulechain/generations.jsonl
```