In [1]:
import sys
import os
import yaml
import torch
import json
import tqdm
from dataclasses import dataclass, field
from typing import Optional, Dict, Any, List
from transformers import AutoModelForCausalLM, AutoTokenizer
from environments.maze_env import MazeEnvironment
from environments.syllogism_env import SyllogismEnvironment
from environments.gsm8k import GSM8KEnvironment

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def load_config(config_path) -> Dict[str, Any]:
    """Read a YAML configuration file and return its parsed contents.

    Args:
        config_path: Path of the YAML file; passed unchanged to ``open``.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file. An empty document,
        which ``safe_load`` reports as ``None``, is returned as an empty dict
        so callers can use ``.get`` without a ``None`` check.

    Raises:
        OSError: If the file cannot be opened.
    """
    # Explicit encoding: the platform default is not UTF-8 everywhere.
    with open(config_path, 'r', encoding='utf-8') as f:
        parsed = yaml.safe_load(f)
    return {} if parsed is None else parsed

def get_environment(config: Dict[str, Any]) -> Any:
    """Instantiate the environment selected by ``config``.

    Selection rules:
      * If the ``environment`` section is missing or empty and
        ``config['data']['dataset_name']`` equals ``'openai/gsm8k'``, a
        ``GSM8KEnvironment`` is returned.
      * Otherwise ``config['environment']['name']`` chooses between
        ``'gsm8k'``, ``'maze'`` and ``'syllogism'``.

    Args:
        config: The full raw configuration dict. The whole dict (not only the
            ``environment`` section) is handed to the environment constructor.

    Returns:
        The constructed environment object.

    Raises:
        ValueError: If no known environment can be determined from ``config``.
    """
    # Need to maintain compat with classes taking dict.
    # `or {}` also covers sections that are present but null (e.g. a YAML key
    # with no value), which `.get(key, {})` would return as None and which
    # would then fail with AttributeError on the next `.get`.
    env_config = config.get('environment') or {}
    data_config = config.get('data') or {}

    if not env_config and data_config.get('dataset_name') == 'openai/gsm8k':
        return GSM8KEnvironment(config)

    name = env_config.get('name')
    if name == 'gsm8k':
        return GSM8KEnvironment(config)
    elif name == 'maze':
        return MazeEnvironment(config)
    elif name == 'syllogism':
        return SyllogismEnvironment(config)
    else:
        raise ValueError(f"Unknown environment: {name}")

In [7]:
from data_gen.generate_sft_data import DataGenConfig

In [None]:
# Location of the data-generation config. It can be overridden through the
# DATA_GEN_CONFIG environment variable so the notebook is not tied to one
# machine's filesystem layout; the default is the path used originally.
CONFIG_PATH = os.environ.get(
    "DATA_GEN_CONFIG", "/home/rlvr_experiments/configs/data_gen_maze.yaml"
)
raw_config = load_config(CONFIG_PATH)
# Parse into typed config
config = DataGenConfig.from_dict(raw_config)

# Initialize Environment
# We pass raw_config to keep compatibility with environment classes that expect a dict
env = get_environment(raw_config)
dataset = env.get_dataset(raw_config)

# Load Model
# Tokenizer and model are both loaded from the same name/path taken from the typed config.
print(f"Loading model: {config.model.name_or_path}")
tokenizer = AutoTokenizer.from_pretrained(config.model.name_or_path)
model = AutoModelForCausalLM.from_pretrained(
    config.model.name_or_path,
    torch_dtype=config.model.dtype,
    device_map=config.model.device_map
)

Generating Maze dataset with config: MazeConfig(min_dist=2, max_dist=8, min_grid_size=5, max_grid_size=10, seed=42, size=10)
Increasing maze generation size from 10 to 200 to match max_samples


Map: 100%|██████████| 200/200 [00:00<00:00, 13377.25 examples/s]

Loading model: Qwen/Qwen2-7B-Instruct



Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

In [15]:
# Build the meta-prompt that asks the model to explain a known-correct answer,
# then print the chat-templated text for the first dataset item as a sanity check.

