In [7]:
import re


def extract_number_from_answer(ans):
    """从多种可能的答案格式中提取数字"""
    # 尝试多种格式：<answer> 数字 </answer>, <div class="answer"> 数字 </div> 等
    patterns = [
        r"<answer>\s*(\d+)\s*</answer>",               # <answer> 1 </answer>
        r"<div class=\"answer\">\s*(\d+)\s*</div>",    # <div class="answer"> 1 </div>
        r"<div class='answer'>\s*(\d+)\s*</div>",      # <div class='answer'> 1 </div>
        r"<div>\s*(\d+)\s*</div>",                     # <div> 1 </div>
        r"答案是:\s*(\d+)",                            # 答案是: 1
        r"答案是\s*(\d+)",                             # 答案是 1
        r"答案：\s*(\d+)",                             # 答案： 1
        r"answer is:?\s*(\d+)",                        # answer is: 1 或 answer is 1
        r"the answer is:?\s*(\d+)",                    # the answer is: 1
        r"(\d+)"                                       # 最后尝试直接匹配数字（最宽松）
    ]
    
    for pattern in patterns:
        match = re.search(pattern, ans, re.IGNORECASE)
        if match:
            return int(match.group(1))
    
    return None

ans = "The answer is <answer> 3 </answer>."
ans2 = 'div class="<answer"> 1 </div>'
match = extract_number_from_answer(ans)
match2 = extract_number_from_answer(ans2)
print(match)
print(match2)

3
1


In [1]:
# 测试SAT数据集的视觉问答能力
import os
import json
import torch
import random
from PIL import Image
from tqdm import tqdm
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

# 设置路径
current_dir = os.getcwd()
models_dir = os.path.join(current_dir, "models")
data_dir = os.path.join(current_dir, "data")
results_dir = os.path.join(current_dir, "results")
sat_dir = os.path.join(data_dir, "SAT")

# 确保文件夹存在
os.makedirs(results_dir, exist_ok=True)

# 设置环境变量
os.environ["TRANSFORMERS_CACHE"] = models_dir
os.environ["HF_HOME"] = models_dir
os.environ["HF_HUB_CACHE"] = models_dir
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"

# 加载SAT数据集
print("正在加载SAT数据集...")
with open(os.path.join(sat_dir, "SAT_train_15000.json"), "r") as f:
    sat_data = json.load(f)

print(f"加载了 {len(sat_data)} 个SAT样本")

# 随机选择样本进行测试
num_samples = 10  # 测试10个样本
selected_samples = random.sample(sat_data, num_samples)

# 加载模型
print("正在加载Qwen视觉模型...")
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "remyxai/SpaceQwen2.5-VL-3B-Instruct", torch_dtype="auto", device_map="auto"
)
processor = AutoProcessor.from_pretrained("remyxai/SpaceQwen2.5-VL-3B-Instruct")


正在加载SAT数据集...
加载了 15000 个SAT样本
正在加载Qwen视觉模型...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some parameters are on the meta device because they were offloaded to the cpu.
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [2]:
# 准备结果统计
results = []
correct_count = 0

print(f"\n开始测试 {num_samples} 个SAT样本...")

for idx, sample in enumerate(tqdm(selected_samples, desc="测试进度")):
    # 准备输入
    image_paths = sample["images"]
    if not image_paths:
        print(f"样本 {idx} 没有图像路径，跳过")
        continue
    
    # 加载图像
    try:
        image_path = os.path.join(sat_dir, image_paths[0])
        image = Image.open(image_path).convert("RGB")
    except Exception as e:
        print(f"无法加载图像 {image_path}: {e}")
        # 尝试使用相对路径
        try:
            image = Image.open(image_paths[0]).convert("RGB")
        except:
            print(f"样本 {idx} 无法加载图像，跳过")
            continue
    
    # 重组问题，去掉前缀
    user_message = sample["messages"][0]["content"]
    if "<image>" in user_message:
        user_message = user_message.replace("<image>", "").strip()
    
    ground_truth = sample["messages"][1]["content"]
    
    # 准备消息格式
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": user_message},
            ],
        }
    ]
    
    # 处理输入
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(model.device)
    
    # 生成回答
    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=256)
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    model_response = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    
    # 检查回答是否正确 (简单的包含匹配)
    is_correct = ground_truth.lower() in model_response.lower()
    if is_correct:
        correct_count += 1
    
    # 保存结果
    result = {
        "index": idx,
        "question": user_message,
        "ground_truth": ground_truth,
        "model_response": model_response,
        "is_correct": is_correct
    }
    results.append(result)
    
    # 打印结果
    print(f"\n样本 {idx+1}:")
    print(f"问题: {user_message}")
    print(f"参考答案: {ground_truth}")
    print(f"模型回答: {model_response}")
    print(f"是否正确: {'✓' if is_correct else '✗'}")
    print("-" * 50)

