# VLM-RL Training with LoRA

A quick training setup for fine-tuning Qwen3-VL with LoRA adapters under PPO.

## Setup Dependencies

Install the required packages. On Colab, this also clones the repository; locally, it installs `vlmrl` in editable mode if it is missing:

In [None]:
# Install specific transformers version required for Qwen3-VL
!pip install transformers==4.57.1

# Clone repo and install if running on Colab
import os
if 'COLAB_GPU' in os.environ:
    !git clone https://github.com/sdan/vlm-gym.git
    %cd vlm-gym
    !pip install -e .
else:
    # Running locally - install in editable mode if not already installed
    try:
        import vlmrl
        print("✓ vlmrl already installed")
    except ImportError:
        print("Installing vlmrl in editable mode...")
        !pip install -e .
        print("✓ vlmrl installed")

In [None]:
# Check working directory
import os
print(f"Working directory: {os.getcwd()}")
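
Because the install cell pins `transformers==4.57.1`, it can help to confirm which version the runtime actually sees. The following is a minimal sketch using only the standard library; `check_pin` is a hypothetical helper, not part of `vlmrl`.

```python
from importlib.metadata import PackageNotFoundError, version


def check_pin(package: str, required: str) -> str:
    """Report whether `package` is installed at exactly the `required` version."""
    try:
        installed = version(package)
    except PackageNotFoundError:
        return f"{package}: not installed (expected {required})"
    if installed == required:
        return f"{package}: {installed} (matches pin)"
    return f"{package}: {installed} (expected {required})"


# Version string taken from the pip command in the install cell
print(check_pin("transformers", "4.57.1"))
```

If the reported version differs from the pin on Colab, restarting the runtime after reinstalling is usually needed so that the new version is the one imported.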

## Convert HF Checkpoint to JAX

Download the Qwen3-VL checkpoint from Hugging Face and convert it to JAX format for LoRA training:

In [None]:
# Convert Qwen3-VL-2B from HuggingFace to JAX format
!python -m vlmrl.utils.hf_to_jax \
    --model_type qwen3vl \
    --hf_repo Qwen/Qwen3-VL-2B-Instruct \
    --model_dir checkpoints/qwen3vl_2b

print("✓ Checkpoint converted and ready for LoRA training")
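
The success message above prints whether or not the conversion command succeeded, since shell commands in a notebook cell do not raise Python errors. A minimal sketch of a sanity check follows; `checkpoint_ready` is a hypothetical helper, and it only checks that the output directory exists and is non-empty, not that its contents are valid.

```python
from pathlib import Path


def checkpoint_ready(model_dir: str) -> bool:
    """Return True if `model_dir` exists and contains at least one entry."""
    path = Path(model_dir)
    return path.is_dir() and any(path.iterdir())


# Same path as the --model_dir passed to the conversion command
model_dir = "checkpoints/qwen3vl_2b"
if checkpoint_ready(model_dir):
    print(f"Found files in {model_dir}")
else:
    print(f"{model_dir} is missing or empty; check the conversion output above")
```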

## Training Configuration

This notebook runs a quick training session with:
- **Model**: Qwen3-VL-2B-Instruct
- **LoRA**: Enabled (rank=16, alpha=32, lr_mult=10.0)
- **Training**: 50 steps, batch_size=4, PPO epochs=1
- **Optimizer**: AdamW with lr=1e-6
- **W&B**: Offline mode
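
To make the LoRA numbers concrete, the sketch below works through the arithmetic they imply. Two points are assumptions rather than facts about `vlmrl`: that the adapter update is scaled by `alpha / rank` (the usual LoRA convention), and that `lora_lr_mult` multiplies the base learning rate for adapter parameters. The 2048 x 2048 layer shape is hypothetical and chosen only for illustration.

```python
# Values from the configuration list above
lora_rank = 16
lora_alpha = 32
learning_rate = 1e-6
lora_lr_mult = 10.0

# Usual LoRA convention: the low-rank update B @ A is scaled by alpha / rank
lora_scale = lora_alpha / lora_rank

# Assumption: lora_lr_mult multiplies the base learning rate for adapter parameters
lora_learning_rate = learning_rate * lora_lr_mult


def lora_param_count(d_in: int, d_out: int, rank: int) -> int:
    """Parameters added by one adapter: A is (rank, d_in), B is (d_out, rank)."""
    return rank * (d_in + d_out)


print(f"scale = {lora_scale}")
print(f"adapter lr = {lora_learning_rate:.1e}")
# Hypothetical 2048 x 2048 projection, for illustration only
print(f"adapter params = {lora_param_count(2048, 2048, lora_rank)}")
```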

In [None]:
# Training parameters
config = {
    'model_dir': 'checkpoints/qwen3vl_2b',
    'wandb_mode': 'offline',
    'total_steps': 50,
    'batch_size': 4,
    'ppo_epochs': 1,
    'learning_rate': 1e-6,
    'optimizer': 'adamw',
    'lora_enable': 1,
    'lora_rank': 16,
    'lora_alpha': 32,
    'lora_lr_mult': 10.0,
}

# Build the command for display, using --flag=value pairs
# so it matches the form used in the Run Training cell
cmd_parts = ['python', '-m', 'vlmrl.core.train']
for key, value in config.items():
    cmd_parts.append(f'--{key}={value}')

command = ' '.join(cmd_parts)
print("Training command:")
print(command)
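
The config dict can also drive the run directly, so that the printed command and the executed one cannot drift apart. The sketch below builds an argument list rather than a joined string; `build_train_args` is a hypothetical helper, and the `subprocess` call is left commented out so that the cell stays side-effect free.

```python
import shlex
import sys


def build_train_args(config: dict) -> list:
    """Turn a config dict into an argument list using --flag=value pairs."""
    args = [sys.executable, "-m", "vlmrl.core.train"]
    args += [f"--{key}={value}" for key, value in config.items()]
    return args


# Trimmed copy of the config above, for illustration
example_config = {
    "model_dir": "checkpoints/qwen3vl_2b",
    "total_steps": 50,
    "lora_enable": 1,
}
train_args = build_train_args(example_config)
print(shlex.join(train_args))

# To launch from Python instead of a shell cell:
# import subprocess
# subprocess.run(train_args, check=True)
```

Passing a list to `subprocess.run` avoids shell quoting issues if a value ever contains spaces.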

## Run Training

Execute the training run. This will:
1. Load the Qwen3-VL model from checkpoint
2. Initialize LoRA adapters
3. Run PPO training for 50 steps
4. Save checkpoints to the runs directory

In [None]:
# Run training - use = instead of space for flag values
!python -m vlmrl.core.train \
    --model_dir=checkpoints/qwen3vl_2b \
    --wandb_mode=offline \
    --total_steps=50 \
    --batch_size=4 \
    --ppo_epochs=1 \
    --learning_rate=1e-6 \
    --optimizer=adamw \
    --lora_enable=1 \
    --lora_rank=16 \
    --lora_alpha=32 \
    --lora_lr_mult=10.0
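
For readers unfamiliar with PPO, the sketch below shows the standard clipped surrogate objective in plain Python. It illustrates the general technique only and is not taken from `vlmrl`; the function name and the clip range of 0.2 are assumptions.

```python
import math


def ppo_clip_loss(logp_new, logp_old, advantages, clip_eps=0.2):
    """Mean clipped surrogate loss over a batch of sampled actions."""
    total = 0.0
    for new, old, adv in zip(logp_new, logp_old, advantages):
        # Probability ratio pi_new(a|s) / pi_old(a|s), computed from log-probs
        ratio = math.exp(new - old)
        clipped = max(1.0 - clip_eps, min(ratio, 1.0 + clip_eps))
        # Pessimistic bound: take the smaller of the two surrogate terms
        total += min(ratio * adv, clipped * adv)
    # Negate because optimizers minimize, while PPO maximizes the surrogate
    return -total / len(advantages)


# When the policy has not changed, the ratio is 1 and the loss is -mean(advantage)
print(ppo_clip_loss([0.0, 0.0], [0.0, 0.0], [1.0, -1.0]))
```

The clipping removes the incentive to move the ratio far from 1 in a single update, which is what keeps each policy step small.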

## Monitor Training

Check the training outputs and W&B logs:

In [None]:
# List saved checkpoints
!ls -lh runs/

In [None]:
# View W&B offline logs
!ls -lh wandb/
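
Runs logged in offline mode stay on disk until they are uploaded with `wandb sync`. The sketch below lists the offline run directories and prints the matching sync commands. It assumes W&B's `offline-run-*` directory naming under `wandb/`; `offline_run_dirs` is a hypothetical helper.

```python
from pathlib import Path


def offline_run_dirs(wandb_dir: str = "wandb") -> list:
    """Return offline run directories under `wandb_dir`, sorted by name."""
    root = Path(wandb_dir)
    if not root.is_dir():
        return []
    return sorted(p for p in root.glob("offline-run-*") if p.is_dir())


# Print one sync command per run; nothing is uploaded by this cell
for run_dir in offline_run_dirs():
    print(f"wandb sync {run_dir}")
```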

## Sample Rollout

Run a few sample episodes. As written, `--model_dir` points at the converted base checkpoint rather than a checkpoint saved under `runs/`:

In [None]:
# Run sample rollouts on the geospot environment
!python -m vlmrl.core.rollout \
    --model_dir checkpoints/qwen3vl_2b \
    --env_name geospot \
    --episodes 5 \
    --batch_size 1 \
    --temperature 0.7 \
    --top_p 0.9 \
    --max_new_tokens 64
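
The `--temperature` and `--top_p` flags control how tokens are sampled. The sketch below shows the general technique (temperature scaling followed by nucleus filtering) in plain Python. It is an illustration under stated assumptions, not the sampler used by `vlmrl.core.rollout`, and all function names are hypothetical.

```python
import math
import random


def softmax_with_temperature(logits, temperature):
    """Convert logits to probabilities; lower temperature sharpens the distribution."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def top_p_filter(probs, top_p):
    """Keep the smallest set of most likely tokens whose mass reaches top_p, renormalized."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    mass = sum(probs[i] for i in kept)
    return {i: probs[i] / mass for i in kept}


def sample_token(logits, temperature, top_p, rng):
    """Sample one token index using temperature scaling and nucleus filtering."""
    filtered = top_p_filter(softmax_with_temperature(logits, temperature), top_p)
    indices = list(filtered)
    weights = [filtered[i] for i in indices]
    return rng.choices(indices, weights=weights, k=1)[0]


# Toy logits over a four-token vocabulary, with the values from the rollout command
print(sample_token([2.0, 1.0, 0.5, -1.0], temperature=0.7, top_p=0.9, rng=random.Random(0)))
```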