# Driving Exam Auto Tagging
Direct tagging with a VLM

## A. Format Question Data

## 0. Set up the environment

Set the working directory to the project's `src` directory so that the relative data paths and package imports below resolve.

In [1]:
import json
import os
from typing import Tuple

In [2]:
# The default is machine-specific; set DRIVETEST_SRC_PATH to run the notebook
# from another checkout without editing this cell.
SRC_PATH = os.environ.get(
    "DRIVETEST_SRC_PATH",
    "/Users/simonxu/Files/Projects/Drivetest_App/2_NLP Tag Creation/drivetest_tag_extraction/src",
)
# All relative paths used below (data_storage/..., my_logs) resolve against this directory.
os.chdir(SRC_PATH)

### 1. Set up the question bank

In [3]:
from entities.question_bank import QuestionBank
from data_access.local_json_db import LocalJsonDB
from data_formatting.data_formatter import DataFormatter, DataFormat

### i) Load the question bank

In [4]:
RAW_DATA_FILE = "data_storage/raw_database/data.json"
RAW_IMG_DIR = "data_storage/raw_database/images"

def load_data() -> QuestionBank:
    """ Load the raw (unformatted) question bank.

    Reads the question data from RAW_DATA_FILE, with images in RAW_IMG_DIR.

    Returns:
        QuestionBank: the bank returned by ``LocalJsonDB.load``.
    """
    raw_db = LocalJsonDB(RAW_DATA_FILE, RAW_IMG_DIR)
    return raw_db.load()

In [5]:
# Load the raw bank and report how many questions it holds.
raw_qb = load_data()
print(raw_qb.question_count())

2836


### ii) Preprocessing

Images are resized to a standard size (256×256) and converted to a standard format (WebP → JPEG).

In [6]:
FORMATTED_IMG_DIR = "data_storage/formatted_database/images"
def format_data(raw_qb: QuestionBank, data_format: DataFormat) -> QuestionBank:
    """ Format a question bank according to ``data_format``.

    Delegates to ``DataFormatter.format_data``, which writes the formatted
    images into FORMATTED_IMG_DIR.

    Args:
        raw_qb: the question bank to format.
        data_format: target image shape and input/output image extensions.

    Returns:
        QuestionBank: the formatted question bank produced by the formatter.
    """
    data_formatter = DataFormatter(data_format=data_format)
    new_qb = data_formatter.format_data(question_bank=raw_qb,
                                        new_img_dir=FORMATTED_IMG_DIR)
    return new_qb

In [7]:
%%time
# Source images are WebP; formatted copies are written as JPEG.
INPUT_IMG_EXTENSION = "webp"
OUTPUT_IMG_EXTENSION = "jpg"

# Target image shape for every formatted image.
data_format = DataFormat(image_shape=(256, 256),
                         input_image_extension=INPUT_IMG_EXTENSION,
                         output_image_extension=OUTPUT_IMG_EXTENSION)
qb = format_data(raw_qb=raw_qb, data_format=data_format)
# Sanity check: the formatted bank should contain the same number of questions.
print(qb.question_count())

2836
CPU times: user 14.1 s, sys: 2.08 s, total: 16.2 s
Wall time: 16.7 s


### iii) Save the formatted question bank

In [8]:
FORMATTED_DB_FILE_PATH = "data_storage/formatted_database/data.json"
def save_formatted_data(question_bank: QuestionBank) -> None:
    """ Persist ``question_bank`` to FORMATTED_DB_FILE_PATH.

    The database object is created with FORMATTED_IMG_DIR as its image
    directory, matching where ``format_data`` wrote the images.
    """
    LocalJsonDB(FORMATTED_DB_FILE_PATH, FORMATTED_IMG_DIR).save(question_bank)

In [9]:
# Write the formatted bank to the formatted database.
save_formatted_data(qb)

## C. Question Bank to Batch Request File

Convert the question bank into a JSONL file of batch requests in the OpenAI-compatible Batch API format.

In [10]:
import datetime
import logging
from logging import Logger

from label_generator.batch_request_factory import BatchRequestFactory

In [11]:
def load_prompt() -> str:
    """ Read and return the full text of the prompt file.

    The path comes from the module-level PROMPT_FILE_PATH (decoded as UTF-8).
    That constant is defined in a later cell, so it must be set before this
    function is called.
    """
    with open(PROMPT_FILE_PATH, 'r', encoding='utf-8') as prompt_file:
        return prompt_file.read()

