In [None]:
import hashlib
import json
import os
import re
import string

import nltk
import numpy as np
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from scipy.sparse import csr_matrix
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm

# Fetch the NLTK resources and build the English stop-word set used by the tokenizer.
nltk.download('punkt')
nltk.download('stopwords')
stop_words = set(stopwords.words('english'))
# NOTE(review): wordnet is downloaded but no code visible in this file uses it.
nltk.download('wordnet')

# Read the abbreviation list, one entry per line. Entries are stripped and blank
# lines dropped so that an empty line cannot make '' count as an abbreviation.
# The encoding is explicit (as everywhere else in this file) instead of relying
# on the platform default.
with open('abbreviations.txt', 'r', encoding='utf-8') as file:
    abbreviations = [line.strip() for line in file.read().splitlines() if line.strip()]

DataNum = ''                        # suffix appended to input/output file names, e.g. documents{DataNum}.jsonl
model_dir = './model/TF_IDF'        # where the TF-IDF feature dictionary is written
SaveDir = './data/TF_IDF_annotated/'  # root directory for saved vectors and index files
CHUNK_SIZE = 1000                   # NOTE(review): not referenced elsewhere in this file as seen; confirm before removing
Debug = 0                           # NOTE(review): not referenced elsewhere in this file as seen; confirm before removing

