# Driving Exam Auto Tagging
Tag each exam question directly with a vision-language model (VLM).

## A. Load Question Data

### 1. Import the scraped question bank

In [1]:
from src.qb.question_bank import QuestionBank
from data_storage.database.json_database import LocalJsonDB

In [2]:
# Open the local JSON database (data file + image directory) and load the
# whole question bank into memory, then report how many questions it holds.
db = LocalJsonDB(
    "data_storage/database/json_db/data.json",
    "data_storage/database/json_db/images",
)
qb: QuestionBank = db.load()
print(qb.question_count())

2836


In [3]:
def set_up_question_chapters(qb: QuestionBank):
    """Attach chapter information to every question in the bank.

    For each chapter id returned by ``qb.list_chapters()``, every question
    listed by ``qb.get_qids_by_chapter(chapter_id)`` receives
    ``set_chapter((chapter_id, description))``, where the description comes
    from ``qb.describe_chapter(chapter_id)``.

    Args:
        qb: The question bank whose questions are updated in place.

    NOTE(review): the recorded output of this cell is ``AttributeError:
    'QuestionBank' object has no attribute 'list_chapters'`` -- confirm that
    the QuestionBank version in use actually exposes this method.
    """
    for chapter_id in qb.list_chapters():
        # The description is the same for every question in the chapter, so
        # look it up once per chapter instead of once per question.
        chapter = (chapter_id, qb.describe_chapter(chapter_id))
        # The chapter id must be passed here, as every other call site does.
        for qid in qb.get_qids_by_chapter(chapter_id):
            qb.get_question(qid).set_chapter(chapter)
# Populate chapter info on the loaded bank; format_question() later reads it
# back through question.get_chapter().
set_up_question_chapters(qb)

AttributeError: 'QuestionBank' object has no attribute 'list_chapters'

### 2. Resize images
Images are resized to 256x256 with grey padding to maintain aspect ratio.

In [None]:
import os

from data_cleaning.img_reshaper import ImgSquarer

In [None]:
def resize_images(qb: QuestionBank, squarer: ImgSquarer, new_dir: str) -> None:
    """Reshape every question image into ``new_dir`` and repoint the bank at it.

    Questions without an image path are left untouched. For the others,
    ``squarer.reshape`` is called with the question id, the bank's current
    image directory and ``new_dir``; the path it returns becomes the
    question's new image path. Finally the bank's image directory is set to
    ``new_dir``.

    Args:
        qb: Question bank to update in place.
        squarer: Object whose ``reshape(qid, src_dir, dst_dir)`` returns the
            new image path.
        new_dir: Destination directory for the reshaped images.
    """
    for qid in qb.get_qid_list():
        question = qb.get_question(qid)
        if question.get_img_path() is None:
            continue  # text-only question, nothing to reshape
        question.set_img_path(squarer.reshape(qid, qb.get_img_dir(), new_dir))
    # Only switch the bank-level directory once all images were processed.
    qb.set_img_dir(new_dir)

In [None]:
IMG_DIR_256 = "data_cleaning/resized_imgs/img256"
squarer_256 = ImgSquarer(256)
# os.listdir raises FileNotFoundError on a missing directory, so make sure the
# output directory exists before checking whether it is empty.
os.makedirs(IMG_DIR_256, exist_ok=True)
# If the directory is empty, resize images.
if not os.listdir(IMG_DIR_256):
    print("Resizing images to 256x256...")
    resize_images(qb, squarer_256, IMG_DIR_256)
else:
    # NOTE(review): in this branch resize_images() is not called, so qb keeps
    # its original image directory and per-question image paths -- confirm
    # that downstream cells are meant to use the original images here.
    print("Images already resized to 256x256, skipping...")

## B. Auto Tagging with VLM

In [None]:
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
import torch
from qwen_vl_utils import process_vision_info

In [None]:
from typing import Dict, List, Any
from src.qb.question import Question

In [None]:
def get_prompt() -> str:
    """Return the (Chinese) system prompt used for keyword tagging.

    The prompt asks the model to produce 2 to 10 concise keyword tags for a
    question/answer pair and to output them as a JSON array. It ends with one
    worked example. The example output is a non-empty array that follows the
    prompt's own rules; an empty ``[]`` example would contradict the
    "2 to 10 tags" requirement and encourage empty outputs.

    Returns:
        The prompt text.
    """
    return """你是一位驾考科目一领域的知识分类专家。你的核心任务是为给定的【题目】和【答案】提炼并归纳出 2 到 10 个核心关键词标签。

这个任务的目标是为后续的自动化聚类分析提供高质量的原始数据。因此，你生成的标签需要遵循以下原则：

---
### 标签生成原则

1.  **高度概括 (Abstractive)**: 标签应该是对考点的抽象和总结，而不仅仅是摘录原文的词语。例如，对于“行驶超过规定时速百分之四十”，一个好的标签是“超速行驶”，而不是“超过规定时速40%”。
2.  **简洁精炼 (Concise)**: 每个标签应该是一个简短的词组或名词，通常不超过6个汉字。
3.  **核心聚焦 (Core Focus)**: 标签必须直接反映题目最核心的知识点或考察场景。
4.  **术语通用 (Common Terminology)**: 尽量使用驾考领域最常用和通用的术语。例如，使用“记分”而不是“扣除分数”。
5.  **自由生成 (Free Generation)**: 你**不需要**遵循任何预设的列表，请根据你对题目的理解自由地生成最合适的标签。

---
### 输出要求

* 仔细阅读【题目】和【答案】，理解其完整含义和考点。
* 根据上述原则，生成2到10个最能代表题目的核心标签。
* 必须以JSON数组的格式输出标签，例如：["标签1", "标签2"]

---
### 示范

**输入 1:**
{'题目': '申请人因故不能按照预约时间参加考试的，应当提前多长时间申请取消预约？', '选项': {'A': '15日', 'B': '1日', 'C': '30日', 'D': '3日'}, '答案': 'B'}

**输出 1:**
["考试预约", "取消预约", "申请时限"]
"""

