# Drive Test Tag Generation With BERTopic
Generate tags for the written portion of the Chinese driving exam using BERTopic.

## 1. Load Data
Load the questions from a local JSON database into a `QuestionBank` object.

In [1]:
from src.qb.question import Question
from src.qb.question_bank import QuestionBank
from data_storage.database.json_database import LocalJsonDB

DB_JSON_PATH = "data_storage/database/json_db/data.json"
DB_IMAGE_DIR = "data_storage/database/json_db/images"

# Read the question bank (question records + image directory) from the local JSON store.
db = LocalJsonDB(DB_JSON_PATH, DB_IMAGE_DIR)
qb: QuestionBank = db.load()
print(qb.question_count())

2836


## 2. Format Data
Although the SigLIP2 model can handle images of different sizes, I still resize all images to common square sizes so the model inputs are consistent.

In [2]:
from data_cleaning.img_reshaper import ImgSquarer

# Output folders for the squared copies of the question images.
IMG_DIR_256 = "data_cleaning/resized_imgs/img256"
IMG_DIR_512 = "data_cleaning/resized_imgs/img512"

# One squarer per target size. NOTE(review): only the 256 variant is used in
# the cells visible here; the 512 one is kept for later experiments.
squarer_256, squarer_512 = ImgSquarer(256), ImgSquarer(512)

In [3]:
def resize_images(qb: QuestionBank, squarer: ImgSquarer, new_dir: str) -> None:
    """Resize every question image in ``qb`` and repoint each question at the copy.

    Questions are visited chapter by chapter. For each question that has an
    image path, ``squarer.reshape(qid, qb.get_img_dir(), new_dir)`` is called
    and whatever it returns becomes the question's new image path. Questions
    without an image are left untouched.

    Args:
        qb: Question bank whose questions are updated in place.
        squarer: Object whose ``reshape`` produces the resized image.
        new_dir: Directory for the resized copies (passed through to ``reshape``).
    """
    every_qid = (
        qid
        for chapter in qb.get_all_chapter_num()
        for qid in qb.get_qids_by_chapter(chapter)
    )
    for qid in every_qid:
        q = qb.get_question(qid)
        if q.get_img_path() is None:
            continue
        q.set_img_path(squarer.reshape(qid, qb.get_img_dir(), new_dir))

In [4]:
import os
# Resize only when the output directory is empty, so re-running the notebook
# does not redo the work. Create the directory first: os.listdir raises
# FileNotFoundError when it does not exist (e.g. on a fresh checkout).
os.makedirs(IMG_DIR_256, exist_ok=True)
if not os.listdir(IMG_DIR_256):
    print("Resizing images to 256x256...")
    resize_images(qb, squarer_256, IMG_DIR_256)
else:
    # NOTE(review): this branch does not call resize_images, so the questions
    # in `qb` keep the image paths LocalJsonDB loaded, not the resized copies.
    # Confirm that downstream cells are meant to see those original paths.
    print("Images already resized to 256x256, skipping...")

Images already resized to 256x256, skipping...


## 3. Create Multimodal Embeddings
Create multimodal embeddings for the questions using a Siglip2 model.

In [5]:
# Library Imports
from transformers import AutoModel, AutoProcessor

# Local Imports
from embedder.siglip2_qb_embedder import Siglip2QBEmbedder

### a) Load/Download the Siglip2 Model
We will be using "google/siglip2-base-patch16-256" for this task.

In [6]:
# Hugging Face checkpoint used to embed question text and images.
MODEL_NAME = "google/siglip2-base-patch16-256"

# from_pretrained loads from the local cache or downloads from the hub;
# use_fast=False requests the slow (Python) processor implementation.
model, processor = (
    AutoModel.from_pretrained(MODEL_NAME),
    AutoProcessor.from_pretrained(MODEL_NAME, use_fast=False),
)

### b) Create embeddings

#### i) Define a logger

In [7]:
import logging
from logging import Logger
from datetime import datetime
import os

LOGGING_PATH = "logs"