In [12]:
def make_logger(logging_directory: str, verbose: bool=False, debug: bool=False) -> Logger:
    """ Build a timestamped batch-request logger that writes into ``logging_directory``.

    Args:
        logging_directory: directory that will hold the ``.log`` file.
        verbose: when True, records are also echoed to the console.
        debug: when True, the logger level is DEBUG; otherwise INFO.

    Returns:
        Logger: a logger named ``batch_request_<YYYYmmdd_HHMM>``.
    """
    log_filename, timestamp = _make_logger_name(logging_directory)
    level = logging.DEBUG if debug else logging.INFO
    logger = logging.getLogger(f"batch_request_{timestamp}")
    logger.setLevel(level)
    _add_handlers(log_filename, logger, verbose, debug)
    return logger

def _make_logger_name(logging_directory):
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M")
    log_filename = os.path.join(logging_directory,
                                f"batch_request_{timestamp}.log")
    return log_filename, timestamp

def _add_handlers(log_filename, logger, verbose, debug):
    """ Replace ``logger``'s handlers with a file handler and, optionally, a console handler.

    Existing handlers are closed as well as removed. ``removeHandler`` alone
    leaves a FileHandler's file open, which leaks a descriptor whenever the
    same logger is rebuilt, e.g. when a cell is re-run within the same minute
    and ``getLogger`` returns the existing logger.

    Args:
        log_filename: path of the log file for the new file handler.
        logger: logger whose handlers are replaced.
        verbose: when True, also attach a console handler.
        debug: forwarded to the file handler to select its level.
    """
    for handler in list(logger.handlers):  # iterate a copy: we mutate the list
        logger.removeHandler(handler)
        handler.close()
    formatter = _add_file_handler(log_filename, logger, debug)
    if verbose:
        _add_console_handler(formatter, logger)

def _add_file_handler(log_filename, logger, debug):
    file_handler = logging.FileHandler(log_filename)
    if debug:
        file_handler.setLevel(logging.DEBUG)
    else:
        file_handler.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    return formatter

def _add_console_handler(formatter, logger):
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.DEBUG)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

Specify model information and request URL.

In [13]:
# Directory for batch-request log files (relative to SRC_PATH).
LOGGING_DIRECTORY = "my_logs"
# Prompt text passed to BatchRequestFactory; read by load_prompt().
PROMPT_FILE_PATH = "data_storage/prompt_file/prompt.txt"
# Vision-language model used for tagging.
MODEL_NAME = "qwen-vl-max"
# Endpoint path given to BatchRequestFactory and used as the batch job's endpoint.
REQUEST_URL = "/v1/chat/completions"

In [14]:
%%time
# Build the batch request from the formatted bank `qb`. The logger writes
# DEBUG-level records to a file only (verbose=False), not to the console.
batch_maker = BatchRequestFactory(
    question_bank=qb,
    prompt=load_prompt(),
    url=REQUEST_URL,
    model_name=MODEL_NAME,
    logger=make_logger(LOGGING_DIRECTORY, verbose=False, debug=True))
batch_request = batch_maker.make_batch_request()

CPU times: user 773 ms, sys: 338 ms, total: 1.11 s
Wall time: 1.11 s


In [15]:
# JSONL file of batch requests; this is the file uploaded to the batch API below.
REQUEST_FILE_PATH = "data_storage/batch_request_file/tagging_request.jsonl"

In [16]:
def clear_request_file(file_path=None):
    """ Truncate the batch request file to zero bytes.

    Nothing is written (rather than e.g. ``{}``): the file must be valid JSONL
    where every line is a batch request, so an empty file is the only correct
    "cleared" state.

    Args:
        file_path: file to clear; defaults to REQUEST_FILE_PATH.
    """
    path = REQUEST_FILE_PATH if file_path is None else file_path
    with open(path, 'w', encoding='utf-8'):
        pass  # opening in 'w' mode truncates the file; nothing to write

In [17]:
def count_lines_in_file(file_path: str) -> int:
    """ Return how many lines the UTF-8 text file ``file_path`` contains.

    A final line without a trailing newline still counts as a line; an empty
    file has zero lines.
    """
    total = 0
    with open(file_path, 'r', encoding='utf-8') as handle:
        for _line in handle:
            total += 1
    return total

In [18]:
# Write the batch request as JSONL and confirm the line count matches the question count.
batch_request.to_jsonl_file(REQUEST_FILE_PATH)
print(f"Number of lines in the request file: {count_lines_in_file(REQUEST_FILE_PATH)}")

Number of lines in the request file: 2836


# 2. Generate the Labels