In [None]:
# Show the system prompt so it can be proofread before running the model.
print(get_prompt())

In [None]:
def format_question(question: Question) -> str:
    """Render a question as the string form of a dict for the model prompt.

    The dict contains, in order: "章节" (chapter, formatted as
    "<id>: <description>"), "题目" (question text), "选项" (the answers sorted
    and keyed by letters starting at "A") and "答案" (the letter whose answer
    equals the correct answer, or "" when none matches).

    The result is ``str(dict)``, i.e. a Python dict repr with single quotes,
    not strict JSON.
    """
    chapter = question.get_chapter()
    options = {}
    correct_letter = ""
    # Sort the answers so the letter assignment is stable across runs.
    for offset, choice in enumerate(sorted(question.get_answers())):
        letter = chr(ord('A') + offset)
        options[letter] = choice
        if choice == question.get_correct_answer():
            correct_letter = letter
    return str({
        "章节": f"{chapter[0]}: {chapter[1]}",
        "题目": question.get_question(),
        "选项": options,
        "答案": correct_letter,
    })

In [1]:
# Preview the prompt formatting on the first question of every chapter.
sample_lst = []
for chapter in qb.list_chapters():
    for qid in qb.get_qids_by_chapter(chapter):
        test_question = qb.get_question(qid)
        sample_lst.append(format_question(test_question))
        break  # one sample per chapter is enough

# sample_lst already holds formatted strings; passing them through
# format_question() again would call Question methods on a str.
for formatted_question in sample_lst:
    print(formatted_question)

NameError: name 'qb' is not defined

In [None]:
def make_content(question: Question) -> List[Dict[str, Any]]:
    """Build the user-turn content list for a single question.

    The list always ends with a text part holding ``format_question(question)``.
    When the question has an image path, an image part pointing at that path
    comes first.
    """
    img_path = question.get_img_path()
    text_part = {"type": "text", "text": format_question(question)}
    if img_path is None:
        return [text_part]
    return [{"type": "image", "image": img_path}, text_part]

In [None]:
def make_message(question: Question) -> List[Dict[str, Any]]:
    """Build the two-turn conversation (system prompt + user question)."""
    system_turn = {"role": "system", "content": get_prompt()}
    user_turn = {"role": "user", "content": make_content(question)}
    return [system_turn, user_turn]

In [None]:
def get_qid_lst(qb: QuestionBank) -> List[str]:
    """Return the question ids of all chapters as one sorted list."""
    return sorted(
        qid
        for chapter in qb.list_chapters()
        for qid in qb.get_qids_by_chapter(chapter)
    )

In [None]:
def make_messages(qb: QuestionBank) -> List[List[Dict[str, Any]]]:
    """Build one conversation per question, in the order given by get_qid_lst."""
    return [make_message(qb.get_question(qid)) for qid in get_qid_lst(qb)]

In [None]:
def make_inputs(messages, processor):
    """Turn a batch of conversations into padded processor inputs.

    Each conversation is rendered to prompt text with the processor's chat
    template (generation prompt appended, no tokenization). Image and video
    inputs are extracted from the whole batch with ``process_vision_info``,
    and everything is passed to the processor with padding enabled and
    PyTorch tensors requested.

    Args:
        messages: A list of conversations, as produced by make_messages().
        processor: The processor used for templating and encoding.

    Returns:
        Whatever ``processor(...)`` returns. The result is not moved to
        another device here (the original kept ``inputs.to("cuda")``
        commented out).
    """
    prompt_texts = [
        processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
        for conversation in messages
    ]
    image_inputs, video_inputs = process_vision_info(messages)
    return processor(
        text=prompt_texts,
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )

In [None]:
# Load the 2B instruct checkpoint in bfloat16 on the CPU with SDPA attention.
# (flash_attention_2 is the recommended setting for better acceleration and
# memory saving, especially in multi-image and video scenarios.)
model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
    device_map="cpu",
)

# The model's default range is 4-16384 visual tokens per image. min_pixels and
# max_pixels narrow that range to balance speed and memory usage; the values
# are expressed as a token count multiplied by 28*28.
min_pixels = 256 * 28 * 28
max_pixels = 258 * 28 * 28
processor = AutoProcessor.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    min_pixels=min_pixels,
    max_pixels=max_pixels,
)

In [None]:
# Build one system+user conversation per question in the bank.
messages = make_messages(qb)

In [None]:
# Smoke test: encode only the first five conversations.
msgs_test = messages[:5]
inputs = make_inputs(msgs_test, processor)

In [None]:
# Generate at most 128 new tokens per sequence for the test batch.
generated_ids = model.generate(**inputs, max_new_tokens=128)

In [None]:
# Drop the first len(in_ids) positions of every output sequence so that only
# the tokens after the (padded) prompt remain.
generated_ids_trimmed = [
    out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]

In [None]:
# Decode the trimmed token ids to text, skipping special tokens and leaving
# tokenization spaces as they are.
output_texts = processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)