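# Vectorization

Embed every question in `../data/qa_clean_data.json` with two models — `bert-base-chinese` ([CLS] vector) and the multilingual MiniLM sentence-transformer — and save each result as a tensor under `../retriever/`.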

In [37]:
import json
import torch
import os
from transformers import BertTokenizer, BertModel
from sentence_transformers import SentenceTransformer

In [38]:
# Peek at the cleaned QA dataset: record count plus the first two records.
# NOTE(review): in the recorded output, `link` values have lost their "://" and "/"
# (e.g. "httpsliveinmelbourne.vic.gov.au...") and the question/answer text carries no
# punctuation or separators — the upstream cleaning step may be over-stripping; verify.
with open("../data/qa_clean_data.json", "r", encoding="utf-8") as qa_file:
    data = json.load(qa_file)

print(f"Number of records: {len(data)}")
preview = data[:2]
print(preview)

Number of records: 221
[{'id': '00001', 'question': '墨尔本的公共交通系统包括哪些交通工具', 'answer': '火车电车巴士出租车', 'source': '维多利亚州政府', 'link': 'httpsliveinmelbourne.vic.gov.auzh-cnlivegetting-aroundpublic-transport', 'tags': ['交通'], 'creator': 'mengyue', 'created_at': '2025-03-25'}, {'id': '00002', 'question': '墨尔本交通指南', 'answer': '通用交通自驾游旅游观光', 'source': '澳大利亚旅游局', 'link': 'httpswww.australia.cnzh-cnplacesmelbourne-and-surroundsgetting-around-melbourne.html', 'tags': ['交通'], 'creator': 'mengyue', 'created_at': '2025-03-25'}]


## BERT (`bert-base-chinese`, [CLS] pooling)

In [39]:
def vectorize_questions_bert(
    json_file="../data/qa_clean_data.json",
    output_file="../retriever/qa_tensors_bert.pt",
    model_name="bert-base-chinese",
    batch_size=16,
    max_length=64,
):
    """Encode every QA question with a BERT model and save the [CLS] vectors.

    Reads the JSON list at ``json_file``, takes the ``"question"`` field of each
    record, runs the questions through ``model_name`` in batches, and keeps the
    final-layer hidden state at token position 0 ([CLS]) for each question.

    Args:
        json_file: Path to a JSON file holding a list of records, each with a
            ``"question"`` key.
        output_file: Destination for ``torch.save``. Its parent directory is
            created if it does not exist.
        model_name: Model id/path used for both the tokenizer and the model.
        batch_size: Number of questions per forward pass.
        max_length: Per-question token limit; longer questions are truncated.

    Returns:
        torch.Tensor of shape ``(num_questions, hidden_size)``, row ``i`` being
        question ``i``. This is the same tensor that is written to ``output_file``.
    """
    with open(json_file, "r", encoding="utf-8") as f:
        data = json.load(f)
    questions = [item["question"] for item in data]

    tokenizer = BertTokenizer.from_pretrained(model_name)
    model = BertModel.from_pretrained(model_name)
    model.eval()  # inference mode: disables dropout so embeddings are deterministic

    batch_outputs = []
    for start in range(0, len(questions), batch_size):
        batch = questions[start:start + batch_size]
        encoded = tokenizer(
            batch,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        )
        with torch.no_grad():
            outputs = model(**encoded)
        # Position 0 of every sequence is the [CLS] token.
        batch_outputs.append(outputs.last_hidden_state[:, 0, :])

    if batch_outputs:
        all_vecs = torch.cat(batch_outputs, dim=0)
    else:
        # torch.cat rejects an empty list; produce a well-formed empty matrix instead.
        all_vecs = torch.empty(0, model.config.hidden_size)

    out_dir = os.path.dirname(output_file)
    if out_dir:  # os.makedirs("") raises, so skip when output_file is a bare filename
        os.makedirs(out_dir, exist_ok=True)
    torch.save(all_vecs, output_file)

    print(f"Saved {len(questions)} question embeddings to: {output_file}, shape={all_vecs.shape}")
    # Return the concatenated tensor (not the per-batch list) so the return value
    # matches the saved file and the sibling sentence-transformers function.
    return all_vecs

In [40]:
# Generate and persist the BERT question embeddings. Inside a Jupyter kernel
# __name__ is always "__main__", so this guard only has an effect if the
# notebook is exported to a .py module and imported elsewhere.
running_as_script = __name__ == "__main__"
if running_as_script:
    vectorize_questions_bert()

Saved 221 question embeddings to: ../retriever/qa_tensors_bert.pt, shape=torch.Size([221, 768])


In [41]:
# Sanity check: reload the BERT embeddings written above and print the
# container type, the overall shape, and the first 10 values of row 0.
bert_vecs_path = "../retriever/qa_tensors_bert.pt"
vecs = torch.load(bert_vecs_path)

print(type(vecs))
print(vecs.shape)
first_row_preview = vecs[0][:10]
print(first_row_preview)

<class 'torch.Tensor'>
torch.Size([221, 768])
tensor([-0.0697,  0.4000, -0.7653, -0.0790,  0.5335, -0.5587,  0.0885, -0.6943,
        -0.0871,  0.1190])


## Sentence-Transformers (`paraphrase-multilingual-MiniLM-L12-v2`)

In [42]:
def vectorize_questions_trans(
    json_file="../data/qa_clean_data.json",
    output_file="../retriever/qa_tensors_trans.pt",
    model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
    batch_size=32,
):
    """Encode every QA question with a sentence-transformers model and save it.

    Reads the JSON list at ``json_file``, takes the ``"question"`` field of each
    record, and encodes all questions with ``SentenceTransformer.encode``.

    Args:
        json_file: Path to a JSON file holding a list of records, each with a
            ``"question"`` key.
        output_file: Destination for ``torch.save``. Its parent directory is
            created if it does not exist. A CPU copy of the embeddings is saved.
        model_name: Sentence-transformers model id/path.
        batch_size: Number of questions per encoding batch.

    Returns:
        The embeddings tensor exactly as returned by ``model.encode`` (it may live
        on an accelerator device, e.g. ``mps:0`` in the recorded run).
    """
    with open(json_file, "r", encoding="utf-8") as f:
        data = json.load(f)
    questions = [item["question"] for item in data]

    model = SentenceTransformer(model_name)

    embeddings = model.encode(
        questions,
        batch_size=batch_size,
        convert_to_tensor=True,
        show_progress_bar=True,
    )

    out_dir = os.path.dirname(output_file)
    if out_dir:  # os.makedirs("") raises, so skip when output_file is a bare filename
        os.makedirs(out_dir, exist_ok=True)
    # encode() leaves the tensor on the model's device; saving a CPU copy keeps the
    # .pt file loadable on machines that lack that device (e.g. no MPS/CUDA).
    torch.save(embeddings.cpu(), output_file)

    print(f"Saved {len(questions)} embeddings to {output_file}, shape={embeddings.shape}")
    return embeddings

In [43]:
# Generate and persist the sentence-transformers question embeddings. Inside a
# Jupyter kernel __name__ is always "__main__", so this guard only has an effect
# if the notebook is exported to a .py module and imported elsewhere.
running_as_script = __name__ == "__main__"
if running_as_script:
    vectorize_questions_trans()

Batches:   0%|          | 0/7 [00:00<?, ?it/s]

Saved 221 embeddings to ../retriever/qa_tensors_trans.pt, shape=torch.Size([221, 384])


In [44]:
# Reload the sentence-transformers embeddings for a sanity check.
# map_location="cpu" lets this load even if the file was saved from an accelerator
# tensor (the recorded output below shows device='mps:0'), e.g. on a machine
# without Apple-silicon MPS or CUDA.
vecs = torch.load("../retriever/qa_tensors_trans.pt", map_location="cpu")

print(type(vecs))
print(vecs.shape)
print(vecs[0][:10])

<class 'torch.Tensor'>
torch.Size([221, 384])
tensor([ 0.5112, -0.6336, -0.0050,  0.0788,  0.1777,  0.3748,  0.1842,  0.0313,
        -0.4515,  0.4147], device='mps:0')