def load_jsonl(file_path):
    """Load a JSON Lines file.

    Args:
        file_path: Path to a UTF-8 file containing one JSON value per line.

    Returns:
        A list with one parsed value per non-blank line, in file order.
        Blank or whitespace-only lines (e.g. a trailing empty line) are
        skipped instead of making ``json.loads`` raise.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        return [json.loads(line) for line in file if line.strip()]

def load_dict(file_path):
    """Return the object decoded from the UTF-8 JSON file at ``file_path``.

    In this script the files read this way hold JSON objects, so the result
    is a dict whose keys are strings.
    """
    with open(file_path, encoding='utf-8') as handle:
        parsed = json.load(handle)
    return parsed

def clean_text(text):
    """Normalise raw document text for tokenisation.

    Steps, in order:
      1. Replace HTML tags (non-greedy ``<...>``) with a space.
      2. Replace URLs with a space. A URL is ``//`` plus the following
         non-whitespace run, together with a scheme such as ``https:``
         directly in front of it. Removing only the ``//...`` part left
         stray tokens like ``https:`` behind.
      3. Collapse whitespace runs into a single space.
      4. Lower-case and strip leading/trailing whitespace.

    Args:
        text: Raw document text.

    Returns:
        The cleaned, lower-cased string.
    """
    text = re.sub(r'<.*?>', ' ', text)
    text = re.sub(r'(?:[A-Za-z][A-Za-z0-9+.\-]*:)?//\S*', ' ', text)
    text = re.sub(r'\s+', ' ', text)
    return text.lower().strip()

def process_text(text):
    """Replace recognised amounts, years, clock times and temperatures with tags.

    The text is split on whitespace. Each word is tested with its trailing
    punctuation removed. The first rule that matches replaces the whole word
    with a tagged token, and the trailing punctuation is dropped. Words that
    match no rule are kept unchanged.

    Rules, in priority order:
      * amount: starts with ``$`` or ``-$``; ``$`` and ``,`` are removed,
        e.g. ``$1,200`` -> ``(Amount$1200)``.
      * year: exactly four decimal digits in 1400..2150,
        e.g. ``1999`` -> ``([Date]1999)``.
      * time: ``H:MM`` or ``H:MM:SS`` (hours 0-23, minutes and seconds 0-59),
        optionally followed by ``a.m``/``p.m``, e.g.
        ``3:30p.m`` -> ``([Time]3h30m0sp.m)`` and ``14:30`` -> ``([Time]14h30m0s)``.
      * temperature: optional ``-``, a number with at most one ``.``, an
        optional ``°``, then one of ``c f C F``,
        e.g. ``25°c`` -> ``([Temperature]25°c)``.

    Args:
        text: Input string.

    Returns:
        The (possibly rewritten) words joined by single spaces.
    """
    meridiem_suffixes = ('a.m', 'p.m')
    temperature_units = ('c', 'f', 'C', 'F')

    def split_meridiem(word):
        # Split a trailing 'a.m'/'p.m' off a clock string; the suffix is '' if absent.
        if word[-3:] in meridiem_suffixes:
            return word[:-3], word[-3:]
        return word, ''

    def is_valid_amount(word):
        # Dollar amounts, including negatives written as '-$'.
        return word.startswith(('$', '-$'))

    def format_amount(word):
        # Drop the '$' sign and thousands separators.
        cleaned_amount = word.replace('$', '').replace(',', '')
        return f"(Amount${cleaned_amount})"

    def is_valid_date(word):
        # Only four-digit years are recognised. isdecimal() (not isdigit())
        # guarantees int() accepts the string; isdigit() is also true for
        # characters such as '²' that int() rejects.
        return len(word) == 4 and word.isdecimal() and 1400 <= int(word) <= 2150

    def format_date(word):
        return f"([Date]{word})"

    def is_valid_time(word):
        # H:MM or H:MM:SS, optionally followed by a.m/p.m.
        clock, _ = split_meridiem(word)
        parts = clock.split(':')
        if len(parts) not in (2, 3) or not all(part.isdecimal() for part in parts):
            return False
        hours, minutes = int(parts[0]), int(parts[1])
        seconds = int(parts[2]) if len(parts) == 3 else 0
        return 0 <= hours <= 23 and 0 <= minutes <= 59 and 0 <= seconds <= 59

    def format_time(word):
        # Strip the suffix BEFORE splitting, so it cannot leak into the
        # minutes/seconds fields. Re-append it only when it was present,
        # so 24-hour times get no a.m/p.m marker.
        clock, suffix = split_meridiem(word)
        parts = clock.split(':')
        seconds = parts[2] if len(parts) == 3 else '0'
        return f"([Time]{parts[0]}h{parts[1]}m{seconds}s{suffix})"

    def is_valid_temperature(word):
        # Optional '-', a number, an optional degree sign, then exactly one
        # unit letter. Only the unit letter (and an optional '°') is removed
        # before the numeric check.
        if word.startswith('-'):
            word = word[1:]
        if len(word) < 2 or word[-1] not in temperature_units:
            return False
        number = word[:-1]
        if number.endswith('°'):
            number = number[:-1]
        return number.replace('.', '', 1).isdecimal()

    def format_temperature(word):
        return f"([Temperature]{word})"

    def is_valid_abbreviation(word):
        # Membership in the module-level `abbreviations` list read from
        # abbreviations.txt.
        return word in abbreviations

    def format_abbreviation(word):
        return f"([Abbreviation]{word})"

    # (predicate, formatter) pairs, tried in order; the first match wins.
    rules = (
        (is_valid_amount, format_amount),
        (is_valid_date, format_date),
        (is_valid_time, format_time),
        (is_valid_temperature, format_temperature),
        # NOTE: abbreviation tagging is disabled. Add
        # (is_valid_abbreviation, format_abbreviation) here to enable it.
    )

    def tag(word):
        stripped = word.rstrip(string.punctuation)
        for matches, formatter in rules:
            if matches(stripped):
                return formatter(stripped)
        return word

    return ' '.join(tag(word) for word in text.split())

def segment_text(text):
    """Tokenize text, keeping tagged amounts/dates/times/temperatures whole.

    ``process_text`` first rewrites recognised values into tagged tokens such
    as ``(Amount$1200)`` or ``([Time]3h30m0s)``. The result is then scanned
    left to right:

    * a span that starts with one of the tag prefixes is emitted verbatim, up
      to and including the next ``)``;
    * otherwise a token is a maximal run of characters other than space,
      ``(`` and ``)``. A token that is a single punctuation character is
      dropped.

    Empty strings are never emitted. Delimiters (spaces and bare parentheses)
    are skipped rather than turned into ``''`` tokens, which would otherwise
    become an empty-string feature in the TF-IDF vocabulary. Tokens found in
    ``stop_words`` are removed at the end.

    Args:
        text: Raw text; passed through ``process_text`` first.

    Returns:
        List of token strings in order of appearance.
    """
    text = process_text(text)
    special_prefixes = ('(Amount$', '([Date]', '([Time]', '([Temperature]', '([Abbreviation]')
    # Single-character tokens that carry no meaning on their own.
    punctuation_tokens = frozenset(',.?!;()[]{}-_+=*/\\|&^%$#@~`:<>"\'')
    delimiters = frozenset(' ()')

    tokens = []
    length = len(text)
    i = 0
    while i < length:
        if text.startswith(special_prefixes, i):
            end = text.find(')', i) + 1
            if end > i:
                tokens.append(text[i:end])
                i = end
            else:
                # Unterminated tag: drop its lone '(' (punctuation) and
                # tokenize the remainder normally.
                i += 1
            continue

        j = i
        while j < length and text[j] not in delimiters:
            j += 1
        if j == i:
            # text[i] is a delimiter (space or bare parenthesis): skip it.
            i += 1
        else:
            word = text[i:j]
            if word not in punctuation_tokens:
                tokens.append(word)
            i = j
        while i < length and text[i] == ' ':
            i += 1

    return [token for token in tokens if token not in stop_words]

def preprocess_documents(documents):
    """Map each document id to its cleaned text.

    Args:
        documents: Iterable of dicts with 'document_id' and 'document_text' keys.

    Returns:
        dict {document_id: clean_text(document_text)} in input order. A later
        document with a repeated id overwrites the earlier one.
    """
    return {
        record['document_id']: clean_text(record['document_text'])
        for record in tqdm(documents, desc="Processing documents")
    }

def preprocess_questions(questions, question_tfidf_matrix):
    """Build a per-question record keyed by question text.

    Args:
        questions: Sequence of dicts with 'question' and 'answer' keys and an
            optional 'document_id' (defaults to ``[]``).
        question_tfidf_matrix: Matrix whose row ``idx`` (taken with
            ``matrix[idx]``) is the vector of ``questions[idx]``.

    Returns:
        dict {question_text: {'answer', 'document_id', 'vector'}}. Because the
        key is the question text, a repeated question keeps only its last
        occurrence.
    """
    question_dict = {}
    # Wrap the sequence itself, not the enumerate object, so tqdm can read
    # its length and display a total and ETA.
    for idx, question in enumerate(tqdm(questions, desc="Processing questions")):
        question_dict[question['question']] = {
            'answer': question['answer'],
            'document_id': question.get('document_id', []),
            'vector': question_tfidf_matrix[idx],
        }
    return question_dict

# def compute_tfidf(text, vectorizer):
#     """Compute TF-IDF values."""
#     tfidf_matrix = vectorizer.fit_transform(text)
#     return tfidf_matrix, vectorizer.get_feature_names_out()
def compute_tfidf(text, vectorizer):
    """Fit ``vectorizer`` and return TF-IDF vectors in which special tags have IDF 1.

    Features whose name starts with ``(Amount``, ``([Date]``, ``([Time]``,
    ``([Temperature]`` or ``([Abbreviation]`` get an IDF of 1.0.

    Args:
        text: Iterable of raw documents. It is materialised into a list
            because it may need to be transformed twice.
        vectorizer: A ``TfidfVectorizer``; it is fitted in place.

    Returns:
        (tfidf_matrix, feature_names). When special features exist,
        ``tfidf_matrix`` is computed with the overridden IDF values.
    """
    special_prefixes = ('(Amount', '([Date]', '([Time]', '([Temperature]', '([Abbreviation]')
    documents = list(text)
    tfidf_matrix = vectorizer.fit_transform(documents)
    feature_names = vectorizer.get_feature_names_out()

    is_special = np.array([name.startswith(special_prefixes) for name in feature_names], dtype=bool)
    if is_special.any() and getattr(vectorizer, 'use_idf', True):
        # Replace the whole IDF vector through the `idf_` setter and
        # re-transform. Writing single elements into `idf_` after
        # fit_transform cannot change the matrix already returned, and on
        # scikit-learn versions where `idf_` is a computed property the
        # element writes are lost entirely.
        idf = np.array(vectorizer.idf_, dtype=np.float64)
        idf[is_special] = 1.0
        vectorizer.idf_ = idf
        tfidf_matrix = vectorizer.transform(documents)

    return tfidf_matrix, feature_names

def Candidate_Calculation(question_vector, doc_vectors, top_k=5):
    """Return indices of the ``top_k`` documents most cosine-similar to a question.

    Args:
        question_vector: The question's vector; reshaped to a single row.
        doc_vectors: One row per document, in the same feature space.
        top_k: Number of candidates to return (default 5, the previous fixed
            value). Values <= 0 return an empty array.

    Returns:
        1-D array of row indices into ``doc_vectors``, highest similarity
        first. Fewer than ``top_k`` indices are returned when there are
        fewer documents.
    """
    if top_k <= 0:
        # Without this guard, the slice [-0:] below would select every document.
        return np.array([], dtype=np.intp)
    question_vector = question_vector.reshape(1, -1)
    similarities = cosine_similarity(question_vector, doc_vectors)
    return similarities.argsort()[0][-top_k:][::-1]

def validate_accuracy(question_dict, doc_dict, doc_vectors):
    """Top-5 retrieval accuracy over the entries of ``question_dict``.

    A question counts as correct when at least one of its reference document
    ids appears among the 5 candidates returned by ``Candidate_Calculation``.

    Args:
        question_dict: Mapping whose values have a sparse 'vector' (converted
            with ``.toarray()``) and a 'document_id' that is either a list of
            ids or a single id.
        doc_dict: Mapping whose iteration order matches the rows of
            ``doc_vectors`` (row i <-> i-th key).
        doc_vectors: Document matrix passed to ``Candidate_Calculation``.

    Returns:
        Fraction of correct questions; 0 when ``question_dict`` is empty.
    """
    # Build the row-index -> document-id lookup once, instead of rebuilding
    # list(doc_dict.keys()) for every candidate of every question.
    doc_ids = list(doc_dict)
    correct = 0
    total = len(question_dict)

    for question_data in tqdm(question_dict.values(), desc="Validating accuracy"):
        top_indices = Candidate_Calculation(question_data['vector'].toarray(), doc_vectors)
        candidate_ids = [doc_ids[i] for i in top_indices]

        reference = question_data['document_id']
        references = reference if isinstance(reference, list) else [reference]
        if any(doc_id in candidate_ids for doc_id in references):
            correct += 1

    return correct / total if total > 0 else 0

def save_sparse_vector(vector, file_path):
    """Save a sparse vector to JSON in CSR form.

    The file holds ``{'indices', 'values', 'shape', 'indptr'}``: the CSR
    column indices, the values stored at those indices, the matrix shape (as
    a list) and the row-pointer array.

    The input is first copied to CSR and explicitly stored zeros are removed,
    so ``indices`` and ``values`` always line up. Pairing ``nonzero()``, which
    skips explicit zeros, with ``.data``, which keeps them, could misalign the
    two lists.

    Args:
        vector: A scipy sparse matrix, or anything ``csr_matrix`` accepts
            (e.g. a 2-D numpy array). It is not modified.
        file_path: Destination JSON path; overwritten if it exists.
    """
    csr = csr_matrix(vector, copy=True)
    csr.eliminate_zeros()
    entry = {
        'indices': csr.indices.tolist(),
        'values': csr.data.tolist(),
        'shape': [int(dim) for dim in csr.shape],
        'indptr': csr.indptr.tolist(),
    }
    with open(file_path, 'w', encoding='utf-8') as f:
        json.dump(entry, f)

def save_dict_to_json(data_dict, file_path):
    """Write ``data_dict`` to ``file_path`` as UTF-8 JSON, replacing any existing file.

    As with any ``json`` serialisation, non-string keys are written as strings
    (e.g. the int 0 becomes "0").
    """
    with open(file_path, mode='w', encoding='utf-8') as out_file:
        json.dump(data_dict, fp=out_file)

# Main program
documents = load_jsonl(f'./data/documents{DataNum}.jsonl')
questions = load_jsonl(f'./data/train{DataNum}.jsonl')

# Clean the documents, then fit one TF-IDF vocabulary over the documents
# followed by the questions, so both share the same feature space.
doc_dict = preprocess_documents(documents)
# token_pattern=None: segment_text replaces the regex tokenizer entirely;
# passing None silences scikit-learn's "token_pattern will not be used" warning.
vectorizer = TfidfVectorizer(tokenizer=segment_text, token_pattern=None, max_df=0.95)
all_text = list(doc_dict.values()) + [q['question'] for q in questions]
tfidf_matrix = vectorizer.fit_transform(all_text)

# Split the TF-IDF matrix: the first num_doc rows are documents, the rest are
# questions (same order as all_text).
num_doc = len(doc_dict)
doc_tfidf_matrix = tfidf_matrix[:num_doc]
question_tfidf_matrix = tfidf_matrix[num_doc:]

# Process questions to include TF-IDF vectors
question_dict = preprocess_questions(questions, question_tfidf_matrix)

# Create the output directories. model_dir must exist before the feature
# dictionary is saved below, otherwise open() raises FileNotFoundError.
vector_dir = f'{SaveDir}vector{DataNum}'
os.makedirs(vector_dir, exist_ok=True)
os.makedirs(model_dir, exist_ok=True)

# Save document vectors: one JSON file per document plus one JSONL index line.
save_doc_vector_path = f'{SaveDir}doc_vector{DataNum}.jsonl'
with open(save_doc_vector_path, 'w', encoding='utf-8') as doc_file:
    for i, (doc_id, doc_text) in enumerate(doc_dict.items()):
        doc_vector = doc_tfidf_matrix[i]
        vector_file_path = f'{vector_dir}/vector_{doc_id}.json'
        save_sparse_vector(doc_vector, vector_file_path)
        doc_entry = {
            'document_id': doc_id,
            'document_text': doc_text,
            'document_vector_path': vector_file_path
        }
        doc_file.write(json.dumps(doc_entry) + '\n')

# Save question vectors. File names use a SHA-256 digest of the question text
# because the built-in hash() of a str is randomised per interpreter process
# (PYTHONHASHSEED), which gave different file names on every run.
save_ques_vector_path = f'{SaveDir}ques_vector{DataNum}.jsonl'
with open(save_ques_vector_path, 'w', encoding='utf-8') as ques_file:
    for question_text, question_data in question_dict.items():
        question_key = hashlib.sha256(question_text.encode('utf-8')).hexdigest()
        vector_file_path = f'{vector_dir}/vector_ques{question_key}.json'
        save_sparse_vector(question_data['vector'], vector_file_path)
        ques_entry = {
            'question_text': question_text,
            'question_answer': question_data['answer'],
            'document_id_answer': question_data['document_id'],
            'question_vector_path': vector_file_path
        }
        ques_file.write(json.dumps(ques_entry) + '\n')

# Save the feature dictionary {column index: feature}. JSON turns the int keys into strings.
feature_names = vectorizer.get_feature_names_out()
tfidf_feature_dict = {i: feature for i, feature in enumerate(feature_names)}
save_dict_to_json(tfidf_feature_dict, f'{model_dir}/dict_{DataNum}.json')

# Validate accuracy
accuracy = validate_accuracy(question_dict, doc_dict, doc_tfidf_matrix)
print("Accuracy:", accuracy)

# Reconstruction accuracy validation
def reconstruct_text(sparse_vector, feature_dict, top_n=100):
    """Rebuild a bag-of-words string from a saved sparse vector.

    Args:
        sparse_vector: Dict as written by ``save_sparse_vector``; only its
            parallel 'indices' and 'values' lists are used.
        feature_dict: Mapping from feature index *as a string* (JSON object
            keys) to the feature token.
        top_n: Maximum number of features to keep, chosen by largest value.
            Values <= 0 yield an empty string; previously the slice ``[-0:]``
            selected every feature.

    Returns:
        The selected features, highest value first, joined by single spaces.

    Raises:
        KeyError: if a selected index is missing from ``feature_dict``.
    """
    if top_n <= 0:
        return ''
    indices = sparse_vector['indices']
    values = sparse_vector['values']
    top_positions = np.argsort(values)[-top_n:][::-1]
    return ' '.join(feature_dict[str(indices[pos])] for pos in top_positions)

def calculate_reconstruction_accuracy(doc_dict, feature_dict, top_n=100):
    """Fraction of documents whose saved vector reconstructs their words.

    For each document, the vector file
    ``{SaveDir}vector{DataNum}/vector_{doc_id}.json`` is loaded and the
    ``top_n`` strongest features are rebuilt with ``reconstruct_text``. The
    original whitespace-split words, minus stop words, are compared with the
    reconstructed words. The overlap score is
    ``|common| / min(|original|, |reconstructed|)`` (0 if either set is
    empty). A score above 0.98 counts as a match; for non-matches the
    differing words and the score are printed.

    Returns:
        matched / total, or 0 when ``doc_dict`` is empty. Also prints the
        match and total counts.
    """
    match_threshold = 0.98
    total = len(doc_dict)
    matched = 0

    for doc_id, original_text in doc_dict.items():
        saved_vector = load_dict(f'{SaveDir}vector{DataNum}/vector_{doc_id}.json')
        rebuilt_words = set(reconstruct_text(saved_vector, feature_dict, top_n).split())
        source_words = set(original_text.split()) - stop_words

        shared = source_words & rebuilt_words
        denominator = min(len(source_words), len(rebuilt_words))
        score = len(shared) / denominator if denominator > 0 else 0

        if score > match_threshold:
            matched += 1
            continue

        print(f"Doc ID: {doc_id}")
        print(source_words - shared)
        print(rebuilt_words - shared)
        print(f"Accuracy: {score}")
        print()

    print(f"Correct: {matched}", f"Total: {total}")
    return matched / total if total > 0 else 0

# Reload the document index and the feature dictionary from the files written
# earlier in this script, then measure how well the saved vectors reconstruct
# each document's words.
doc_vector_data = load_jsonl(f'{SaveDir}doc_vector{DataNum}.jsonl')
doc_dict = {entry['document_id']: entry['document_text'] for entry in doc_vector_data}
tfidf_feature_dict = load_dict(f'{model_dir}/dict_{DataNum}.json')

reconstruction_accuracy = calculate_reconstruction_accuracy(doc_dict, tfidf_feature_dict)
print("Reconstruction Accuracy:", reconstruction_accuracy)