In [2]:
import pandas as pd
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
np.random.seed(42)
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
DATA_DIR = "/xxx/science-dataset/data/train-00000-of-00001-2fe71122e39b8e2d.parquet"
parquet_data = pd.read_parquet(DATA_DIR)[["question", "correct_answer", "support"]]
sampled_parquest_data = parquet_data.sample(100, random_state=42)


In [None]:
model_base_dir = "/xxx/qwen-25-1_5b"
model_r1_dir = "/xxx/openbookqa_rollout8_8k_ent1e4/global_step_620/actor/huggingface"
model_gsm8k_dir = "/xxx/verl_grpo_example_gsm8k_3/qwen25_1_5b_firstry_10epoch/global_step_580/actor/huggingface"


In [5]:
model_base = AutoModelForCausalLM.from_pretrained(
    model_base_dir,
    torch_dtype=torch.bfloat16,
    device_map= "cuda:1"
)
model_r1 = AutoModelForCausalLM.from_pretrained(
    model_r1_dir,
    torch_dtype=torch.bfloat16,
    device_map= "cuda:1"
)
model_gsm8k = AutoModelForCausalLM.from_pretrained(
    model_gsm8k_dir,
    torch_dtype=torch.bfloat16,
    device_map= "cuda:1"
)
tokenizer = AutoTokenizer.from_pretrained(model_base_dir)

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


In [7]:
prompt = "Think about the following questions step by step.\n question:\n {}"
all_data = []
T = 0.6
MAX_LEN = 1024
for i in range(len(sampled_parquest_data)):
    question = sampled_parquest_data.iloc[i]["question"]
    support = sampled_parquest_data.iloc[i]["support"]
    correct_answer = sampled_parquest_data.iloc[i]["correct_answer"]
    messages = [
    {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
    {"role": "user", "content": prompt.format(question)}
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model_base.device)

    generated_ids_base = model_base.generate(
        **model_inputs,
        max_new_tokens=MAX_LEN,
        temperature=T,
    )
    generated_ids_base = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids_base)
    ]

    response_base = tokenizer.batch_decode(generated_ids_base, skip_special_tokens=True)[0]
    # print("response_base:",response_base, len(generated_ids_base[0]))
    #=================================
    generated_ids_r1 = model_r1.generate(
        **model_inputs,
        max_new_tokens=MAX_LEN,
        temperature=T,
        repetition_penalty=1.1
    )
    generated_ids_r1 = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids_r1)
    ]
    response_r1 = tokenizer.batch_decode(generated_ids_r1, skip_special_tokens=True)[0]
    # print("response_r1:",response_r1,len(generated_ids_r1[0]))
    #=================================
    generated_ids_gsm8k = model_gsm8k.generate(
        **model_inputs,
        max_new_tokens=MAX_LEN,
        temperature=T,
        repetition_penalty=1.1
    )
    generated_ids_gsm8k = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids_gsm8k)
    ]
    response_gsm8k = tokenizer.batch_decode(generated_ids_gsm8k, skip_special_tokens=True)[0]
    #print("response_gsm8k:",response_gsm8k,len(generated_ids_gsm8k[0]))
    all_data.append([question, support, correct_answer, response_base, response_r1, response_gsm8k,
                     len(generated_ids_base[0]), len(generated_ids_r1[0]), len(generated_ids_gsm8k[0])])
    print(i, len(generated_ids_base[0]), len(generated_ids_r1[0]), len(generated_ids_gsm8k[0]))
    # print("="*20)
    # if i > 4:break 

0 178 375 204
1 304 222 224
2 106 174 146
3 431 325 205
4 306 318 340
5 393 308 265
6 307 258 176
7 130 328 221
8 368 255 169
9 478 379 309
10 148 322 241
11 80 227 330
12 324 265 246
13 417 362 334
14 301 93 272
15 311 286 191
16 240 383 233
17 206 287 253
18 358 357 205
19 308 295 200
20 363 293 297
21 282 250 223
22 234 439 243
23 300 256 244
24 208 263 194
25 291 496 194
26 380 341 326
27 13 105 139
28 347 197 262
29 374 414 309
30 142 231 284
31 506 429 225
32 205 308 256
33 300 452 313
34 71 167 216
35 364 358 360
36 51 65 211
37 321 359 218
38 307 250 305
39 81 71 139
40 223 280 249
41 146 231 184
42 16 291 192
43 366 349 312
44 166 131 293
45 456 350 260
46 326 144 307
47 88 205 63
48 160 287 220
49 73 229 197
50 254 240 307
51 347 317 299
52 306 368 293
53 208 262 138
54 25 177 123
55 15 226 146
56 266 314 263
57 320 326 309
58 262 227 202
59 62 240 64
60 569 525 251
61 340 365 268
62 63 224 71
63 549 415 352
64 310 312 252
65 344 348 253
66 238 300 228
67 299 343 340
68 556 3

In [8]:
print(all_data[0])

['What part of a plant protects the plant cell, maintains its shape, and prevents excessive uptake of water?', '', 'wall', "To answer this question, let's break it down into parts:\n\n1. **Protecting the Plant Cell:**\n   - The primary component that provides protection for the plant cell is the cell wall.\n\n2. **Maintaining Shape:**\n   - Another important function of the cell wall is to maintain the structural integrity of the plant cells. It helps in providing rigidity and support, which is crucial for maintaining the overall shape of the plant.\n\n3. **Preventing Excessive Uptake of Water:**\n   - The cell wall also plays a role in regulating the movement of substances across the cell membrane. This includes controlling the amount of water that enters or leaves the cell through processes like osmosis and diffusion.\n\nTherefore, the part of a plant that performs all these functions—protection, maintenance of shape, and regulation of substance transport—is the **cell wall**.", "To 

In [18]:
save_data = pd.DataFrame(all_data, columns=["question", "support", "correct_answer", "response_base", "response_rl_choice", "response_rl_gsm8k",
                                             "len_base", "len_rl_choice", "len_rl_gsm8k"])

In [19]:
print(np.mean(save_data["len_base"]), np.mean(save_data["len_rl_choice"]), np.mean(save_data["len_rl_gsm8k"]))

273.52 283.81 235.15


In [None]:
idx = 1
print(save_data['question'].iloc[idx])
print('='*10)
print(save_data['response_base'].iloc[idx])
print('='*10)
print(save_data["response_rl_choice"].iloc[idx])

What part of a plant protects the plant cell, maintains its shape, and prevents excessive uptake of water?
To determine what part of a plant protects the plant cells, maintains their shape, and prevents excessive uptake of water, let's break down these functions one by one:

1. **Protecting Plant Cells:**
   - The primary protective layer in plants is the cuticle on the outer surface of leaves and stems. This cuticle acts as a barrier against pathogens, insects, and extreme environmental conditions like high humidity or low temperatures.

2. **Maintaining Shape:**
   - In most plants, especially woody ones, the main structural support comes from secondary xylem (wood). Secondary xylem provides strength and rigidity to the stem, helping it maintain its upright shape even under heavy loads.

3. **Preventing Excessive Uptake of Water:**
   - To prevent excessive water absorption, some plants have specialized structures called guard cells in their stomata. Guard cells control the opening a

In [20]:
save_data.to_csv("base_rl_response_ent0_penalty1_1_temp06.csv", index=False)