sharma-yash01/MiniGridPT
MiniGridPT

Post-training package for MiniGridEnv. It connects to a running MiniGrid OpenEnv server via WebSocket only (no dependency on the MiniGridEnv Python package) and uses TRL's GRPOTrainer with a custom rollout_func plus optional cross-episodic memory (memory/*.md).
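The WebSocket wire format is not documented here. As a hedged sketch, assuming a JSON-over-WebSocket protocol with reset and step message types (the "type", "level", and "action" field names below are illustrative assumptions, not the package's actual schema):

```python
import json

# Hypothetical message builders for the MiniGrid OpenEnv WebSocket protocol.
# Field names are assumptions for illustration, not the real schema.

def reset_message(level: str) -> str:
    """Payload asking the server to start a new episode at the given level."""
    return json.dumps({"type": "reset", "level": level})

def step_message(action: str) -> str:
    """Payload sending one canonical action to the running episode."""
    return json.dumps({"type": "step", "action": action})
```

A client would send reset_message(...) once per episode, then one step_message(...) per model turn until the server reports the episode is done.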

Setup

cd MiniGridPT
pip install -e ".[vllm]"
# Start MiniGridEnv elsewhere, e.g. from MiniGridEnv: uv run server

Training

python -m training.grpo_train \
  --env_base_url http://127.0.0.1:8000 \
  --model Qwen/Qwen2.5-1.5B-Instruct \
  --level GoToRedBall \
  --output_dir runs/minigrid_v1

Memory mode:

python -m training.grpo_train \
  --env_base_url http://127.0.0.1:8000 \
  --model Qwen/Qwen2.5-1.5B-Instruct \
  --level GoToRedBall \
  --memory \
  --memory_dir ./memory \
  --output_dir runs/minigrid_mem

Branch-stable memory (GRPO): files are named by generation index k (rank{R}_br{k}_train.md) so that each of the num_generations parallel chains keeps a single file across training steps. Prefer setting --per_device_train_batch_size equal to --num_generations.

python -m training.grpo_train \
  --env_base_url http://127.0.0.1:8000 \
  --model Qwen/Qwen2.5-1.5B-Instruct \
  --level GoToRedBall \
  --memory \
  --memory_branch_stable \
  --per_device_train_batch_size 8 \
  --num_generations 8 \
  --memory_dir ./memory \
  --output_dir runs/minigrid_mem_branch
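The naming scheme above can be sketched as a small helper (the function name and signature are illustrative, not part of the package's API; only the rank{R}_br{k}_train.md convention comes from this README):

```python
from pathlib import Path

def branch_memory_path(memory_dir: str, rank: int, gen_index: int,
                       split: str = "train") -> Path:
    """Build the rank{R}_br{k}_{split}.md path for one generation chain.

    With per_device_train_batch_size == num_generations, chain k on rank R
    resolves to the same file on every step, so its memory persists across
    training steps.
    """
    return Path(memory_dir) / f"rank{rank}_br{gen_index}_{split}.md"
```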

Optional: --space_url instead of --env_base_url for Hugging Face Spaces (same resolution as ReasoningEconomicsPT).

Inference

python -m inference.run_episode --env_base_url http://127.0.0.1:8000 --level GoToRedBall --model Qwen/Qwen2.5-1.5B-Instruct
python -m inference.run_eval --env_base_url http://127.0.0.1:8000 --levels GoToRedBall,GoToObj --model Qwen/Qwen2.5-1.5B-Instruct --n_episodes 10
python -m inference.run_interactive --env_base_url http://127.0.0.1:8000 --level GoToRedBall

Scripts

  • scripts/preflight.sh — import checks and optional HTTP health
  • scripts/launch_train.sh, scripts/launch_curriculum.sh, scripts/launch_eval.sh — thin wrappers (set ENV_URL, etc.)

Lambda (A100 / Lambda Cloud)

These mirror the ReasoningEconomicsPT workflow: bootstrap a venv with caches on the Lambda NFS mount, run preflight, then launch GRPO (optional multi-GPU: trl vllm-serve + accelerate launch). You must run a MiniGrid OpenEnv HTTP server separately and point ENV_BASE_URL at it.

1. One-time bootstrap (creates venv, installs requirements.lambda.txt, sets PIP_CACHE_DIR / HF_HOME / TMPDIR under /lambda/nfs/.../minigridpt/cache):

export MGPT_ROOT=/path/to/MiniGridPT
export MGPT_VENV=/path/to/.venvs/minigridpt-lambda
export MGPT_FS_NAME=<lambda-filesystem-name>   # optional if MGPT_DATA_ROOT is set
export PYTORCH_WHEEL_INDEX=https://download.pytorch.org/whl/cu121
bash scripts/bootstrap_lambda.sh

2. Preflight (requires ENV_BASE_URL with a healthy /health):

source "$MGPT_VENV/bin/activate"
export ENV_BASE_URL=http://127.0.0.1:8000
bash scripts/preflight_lambda.sh

3. Training run:

source "$MGPT_VENV/bin/activate"
export MGPT_ROOT=/path/to/MiniGridPT
export MGPT_VENV=/path/to/.venvs/minigridpt-lambda
export ENV_BASE_URL=http://127.0.0.1:8000
# Optional: MGPT_MODEL, MGPT_LEVEL, MGPT_OUTPUT_DIR, MGPT_VLLM_MODE, MGPT_GPU_LIST, MGPT_MEMORY=1, ...
bash scripts/run_grpo_lambda.sh

Use bash scripts/run_grpo_lambda.sh --dry-run to print the resolved accelerate / python -m training.grpo_train command without executing. For FSDP on large models, set MGPT_MODEL_SHARDING=1 (requires server-style multi-GPU layout). See scripts/run_grpo_lambda.sh header for the full environment variable list.

Action parser

training/openenv_runtime.parse_action is a deliberate copy of the logic in MiniGridEnv/env/action_parser.py. If you change canonical actions in the env, update both.

Reset level / level_name

MiniGridEnv accepts level or level_name in the WebSocket reset payload so one server can run a curriculum. MiniGridPT passes this on each episode reset.

About

Post-training code for LLMs/LRMs on the MiniGridEnv RL environment: a simple game designed to elicit optimal memory storage with in-episode and cross-episode properties.