In [None]:
from pathlib import Path
from openai import OpenAI

In [None]:
# OpenAI-compatible client pointed at Alibaba Cloud DashScope; the API key is
# read from the DASHSCOPE_API_KEY environment variable.
client = OpenAI(
    api_key=os.getenv("DASHSCOPE_API_KEY"),
    base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
)

## a) Upload batch file

In [None]:
%%time
# Upload the JSONL request file; the returned object's id is used to create the batch job.
file_object = client.files.create(file=Path(REQUEST_FILE_PATH), purpose="batch")

In [None]:
# Show the uploaded file's metadata.
print(file_object.model_dump_json())

## b) Create batch job

In [None]:
# Batch job metadata (Chinese). ds_name: "Subject 1 tag generation".
# ds_description: automatically generate "tags" and "keywords" for Subject 1
# driving-test questions; "tags" capture the knowledge points / exam points
# the question tests, and "keywords" are explicit or implicit terms in the
# question, used for retrieval.
REQUEST_METADATA = {'ds_name':"科目一标签生成",
                    'ds_description':'为驾考科目一题目自动生成 "tags" 和 "keywords"。 其中"tags" 需要深入理解问题的测试内容，代表问题的知识点与考点。“keywords”需要提取问题中明确或隐含的关键词， 用来检索问题内容。'}

In [None]:
%%time
# Create the batch job from the uploaded file (completion window: 24 hours).
request_id = file_object.id
batch = client.batches.create(
    input_file_id=request_id,
    endpoint=REQUEST_URL,
    completion_window="24h",
    metadata=REQUEST_METADATA
)
print(batch)

Periodically check the status of the batch job.

In [None]:
from time import sleep

In [None]:
WAIT_TIME = 300 # 5 Minutes
# Statuses for which polling continues.
IN_PROGRESS_STATUS_CODES = ["validating", "in_progress", "finalizing", "cancelling"]
# Terminal statuses that mean the job did not complete successfully.
ERROR_STATUS_CODES = ["failed", "expired", "cancelled"]

Alternatively, it may be easier to check the batch status on the DashScope website.

In [None]:
# Poll until the job leaves an in-progress state (this blocks the kernel).
batch_status = client.batches.retrieve(batch.id)
while batch_status.status in IN_PROGRESS_STATUS_CODES:
    print(f"Batch job status: {batch_status.status}")
    sleep(WAIT_TIME)
    batch_status = client.batches.retrieve(batch.id)
print(f"Final job status: {batch_status.status}")
# Report terminal failures explicitly. Only a warning is printed, so the
# following cells can still download the error file.
if batch_status.status in ERROR_STATUS_CODES:
    print(f"WARNING: batch job ended with status '{batch_status.status}'. "
          f"Errors: {getattr(batch_status, 'errors', None)}")

## c) Error handling

In [None]:
# Local copy of the batch job's error file (only written when the job reports one).
ERROR_FILE_PATH = "data_storage/tagging_results/error.jsonl"

### i) Clear the error file

In [None]:
# Empty any error file left over from a previous run; do nothing if none exists.
if os.path.exists(ERROR_FILE_PATH):
    open(ERROR_FILE_PATH, 'w', encoding='utf-8').close()  # 'w' truncates

In [None]:
# Inspect the final batch object (status, file ids, request counts).
print(batch_status)

### ii) Save the new errors

In [None]:
# Download the error file only if the job produced one.
if batch_status.error_file_id is not None:
    content = client.files.content(batch_status.error_file_id)
    content.write_to_file(ERROR_FILE_PATH)
    # Message: "Full request-failure details saved to local error file: <path>"
    print(f"完整的请求失败信息已保存至本地错误文件: {ERROR_FILE_PATH}")

## d) Retrieve result

In [None]:
# Local copy of the batch job's output (one JSON line per completed request).
RESULT_OUTPUT_PATH = "data_storage/tagging_results/result.jsonl"

### i) Clear the result file

In [None]:
# Empty any result file left over from a previous run; do nothing if none exists.
if os.path.exists(RESULT_OUTPUT_PATH):
    open(RESULT_OUTPUT_PATH, 'w', encoding='utf-8').close()  # 'w' truncates

### ii) Save the result file

In [None]:
# output_file_id can be None (e.g. when the job failed before producing any
# output), so check it before downloading.
if batch_status.output_file_id is not None:
    output_file = client.files.content(file_id=batch_status.output_file_id)
    output_file.write_to_file(RESULT_OUTPUT_PATH)