# Get tokens from config. These do not depend on the dataset item, so they are
# read (and the system message rendered) once, before the loop.
think_start = config.sft.think_start_token
think_end = config.sft.think_end_token
answer_start = config.sft.answer_start_token
answer_end = config.sft.answer_end_token
system_prompt_tmpl = config.sft.system_prompt

# If the template contains format placeholders for tokens, format them.
try:
    system_msg = system_prompt_tmpl.format(
        think_start=think_start,
        think_end=think_end,
        answer_start=answer_start,
        answer_end=answer_end,
    )
except (KeyError, IndexError, ValueError):
    # Fallback: the template contains braces that are not our named
    # placeholders (unknown key, positional `{}`, or an unbalanced brace),
    # so use it verbatim instead of failing.
    system_msg = system_prompt_tmpl

for i, item in tqdm.tqdm(enumerate(dataset), total=min(len(dataset), 2)):
    # item['prompt'] is a list of {'role':..., 'content':...}
    # item['answer'] is the ground truth

    # Extract the original user query (None when no message has that role).
    user_content = next((msg['content'] for msg in item['prompt'] if msg['role'] == 'user'), None)
    # The dataset's own system message is extracted but not used in the meta-prompt below.
    system_content = next((msg['content'] for msg in item['prompt'] if msg['role'] == 'system'), None)

    ground_truth = item['answer']

    # Construct a meta-prompt to get the reasoning: the problem and its correct
    # solution are both shown, and the model is asked to explain the steps.
    meta_prompt = [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": f"Problem:\n{user_content}\n\nCorrect Solution:\n{ground_truth}\n\nPlease explain the reasoning step-by-step to arrive at this solution."}
    ]

    text = tokenizer.apply_chat_template(meta_prompt, tokenize=False, add_generation_prompt=True)
    print(text)
    # Preview only: stop after the first item.
    break

  0%|          | 0/2 [00:00<?, ?it/s]

<|im_start|>system
You are a helpful assistant. You are given a problem and its correct solution. 
Your task is to generate the step-by-step reasoning that leads to this solution.
<|im_end|>
<|im_start|>user
Problem:
Navigate from '3' (start) to 'z' (goal):

```
>>>>>>
>e>e>>
>eeee>
>ezee>
>ee3>>
>>>>>>
```
Legend: '>' = Wall, 'e' = Passage

What is the minimum number of steps to reach the goal?
Give only the number of steps as your final answer, no other text or formatting.

Correct Solution:
2

Please explain the reasoning step-by-step to arrive at this solution.<|im_end|>
<|im_start|>assistant






In [11]:
# Tokenize the chat-templated prompt and move it to the model's device.
inputs = tokenizer(text, return_tensors="pt").to(model.device)

# Force the model to start with the think_start token
think_ids = tokenizer(think_start, add_special_tokens=False).input_ids

current_ids = inputs.input_ids
current_mask = inputs.attention_mask

# Give the forced prefix the same dtype as the prompt ids. Without an explicit
# dtype an empty id list would produce a floating-point tensor.
think_tensor = torch.tensor([think_ids], dtype=current_ids.dtype, device=model.device)

# Concatenate
input_ids = torch.cat([current_ids, think_tensor], dim=1)
# Extend attention mask. torch.ones defaults to a floating-point dtype, so the
# dtype is pinned to the existing mask's to keep the concatenated mask unchanged in type.
think_mask = torch.ones((1, len(think_ids)), dtype=current_mask.dtype, device=model.device)
attention_mask = torch.cat([current_mask, think_mask], dim=1)

# Sampling-based generation; gradients are not needed.
with torch.no_grad():
    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=1024,
        do_sample=True,
        temperature=0.7,
        pad_token_id=tokenizer.pad_token_id
    )

# Decode only the NEW tokens. The forced think_start prefix is part of
# input_ids, so it is sliced off here and does not appear in generated_text.
generated_ids = outputs[0][input_ids.shape[1]:]
generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)


In [12]:
# Display the decoded continuation (tokens produced after the prompt and the forced think_start prefix).
generated_text

'Initially, we start by moving 1 step in any direction. We can go left, right, up, or down.</think>\n\n<answer>2</answer>'