# 计算准确率
accuracy = correct_count / num_samples
print(f"\n测试完成! 准确率: {accuracy:.2%} ({correct_count}/{num_samples})")
"""
# 保存结果
output_file = os.path.join(results_dir, f"sat_test_results_{num_samples}.json")
with open(output_file, "w", encoding="utf-8") as f:
    json.dump({
        "accuracy": accuracy,
        "correct_count": correct_count,
        "total_samples": num_samples,
        "results": results
    }, f, ensure_ascii=False, indent=2)

print(f"结果已保存到: {output_file}")"""


开始测试 10 个SAT样本...


测试进度:  10%|█         | 1/10 [00:08<01:12,  8.11s/it]


样本 1:
问题: Answer in natural language. Is black round side table to the right of pedestal sink with a wide basin? Choose between the following options: yes or no
参考答案: yes
模型回答: yes
是否正确: ✓
--------------------------------------------------


测试进度:  20%|██        | 2/10 [00:19<01:18,  9.83s/it]


样本 2:
问题: Answer in natural language. How many Stools are visible in the scene? Choose between the following options: 0, 4, 2, 3 or 1
参考答案: 1
模型回答: There is 1 Stool in the scene.
是否正确: ✓
--------------------------------------------------


测试进度:  30%|███       | 3/10 [00:33<01:23, 11.87s/it]


样本 3:
问题: Answer in natural language. Considering the relative positions, where is Window with respect to black shelving unit with two shelves? Choose between the following options: left or right
参考答案: left
模型回答: Window with respect to black shelving unit with two shelves is on the left.
是否正确: ✓
--------------------------------------------------


测试进度:  40%|████      | 4/10 [00:45<01:11, 11.85s/it]


样本 4:
问题: Answer in natural language. How many HousePlants are visible in the scene? Choose between the following options: 3, 1, 0, 4 or 2
参考答案: 1
模型回答: The number of HousePlants visible in the scene is 2.
是否正确: ✗
--------------------------------------------------


测试进度:  50%|█████     | 5/10 [01:01<01:06, 13.28s/it]


样本 5:
问题: Answer in natural language. Considering the relative positions, where is brown side table with a drawer with respect to room with bed covered in bedsheet? Choose between the following options: right or left
参考答案: right
模型回答: The brown side table with a drawer is on the right side of the room with the bed covered in a bedsheet.
是否正确: ✓
--------------------------------------------------


测试进度:  60%|██████    | 6/10 [01:02<00:37,  9.41s/it]


样本 6:
问题: Answer in natural language. Which point is closer to the camera taking this photo, point B  or point C? Choose between the following options: C or B
参考答案: C
模型回答: B
是否正确: ✗
--------------------------------------------------


测试进度:  70%|███████   | 7/10 [01:09<00:25,  8.48s/it]


样本 7:
问题: Answer in natural language. Which object is closer to the camera taking this photo, beautiful wall painting (highlighted by a blue box) or Window (highlighted by a green box)? Choose between the following options: beautiful wall painting (highlighted by a blue box) or Window (highlighted by a green box)
参考答案: Window (highlighted by a green box)
模型回答: Window (highlighted by a green box)
是否正确: ✓
--------------------------------------------------


测试进度:  80%|████████  | 8/10 [01:19<00:17,  8.83s/it]


样本 8:
问题: Answer in natural language. Which object is closer to the camera taking this photo, brown wooden chair with soft cushion (highlighted by a red box) or lowpoly bed with brown blanket (highlighted by a blue box)? Choose between the following options: lowpoly bed with brown blanket (highlighted by a blue box) or brown wooden chair with soft cushion (highlighted by a red box)
参考答案: brown wooden chair with soft cushion (highlighted by a red box)
模型回答: lowpoly bed with brown blanket (highlighted by a blue box)
是否正确: ✗
--------------------------------------------------


测试进度:  90%|█████████ | 9/10 [01:28<00:08,  8.93s/it]


样本 9:
问题: Answer in natural language. How many DiningTables are visible in the scene? Choose between the following options: 4, 2, 2, 3 or 0
参考答案: 2
模型回答: The number of DiningTables visible in the scene is 2.
是否正确: ✓
--------------------------------------------------


测试进度: 100%|██████████| 10/10 [01:30<00:00,  9.02s/it]


样本 10:
问题: Answer in natural language. Considering the relative positions, is white color sperical shape office chair (marked B) to the left or right of dresser aspelund (marked C)? Choose between the following options: right or left
参考答案: right
模型回答: left
是否正确: ✗
--------------------------------------------------

测试完成! 准确率: 60.00% (6/10)





'\n# 保存结果\noutput_file = os.path.join(results_dir, f"sat_test_results_{num_samples}.json")\nwith open(output_file, "w", encoding="utf-8") as f:\n    json.dump({\n        "accuracy": accuracy,\n        "correct_count": correct_count,\n        "total_samples": num_samples,\n        "results": results\n    }, f, ensure_ascii=False, indent=2)\n\nprint(f"结果已保存到: {output_file}")'