else:
    print(f"No output file for batch {batch_status.id} (status: {batch_status.status})")

In [None]:
# Number of result lines; compare with the number of request lines.
count_lines_in_file(RESULT_OUTPUT_PATH)

## e) Archive result

In [None]:
# Directories holding timestamped copies of past result and error files.
RESULT_ARCHIVE_PATH = "data_storage/tagging_results/result_archive"
ERROR_ARCHIVE_PATH = "data_storage/tagging_results/error_archive"

In [None]:
def get_archive_paths() -> Tuple[str, str]:
    """ Return timestamped archive paths ``(result_path, error_path)``.

    Creates RESULT_ARCHIVE_PATH and ERROR_ARCHIVE_PATH if they do not exist.
    Both file names share one minute-resolution timestamp (``%Y%m%d_%H%M``),
    so archiving twice within the same minute targets the same files.

    Returns:
        Tuple[str, str]: ``(result_archive_file, error_archive_file)``.
    """
    os.makedirs(RESULT_ARCHIVE_PATH, exist_ok=True)
    os.makedirs(ERROR_ARCHIVE_PATH, exist_ok=True)
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M")
    archive_path = os.path.join(RESULT_ARCHIVE_PATH, f"result_{timestamp}.jsonl")
    error_archive_path = os.path.join(ERROR_ARCHIVE_PATH, f"error_{timestamp}.jsonl")
    return archive_path, error_archive_path

In [None]:
import shutil

In [None]:
def archive_files():
    """ Copy the current result and error files into their archive directories.

    Either source may be missing. The error file is only written when the
    batch reported an error file, and the result file only exists once an
    output file has been downloaded. Missing sources are skipped with a
    message instead of raising FileNotFoundError.
    """
    result_archive, error_archive = get_archive_paths()
    for source, destination in ((RESULT_OUTPUT_PATH, result_archive),
                                (ERROR_FILE_PATH, error_archive)):
        if os.path.exists(source):
            shutil.copy(source, destination)
        else:
            print(f"Skipping archive of missing file: {source}")

In [None]:
# Keep timestamped copies of this run's result and error files.
archive_files()

# 3) Parse the results

## a) Load labels into the question bank

In [56]:
from label_generator.response_parsing_pipeline import ResponseParsingPipeline
from label_generator.label_factory import MessageFormatConfig

In [57]:
# Same path as in section 2(d); redefined so this section can run on its own.
RESULT_OUTPUT_PATH = "data_storage/tagging_results/result.jsonl"

In [58]:
# Delimiters around the JSON payload in each model response (presumably the
# format requested in prompt.txt; verify against the prompt).
message_format = MessageFormatConfig(output_end_tag="</JSON>",
                                     output_start_tag="<JSON>")

In [59]:
# NOTE: this binds final_qb to the same object as raw_qb (no copy), so labels
# loaded into final_qb are also visible through raw_qb.
final_qb = raw_qb
response_parser = ResponseParsingPipeline(
    question_bank=final_qb,
    message_format=message_format,
    result_path=RESULT_OUTPUT_PATH)

In [60]:
# Parse the result file and attach tags/keywords to the questions in final_qb.
response_parser.parse_and_load()

In [61]:
# Preview the labels of the first 10 questions.
for qid in final_qb.get_qid_list()[:10]:
    question = final_qb.get_question(qid)
    print(f"Question ID: {qid}")
    print(f"Tags: {question.tags}")
    print(f"Keywords: {question.keywords}")
    print("-" * 40)

Question ID: 00428
Tags: ['交通信号-标志-禁令']
Keywords: ['交通标志', '禁令标志', '禁止直行', '禁止右转', '红色圆形']
----------------------------------------
Question ID: 00624
Tags: ['交通信号-标志-预告', '道路信息-交叉路口']
Keywords: ['蓝色标志', '南京路', '东北路', '前方500m', 'G2', '十字交叉路口预告']
----------------------------------------
Question ID: 0079e
Tags: ['交通信号-标志-会车让行', '安全驾驶-驾驶习惯-礼让行车']
Keywords: ['交通标志', '红色边框', '黑色箭头', '红色箭头', '会车', '停车让行']
----------------------------------------
Question ID: 008da
Tags: ['安全驾驶-基本原则-谨慎驾驶', '规则记忆型-安全理念']
Keywords: ['谨慎驾驶', '集中注意力', '仔细观察', '提前预防']
----------------------------------------
Question ID: 0092c
Tags: ['通行规定-速度限制-无标志公路限速', '法规记忆型-法定限速']
Keywords: ['公路', '最高速度', '40公里/小时', '双车道', '茂密树木']
----------------------------------------
Question ID: 0146a
Tags: ['安全驾驶-灯光使用-雾天', '规则记忆型-法规细节']
Keywords: ['雾天', '行车', '雾灯', '危险报警闪光灯', '开启']
----------------------------------------
Question ID: 019e1
Tags: ['交通信号-信号灯-绿灯通行规则', '情景判断型-交通信号理解']
Keywords: ['交通信号灯', '绿灯', '右转弯', '直行', '左转', '不能右转']
--