def get_logger(name: str, log_dir: str = LOGGING_PATH) -> Logger:
    """Return a DEBUG-level logger named ``name`` that writes to a timestamped file.

    The log file is ``<log_dir>/<name>_<YYYYmmdd_HHMM>.log``. Any
    ``FileHandler`` already attached to the logger is removed and closed
    before the new one is added. As a result, calling this again (e.g. by
    re-executing the notebook cell) does not make every record get written
    twice.

    Args:
        name: Logger name; also used as the log-file prefix.
        log_dir: Directory for the log file; created if it does not exist.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)

    # logging.getLogger returns the same object for the same name, so without
    # this a second call would stack another handler and duplicate every line.
    for old_handler in list(logger.handlers):
        if isinstance(old_handler, logging.FileHandler):
            logger.removeHandler(old_handler)
            old_handler.close()

    # FileHandler does not create missing parent directories.
    os.makedirs(log_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M")
    file_handler = logging.FileHandler(
        os.path.join(log_dir, f"{name}_{timestamp}.log")
    )
    file_handler.setFormatter(
        logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    )
    logger.addHandler(file_handler)

    return logger


embedder_logger = get_logger("embedder")

#### ii) Create the embedder

In [8]:
# Wrap the SigLIP2 model and processor in the project's question-bank embedder.
# The "embedder" file logger is passed in as well (presumably for progress /
# diagnostic messages — confirm in Siglip2QBEmbedder).
custom_embedder = Siglip2QBEmbedder(model, processor, embedder_logger)

#### iii) Generate embeddings

In [12]:
EMBEDDINGS_DIR = "data_storage/embedding_dir"
embedding_file = os.path.join(EMBEDDINGS_DIR, "siglip2_embeddings.npz")

# Ensure the cache directory exists before anything is written into it.
os.makedirs(EMBEDDINGS_DIR, exist_ok=True)

In [9]:


# Embedding the whole bank is expensive, so skip it when the cached archive exists.
# Check the full path directly: os.listdir returns bare file names, which never
# equal the directory-prefixed `embedding_file`.
if os.path.exists(embedding_file):
    print(f"Embeddings already exist at {embedding_file}, skipping generation.")
else:
    print("Generating embeddings...")
    # Generate embeddings for the question bank
    embeddings = custom_embedder.encode_qb(qb)

#### iv) Save embeddings

In [13]:
import numpy as np

def save_embeddings(embeddings, file_path):
    """Write a ``{qid: vector}`` mapping to an uncompressed NumPy ``.npz`` archive.

    Each entry is stored under ``str(qid)`` because ``np.savez`` keyword
    names must be strings.

    Args:
        embeddings: Mapping from question id to an array-like embedding.
        file_path: Destination path; ``np.savez`` appends ``.npz`` if missing.
    """
    # Write to the caller-supplied path (previously a global was used instead).
    np.savez(file_path, **{str(qid): embeddings[qid] for qid in embeddings})

# Persist freshly generated embeddings; an existing archive is never overwritten.
if os.path.exists(embedding_file):
    print(f"Embeddings file {embedding_file} already exists, skipping save.")
else:
    print(f"Saving embeddings to {embedding_file}...")
    save_embeddings(embeddings, embedding_file)

Saving embeddings to data_storage/embedding_dir/siglip2_embeddings.npz...


## 4. Generate Tags with BERTopic

### a) Load Embeddings

In [None]:
# To load the embeddings later:
def load_embeddings(file_path):
    """Load an ``.npz`` archive of embeddings into a plain dict.

    Args:
        file_path: Path to the ``.npz`` file.

    Returns:
        Dict mapping each archive key to its array. Keys are the names stored
        in the archive, so they are ``str`` even if the original ids were ints.
    """
    # np.load on an .npz returns a lazily-read NpzFile that keeps the file
    # open; read every array inside the context manager so it is closed after.
    with np.load(file_path) as archive:
        return {key: archive[key] for key in archive.files}
# Reload from disk so this section does not depend on in-memory state from above.
embeddings = load_embeddings(file_path=embedding_file)