In [2]:
import json
import re
import numpy as np
from tqdm import tqdm
from gensim.models import Word2Vec
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
import torch
from sklearn.metrics.pairwise import cosine_similarity

# Load the NLTK English stop-word list once; segment_text filters tokens against it.
stop_words = set(stopwords.words('english'))
# Word2Vec hyperparameters, passed to Word2Vec(...) in train_word2vec as
# vector_size, window and min_count respectively.
# NOTE(review): a window of 1000 and a vector size of 1400 are far larger than
# commonly used values -- confirm these are intentional.
Vector_size = 1400
Window_size = 1000
Min_Count = 10
# Dataset tag (kept as a string) that is concatenated into the input and
# output file names used by the main script.
DataNum = '1800'

def load_jsonl(file_path):
    """Load a JSONL file and return its records as a list.

    Args:
        file_path: Path to a UTF-8 encoded file holding one JSON value per line.

    Returns:
        A list with one parsed JSON value per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        # Blank lines (typically a trailing newline at the end of the file)
        # are not records; skipping them avoids a spurious JSONDecodeError.
        return [json.loads(line) for line in file if line.strip()]

def clean_text(text):
    """Replace HTML-like tags with a space, collapse whitespace, and lowercase.

    A tag is any non-greedy ``<...>`` span on a single line; because ``.``
    does not match newlines, a tag split across lines is left in place.
    """
    tag_pattern = re.compile(r'<.*?>')
    whitespace_pattern = re.compile(r'\s+')

    without_tags = tag_pattern.sub(' ', text)
    single_spaced = whitespace_pattern.sub(' ', without_tags)
    return single_spaced.lower()

def segment_text(text):
    """Tokenize the cleaned text and drop tokens found in ``stop_words``.

    The text is first normalized by ``clean_text`` and then split with NLTK's
    ``word_tokenize``; the surviving tokens are returned in order.
    """
    return [
        token
        for token in word_tokenize(clean_text(text))
        if token not in stop_words
    ]

def preprocess_documents(documents):
    """Build a mapping from each record's ``document_id`` to its ``document_text``.

    If the same ID appears more than once, the last occurrence wins.
    """
    progress = tqdm(documents, desc="Processing documents")
    return {entry['document_id']: entry['document_text'] for entry in progress}

def get_vector_mean(model, words):
    """Average the word vectors of the in-vocabulary words.

    Words that are not present in ``model.wv`` are ignored.

    Returns:
        A tensor with a leading dimension of size 1 holding the mean vector,
        or ``None`` when none of the words is in the vocabulary.
    """
    known_words = [token for token in words if token in model.wv]
    if not known_words:
        # Nothing to average.
        return None
    stacked = torch.tensor(model.wv[known_words])
    # keepdim=True keeps the leading axis, so the result is one row.
    return stacked.mean(dim=0, keepdim=True)

def preprocess_questions(questions, model):
    """Index questions by their text and attach a mean word vector to each.

    Each value holds the question's answer, its reference ``document_id``
    (an empty list when the record has none), and the vector produced by
    ``get_vector_mean`` (which may be ``None``). Records that share the same
    question text overwrite one another; the last one wins.
    """
    indexed = {}
    for record in tqdm(questions, desc="Processing questions"):
        text = record['question']
        indexed[text] = {
            'answer': record['answer'],
            'document_id': record.get('document_id', []),
            # Mean vector of the question's tokens.
            'vector': get_vector_mean(model, segment_text(text)),
        }
    return indexed

def train_word2vec(documents):
    """Tokenize every document text and train a Word2Vec model on the result.

    ``documents`` maps document IDs to raw text; only the values are used.
    Hyperparameters come from the module-level ``Vector_size``,
    ``Window_size`` and ``Min_Count`` settings.
    """
    tokenized = [
        segment_text(text)
        for text in tqdm(documents.values(), desc="Tokenizing documents")
    ]
    return Word2Vec(
        sentences=tokenized,
        vector_size=Vector_size,
        window=Window_size,
        min_count=Min_Count,
        workers=32,
    )

def Candidate_Calculation(question_vector, doc_dict, doc_vecs=None, top_k=5):
    """Return the IDs of the documents most similar to a question vector.

    Args:
        question_vector: Array-like question embedding; it is reshaped to a
            single row before the comparison.
        doc_dict: Mapping of document ID to text. Its iteration order defines
            the candidate order; documents without a vector are skipped.
        doc_vecs: Mapping of document ID to vector tensor. Defaults to the
            module-level ``doc_vectors`` mapping.
        top_k: Maximum number of document IDs to return.

    Returns:
        Up to ``top_k`` document IDs ordered from most to least similar by
        cosine similarity. Empty when there are no candidates.
    """
    if doc_vecs is None:
        doc_vecs = doc_vectors

    # Build the ID list and the matrix from the SAME sequence. Documents
    # without a vector are absent from doc_vecs, so indexing into
    # doc_dict.keys() by matrix row would return the wrong IDs.
    candidate_ids = [doc_id for doc_id in doc_dict if doc_id in doc_vecs]
    if not candidate_ids or top_k <= 0:
        return []

    # Reshape explicitly instead of squeeze(): squeeze() would also drop the
    # document axis when there is exactly one candidate.
    doc_matrix = (
        torch.stack([doc_vecs[doc_id] for doc_id in candidate_ids])
        .detach()
        .reshape(len(candidate_ids), -1)
        .numpy()
    )
    query = np.asarray(question_vector).reshape(1, -1)

    similarities = cosine_similarity(query, doc_matrix)

    # argsort is ascending: take the last top_k columns and reverse them.
    top_indices = similarities.argsort()[0][-top_k:][::-1]
    return [candidate_ids[index] for index in top_indices]

def validate_accuracy(question_dict):
    """Return the fraction of entries whose answer document is in their top 5.

    An entry counts as correct when its ``document_id_answer`` value is
    contained in its ``document_id_top5`` value. Returns 0 for an empty
    mapping.
    """
    hits = sum(
        1
        for entry in tqdm(question_dict.values(), desc="Validating accuracy")
        if entry['document_id_answer'] in entry['document_id_top5']
    )
    if not question_dict:
        return 0
    return hits / len(question_dict)

# Main program: load the data, train Word2Vec, embed documents and questions,
# persist everything, and report top-5 retrieval accuracy.
documents = load_jsonl('./data/documents'+DataNum+'.jsonl')
questions = load_jsonl('./data/train'+DataNum+'.jsonl')

# Preprocess documents and questions
doc_dict = preprocess_documents(documents)
model = train_word2vec(doc_dict)
question_dict = preprocess_questions(questions, model)

# Generate document vectors. Documents with no in-vocabulary word have no
# vector and are left out of doc_vectors.
doc_vectors = {}
for doc_id, doc_text in tqdm(doc_dict.items(), desc="Generating document vectors"):
    doc_vec_mean = get_vector_mean(model, segment_text(doc_text))
    if doc_vec_mean is not None:
        doc_vectors[doc_id] = doc_vec_mean

# All output names share the same "<data>_<vector>_<window>_<min_count>" tag.
run_tag = f"{DataNum}_{Vector_size}_{Window_size}_{Min_Count}"
save_path = './model/word2vec' + run_tag + '.model'
model.save(save_path)
print(f"Word2Vec model saved to {save_path}")
save_doc_vector_path = './data/doc_vector' + run_tag + '.jsonl'
save_ques_vector_path = './data/ques_vector' + run_tag + '.jsonl'

# Save document vectors to JSONL (vector is null for documents without one)
with open(save_doc_vector_path, 'w', encoding='utf-8') as doc_file:
    for doc_id, doc_text in doc_dict.items():
        doc_tensor = doc_vectors.get(doc_id)
        doc_vector = doc_tensor.detach().numpy().tolist() if doc_tensor is not None else None
        doc_entry = {
            'document_id': doc_id,
            'document_text': doc_text,
            'document_vector': doc_vector
        }
        doc_file.write(json.dumps(doc_entry) + '\n')

# Save question vectors to JSONL and keep every entry in val_dict for validation
val_dict = {}
with open(save_ques_vector_path, 'w', encoding='utf-8') as ques_file:
    for question_text, question_data in question_dict.items():
        question_tensor = question_data['vector']
        if question_tensor is None:
            # The question has no in-vocabulary word, so it cannot be ranked.
            # Record it with no candidates instead of crashing on
            # None.detach().
            top5_doc_ids = []
            question_vector = None
        else:
            question_array = question_tensor.detach().numpy()
            top5_doc_ids = Candidate_Calculation(question_array, doc_dict)
            question_vector = question_array.tolist()

        ques_entry = {
            'question_text': question_text,
            'question_answer': question_data['answer'],
            'document_id_answer': question_data['document_id'],
            'document_id_top5': top5_doc_ids,
            'question_vector': question_vector
        }
        val_dict[question_text] = ques_entry
        ques_file.write(json.dumps(ques_entry) + '\n')

# Validate accuracy
print(validate_accuracy(val_dict))

Processing documents: 100%|██████████| 1801/1801 [00:00<?, ?it/s]
Tokenizing documents: 100%|██████████| 1801/1801 [01:11<00:00, 25.36it/s]
Processing questions: 100%|██████████| 530/530 [00:00<00:00, 6465.95it/s]
Generating document vectors: 100%|██████████| 1801/1801 [02:34<00:00, 11.69it/s]


Word2Vec model saved to ./model/word2vec1800_1400_1000_10.model


Validating accuracy: 100%|██████████| 530/530 [00:00<00:00, 530164.83it/s]

0.7509433962264151