## b) Examine missing labels and fill them in manually

In [62]:
from typing import List

In [63]:
def get_empty_labels(question_bank: QuestionBank) -> List[str]:
    """ Return the IDs of questions whose tags or keywords are empty.

    IDs are returned in the order given by ``question_bank.get_qid_list()``.
    """
    missing = []
    for qid in question_bank.get_qid_list():
        q = question_bank.get_question(qid)
        if q.tags and q.keywords:
            continue  # fully labeled
        missing.append(qid)
    return missing

In [64]:
# Show every question that is still missing tags or keywords. Check final_qb,
# the bank the labels were parsed into, rather than `qb`, the separately
# formatted bank.
for qid in get_empty_labels(question_bank=final_qb):
    question = final_qb.get_question(qid)
    print(question, "\n")

In [65]:
# Manually label question e386c, which presumably lacked labels after parsing.
# Keywords: vehicle impoundment, no licence plate, no compulsory traffic
# insurance, no vehicle registration certificate, no fire extinguisher,
# traffic violation.
# Tags: "violation handling - administrative penalty - vehicle impoundment",
#       "rule-memorisation type - regulatory penalties".
q_e386c = final_qb.get_question("e386c")
q_e386c.set_keywords(["扣留车辆", "未悬挂号牌", "未购买交强险", "未携带行驶证", "未携带灭火器", "违法行为"])
q_e386c.set_tags(["违法处理-行政处罚-扣留车辆",
                  "规则记忆型-法规处罚"])

## c) Save the updated question bank

In [66]:
LABELED_DATA_DIR = "data_storage/labeled_database"
LABELED_DB_FILE_PATH = os.path.join(LABELED_DATA_DIR, "data.json")
# final_qb is raw_qb (not the formatted qb), so pair it with the raw image directory.
FINAL_IMG_DIR = RAW_IMG_DIR
labeled_db = LocalJsonDB(db_file_path=LABELED_DB_FILE_PATH,
                         img_dir=FINAL_IMG_DIR)
# NOTE(review): the reload in the next cell shows empty tags/keywords, so the
# labels may not be persisted by save (or dropped by load). Verify that
# LocalJsonDB serializes question tags and keywords.
labeled_db.save(final_qb)

True

In [67]:
# Reload from disk and verify that the labels survived the save/load round trip.
# Without this check, a lossy round trip silently replaces the labeled bank
# with an unlabeled one.
reloaded_qb = labeled_db.load()
missing_after_reload = get_empty_labels(question_bank=reloaded_qb)
if missing_after_reload:
    print(f"WARNING: {len(missing_after_reload)} of {reloaded_qb.question_count()} "
          "questions have empty tags or keywords after reloading "
          f"{LABELED_DB_FILE_PATH}; the saved labels may not have been persisted.")
final_qb = reloaded_qb

In [68]:
# Preview the first 10 questions of the reloaded bank to confirm labels persisted.
for qid in final_qb.get_qid_list()[:10]:
    question = final_qb.get_question(qid)
    print(f"Question ID: {qid}")
    print(f"Tags: {question.tags}")
    print(f"Keywords: {question.keywords}")
    print("-" * 40)

Question ID: 00428
Tags: []
Keywords: []
----------------------------------------
Question ID: 00624
Tags: []
Keywords: []
----------------------------------------
Question ID: 0079e
Tags: []
Keywords: []
----------------------------------------
Question ID: 008da
Tags: []
Keywords: []
----------------------------------------
Question ID: 0092c
Tags: []
Keywords: []
----------------------------------------
Question ID: 0146a
Tags: []
Keywords: []
----------------------------------------
Question ID: 019e1
Tags: []
Keywords: []
----------------------------------------
Question ID: 01e32
Tags: []
Keywords: []
----------------------------------------
Question ID: 01ea3
Tags: []
Keywords: []
----------------------------------------
Question ID: 01fdd
Tags: []
Keywords: []
----------------------------------------
