In [1]:
""" 0. set-up part:  import necessary libraries and set up environment """

import pandas as pd
import re
import nltk
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from nltk import pos_tag, word_tokenize
from collections import Counter, defaultdict
import numpy as np
import math
import copy
import itertools
import matplotlib.pyplot as plt
import matplotlib as mpl

import joblib
from joblib import Parallel, delayed
from threading import Thread

import os
import pickle
import time

import operator
from functools import reduce
import json
import cProfile

# Download the NLTK data below once (uncomment on first run)
# nltk.download('punkt')
# nltk.download('stopwords')
# nltk.download('wordnet')
# nltk.download('averaged_perceptron_tagger')
# nltk.download('omw-1.4')
# nltk.download('punkt_tab')
# nltk.download('averaged_perceptron_tagger_eng')

# Chinese character support in matplotlib.
# The entries must be comma-separated: adjacent string literals are concatenated
# by Python, which would otherwise yield a single non-existent font name and
# defeat the fallback chain.
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS', 'SimHei', 'DejaVu Sans']
# Render the minus sign with an ASCII hyphen instead of U+2212.
plt.rcParams['axes.unicode_minus'] = False


In [2]:
""" 1.1 Data Preprocessing: load data, clean text, lemmatization, remove low-frequency words"""

# Map POS tags to WordNet format， Penn Treebank annotation: fine-grained (45 tags), WordNet annotation: coarse-grained (4 tags: a, v, n, r)
def get_wordnet_pos(treebank_tag):
    """Translate a Penn Treebank POS tag into a WordNet POS letter.

    Args:
        treebank_tag (str): Tag such as 'JJ', 'VBD', 'NNS' or 'RB'.

    Returns:
        str: 'a', 'v', 'n' or 'r' for tags starting with J, V, N or R
        respectively; 'n' (noun) for every other tag.
    """
    prefix_to_wordnet = (
        ('J', 'a'),  # adjective
        ('V', 'v'),  # verb
        ('N', 'n'),  # noun
        ('R', 'r'),  # adverb
    )
    for prefix, wordnet_pos in prefix_to_wordnet:
        if treebank_tag.startswith(prefix):
            return wordnet_pos
    return 'n'  # default: noun

# Text cleaning and lemmatization preprocessing function
def clean_and_lemmatize(text):
    """Turn one raw abstract into a list of lemmatized, stop-word-free tokens.

    The text is lower-cased, every character other than a-z and whitespace is
    dropped, the remainder is tokenized, stop words are removed, and each
    surviving token is lemmatized using its POS tag.

    Reads the module-level ``stop_words`` and ``lemmatizer`` objects.

    Args:
        text: Abstract text; a null value (as judged by ``pd.isnull``) is allowed.

    Returns:
        list: Lemmatized tokens, or an empty list for a null input.
    """
    if pd.isnull(text):
        return []

    letters_only = re.sub(r'[^a-z\s]', '', text.lower())

    kept_tokens = []
    for token in word_tokenize(letters_only):
        if token not in stop_words:
            kept_tokens.append(token)

    # POS tags are computed on the already-filtered token sequence.
    return [
        lemmatizer.lemmatize(token, get_wordnet_pos(tag))
        for token, tag in pos_tag(kept_tokens)
    ]

#-----------------Load data----------------
# Only the four columns used by this notebook are read from the spreadsheet.
data = pd.read_excel('./data/raw/papers_CM.xlsx', usecols=['PaperID', 'Abstract', 'Keywords', 'Year'])

# Shared NLTK resources; clean_and_lemmatize reads both as module-level names.
lemmatizer = WordNetLemmatizer()
stop_words = set(stopwords.words('english'))

# Clean and lemmatize every abstract (one token list per row).
data['Lemmatized_Tokens'] = data['Abstract'].apply(clean_and_lemmatize)

# Corpus-wide token frequencies.
all_tokens = list(itertools.chain.from_iterable(data['Lemmatized_Tokens']))
word_counts = Counter(all_tokens)

# Words seen fewer than min_freq times in the whole corpus are treated as rare.
min_freq = 10
valid_words = {word for word, freq in word_counts.items() if freq >= min_freq}

# remove rare words based on frequency threshold
def remove_rare_words(tokens):
    """Return the tokens that belong to the module-level ``valid_words`` set.

    Order and duplicates of the surviving tokens are preserved.
    """
    kept = []
    for word in tokens:
        if word in valid_words:
            kept.append(word)
    return kept

# Drop rare words from every token list.
data['Filtered_Tokens'] = data['Lemmatized_Tokens'].apply(remove_rare_words)

# Space-joined form of the filtered tokens.
data['Cleaned_Abstract'] = data['Filtered_Tokens'].apply(" ".join)

# Keep the relevant columns and drop PaperID 57188 (this paper has no abstract),
# then renumber the rows so that Document_ID is a dense 0..n-1 index.
cleaned_data = data.loc[
    ~(data['PaperID'] == 57188),
    ['PaperID', 'Year', 'Cleaned_Abstract'],
].reset_index(drop=True)
cleaned_data.insert(0, 'Document_ID', range(len(cleaned_data)))

# One token list per document, in Document_ID order.
abstract_list = [text.split() for text in cleaned_data['Cleaned_Abstract']]

# corpus: {doc_id: [word1, word2, ...]}
corpus = dict(enumerate(abstract_list))
# cleaned_data.to_csv('./data/processed/cleaned_data.xlsx', index=False, encoding='utf-8-sig')

In [3]:
""" 1.2 Corpus Analysis: create corpus word frequency table, document-word matrix, and generate corpus statistics """

# Create corpus word frequency table
def create_corpus_word_frequency_table(corpus):
    """Build a per-word frequency table for the whole corpus.

    Args:
        corpus (dict): {doc_id: [word1, word2, ...]}.

    Returns:
        pandas.DataFrame: One row per distinct word, sorted by descending
        frequency, with the columns
        Rank, Word, Frequency, Percentage, Cumulative_Percentage,
        Document_Count, Document_Percentage, Frequency_Category.
        An empty frame with those columns is returned when the corpus
        contains no words.
    """
    columns = [
        'Rank', 'Word', 'Frequency', 'Percentage', 'Cumulative_Percentage',
        'Document_Count', 'Document_Percentage', 'Frequency_Category'
    ]

    # One pass over the corpus collects both the total count of every word and
    # the number of documents it occurs in. (Testing `word in doc` for every
    # vocabulary word against every document list is quadratic.)
    word_freq_counter = Counter()
    word_doc_count = Counter()
    for doc in corpus.values():
        word_freq_counter.update(doc)
        word_doc_count.update(set(doc))

    total_words = sum(word_freq_counter.values())
    if total_words == 0:
        # Nothing to rank; also avoids dividing by zero below.
        return pd.DataFrame(columns=columns)

    word_freq_df = (
        pd.DataFrame(list(word_freq_counter.items()), columns=['Word', 'Frequency'])
        .sort_values('Frequency', ascending=False)
        .reset_index(drop=True)
    )

    # Share of all tokens; the cumulative column sums the rounded percentages.
    word_freq_df['Percentage'] = (word_freq_df['Frequency'] / total_words * 100).round(4)
    word_freq_df['Cumulative_Percentage'] = word_freq_df['Percentage'].cumsum().round(4)

    # Spread of each word across documents.
    word_freq_df['Document_Count'] = word_freq_df['Word'].map(dict(word_doc_count))
    word_freq_df['Document_Percentage'] = (word_freq_df['Document_Count'] / len(corpus) * 100).round(4)

    def categorize_frequency(freq, total):
        """Bucket a raw count by its share of all tokens."""
        if freq >= total * 0.01:  # >= 1%
            return 'High'
        elif freq >= total * 0.001:  # 0.1% - 1%
            return 'Medium'
        elif freq >= total * 0.0001:  # 0.01% - 0.1%
            return 'Low'
        else:
            return 'Very Low'

    word_freq_df['Frequency_Category'] = word_freq_df['Frequency'].apply(
        lambda x: categorize_frequency(x, total_words)
    )

    # Rank follows the sorted row order (1 = most frequent).
    word_freq_df['Rank'] = range(1, len(word_freq_df) + 1)

    return word_freq_df[columns]

# Create document-word matrix
def create_document_word_matrix(corpus, cleaned_data=None, top_n=1000):
    """Build a document x word count matrix over the most frequent words.

    Args:
        corpus (dict): {doc_id: [word1, word2, ...]}.
        cleaned_data (pandas.DataFrame, optional): Frame with 'PaperID' and
            'Year' columns. The columns are attached only when the frame has
            exactly one row per corpus document; otherwise it is ignored.
        top_n (int or None): Number of most frequent corpus words kept as
            columns. Defaults to 1000; None keeps the whole vocabulary.

    Returns:
        pandas.DataFrame: One row per document in corpus iteration order.
        Leading column 'Document_ID' (0..n-1), optionally 'PaperID' and
        'Year', followed by one count column per kept word.
    """
    # Vocabulary columns, ordered by descending corpus frequency.
    word_counter = Counter(word for doc in corpus.values() for word in doc)
    top_words = [word for word, _ in word_counter.most_common(top_n)]

    # Each row holds the in-document frequency of every kept word.
    doc_word_matrix = []
    for doc in corpus.values():
        doc_counter = Counter(doc)
        doc_word_matrix.append([doc_counter.get(word, 0) for word in top_words])

    df = pd.DataFrame(doc_word_matrix, columns=top_words)

    # Document_ID is positional (row number), not the corpus key.
    df.insert(0, 'Document_ID', range(len(corpus)))
    if cleaned_data is not None and len(cleaned_data) == len(corpus):
        df.insert(1, 'PaperID', cleaned_data['PaperID'].values)
        df.insert(2, 'Year', cleaned_data['Year'].values)

    return df

# Generate corpus statistics like total documents, total words, unique words, vocabulary richness, average document length, etc.
def generate_corpus_statistics(corpus, cleaned_data=None):
    """Summarise corpus size, document lengths and the word-frequency spectrum.

    Args:
        corpus (dict): {doc_id: [word1, word2, ...]}.
        cleaned_data: Accepted for call compatibility; not used.

    Returns:
        dict with three entries:
            'basic_stats': document/word totals, vocabulary richness
                (unique words / total words) and document-length statistics.
            'frequency_distribution': number of distinct words occurring
                once, 2-5 times, 6-20 times and more than 20 times.
            'top_words': the 20 most frequent (word, count) pairs.
        For a corpus without documents or words the ratios and length
        statistics are reported as 0 instead of raising.
    """
    doc_lengths = [len(doc) for doc in corpus.values()]
    word_counter = Counter(word for doc in corpus.values() for word in doc)

    total_words = sum(doc_lengths)
    unique_words = len(word_counter)
    has_docs = len(doc_lengths) > 0

    # Bucket the per-word counts in a single pass.
    once = two_to_five = six_to_twenty = over_twenty = 0
    for freq in word_counter.values():
        if freq == 1:
            once += 1
        elif freq <= 5:
            two_to_five += 1
        elif freq <= 20:
            six_to_twenty += 1
        else:
            over_twenty += 1

    statistics = {
        'basic_stats': {
            'total_documents': len(corpus),
            'total_words': total_words,
            'unique_words': unique_words,
            # Guarded: an empty corpus has no words to divide by.
            'vocabulary_richness': unique_words / total_words if total_words else 0.0,
            'average_doc_length': np.mean(doc_lengths) if has_docs else 0.0,
            'median_doc_length': np.median(doc_lengths) if has_docs else 0.0,
            'min_doc_length': min(doc_lengths) if has_docs else 0,
            'max_doc_length': max(doc_lengths) if has_docs else 0,
            'std_doc_length': np.std(doc_lengths) if has_docs else 0.0
        },
        'frequency_distribution': {
            'words_appearing_once': once,
            'words_appearing_2_5_times': two_to_five,
            'words_appearing_6_20_times': six_to_twenty,
            'words_appearing_more_than_20': over_twenty,
        },
        'top_words': word_counter.most_common(20)
    }

    return statistics

#-------------------create word frequency summary-----------------
word_freq_table = create_corpus_word_frequency_table(corpus)

print(f"\n📊 Corpus word frequency information:")
print(f"total words: {word_freq_table['Frequency'].sum():,}")
print(f"num of unique words: {len(word_freq_table):,}")
# 'Frequency' holds one count per distinct word, so its mean and standard
# deviation describe the vocabulary, not document lengths (document-length
# statistics are printed under BASIC_STATS further down).
print(f"average frequency per unique word: {word_freq_table['Frequency'].mean():.2f}")
print(f"standard deviation of word frequency across the vocabulary: {word_freq_table['Frequency'].std():.2f}")

print(f"\n🏆 Top 20 High-frequency words:")
print(word_freq_table.head(20).to_string(index=False))

print(f"\n📈 Word Rank Information:")
freq_category_stats = word_freq_table['Frequency_Category'].value_counts()
print(freq_category_stats)

#-------------------create document-word matrix and generate corpus statistics----------------
# create document-word matrix (optional, suitable for smaller datasets)

doc_word_matrix = create_document_word_matrix(corpus, cleaned_data)
print(f"\n📋 create doc-words matrix, shape of matrix is : {doc_word_matrix.shape}")

# generate detailed statistics
corpus_stats = generate_corpus_statistics(corpus, cleaned_data)

print(f"\n📊 Corpus detailed statistics info:")
print("=" * 50)
for category, stats in corpus_stats.items():
    print(f"\n{category.upper()}:")
    if isinstance(stats, dict):
        for key, value in stats.items():
            if isinstance(value, float):
                print(f"  {key}: {value:.4f}")
            else:
                print(f"  {key}: {value}")
    elif isinstance(stats, list):
        print("  Top words:")
        for word, freq in stats:
            print(f"    {word}: {freq}")

# ---------------------save statistics tables-------------------
# Flat Metric/Value table built from the dict-valued statistic groups.
stats_df = pd.DataFrame([
    {'Metric': key, 'Value': value} 
    for category_stats in corpus_stats.values() 
    if isinstance(category_stats, dict)
    for key, value in category_stats.items()
])

# Writing files is opt-in. The confirmation lines are printed only when a file
# has actually been written, so the log never reports a save that did not happen.
SAVE_STATISTICS_TABLES = False

if SAVE_STATISTICS_TABLES:
    print(f"\n💾 save statistics tables...")

    word_freq_table.to_csv('corpus_word_frequency_table.csv', index=False, encoding='utf-8')
    print("✅ save <- corpus_word_frequency_table.csv")

    doc_word_matrix.to_csv('document_word_matrix.csv', index=False, encoding='utf-8')
    print("✅  save <- document_word_matrix.csv")

    stats_df.to_csv('corpus_statistics_summary.csv', index=False, encoding='utf-8')
    print("✅ save <- corpus_statistics_summary.csv")
else:
    print("\n💾 statistics tables not saved (SAVE_STATISTICS_TABLES is False)")


📊 Corpus word frequency information:
total words: 83,202
num of unique words: 1,490
average num of words for each doc: 55.84
standard deviation of word frequency among docs: 104.60

🏆 Top 20 High-frequency words:
 Rank          Word  Frequency  Percentage  Cumulative_Percentage  Document_Count  Document_Percentage Frequency_Category
    1        method       1598      1.9206                 1.9206             654              67.4227               High
    2         model       1554      1.8677                 3.7883             579              59.6907               High
    3       element       1112      1.3365                 5.1248             513              52.8866               High
    4           use        999      1.2007                 6.3255             572              58.9691               High
    5       propose        823      0.9892                 7.3147             510              52.5773             Medium
    6     numerical        799      0.9603            

In [4]:
"""2. Core logic function: Chain Rule Process (CRP)"""
class Node:
    """
    Tree node for hierarchical topic modeling with the nested Chinese
    Restaurant Process (nCRP).

    Every node stands for one topic; the parent/child links encode how general
    topics (near the root) nest more specific ones.

    Class-level bookkeeping shared by all nodes:
        last_node_id: id handed to the next node that is created.
        total_node:   number of nodes created minus nodes removed.
        node_with_id: {node_id: Node}; the entry of a removed node is set to None.
    """
    last_node_id = 0
    total_node = 0
    node_with_id = {}

    def __init__(self, parent=None, layer=0):
        """Create a node, give it the next free id and register it."""
        # Ids are handed out sequentially and never reused.
        self.node_id = Node.last_node_id
        Node.node_with_id[self.node_id] = self
        Node.last_node_id += 1
        Node.total_node += 1

        self.parent = parent   # parent Node, None for the root
        self.layer = layer     # depth in the hierarchy (0 = root)
        self.children = []     # child Node objects
        self.docs_list = []    # ids of the documents whose path goes through this node

    def add_child(self):
        """Create a node one layer below this one, attach it and return it."""
        new_child = Node(parent=self, layer=self.layer+1)
        self.children.append(new_child)
        return new_child
    
    def remove_child(self, child):
        """Detach ``child`` from this node and blank its registry entry.

        Raises ValueError (from list.remove) if ``child`` is not a child of
        this node.
        """
        self.children.remove(child)
        child.parent = None
        # The id stays in the registry, mapped to None.
        Node.node_with_id[child.node_id] = None
        Node.total_node -= 1

def nCRP(corpus, depth, gamma):
    """
    Initialise a topic tree with the nested Chinese Restaurant Process (nCRP).

    Documents are seated one after another:
    1. every document walks from the root down to layer ``depth - 1``;
    2. at each layer it opens a new child with probability
       gamma / (gamma + n - 1) or joins an existing child with probability
       n_child / (gamma + n - 1), where n counts the documents that reached the
       parent so far (current document included);
    3. every word of the document is assigned to a uniformly random layer.

    The class-level counters and registry of ``Node`` are reset on every call.

    Args:
        corpus (dict): Document collection {doc_id: [word1, word2, ...]}
        depth (int): Number of layers in the topic hierarchy
        gamma (float): Concentration parameter; higher gamma makes new
                       topics more likely

    Returns: [root_node, path_list, doc_path, doc_word_allocation] where:
        - root_node (Node): Root node of the topic tree
        - path_list (dict): {leaf_node_id: [node0, node1, ...]} - path from root to each leaf
        - doc_path (dict): {doc_id: leaf_node_id} - leaf assigned to each document
        - doc_word_allocation (dict): {doc_id: [layer, layer, ...]} - one layer per word,
          aligned with the document's word list
    """
    # Start from a clean node registry each time nCRP is called.
    Node.last_node_id = 0
    Node.total_node = 0
    Node.node_with_id = {}

    root_node = Node()
    path_list = {}
    doc_word_allocation = {}
    doc_path = {}

    for doc_id, doc in corpus.items():
        # Every document starts at the root.
        root_node.docs_list.append(doc_id)
        path = [root_node]

        for level in range(1, depth):
            parent_node = path[level-1]

            # Index 0 = open a new child, index k = join the k-th existing child.
            CRP_probs = [gamma/(gamma + len(parent_node.docs_list) - 1)]
            CRP_probs.extend(
                len(child.docs_list)/(gamma + len(parent_node.docs_list) - 1)
                for child in parent_node.children
            )

            chosen_index = np.random.choice(len(CRP_probs), p=CRP_probs)
            if chosen_index == 0:
                current_node = parent_node.add_child()
            else:
                current_node = parent_node.children[chosen_index - 1]

            current_node.docs_list.append(doc_id)
            path.append(current_node)

        # Paths are keyed by their leaf; the first document reaching a leaf stores it.
        leaf_id = path[-1].node_id
        path_list.setdefault(leaf_id, path)
        doc_path[doc_id] = leaf_id

        # One uniformly random layer per word.
        doc_word_allocation[doc_id] = [np.random.randint(0,depth) for _ in doc]

    return [root_node, path_list, doc_path, doc_word_allocation]

In [5]:
"""3. Core logic functions: Gibbs sampling, and relevant sub-function: word distribution, likelihood calculation, etc."""
def _create_default_dict_int():
    return defaultdict(int)

def aggregate_words(corpus, doc_word_allocation):
    """
    Collapse per-word layer assignments into per-layer word counts.

    Args:
        corpus: {doc_id: [word1, word2, ...]} - Document collection
        doc_word_allocation: {doc_id: [layer, layer, ...]} - layer of each word,
            aligned position-by-position with the document's word list

    Returns:
        doc_node_allocation: {doc_id: {layer: {word: count}}} - nested
        defaultdicts; layers without any word of the document are absent
    """
    doc_node_allocation = {}

    for doc_id, words in corpus.items():
        layers = doc_word_allocation[doc_id]

        # Named factory (not a lambda) for the inner counters.
        per_layer = defaultdict(_create_default_dict_int)
        for layer, word in zip(layers, words):
            per_layer[layer][word] += 1

        doc_node_allocation[doc_id] = per_layer

    return doc_node_allocation

def node_word_distribution(doc_node_allocation, doc_path, path_list, exclude_docs=None):
    """
    Sum the per-layer word counts of all documents into per-node word counts.

    A document contributes the words it assigned to layer L to the node that
    sits at layer L on the document's path.

    Args:
        doc_node_allocation (dict): {doc_id: {layer: {word: count}}}
        doc_path (dict): {doc_id: leaf_id} - leaf node of each document
        path_list (dict): {leaf_id: [Node, Node, ...]} - root-to-leaf paths
        exclude_docs (iterable, optional): document ids to leave out

    Returns:
        node_word_dist: {node_id: {word: count}} as nested defaultdicts; a node
        appears only if at least one word was counted for it
    """
    skipped_docs = set(exclude_docs) if exclude_docs is not None else set()
    node_word_dist = defaultdict(lambda: defaultdict(int))

    for doc_id, leaf_id in doc_path.items():
        if doc_id in skipped_docs:
            continue

        path_nodes = path_list[leaf_id]
        layer_counts = doc_node_allocation[doc_id]

        for node in path_nodes:
            if node.layer not in layer_counts:
                continue  # the document put no words on this layer
            for word, count in layer_counts[node.layer].items():
                node_word_dist[node.node_id][word] += count

    return node_word_dist

def calc_node_likelihood(compare_dist, target_dist, eta, len_W, return_log=False):
    """
    Dirichlet-multinomial likelihood of a block of words given a node's counts.

    Computes
        Gamma(N + W*eta) / Gamma(N + M + W*eta)
        * prod_w Gamma(n_w + m_w + eta) / Gamma(n_w + eta)
    where n_w are the counts in ``compare_dist`` (N their sum), m_w the counts
    in ``target_dist`` (M their sum) and W = ``len_W``. The product runs over
    the words of ``target_dist`` only.

    Args:
        compare_dist: {word: count} - word counts already held by the node
        target_dist: {word: count} - word counts to be generated
        eta (float): symmetric Dirichlet smoothing parameter
        len_W (int): vocabulary size
        return_log (bool): when True return the log-likelihood instead of the
            likelihood. For documents of more than a few hundred words the
            plain likelihood underflows to exactly 0.0, so callers that
            multiply or compare many such values should work with logs.

    Returns:
        float: the likelihood, or its natural logarithm when ``return_log`` is
        True. If exponentiation overflows, float('-inf') is returned (kept for
        compatibility; with eta > 0 and non-negative counts the
        log-likelihood is <= 0, so this branch is not expected to be reached).

    Raises:
        ValueError: from math.lgamma when an argument is a non-positive
        integer (for example eta == 0 together with an unseen word).
    """
    # The node total is used by both normalisers; compute it once.
    compare_total = sum(compare_dist.values())
    sum_A = compare_total + eta * len_W
    sum_B = compare_total + sum(target_dist.values()) + eta * len_W

    lgamma_sum_A = math.lgamma(sum_A)
    lgamma_sum_B = math.lgamma(sum_B)

    lgamma_prod_A = 0.0
    lgamma_prod_B = 0.0

    for word, count in target_dist.items():
        comp_val = compare_dist.get(word, 0)
        lgamma_prod_A += math.lgamma(comp_val + eta)
        lgamma_prod_B += math.lgamma(comp_val + count + eta)

    log_likelihood = (lgamma_sum_A - lgamma_prod_A) + (lgamma_prod_B - lgamma_sum_B)
    if return_log:
        return log_likelihood

    try:
        return math.exp(log_likelihood)
    except (OverflowError, ValueError):
        return float('-inf')

def create_new_path(base_node, doc_id, depth):
    """Grow a fresh branch below ``base_node`` down to layer ``depth - 1``.

    Args:
        base_node: existing node under which the new branch starts
        doc_id: not used by this function
        depth (int): number of layers in the tree

    Returns:
        (leaf_id, new_path): id of the last node on the path and the full
        root-to-leaf node list (existing ancestors, ``base_node``, then the
        newly created nodes). No node is created when ``base_node`` already
        sits at layer ``depth - 1``.
    """
    # Walk up to the root, then flip the list to get root -> base_node order.
    ancestors = []
    cursor = base_node
    while cursor:
        ancestors.append(cursor)
        cursor = cursor.parent
    new_path = ancestors[::-1]

    # Add one child per missing layer.
    tail = base_node
    for _ in range(base_node.layer, depth-1):
        tail = tail.add_child()
        new_path.append(tail)

    return tail.node_id, new_path

def exclude_doc_from_node_dist(global_node_word_dist, doc_node_allocation, doc_path, path_list, doc_id):
    """Subtract one document's word counts from the nodes on its current path.

    For every node of the document's path that has an entry in
    ``global_node_word_dist``, the counts the document holds on that node's
    layer are subtracted; words whose count drops to zero or below are deleted.

    The distribution is modified in place and the same object is returned.
    """
    for node in path_list[doc_path[doc_id]]:
        if node.node_id not in global_node_word_dist:
            continue
        doc_layer_counts = doc_node_allocation[doc_id].get(node.layer, {})
        for word, count in doc_layer_counts.items():
            global_node_word_dist[node.node_id][word] -= count
            if global_node_word_dist[node.node_id][word] <= 0:
                del global_node_word_dist[node.node_id][word]
    return global_node_word_dist

def add_doc_to_node_dist(global_node_word_dist, doc_node_allocation, doc_path, path_list, doc_id):
    """Add one document's word counts to the nodes on its current path.

    The counts the document holds on layer L are added to the path node at
    layer L. A node entry is created (as ``defaultdict(int)``) the first time
    a word is added to it.

    The distribution is modified in place and the same object is returned.
    """
    for node in path_list[doc_path[doc_id]]:
        doc_layers = doc_node_allocation[doc_id]
        if node.layer not in doc_layers:
            continue  # the document has no words on this layer
        for word, count in doc_layers[node.layer].items():
            global_node_word_dist.setdefault(node.node_id, defaultdict(int))[word] += count
    return global_node_word_dist

def _log_ratio(numerator, denominator):
    """Return log(numerator / denominator), or -inf when the ratio is not positive."""
    if numerator <= 0 or denominator <= 0:
        return float('-inf')
    return math.log(numerator / denominator)


def _safe_exp(value):
    """math.exp that saturates to +inf instead of raising OverflowError."""
    try:
        return math.exp(value)
    except OverflowError:
        return float('inf')


def _log_node_likelihood(compare_dist, target_dist, eta, len_W):
    """Log Dirichlet-multinomial likelihood (same formula as calc_node_likelihood, before exp)."""
    compare_total = sum(compare_dist.values())
    log_lik = (math.lgamma(compare_total + eta * len_W)
               - math.lgamma(compare_total + sum(target_dist.values()) + eta * len_W))
    for word, count in target_dist.items():
        base = compare_dist.get(word, 0) + eta
        log_lik += math.lgamma(base + count) - math.lgamma(base)
    return log_lik


def _log_path_likelihood(nodes, node_word_dist, doc_layers, eta, len_W):
    """Sum the log-likelihoods of a document's layer counts over the given nodes.

    Layers on which the document has no entry contribute nothing (factor 1).
    """
    total = 0.0
    for node in nodes:
        if node.layer in doc_layers:
            total += _log_node_likelihood(
                node_word_dist.get(node.node_id, {}), doc_layers[node.layer], eta, len_W)
    return total


def _node_counts(node_word_dist, node_id):
    """Return the word counter of a node, creating an empty one when missing."""
    counts = node_word_dist.get(node_id)
    if counts is None:
        counts = node_word_dist[node_id] = defaultdict(int)
    return counts


def Gibbs_sampling(corpus, root_node, path_list, doc_path, doc_word_allocation, doc_node_allocation, global_node_word_dist,
                   gamma, eta, alpha, depth, iteration, iter_start=None, iter_end=None):
    """Run one Gibbs sweep over all documents (path sampling, then layer sampling).

    For every document, in corpus order:
      1. its word counts are removed from the node distributions and it is
         detached from its path; nodes left without documents are removed;
      2. a path is sampled among all existing root-to-leaf paths and all
         "branch off below node X" options, with weight
         nCRP prior * Dirichlet-multinomial likelihood of the document's
         per-layer word counts;
      3. the document is attached to the chosen path and each of its words is
         re-assigned to a layer of that path.

    Path weights are accumulated in log space and normalised with the
    log-sum-exp trick. Multiplying raw likelihoods underflows to 0.0 for long
    documents, and a 0.0 factor must not be mistaken for "this layer has no
    words" (which is a factor of 1).

    All of ``path_list``, ``doc_path``, ``doc_word_allocation``,
    ``doc_node_allocation``, ``global_node_word_dist`` and the ``Node`` tree
    are updated in place.

    Args:
        corpus (dict): {doc_id: [word1, word2, ...]}
        root_node (Node): root of the topic tree
        path_list (dict): {leaf_id: [Node, ...]} root-to-leaf paths
        doc_path (dict): {doc_id: leaf_id}
        doc_word_allocation (dict): {doc_id: [layer, ...]} layer of each word
        doc_node_allocation (dict): {doc_id: {layer: {word: count}}}
        global_node_word_dist (dict): {node_id: {word: count}}
        gamma (float): nCRP concentration parameter
        eta (float): Dirichlet smoothing of the topic-word distributions
        alpha (float): smoothing of the per-document layer proportions
        depth (int): number of layers in the tree
        iteration (int): index of the current sweep (used for logging only)
        iter_start, iter_end (int or None): inclusive iteration window in which
            path changes are logged. A None bound is open-ended; when both are
            None nothing is logged.

    Returns:
        (jump_records, detail_records): two lists with one dict per document
        that changed path while logging was active.
        NOTE(review): the jump records use the keys 'old_path'/'new_path' when
        a path was created but 'old_path_ids'/'new_path_ids' when an existing
        path was reused, and 'origal_probs' is misspelled. They are kept
        unchanged here because code reading these records is not visible.
    """
    vocab_size = len(set(itertools.chain.from_iterable(corpus.values())))
    all_jump_records_this_iteration = []
    all_detail_records_this_iteration = []

    # Comparing None with an int raises TypeError, so the bounds are checked
    # explicitly.
    logging_enabled = (
        (iter_start is not None or iter_end is not None)
        and (iter_start is None or iter_start <= iteration)
        and (iter_end is None or iteration <= iter_end)
    )

    for doc_id, doc in corpus.items():
        old_path = path_list[doc_path[doc_id]]
        old_leaf_id = old_path[-1].node_id
        doc_layers = doc_node_allocation[doc_id]

        # ---- 1. take the document out of the current state ----
        # Must run before path_list is modified: it looks the old path up there.
        node_word_dist = exclude_doc_from_node_dist(
            global_node_word_dist, doc_node_allocation, doc_path, path_list, doc_id)

        emptied_nodes = []
        for node in reversed(old_path):
            node.docs_list.remove(doc_id)
            if len(node.docs_list) == 0 and node is not root_node:
                emptied_nodes.append(node)

        if emptied_nodes:
            # The leaf is always among the emptied nodes, so the path is gone.
            del path_list[doc_path[doc_id]]
            for node in emptied_nodes:
                if node.parent:
                    node.parent.remove_child(node)

        # ---- 2. score every candidate path in log space ----
        # Existing paths are keyed by their leaf id (int); new branches by the
        # string 'create<base_node_id>'. Existing paths come first.
        candidate_keys = []
        log_weights = []

        for leaf_id, path in path_list.items():
            log_prior = 0.0
            for node in path:
                if node.parent:
                    log_prior += _log_ratio(len(node.docs_list), gamma + len(node.parent.docs_list))
            log_lik = _log_path_likelihood(path, node_word_dist, doc_layers, eta, vocab_size)
            candidate_keys.append(leaf_id)
            log_weights.append(log_prior + log_lik)

        for node_id, node in Node.node_with_id.items():
            if node is None or node.layer >= depth-1:
                continue

            # Existing part of the new path: root ... node.
            ancestors = []
            cursor = node
            while cursor:
                ancestors.append(cursor)
                cursor = cursor.parent
            ancestors.reverse()

            # Prior: follow the existing nodes, then open a new child below `node`.
            log_prior = _log_ratio(gamma, gamma + len(node.docs_list))
            for ancestor in ancestors[1:]:
                log_prior += _log_ratio(len(ancestor.docs_list), gamma + len(ancestor.parent.docs_list))

            # Likelihood: existing nodes use their counts, new nodes are empty.
            log_lik = _log_path_likelihood(ancestors, node_word_dist, doc_layers, eta, vocab_size)
            for layer in range(node.layer+1, depth):
                if layer in doc_layers:
                    log_lik += _log_node_likelihood({}, doc_layers[layer], eta, vocab_size)

            candidate_keys.append(f'create{node_id}')
            log_weights.append(log_prior + log_lik)

        # Normalise with log-sum-exp; fall back to uniform when no candidate
        # has a usable weight.
        max_log_weight = max(log_weights)
        if math.isfinite(max_log_weight):
            scaled = [math.exp(weight - max_log_weight) for weight in log_weights]
            scaled_total = sum(scaled)
            probs = [value / scaled_total for value in scaled]
        else:
            probs = [1.0 / len(log_weights)] * len(log_weights)
        normalized_probs = dict(zip(candidate_keys, probs))

        # Sampling an index keeps the keys in their original types
        # (np.random.choice over a mixed int/str list converts all to strings).
        chosen_key = candidate_keys[np.random.choice(len(candidate_keys), p=probs)]
        created_new_path = isinstance(chosen_key, str)

        if created_new_path:
            base_node = Node.node_with_id[int(chosen_key[len('create'):])]
            leaf_id, added_path = create_new_path(base_node, doc_id, depth)
            path_list[leaf_id] = added_path
        else:
            leaf_id = chosen_key
            added_path = path_list[leaf_id]
        doc_path[doc_id] = leaf_id

        # ---- optional logging of path changes ----
        if logging_enabled and (created_new_path or chosen_key != old_leaf_id):
            all_probs = {key: _safe_exp(weight) for key, weight in zip(candidate_keys, log_weights)}
            chosen_path_prob = normalized_probs[chosen_key]
            values = list(normalized_probs.values())
            rank = sorted(values, reverse=True).index(chosen_path_prob) + 1

            # Key names differ between the two cases; see NOTE(review) above.
            if created_new_path:
                old_key, new_key = 'old_path', 'new_path'
            else:
                old_key, new_key = 'old_path_ids', 'new_path_ids'

            all_jump_records_this_iteration.append({
                'iteration': iteration,
                'doc_id': doc_id,
                'old_leaf_id': old_leaf_id,
                old_key: [n.node_id for n in old_path],
                'new_leaf_id': added_path[-1].node_id,
                new_key: [n.node_id for n in added_path],
                'origal_probs': all_probs,
                'normalized_probs': normalized_probs,
                'chosen_path_prob': chosen_path_prob,
                'rank': f'{rank} out of {len(values)}',
                'create_path': created_new_path
            })

            # Snapshots are copies so later sweeps cannot alter the record.
            all_detail_records_this_iteration.append({
                'iteration': iteration,
                'doc_id': doc_id,
                'deleted_path_list': {lid: [node.node_id for node in path_nodes]
                                      for lid, path_nodes in path_list.items()},
                'doc_path': dict(doc_path),
                'doc_word_allocation': list(doc_word_allocation[doc_id]),
                'doc_node_allocation': {layer: dict(layer_counts)
                                        for layer, layer_counts in doc_layers.items()}
            })

        # ---- 3. attach the document to the chosen path ----
        for node in added_path:
            node.docs_list.append(doc_id)

        node_word_dist_update = add_doc_to_node_dist(
            global_node_word_dist, doc_node_allocation, doc_path, path_list, doc_id)

        # ---- 4. re-sample the layer of every word ----
        current_path = path_list[doc_path[doc_id]]
        update_doc_word2node = []
        for word, old_layer in zip(doc, doc_word_allocation[doc_id]):
            # Remove the word from its current layer before scoring the layers.
            doc_layers[old_layer][word] -= 1
            _node_counts(node_word_dist_update, current_path[old_layer].node_id)[word] -= 1

            topic_probs = {}
            for layer, topic in enumerate(current_path):
                topic_counts = _node_counts(node_word_dist_update, topic.node_id)
                word_in_topic = topic_counts.get(word, 0)
                topic_in_doc = sum(doc_layers.get(layer, {}).values())

                word_prop = (word_in_topic+eta)/(sum(topic_counts.values())+vocab_size*eta)
                topic_prop = (topic_in_doc+alpha)/(len(doc)+depth*alpha)
                topic_probs[layer] = word_prop*topic_prop

            total_topic_prob = sum(v for v in topic_probs.values() if v > 0)
            if total_topic_prob > 0:
                normalized_topic_probs = {k: v/total_topic_prob for k, v in topic_probs.items()}
            else:
                normalized_topic_probs = {k: 1.0/len(topic_probs) for k in topic_probs}

            # Stored as a plain int: np.random.choice returns a NumPy integer.
            chosen_layer = int(np.random.choice(list(normalized_topic_probs.keys()),
                                                p=list(normalized_topic_probs.values())))
            update_doc_word2node.append(chosen_layer)
            doc_layers[chosen_layer][word] += 1
            _node_counts(node_word_dist_update, current_path[chosen_layer].node_id)[word] += 1

        doc_word_allocation[doc_id] = update_doc_word2node

    return all_jump_records_this_iteration, all_detail_records_this_iteration

In [6]:
class Recorder:
    """Collects per-iteration records of the Gibbs sampler and exports them to CSV.

    Attributes (all set in ``__init__``):
        corpus: mapping ``{doc_id: [word, ...]}``.
        depth: number of tree layers; controls how many ``layer_{i}_node_id``
            columns are emitted for each path.
        eta: smoothing term used in the per-word log-likelihood.
        alpha: stored on the instance; not read by any method of this class.
        iteration_records: list holding one dict per recorded iteration.
        vocab: set of all distinct words appearing in ``corpus``.
    """

    def __init__(self, corpus, depth, eta, alpha):
        self.corpus = corpus
        self.depth = depth
        self.eta = eta
        self.alpha = alpha
        self.iteration_records = []
        self.vocab = set(itertools.chain.from_iterable(corpus.values()))

    def record_iteration(self, iteration_num, root_node, path_list, doc_path,
                        doc_word_allocation, doc_node_allocation, global_node_word_dist, jump_record=None, detail_record=None,
                        newly_created_paths=None, iter_start_for_detailed_log=float('inf')):
        """Record everything about one sampling iteration and return its summary.

        Args:
            iteration_num: index of the iteration being recorded.
            root_node: accepted for interface compatibility; not read here.
            path_list: ``{leaf_id: [Node, ...]}``; each node exposes ``node_id``.
            doc_path: ``{doc_id: leaf_id}``.
            doc_word_allocation: ``{doc_id: [layer_for_word1, layer_for_word2, ...]}``.
            doc_node_allocation: accepted for interface compatibility; not read here.
            global_node_word_dist: ``{node_id: {word: count}}``.
            jump_record: optional list of jump records produced by the sampler.
            detail_record: optional list of detail records produced by the sampler.
            newly_created_paths: optional collection of leaf ids created this iteration.
            iter_start_for_detailed_log: word-level and node-level records are only
                filled when ``iteration_num`` is at least this value.

        Returns:
            dict: the iteration summary (also stored in ``self.iteration_records``).
        """
        record_details = iteration_num >= iter_start_for_detailed_log
        created_paths = newly_created_paths or []

        # Number of documents whose leaf changed compared with the previously
        # recorded iteration. The emptiness guard avoids an IndexError when the
        # first recorded iteration does not have index 0.
        changed_docs_count = 0
        if iteration_num > 0 and self.iteration_records:
            previous_doc_paths = {
                entry['document_id']: entry['leaf_node_id']
                for entry in self.iteration_records[-1]['doc_path_assignments']
            }
            changed_docs_count = sum(
                1 for doc_id, leaf_id in doc_path.items()
                if doc_id in previous_doc_paths and previous_doc_paths[doc_id] != leaf_id
            )

        # 1. Document -> path assignments.
        doc_path_records = [
            {
                'iteration': iteration_num,
                'document_id': doc_id,
                'leaf_node_id': leaf_id,
                'path_created_this_iteration': leaf_id in created_paths,
            }
            for doc_id, leaf_id in doc_path.items()
        ]

        # Group documents by leaf once (single pass) instead of rescanning
        # doc_path for every path; document order inside each group follows
        # the iteration order of doc_path.
        docs_by_leaf = defaultdict(list)
        for doc_id, leaf_id in doc_path.items():
            docs_by_leaf[leaf_id].append(doc_id)

        # 2. Path / tree structure, one record per leaf.
        path_structure_records = []
        path_sizes = []
        for leaf_id, path_nodes in path_list.items():
            node_ids = [node.node_id for node in path_nodes]
            docs_in_this_path = list(docs_by_leaf.get(leaf_id, ()))
            path_sizes.append(len(docs_in_this_path))

            path_record = {
                'iteration': iteration_num,
                'leaf_node_id': leaf_id,
                'document_count': len(docs_in_this_path),
                'documents_in_path': docs_in_this_path,
                'path_created_this_iteration': leaf_id in created_paths,
            }
            # Paths shorter than self.depth are padded with None.
            for i in range(self.depth):
                path_record[f'layer_{i}_node_id'] = node_ids[i] if i < len(node_ids) else None

            path_structure_records.append(path_record)

        # 3. Word -> layer/node assignments (only inside the detailed-log window).
        word_allocation_records = []
        if record_details:
            for doc_id, word_assignments in doc_word_allocation.items():
                doc_words = self.corpus[doc_id]
                leaf_id = doc_path[doc_id]
                for word_idx, (word, layer) in enumerate(zip(doc_words, word_assignments)):
                    word_allocation_records.append({
                        'iteration': iteration_num,
                        'document_id': doc_id,
                        'word_index': word_idx,
                        'word': word,
                        'assigned_layer': layer,
                        'leaf_node_id': leaf_id,
                        'assigned_node_id': path_list[leaf_id][layer].node_id
                    })

        # 4. Node word distributions (only inside the detailed-log window).
        #    The distribution itself is supplied by the caller.
        node_word_records = []
        if record_details:
            for node_id, current_word_dist in global_node_word_dist.items():
                for word, count in current_word_dist.items():
                    node_word_records.append({
                        'iteration': iteration_num,
                        'node_id': node_id,
                        'word': word,
                        'count': count
                    })

        # 5. Log-likelihood of the corpus under the current assignments.
        log_likelihood = self._calculate_log_likelihood(
            doc_path, path_list, doc_word_allocation, global_node_word_dist
        )

        # 6. Overall statistics for this iteration.
        iteration_summary = {
            'iteration': iteration_num,
            'total_paths': len(path_list),
            'total_documents': len(doc_path),
            'log_likelihood': log_likelihood,
            'changed_docs_count': changed_docs_count,
            'newly_created_paths': len(created_paths),
            'avg_path_size': np.mean(path_sizes) if path_list else 0,
            'max_path_size': max(path_sizes) if path_list else 0,
            'min_path_size': min(path_sizes) if path_list else 0,
        }

        # 7. Jump / detail records handed over by the sampler (empty list when absent).
        iteration_jump_record = jump_record if jump_record else []
        iteration_detail_record = detail_record if detail_record else []

        self.iteration_records.append({
            'iteration': iteration_num,
            'doc_path_assignments': doc_path_records,
            'path_structures': path_structure_records,
            'word_allocations': word_allocation_records,  # empty outside the detailed-log window
            'node_word_distributions': node_word_records,  # empty outside the detailed-log window
            'iteration_summary': iteration_summary,
            'jump_records_list': iteration_jump_record,
            'detail_records_list': iteration_detail_record
        })

        return iteration_summary

    def _calculate_log_likelihood(self, doc_path, path_list,
                                    doc_word_allocation, # {doc_id: [layer_for_word1, layer_for_word2,...]}
                                    node_word_dist):     # {node_id: {word: count}}
        """Return log P(words | word-layer assignments, paths, eta) for the whole corpus.

        Each word contributes
        ``log((count_of_word_in_node + eta) / (total_words_in_node + vocab_size * eta))``
        where the node is the one at the word's assigned layer on the document's path.
        A word whose assigned layer lies beyond its document's path, or whose
        numerator/denominator is not positive, contributes ``-inf``.
        """
        total_log_likelihood = 0.0
        vocab_size = len(self.vocab)

        # Total word count per node, computed once up front.
        node_totals = {node_id: sum(dist.values()) for node_id, dist in node_word_dist.items()}

        for doc_id, words_in_doc in self.corpus.items():
            if not words_in_doc:
                continue

            layer_assignments = doc_word_allocation[doc_id]
            path_nodes = path_list[doc_path[doc_id]]  # list of Node objects

            for i, word in enumerate(words_in_doc):
                assigned_layer = layer_assignments[i]

                if assigned_layer >= len(path_nodes):
                    # Inconsistent assignment: penalise instead of raising.
                    total_log_likelihood += -float('inf')
                    continue

                node_id = path_nodes[assigned_layer].node_id
                numerator = node_word_dist.get(node_id, {}).get(word, 0) + self.eta
                denominator = node_totals.get(node_id, 0) + vocab_size * self.eta

                # Neither should be non-positive when eta > 0; guard against math domain errors.
                if denominator <= 0 or numerator <= 0:
                    total_log_likelihood += -float('inf')
                else:
                    total_log_likelihood += math.log(numerator) - math.log(denominator)

        return total_log_likelihood

    def save_to_files(self, base_filename="iteration", last_n_iterations=20):
        """Write all collected records to CSV files and return them as DataFrames.

        Args:
            base_filename: prefix for every output file (``{base_filename}_*.csv``).
            last_n_iterations: word allocations, node word distributions, jump
                records and detail records are exported only for the last
                ``last_n_iterations`` recorded iterations. Note that, because of
                slice semantics (``records[-0:]``), a value of 0 selects ALL records.

        Returns:
            dict of DataFrames keyed by ``doc_path_df``, ``path_structure_df``,
            ``word_allocation_df``, ``node_word_df``, ``summary_df``,
            ``path_doc_mapping_df``, ``jump_records_df`` and ``detail_records_df``;
            ``None`` (after printing a notice) when nothing has been recorded.
        """
        if not self.iteration_records:
            print("No iteration records to save.")
            return

        # 1. Document -> path assignments.
        all_doc_paths = []
        for record in self.iteration_records:
            all_doc_paths.extend(record['doc_path_assignments'])

        doc_path_df = pd.DataFrame(all_doc_paths)
        doc_path_df.to_csv(f'{base_filename}_document_paths.csv', index=False, encoding='utf-8')
        print(f"✅ doc-path allocation is saved : {base_filename}_document_paths.csv")

        # 2. Path structures (each item already carries its layer_{i}_node_id columns).
        all_path_structures = [
            path_record_item.copy()
            for record in self.iteration_records
            for path_record_item in record['path_structures']
        ]

        path_structure_df = pd.DataFrame(all_path_structures)
        path_structure_df.to_csv(f'{base_filename}_path_structures.csv', index=False, encoding='utf-8')
        print(f"✅ structure of tree path is saved: {base_filename}_path_structures.csv")

        # 5. Iteration summaries.
        iteration_summaries = [record['iteration_summary'] for record in self.iteration_records]
        summary_df = pd.DataFrame(iteration_summaries)
        summary_df.to_csv(f'{base_filename}_iteration_summaries.csv', index=False, encoding='utf-8')
        print(f"✅ info recoreded in iteration is saved: {base_filename}_iteration_summaries.csv")

        # 6. Long-format path <-> document mapping, one row per (iteration, document).
        #    getattr keeps instances that lack `depth` working (they get no layer columns).
        depth = getattr(self, 'depth', 0)
        path_document_mapping = []
        for record in self.iteration_records:
            for path_record_item in record['path_structures']:
                for doc_id_in_path in path_record_item['documents_in_path']:
                    mapping_entry = {
                        'iteration': path_record_item['iteration'],
                        'leaf_node_id': path_record_item['leaf_node_id'],
                        'document_id': doc_id_in_path,
                    }
                    for i in range(depth):
                        layer_key = f'layer_{i}_node_id'
                        mapping_entry[layer_key] = path_record_item.get(layer_key)
                    path_document_mapping.append(mapping_entry)

        path_doc_mapping_df = pd.DataFrame(path_document_mapping)
        path_doc_mapping_df.to_csv(f'{base_filename}_path_document_mapping.csv', index=False, encoding='utf-8')
        print(f"✅ doc-path mapping is saved: {base_filename}_path_document_mapping.csv")

        # The remaining exports are restricted to the trailing window of iterations.
        filtered_records_for_detail = self.iteration_records[-last_n_iterations:]

        # 3. Word allocations.
        all_word_allocations = []
        for record in filtered_records_for_detail:
            all_word_allocations.extend(record['word_allocations'])

        word_allocation_df = pd.DataFrame(all_word_allocations)
        word_allocation_df.to_csv(f'{base_filename}_word_allocations.csv', index=False, encoding='utf-8')
        print(f"✅ doc's words allocation to node is saved: {base_filename}_word_allocations.csv")

        # 4. Node word distributions.
        all_node_words = []
        for record in filtered_records_for_detail:
            all_node_words.extend(record['node_word_distributions'])

        node_word_df = pd.DataFrame(all_node_words)
        node_word_df.to_csv(f'{base_filename}_node_word_distributions.csv', index=False, encoding='utf-8')
        print(f"✅ node word distribution is saved: {base_filename}_node_word_distributions.csv")

        # 7. Jump records (file is written only when at least one record exists).
        all_jump_records = []
        for record in filtered_records_for_detail:
            all_jump_records.extend(record.get('jump_records_list') or [])

        jump_records_df = pd.DataFrame()
        if all_jump_records:
            jump_records_df = pd.DataFrame(all_jump_records)
            # Keep a plain-string copy of the (possibly nested) probability summary for CSV readers.
            if 'all_probs_summary' in jump_records_df.columns:
                jump_records_df['all_probs_summary_str'] = jump_records_df['all_probs_summary'].astype(str)
            jump_records_df.to_csv(f'{base_filename}_jump_records.csv', index=False, encoding='utf-8')
            print(f"✅ jump record info is saved: {base_filename}_jump_records.csv")
        else:
            print("ℹ️ no jumpy record info...")

        # 8. Detail records (file is written only when at least one record exists).
        all_detail_records = []
        for record in filtered_records_for_detail:
            all_detail_records.extend(record.get('detail_records_list') or [])

        detail_records_df = pd.DataFrame()
        if all_detail_records:
            detail_records_df = pd.DataFrame(all_detail_records)
            detail_records_df.to_csv(f'{base_filename}_detail_records.csv', index=False, encoding='utf-8')
            print(f"✅ other detailed info is saved: {base_filename}_detail_records.csv")
        else:
            print("ℹ️ no other detailed record info...")

        return {
            'doc_path_df': doc_path_df,
            'path_structure_df': path_structure_df,
            'word_allocation_df': word_allocation_df,
            'node_word_df': node_word_df,
            'summary_df': summary_df,
            'path_doc_mapping_df': path_doc_mapping_df,
            'jump_records_df': jump_records_df,
            'detail_records_df': detail_records_df
        }

    def get_iteration_summary(self):
        """Return all iteration summaries as a DataFrame, or None when nothing was recorded."""
        if not self.iteration_records:
            return None

        summaries = [record['iteration_summary'] for record in self.iteration_records]
        return pd.DataFrame(summaries)

In [7]:
import pickle
import os

def save_gibbs_checkpoint(filename, recorder, root_node, path_list, doc_path, doc_word_allocation, doc_node_allocation, iteration):
    """Persist all key Gibbs-sampling variables to ``filename`` as a pickle.

    The checkpoint is first written to ``{filename}.tmp`` and then moved into
    place with ``os.replace``, so an interrupted or failed dump never
    truncates or corrupts a previously saved checkpoint.

    Args:
        filename: destination path of the checkpoint file.
        recorder, root_node, path_list, doc_path, doc_word_allocation,
        doc_node_allocation: sampler state objects, stored as-is (must be picklable).
        iteration: iteration number the checkpoint corresponds to.
    """
    checkpoint = {
        'recorder': recorder,
        'root_node': root_node,
        'path_list': path_list,
        'doc_path': doc_path,
        'doc_word_allocation': doc_word_allocation,
        'doc_node_allocation': doc_node_allocation,
        'iteration': iteration
    }
    tmp_filename = f"{filename}.tmp"
    try:
        with open(tmp_filename, 'wb') as f:
            pickle.dump(checkpoint, f)
        os.replace(tmp_filename, filename)
    finally:
        # After a successful replace the temp file no longer exists; this only
        # cleans up the partial file left behind by a failed dump.
        if os.path.exists(tmp_filename):
            os.remove(tmp_filename)
    print(f"✅ checkpoint is saved to: {filename}（迭代 {iteration}）")

In [8]:
def load_gibbs_checkpoint(filename):
    """Restore the Gibbs-sampling state stored in a checkpoint file.

    Returns:
        tuple: ``(recorder, root_node, path_list, doc_path,
        doc_word_allocation, doc_node_allocation, iteration)``.

    Note:
        The file is read with ``pickle.load``, which can execute arbitrary
        code; only load checkpoints that this project produced itself.
    """
    with open(filename, 'rb') as checkpoint_file:
        state = pickle.load(checkpoint_file)
    print(f"✅ from checkpoint file {filename} recover to iteration {state['iteration']}")

    # Order matters: callers unpack the result positionally.
    ordered_keys = (
        'recorder', 'root_node', 'path_list', 'doc_path',
        'doc_word_allocation', 'doc_node_allocation', 'iteration',
    )
    return tuple(state[key] for key in ordered_keys)

In [9]:
# This function was designed to run after all iterations have finished, returning the documents whose path stayed unchanged over the last `back_window` iterations.
# It can also be used in a sliding-window manner during sampling by passing only the windowed portion of doc_path_lst.
def get_stable_docs_sliding_window(doc_path_lst_portion, back_window=5):
    """Find, for every iteration, the documents whose path was stable over a trailing window.

    Args:
        doc_path_lst_portion: ``{iteration: {doc_id: leaf_id}}``. Leaf ids are
            stored in an int32 matrix, with -1 reserved as the "document absent"
            marker (so a real leaf id of -1 would be treated as absent).
        back_window: number of consecutive iterations (including the current
            one) over which a document must keep the same leaf.

    Returns:
        dict: ``{iteration: set(stable_doc_ids)}`` for every iteration that has
        at least ``back_window`` iterations of history inside the given portion
        (i.e. starting from the ``back_window``-th smallest iteration key).
        An empty input yields an empty dict.

    Raises:
        ValueError: if ``back_window`` is smaller than 1.
    """
    if back_window < 1:
        raise ValueError("back_window must be >= 1")
    if not doc_path_lst_portion:
        return {}

    iter_portion = sorted(doc_path_lst_portion.keys())

    # Use the union of document ids over all iterations so that a document
    # missing from the last iteration does not raise a KeyError below.
    doc_ids = sorted(set().union(*(doc_path_lst_portion[it].keys() for it in iter_portion)))
    doc_idx_map = {doc_id: idx for idx, doc_id in enumerate(doc_ids)}  # doc ids may be non-contiguous
    n_docs = len(doc_ids)
    n_iters = len(iter_portion)

    # Full path matrix, shape (n_iters, n_docs); -1 marks "document absent".
    path_matrix = np.full((n_iters, n_docs), -1, dtype=np.int32)
    for row, iter_num in enumerate(iter_portion):
        for doc_id, leaf_id in doc_path_lst_portion[iter_num].items():
            path_matrix[row, doc_idx_map[doc_id]] = leaf_id

    # Slide a window of `back_window` rows ending at row i.
    stable_dict = {}
    for i in range(back_window - 1, n_iters):
        window_matrix = path_matrix[i - back_window + 1:i + 1, :]  # shape (back_window, n_docs)
        first_row = window_matrix[0]
        # Stable = same leaf in every row of the window and present throughout.
        stable_mask = np.all(window_matrix == first_row, axis=0) & (first_row != -1)
        stable_dict[iter_portion[i]] = {doc_ids[j] for j in np.where(stable_mask)[0]}
    return stable_dict
## now, it can be used for check_inner_convergence

def get_jaccard_list_from_stable_dict(stable_dict, doc_node_allocation_lst_portion, back_window=5, depth=3):
    """Compute Jaccard similarities of consecutive stable-document sets and word-allocation stability ratios.

    Args:
        stable_dict: ``{iteration: set(stable_doc_ids)}``.
        doc_node_allocation_lst_portion: indexable by iteration number; each
            entry maps ``doc_id -> {layer: {word: count}}``.
        back_window: window length; for iteration ``t`` the allocations of
            iterations ``t - back_window + 1 .. t`` are inspected.
        depth: number of layers to inspect (layers ``0 .. depth - 1``).

    Returns:
        tuple ``(jaccard_list, stable_word_ratio_list)``:
            * ``jaccard_list[k]`` is the Jaccard index between the stable sets
              of the k-th and (k+1)-th iteration in sorted order (0 when both
              sets are empty).
            * ``stable_word_ratio_list`` has one entry per iteration: over all
              stable documents, layers and words, the sum of each word's
              minimum count within the window divided by the sum of its counts
              over the window (0 when that total is 0).

    Raises:
        ValueError: if ``stable_dict`` is non-empty and either
            ``doc_node_allocation_lst_portion`` is None or the first iteration
            is smaller than ``back_window - 1``. (This replaces a former call
            to ``exit()``, which terminated the whole kernel/worker process.)
    """
    iters = sorted(stable_dict.keys())

    jaccard_list = []
    for prev_iter, cur_iter in zip(iters, iters[1:]):
        prev_set = stable_dict[prev_iter]
        cur_set = stable_dict[cur_iter]
        union = len(prev_set | cur_set)
        jaccard_list.append(len(prev_set & cur_set) / union if union != 0 else 0)

    stable_word_ratio_list = []
    if not iters:
        return jaccard_list, stable_word_ratio_list

    # The precondition does not depend on the loop variable, so check it once.
    if doc_node_allocation_lst_portion is None or iters[0] < back_window - 1:
        raise ValueError("⚠️ unable to calculate word ratio stability, missing doc_node_allocation_lst or not enough iterations")

    for cur_iter in iters:
        min_freq_sum = 0
        total_freq_sum = 0
        for doc_id in stable_dict[cur_iter]:
            # Allocations of this document for every iteration inside the window.
            window_allocs = [
                doc_node_allocation_lst_portion[j].get(doc_id, {})
                for j in range(cur_iter - back_window + 1, cur_iter + 1)
            ]
            for layer in range(depth):
                # Every word seen at this layer anywhere in the window.
                words = set()
                for alloc in window_allocs:
                    words.update(alloc.get(layer, {}).keys())
                for word in words:
                    freq_list = [alloc.get(layer, {}).get(word, 0) for alloc in window_allocs]
                    min_freq_sum += min(freq_list)
                    total_freq_sum += sum(freq_list)
        stable_word_ratio_list.append(min_freq_sum / total_freq_sum if total_freq_sum > 0 else 0)

    return jaccard_list, stable_word_ratio_list

def calculate_gelman_rubin(chain_values, use_differences=False):
    """Compute the Gelman-Rubin statistic (R-hat) across several chains.

    Args:
        chain_values: dict mapping chain id -> sequence of numeric samples.
            The sample count is taken from the first chain, so all chains are
            expected to have the same length.
        use_differences: when True, each chain is first replaced by the
            absolute differences between consecutive samples.

    Returns:
        float: R-hat = sqrt(((n-1)/n * W + B/n) / W), where W is the mean
        within-chain variance (ddof=1) and B the between-chain variance.
        Special cases:
            * fewer than 2 chains or fewer than 2 samples -> NaN (undefined);
            * W == 0 and B == 0 (all chains constant and identical) -> 1.0,
              so fully agreeing chains count as converged instead of NaN;
            * W == 0 and B > 0 (constant chains that disagree) -> inf.
    """
    if use_differences:
        # Absolute change between consecutive samples of each chain.
        chain_values = {
            chain_id: [abs(values[i] - values[i - 1]) for i in range(1, len(values))]
            for chain_id, values in chain_values.items()
        }

    chains = list(chain_values.values())
    n_chains = len(chains)
    n_samples = len(chains[0]) if chains else 0

    # Both variances need at least two chains and two samples per chain.
    if n_chains < 2 or n_samples < 2:
        return float('nan')

    # Per-chain means and their grand mean.
    chain_means = np.array([np.mean(chain) for chain in chains], dtype=float)
    global_mean = np.mean(chain_means)

    # Between-chain variance B.
    between_chain_var = n_samples * np.sum((chain_means - global_mean) ** 2) / (n_chains - 1)

    # Within-chain variance W.
    within_chain_var = np.mean([np.var(chain, ddof=1) for chain in chains])

    if within_chain_var == 0:
        # Avoid 0/0 (NaN) and x/0 warnings for constant chains.
        return 1.0 if between_chain_var == 0 else float('inf')

    # Pooled variance estimate.
    var_estimate = ((n_samples - 1) / n_samples) * within_chain_var + (1 / n_samples) * between_chain_var

    return np.sqrt(var_estimate / within_chain_var)

def check_convergence_with_gelman_rubin(chain_len_docs, mean_jaccard, mean_ratio, iteration, r_hat_threshold=1.1):
    """Decide multi-chain convergence from the R-hat of the stable-document counts.

    Only ``chain_len_docs`` enters the R-hat computation; ``mean_jaccard`` and
    ``mean_ratio`` are printed and copied into the returned info dict but do
    not influence the decision.

    Returns:
        tuple ``(is_converged, convergence_info)`` where ``convergence_info``
        holds the keys ``iteration``, ``r_hat_len_docs``, ``mean_jaccard``,
        ``mean_ratio``, ``threshold`` and ``converged``.
    """
    print(f"check_dual_convergence_with_gelman_rubin: chain_len_docs:{chain_len_docs}, mean_jaccard:{mean_jaccard}, mean_ratio:{mean_ratio}")

    r_hat = calculate_gelman_rubin(chain_len_docs, use_differences=False)
    converged = r_hat < r_hat_threshold

    info = dict(
        iteration=iteration,
        r_hat_len_docs=r_hat,
        mean_jaccard=mean_jaccard,
        mean_ratio=mean_ratio,
        threshold=r_hat_threshold,
        converged=converged,
    )

    if converged:
        status_text = '✅ Converged!'
    else:
        status_text = '❌ Not converged...'
    print(f"🥰 Convergence status: {status_text}")
    print(f"🥰  R-hat info: {r_hat:.4f}，mean_jaccard:{mean_jaccard}, mean_ratio:{mean_ratio}")
    return converged, info

In [10]:
class SharedState:
    """File-backed state shared between sampling chains and the R-hat monitor.

    The whole state is one JSON document at ``<base_dir>/shared_state.json``
    with the top-level keys ``chains``, ``last_update``, ``rhat_history`` and
    ``convergence``. Constructing an instance (re)initialises that file.

    Note: every update is a read-modify-write cycle without locking, so two
    writers updating at the same moment can still overwrite each other's
    change. Individual writes are atomic (see ``save_state``), so readers never
    observe a half-written file.
    """
    def __init__(self, base_dir):
        self.base_dir = base_dir
        self.state_file = os.path.join(base_dir, "shared_state.json")
        self.ensure_dir()
        self.init_state()

    def ensure_dir(self):
        """Create ``base_dir`` if it does not exist yet."""
        os.makedirs(self.base_dir, exist_ok=True)

    def init_state(self):
        """Overwrite the state file with a fresh, empty state."""
        state = {
            "chains": {},
            "last_update": time.time(),
            "rhat_history": [],
            "convergence": False
        }
        self.save_state(state)

    def save_state(self, state):
        """Serialise ``state`` to the state file atomically.

        NumPy scalars/arrays are converted to plain Python types first. The
        JSON is written to a temporary file in ``base_dir`` and then moved into
        place with ``os.replace``, so a concurrent reader sees either the old
        or the new file, never a partial one.
        """
        import tempfile

        def convert_numpy_types(obj):
            if isinstance(obj, np.integer):
                return int(obj)
            elif isinstance(obj, np.floating):
                return float(obj)
            elif isinstance(obj, np.ndarray):
                return obj.tolist()
            elif isinstance(obj, np.bool_):
                return bool(obj)
            elif isinstance(obj, dict):
                return {k: convert_numpy_types(v) for k, v in obj.items()}
            elif isinstance(obj, (list, tuple)):
                return [convert_numpy_types(i) for i in obj]
            return obj

        converted_state = convert_numpy_types(state)
        fd, tmp_path = tempfile.mkstemp(dir=self.base_dir, prefix="shared_state_", suffix=".tmp")
        try:
            with os.fdopen(fd, "w") as f:
                json.dump(converted_state, f)
            os.replace(tmp_path, self.state_file)
        finally:
            # Only present if the dump or the replace failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    def load_state(self):
        """Read and return the state dict from the state file.

        An unreadable JSON file is retried a few times before giving up, so a
        transient read problem does not wipe the data of all chains. If the
        file is missing, or still unreadable after the retries, the state is
        re-initialised and the fresh state returned.
        """
        for _ in range(3):
            try:
                with open(self.state_file, "r") as f:
                    return json.load(f)
            except FileNotFoundError:
                break
            except json.JSONDecodeError:
                time.sleep(0.05)
        self.init_state()
        with open(self.state_file, "r") as f:
            return json.load(f)

    def update_chain_data(self, chain_id, len_stable_docs, jaccard_list, ratio_list, iteration, record_time): # update a single chain's data
        """Store one chain's results for ``record_time``.

        ``len_stable_docs``, ``jaccard_list`` and ``ratio_list`` are stored
        under ``str(record_time)`` so earlier records are kept rather than
        overwritten; the chain's ``record_time`` field tracks the latest one.
        ``iteration`` is accepted for interface compatibility and not stored.
        """
        state = self.load_state()
        chain_key = str(chain_id)
        if chain_key not in state["chains"]:
            state["chains"][chain_key] = {
                "len_stable_docs": {},
                "jaccard": {},
                "ratio": {},
                "record_time": -1,
                "converged": False
            }
        chain_data = state["chains"][chain_key]

        chain_data["len_stable_docs"][str(record_time)] = len_stable_docs
        chain_data["jaccard"][str(record_time)] = jaccard_list
        chain_data["ratio"][str(record_time)] = ratio_list
        chain_data["record_time"] = record_time
        state["last_update"] = time.time()
        self.save_state(state)

    def update_rhat(self, rhat_info, iteration, converged=False):
        """Append one R-hat check to ``rhat_history`` and set the global convergence flag.

        The threshold is read from ``r_hat_threshold`` and, if absent, from
        ``threshold`` (the key used by ``check_convergence_with_gelman_rubin``),
        so it is no longer always recorded as None.
        """
        state = self.load_state()
        state["rhat_history"].append({
            "rhat_len_docs": rhat_info.get("r_hat_len_docs"),
            "mean_jaccard": rhat_info.get("mean_jaccard"),
            "mean_ratio": rhat_info.get("mean_ratio"),
            "rhat_threshold": rhat_info.get("r_hat_threshold", rhat_info.get("threshold")),
            "converged": converged,
            "iteration": iteration,
            "timestamp": time.time()
        })
        state["convergence"] = converged
        self.save_state(state)

    def update_convergence_status(self, converged):
        """Set the global ``convergence`` flag."""
        state = self.load_state()
        state["convergence"] = converged
        self.save_state(state)

    def get_chain_data_histories(self, record_time, max_retry=5, wait_sec=2):
        """Return every chain's data for ``record_time``.

        If some chain has not stored that record yet, wait ``wait_sec`` seconds
        and try again, up to ``max_retry`` attempts (at least one attempt is
        always made).

        Returns:
            tuple of three dicts keyed by chain id (str):
            ``(len_stable_docs, jaccard_lists, ratio_lists)``.

        Raises:
            KeyError: if the record is still missing after the last attempt.
        """
        record_key = str(record_time)
        attempts = max(1, max_retry)
        for attempt in range(attempts):
            state = self.load_state()
            try:
                len_stable_docs = {}
                jaccard_lists = {}
                ratio_lists = {}
                for chain_id, data in state["chains"].items():
                    # Only the entry belonging to the requested record_time.
                    len_stable_docs[str(chain_id)] = data["len_stable_docs"][record_key]
                    jaccard_lists[str(chain_id)] = data["jaccard"][record_key]
                    ratio_lists[str(chain_id)] = data["ratio"][record_key]
                return len_stable_docs, jaccard_lists, ratio_lists
            except KeyError:
                if attempt == attempts - 1:
                    raise
                time.sleep(wait_sec)

    def get_latest_completed_record(self, n_chains): # find the min last record_time across all chains
        """Return the newest ``record_time`` that every chain has reached.

        Returns -1 while fewer than ``n_chains`` chains have registered (or
        when no chain is registered at all).
        """
        state = self.load_state()  # dict with keys: chains, last_update, rhat_history, convergence
        if len(state["chains"]) < n_chains:  # all chains must have reported
            return -1
        chain_record = [data.get("record_time", -1) for data in state["chains"].values()]
        if not chain_record:
            return -1
        return min(chain_record)

    def get_convergence_status(self):
        """Return the global ``convergence`` flag."""
        state = self.load_state()
        return state["convergence"]

    def all_chains_finished(self, n_chains):
        """Return True when at least ``n_chains`` chains are registered and each has ``converged`` set."""
        state = self.load_state()
        if len(state["chains"]) < n_chains:
            return False
        return all(data.get("converged", False) for data in state["chains"].values())


In [11]:
def rhat_monitor_process(shared_state, n_chains, max_iterations, burn_in, check_interval, back_window, r_hat_threshold=1.1):
    """Monitor multi-chain convergence via R-hat until the chains are done.

    The loop waits until every chain has published the record with index
    ``record_time``, runs ``check_convergence_with_gelman_rubin`` on that
    record, stores the result through ``shared_state.update_rhat`` and moves on
    to the next record. It exits when all chains report finished or the
    computed iteration reaches ``max_iterations``.

    Args:
        shared_state: SharedState-like object providing
            ``get_latest_completed_record``, ``get_chain_data_histories``,
            ``update_rhat``, ``update_convergence_status`` and
            ``all_chains_finished``.
        n_chains: number of chains that must report before a check runs.
        max_iterations: upper bound on the sampling iteration.
        burn_in, back_window, check_interval: used to map a record index to an
            iteration number: ``burn_in + back_window + record * check_interval``.
        r_hat_threshold: R-hat value below which the chains count as converged.
    """
    print("🔍 Initialize R-hat independent monitor thread (stable docs & Jaccard & Ratio)")
    convergence_detected = False
    record_time = 0

    while True:
        # Newest record index that ALL chains have already published.
        current_record_time = shared_state.get_latest_completed_record(n_chains=n_chains)

        if not current_record_time >= record_time:
            # The awaited record is not available yet. Also test the exit
            # condition here: otherwise, if the chains finish without ever
            # publishing this record, the monitor would wait forever.
            if shared_state.all_chains_finished(n_chains):
                print("✊ All chains finished or max iterations reached, exiting r-hat monitor process.")
                break
            time.sleep(5)
            continue

        chain_len_docs, chain_jaccard_lists, chain_ratio_lists = shared_state.get_chain_data_histories(record_time = record_time)

        mean_jaccard = [np.mean(jaccard_lst) for jaccard_lst in chain_jaccard_lists.values()]
        mean_ratio = [np.mean(ratio_lst) for ratio_lst in chain_ratio_lists.values()]

        # NOTE(review): the iteration label is derived from the newest published
        # record (current_record_time), while the data analysed belongs to
        # record_time; the two differ when the monitor lags behind — confirm
        # that this is intended.
        current_last_iteration = min(burn_in+back_window+current_record_time*check_interval, max_iterations)
        is_converged, convergence_info = check_convergence_with_gelman_rubin(chain_len_docs=chain_len_docs, mean_jaccard=mean_jaccard, mean_ratio=mean_ratio, iteration=current_last_iteration, r_hat_threshold=r_hat_threshold)

        # convergence_info keys: iteration, r_hat_len_docs, mean_jaccard,
        # mean_ratio, threshold, converged.

        # Persist this R-hat check.
        shared_state.update_rhat(convergence_info, current_last_iteration, converged=is_converged)

        if is_converged:
            print(f"✅ r-hat of multi-chains is detected to convergence {burn_in+back_window+record_time*check_interval} in iteration:{current_last_iteration}, R-hat={convergence_info['r_hat_len_docs']:.4f}, Jaccard平均值={convergence_info['mean_jaccard']}, Ratio平均值={convergence_info['mean_ratio']}")
            convergence_detected = True
            shared_state.update_convergence_status(True)

        record_time += 1

        # Stop once every chain has finished or the iteration budget is used up.
        if shared_state.all_chains_finished(n_chains) or current_last_iteration >= max_iterations:
            print("✊ All chains finished or max iterations reached, exiting r-hat monitor process.")
            break

        if convergence_detected:
            print("✅ R-hat convergence is detected, waiting for all chains to finish.")

        time.sleep(2)
    print("✅ R-hat monitor process finished.")

In [12]:
def run_gibbs_with_checkpoint(
    corpus, depth=3, gamma=None, eta=None, alpha=None,
    max_iterations=200, save_every_n_iter=10, checkpoint_dir='gibbs_checkpoints',
    check_interval=10, resume_from_checkpoint=None, base_dir=None,
    shared_state=None, chain_id=None, back_window=5, burn_in = 50):
    """
    Run the Gibbs sampling main loop for one chain, optionally resuming from a checkpoint.

    The starting state is loaded with ``load_gibbs_checkpoint`` when
    ``resume_from_checkpoint`` names an existing file; otherwise it is drawn with
    ``nCRP`` and recorded as iteration 0. Each iteration calls ``Gibbs_sampling``,
    snapshots ``doc_path`` and ``doc_node_allocation``, and records a summary
    through the ``Recorder``.

    When ``shared_state`` is given, a window of snapshots is published with
    ``shared_state.update_chain_data`` the first time at iteration
    ``burn_in + back_window + check_interval``, then every ``check_interval``
    iterations, and once more at ``max_iterations``. At those same points the
    chain stops early if ``shared_state.get_convergence_status()`` is true.

    Args:
        corpus, depth, gamma, eta, alpha: forwarded to ``nCRP``, ``Recorder`` and
            ``Gibbs_sampling``.
        max_iterations: index of the last iteration to run.
        save_every_n_iter: currently unused (periodic checkpoint saving is
            commented out below).
        checkpoint_dir: directory created on entry.
        check_interval: number of iterations between two window publications.
        resume_from_checkpoint: path of a checkpoint file, or None to start fresh.
        base_dir: directory used as prefix for ``recorder.save_to_files``; the
            bare name ``"iteration"`` is used when it is falsy.
        shared_state: object exposing ``get_convergence_status()`` and
            ``update_chain_data(...)``; None disables publication and early stop.
        chain_id: identifier used in log messages and in ``update_chain_data``.
        back_window: forwarded to the stable-docs / Jaccard helpers and added to
            ``burn_in`` to place the first window.
        burn_in: number of initial iterations before windows are considered.

    Returns:
        tuple: ``(recorder, root_node, path_list, doc_path, doc_word_allocation,
        doc_node_allocation, saved_files, change_docs_list, loglikelihood_list,
        doc_path_lst, doc_node_allocation_lst)``.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Iteration from which Gibbs_sampling / Recorder are asked for detailed logs.
    # It is loop-invariant, so it is computed once here.
    detailed_log_start = burn_in + back_window

    if resume_from_checkpoint is not None and os.path.exists(resume_from_checkpoint):
        # Resume from the checkpoint.
        (recorder, root_node, path_list, doc_path, doc_word_allocation, doc_node_allocation, start_iter) = load_gibbs_checkpoint(resume_from_checkpoint)
        # The checkpoint tuple does not carry the node/word counts that
        # Gibbs_sampling needs, so rebuild them exactly like the fresh-start branch
        # does (previously this branch left the name undefined -> NameError).
        global_node_word_dist = node_word_distribution(doc_node_allocation, doc_path, path_list, exclude_docs=None)
        # NOTE(review): the message announces start_iter+1 but the loop below starts
        # at start_iter - confirm which value load_gibbs_checkpoint returns.
        print(f"Resume from checkpoint and continue sampling (from iteration {start_iter+1})")
    else:
        # Fresh start.
        print("📊 Initialize nCRP Process...")
        [root_node, path_list, doc_path, doc_word_allocation] = nCRP(
            corpus=corpus, depth=depth, gamma=gamma,
        )
        # path_list: {leaf_id: [Node1, Node2, ...]}
        # doc_path: {doc_id: leaf_id}
        # doc_word_allocation: {doc_id: [layer_for_word1, layer_for_word2,...]}

        recorder = Recorder(corpus, depth, eta, alpha)
        doc_node_allocation = aggregate_words(corpus, doc_word_allocation)
        # doc_node_allocation: {doc_id: {layer: {word: count}}}

        global_node_word_dist = node_word_distribution(doc_node_allocation, doc_path, path_list, exclude_docs=None)
        # global_node_word_dist[node_id]: {layer: {word: count}}

        # Record the initial state as iteration 0. The returned summary is read
        # below for its 'total_paths' and 'log_likelihood' entries.
        initial_summary = recorder.record_iteration(
            0, root_node, path_list, doc_path, doc_word_allocation, doc_node_allocation,global_node_word_dist,
            iter_start_for_detailed_log=detailed_log_start
        )
        print(f"📝 Chain {chain_id} initial state after nCRP is recorded: {initial_summary['total_paths']} path, ", 
                f"Log-likelihood: {initial_summary['log_likelihood']:.2f}")
        start_iter = 1

    previous_log_likelihood = recorder.iteration_records[-1]['iteration_summary']['log_likelihood']
    loglikelihood_list = [previous_log_likelihood]
    change_docs_list = []

    # Per-iteration snapshots, keyed by iteration number (key 0 = starting state).
    # NOTE(review): .copy() is shallow. If Gibbs_sampling mutates the nested
    # per-document dicts of doc_node_allocation in place, every stored snapshot
    # aliases the same inner dicts - confirm, and switch to copy.deepcopy if so.
    doc_path_lst = {0:doc_path.copy()}
    doc_node_allocation_lst = {0:doc_node_allocation.copy()}

    record_time = 0
    start_window = burn_in + back_window

    for iteration in range(start_iter, max_iterations + 1):
        old_paths = set(path_list.keys())

        # jumpy_record / detailed_record are passed straight on to the Recorder.
        jumpy_record, detailed_record = Gibbs_sampling(
            corpus, root_node, path_list, doc_path, doc_word_allocation, doc_node_allocation, global_node_word_dist,
            gamma, eta, alpha, depth, iteration, 
            iter_start=detailed_log_start,
            iter_end=max_iterations
        )

        # Snapshot the post-sampling state of this iteration for the window metrics.
        doc_path_lst[iteration] = doc_path.copy()
        doc_node_allocation_lst[iteration] = doc_node_allocation.copy()

        new_paths = set(path_list.keys()) - old_paths
        iteration_summary = recorder.record_iteration(
            iteration, root_node, path_list, doc_path, doc_word_allocation, doc_node_allocation, global_node_word_dist,
            jumpy_record, detailed_record, new_paths,
            iter_start_for_detailed_log=detailed_log_start
        )
        print(f"🔄 Chain {chain_id} in iteration {iteration}/{max_iterations},\n",
                f"📊 path: {iteration_summary['total_paths']}, ",
                f"new path: {iteration_summary['newly_created_paths']}, ",
                f"docs changed path: {iteration_summary['changed_docs_count']}, ",
                f"Log-likelihood: {iteration_summary['log_likelihood']:.2f}")

        # Checkpoint saving (currently disabled)
        # if (iteration % save_every_n_iter == 0) or (iteration == max_iterations):
        #     checkpoint_path = os.path.join(checkpoint_dir, f'gibbs_checkpoint_iter{iteration}.pkl')
        #     save_gibbs_checkpoint(
        #         checkpoint_path, recorder, root_node, path_list, doc_path,
        #         doc_word_allocation, doc_node_allocation, iteration,
        #     )

        change_docs_list.append(iteration_summary['changed_docs_count'])
        loglikelihood_list.append(iteration_summary['log_likelihood'])

        # A window closes exactly check_interval iterations after the previous one
        # (integer comparison instead of the former float-division equality), and
        # the final iteration always closes one.
        window_due = (iteration - start_window) == check_interval
        if window_due or iteration == max_iterations:
            if shared_state is not None:
                # Cheap read of the flag set by the monitor; stop before doing any work.
                if shared_state.get_convergence_status():
                    print(f"🎉 Chain {chain_id} is detected to convergence in iteration {iteration}, sampling ends prematurely.")
                    break

                # The window spans check_interval + back_window iterations ending
                # at the current one. Keep only iterations that were actually
                # recorded, so a short run or a resumed run cannot raise KeyError.
                first_key = iteration - check_interval - back_window + 1
                window_keys = [k for k in range(first_key, iteration + 1) if k in doc_path_lst]

                doc_path_window = {k: doc_path_lst[k] for k in window_keys}
                stable_docs = get_stable_docs_sliding_window(doc_path_window, back_window=back_window)
                chain_len_docs = [len(doc_list) for doc_list in stable_docs.values()]

                doc_node_allocation_window = {k: doc_node_allocation_lst[k] for k in window_keys}
                jaccard_lst, ratio_lst = get_jaccard_list_from_stable_dict(stable_docs, doc_node_allocation_window, back_window=back_window, depth=depth)

                # Publish this chain's window data under the current record_time.
                shared_state.update_chain_data(chain_id, chain_len_docs, jaccard_lst, ratio_lst, iteration, record_time)

            start_window = iteration
            record_time += 1

    print(f"🎯 Chain {chain_id} finish Gibbs Sampling.")
    print("💾 Save iterations info...")
    saved_files = recorder.save_to_files(base_filename=os.path.join(base_dir, "iteration") if base_dir else "iteration")
    return recorder, root_node, path_list, doc_path, doc_word_allocation, doc_node_allocation, saved_files, change_docs_list, loglikelihood_list, doc_path_lst, doc_node_allocation_lst

In [13]:
def calculate_renyi_entropy(prob_dist, q=2.0):
    """
    Rényi entropy of order ``q`` (natural log) of an unnormalised distribution.

    Args:
        prob_dist: mapping ``key -> non-negative weight``; the weights are
            normalised by their sum before use.
        q: entropy order. ``q == 1`` returns the Shannon entropy, which is the
            limit of the Rényi entropy as ``q -> 1`` (the closed form divides by
            ``1 - q`` and cannot be evaluated there).

    Returns:
        float: ``0.0`` when the total weight is not positive, ``inf`` when the
        q-th moment is not positive, otherwise ``log(sum(p**q)) / (1 - q)``.
    """
    total = sum(prob_dist.values())
    if total <= 0:
        return 0.0

    # Keep only strictly positive probabilities, in the mapping's own order.
    probs = []
    for v in prob_dist.values():
        p = v / total
        if p > 0:
            probs.append(p)

    if q == 1:
        shannon = 0.0
        for p in probs:
            shannon -= p * math.log(p)
        # "+ 0.0" turns a possible -0.0 into 0.0 and leaves other values untouched.
        return shannon + 0.0

    moment_q = 0.0
    for p in probs:
        moment_q += p ** q
    if moment_q <= 0:
        return float('inf')
    # For a one-point distribution log(1) == 0.0 and a negative prefactor would
    # yield -0.0; "+ 0.0" normalises the sign so it never prints as "-0.0".
    return (1.0 / (1.0 - q)) * math.log(moment_q) + 0.0

def jensen_shannon_divergence(dist1, dist2):
    """
    Jensen-Shannon divergence (natural log) between two unnormalised distributions.

    Both mappings are normalised over the union of their keys; a key missing from
    one side counts as weight 0. If either side has zero total weight the
    function returns the sentinel ``1.0`` (note that this lies above ``ln 2``,
    the largest value the natural-log divergence itself can take).
    """
    vocab = set(dist1) | set(dist2)
    total_a = float(sum(dist1.get(term, 0.0) for term in vocab))
    total_b = float(sum(dist2.get(term, 0.0) for term in vocab))
    if not total_a or not total_b:
        return 1.0

    # Accumulate KL(p || m) and KL(q || m) side by side, m being the midpoint
    # mixture; terms with zero mass contribute nothing.
    kl_a_to_mid = 0.0
    kl_b_to_mid = 0.0
    for term in vocab:
        prob_a = dist1.get(term, 0.0) / total_a
        prob_b = dist2.get(term, 0.0) / total_b
        mid = 0.5 * (prob_a + prob_b)
        if prob_a > 0 and mid > 0:
            kl_a_to_mid += prob_a * math.log(prob_a / mid)
        if prob_b > 0 and mid > 0:
            kl_b_to_mid += prob_b * math.log(prob_b / mid)

    return 0.5 * kl_a_to_mid + 0.5 * kl_b_to_mid

def evaluate_tree_structure_with_nodes(root_node, path_list, doc_path, doc_word_allocation, doc_node_allocation):
    """
    Compute per-node and per-layer quality metrics for a topic tree.

    Args:
        root_node: if it is a dict ``node_id -> obj``, ``obj.parent_id`` is used
            as the node's parent; any other value is ignored.
        path_list: ``{leaf_id: [Node, ...]}``; each node is read for ``.layer``
            and ``.node_id``.
        doc_path: ``{doc_id: leaf_id}``.
        doc_word_allocation: accepted for interface compatibility; not read here.
        doc_node_allocation: ``{doc_id: {layer: {word: count}}}``.

    Returns:
        dict with
          - ``node_records``: one dict per node (layer, node_id, parent_id,
            entropy, doc_count, unique_words, total_words, path=None);
          - ``layer_entropy_wavg``: ``{layer: document-weighted Rényi entropy (q=2)}``;
          - ``layer_distinctiveness_wavg``: ``{layer: pairwise JSD weighted by
            m_i * m_j}``, where ``m_i`` is node i's document count;
          - ``nodes_per_layer``: ``{layer: number of nodes}``.

    Notes:
        A node's document count is the number of distinct documents that have an
        allocation entry for the node's layer. When ``root_node`` supplies no
        parent, ``parent_id`` falls back to the id of the node that precedes this
        one on a path in ``path_list`` (None for the first node of a path).
    """
    layer_word_dist = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))  # layer -> node_id -> word -> count
    node_doc_sets = defaultdict(set)  # node_id -> {doc_ids}
    path_parent = {}  # node_id -> node_id of its predecessor on a path

    def get_parent_id(nid):
        explicit = None
        if isinstance(root_node, dict):
            explicit = getattr(root_node.get(nid), 'parent_id', None)
        return explicit if explicit is not None else path_parent.get(nid)

    # Accumulate word counts, document sets and path-derived parents per node.
    for doc_id, leaf_id in doc_path.items():
        path_nodes = path_list[leaf_id]  # [Node1, Node2, ...]
        per_doc_alloc = doc_node_allocation.get(doc_id, {})
        previous_id = None
        for node in path_nodes:
            path_parent.setdefault(node.node_id, previous_id)
            previous_id = node.node_id
            lyr = node.layer
            if lyr in per_doc_alloc:
                for w, c in per_doc_alloc[lyr].items():
                    layer_word_dist[lyr][node.node_id][w] += c
                node_doc_sets[node.node_id].add(doc_id)

    # Per-node records; each entropy is computed once and reused for the layer
    # averages below.
    node_records = []
    node_entropy = {}
    for layer, nodes_dist in layer_word_dist.items():
        for nid, wdist in nodes_dist.items():
            entropy = calculate_renyi_entropy(wdist, q=2.0)
            node_entropy[(layer, nid)] = entropy
            doc_count = len(node_doc_sets[nid]) if nid in node_doc_sets else 0
            total_words = sum(wdist.values())

            # Debug output for zero-entropy nodes.
            if entropy == 0:
                print(f"⚠️ Node {nid} (Layer {layer}) has entropy=0:")
                print(f"   Word distribution: {dict(wdist)}")
                print(f"   Number of unique words: {len(wdist)}")
                print(f"   Total word count: {total_words}")
                print(f"   Document count: {doc_count}")
                print()

            node_records.append({
                "layer": layer,
                "node_id": nid,
                "parent_id": get_parent_id(nid),
                "entropy": entropy,
                "doc_count": doc_count,
                "unique_words": len(wdist),
                "total_words": total_words,
                "path": None
            })

    # Document-weighted entropy per layer.
    layer_entropy_wavg = {}
    nodes_per_layer = {}
    for layer, nodes_dist in layer_word_dist.items():
        nodes = list(nodes_dist.keys())
        nodes_per_layer[layer] = len(nodes)
        total_docs = sum(len(node_doc_sets[nid]) for nid in nodes)
        if total_docs == 0:
            layer_entropy_wavg[layer] = 0.0
            continue
        wsum = 0.0
        for nid in nodes:
            wsum += node_entropy[(layer, nid)] * len(node_doc_sets[nid])
        layer_entropy_wavg[layer] = wsum / total_docs

    # Document-weighted topic distinctiveness per layer (weighted pairwise JSD).
    layer_distinctiveness_wavg = {}
    for layer, nodes_dist in layer_word_dist.items():
        nids = list(nodes_dist.keys())
        if len(nids) < 2:
            layer_distinctiveness_wavg[layer] = 0.0
            continue
        wsum_jsd = 0.0
        wsum = 0.0
        for i in range(len(nids)):
            for j in range(i+1, len(nids)):
                ni, nj = nids[i], nids[j]
                mi, mj = len(node_doc_sets[ni]), len(node_doc_sets[nj])
                if mi == 0 or mj == 0:
                    continue
                jsd = jensen_shannon_divergence(nodes_dist[ni], nodes_dist[nj])
                w = mi * mj
                wsum_jsd += jsd * w
                wsum += w
        layer_distinctiveness_wavg[layer] = (wsum_jsd / wsum) if wsum > 0 else 0.0

    return {
        "node_records": node_records,
        "layer_entropy_wavg": layer_entropy_wavg,
        "layer_distinctiveness_wavg": layer_distinctiveness_wavg,
        "nodes_per_layer": nodes_per_layer
    }

In [14]:
def _append_or_create_csv(frame, csv_path):
    """Write ``frame`` to ``csv_path``: create it with a header, or append rows without one."""
    is_new_file = not os.path.exists(csv_path)
    frame.to_csv(csv_path, index=False, mode='w' if is_new_file else 'a', header=is_new_file)


def _run_single_chain(args):
    """
    Run one hLDA chain; intended as the worker function for joblib.

    Args:
        args: tuple ``(chain_id, corpus, depth, gamma, eta, alpha, max_iterations,
            general_dir, seed, shared_state, back_window, check_interval, burn_in)``.

    Returns:
        dict: ``chain_id``, ``loglikelihood_history``, ``result_file``,
        ``changed_docs_history``, ``doc_path_lst`` and ``doc_node_allocation_lst``.

    Side effects:
        Creates ``<general_dir>/depth_<depth>_gamma_<gamma>_run_<chain_id>`` and
        writes ``final_checkpoint.pkl``, ``result_layers.csv``,
        ``result_nodes.csv`` (only when there are node records) and
        ``result_metrics.csv`` into it. Existing CSV files are appended to.
    """
    (chain_id, corpus, depth, gamma, eta, alpha, max_iterations, general_dir, seed, shared_state, back_window, check_interval, burn_in) = args

    print(f"⛓️ Chain {chain_id} starts（PID: {os.getpid()}）")

    # Seed NumPy's global generator for this chain.
    np.random.seed(seed)

    # Per-chain output directory.
    chain_dir = os.path.join(general_dir, f"depth_{depth}_gamma_{gamma}_run_{chain_id}")
    os.makedirs(chain_dir, exist_ok=True)

    # Run the Gibbs sampler; it always returns the same 11-element tuple.
    result = run_gibbs_with_checkpoint(
        corpus=corpus,
        depth=depth,
        gamma=gamma,
        eta=eta,
        alpha=alpha,
        max_iterations=max_iterations,
        checkpoint_dir=os.path.join(chain_dir, "checkpoints"),
        base_dir=chain_dir,
        chain_id=chain_id,
        shared_state=shared_state,
        back_window=back_window,
        burn_in=burn_in,
        check_interval=check_interval
    )
    (_recorder, root_node, path_list, doc_path, doc_word_allocation, doc_node_allocation,
     _saved_files, change_docs_list, loglikelihood_list, doc_path_lst, doc_node_allocation_lst) = result

    # Persist a lightweight result for the parent process (large objects are left out).
    result_file = os.path.join(chain_dir, "final_checkpoint.pkl")
    with open(result_file, 'wb') as f:
        chain_result = {
            'chain_id': chain_id,
            'loglikelihood_history': loglikelihood_list,
            'changed_docs_history': change_docs_list
        }
        pickle.dump(chain_result, f)

    print(f"✅ Chain {chain_id} Finished !")

    res = evaluate_tree_structure_with_nodes(root_node, path_list, doc_path, doc_word_allocation, doc_node_allocation)
    layers = sorted(res["layer_entropy_wavg"].keys())

    # 1. One row per layer.
    layer_rows = [
        {
            "depth": depth,
            "gamma": gamma,
            "eta": eta,
            "alpha": alpha,
            "layer": L,
            "entropy_wavg": res["layer_entropy_wavg"].get(L, 0.0),
            "distinctiveness_wavg_jsd": res["layer_distinctiveness_wavg"].get(L, 0.0),
            "nodes_in_layer": res["nodes_per_layer"].get(L, 0)
        }
        for L in layers
    ]
    _append_or_create_csv(pd.DataFrame(layer_rows), os.path.join(chain_dir, "result_layers.csv"))

    # 2. One row per node, with the hyper-parameters as leading columns.
    df_nodes = pd.DataFrame(res["node_records"])
    if not df_nodes.empty:
        df_nodes.insert(0, "depth", depth)
        df_nodes.insert(1, "gamma", gamma)
        df_nodes.insert(2, "eta", eta)
        df_nodes.insert(3, "alpha", alpha)
        _append_or_create_csv(df_nodes, os.path.join(chain_dir, "result_nodes.csv"))

    # 3. Single-row summary (result_metrics.csv).
    summary = {
        "depth": depth,
        "gamma": gamma,
        "eta": eta,
        "alpha": alpha,
        "avg_entropy_wavg_over_layers": float(np.mean([res["layer_entropy_wavg"][L] for L in layers])) if layers else 0.0,
        "avg_distinctiveness_wavg_over_layers": float(np.mean([res["layer_distinctiveness_wavg"][L] for L in layers])) if layers else 0.0,
        "total_layers": len(layers),
        "total_nodes": int(sum(res["nodes_per_layer"].values())) if layers else 0
    }
    _append_or_create_csv(pd.DataFrame([summary]), os.path.join(chain_dir, "result_metrics.csv"))

    # NOTE(review): the two *_lst entries hold one snapshot per iteration and are
    # part of the value handed back to the caller; they are kept for interface
    # compatibility even though the pickle above deliberately omits large objects.
    return {
        'chain_id': chain_id,
        'loglikelihood_history': loglikelihood_list,
        'result_file': result_file,
        'changed_docs_history': change_docs_list,
        'doc_path_lst': doc_path_lst,
        'doc_node_allocation_lst': doc_node_allocation_lst
    }

In [15]:
from threading import Thread
import json

def run_multi_chain_hlda(
    corpus, depth=3, gamma=0.01, eta=0.01, alpha=0.1,
    n_chains=3, max_iterations=20, r_hat_threshold=1.1,
    general_dir="multi_chain_hlda_results",
    back_window=5, check_interval=10, burn_in=50,
    base_seed=None
):
    """
    Run several hLDA chains in parallel while a monitor thread checks convergence.

    A ``SharedState`` rooted at ``general_dir`` is created, ``rhat_monitor_process``
    is started in a thread, and ``_run_single_chain`` is executed once per chain
    through joblib's multiprocessing backend. Afterwards the convergence history
    found in ``<general_dir>/shared_state.json`` is exported to CSV:
    ``convergence_info.csv`` in ``general_dir`` and ``chain_convergence_info.csv``
    in each chain directory.

    Args:
        corpus, depth, gamma, eta, alpha: forwarded to every chain.
        n_chains: number of chains; chain ids run from 1 to ``n_chains``.
        max_iterations: iteration cap handed to each chain and to the monitor.
        r_hat_threshold: threshold handed to the monitor.
        general_dir: output root directory (created if missing).
        back_window, check_interval, burn_in: window settings handed to each chain.
        base_seed: optional integer. When given, chain ``i`` is seeded with
            ``i * 1000 + base_seed`` so a run can be repeated; when None (default)
            the offset is derived from the clock as before. The resulting seed is
            passed to ``np.random.seed`` and must be acceptable to it.

    Returns:
        dict: ``{chain_id: unpickled content of that chain's result file}``.
    """

    os.makedirs(general_dir, exist_ok=True)

    # Shared state plus the monitor thread.
    # NOTE(review): back_window is not forwarded to rhat_monitor_process - confirm
    # the monitor obtains the same value the chains use.
    shared_state = SharedState(general_dir)
    monitor = Thread(
        target=rhat_monitor_process,
        args=(shared_state, n_chains, max_iterations, burn_in, r_hat_threshold, check_interval)
    )
    monitor.daemon = False
    monitor.start()
    print("🚀 Start R-hat monitor process")

    # One argument tuple per chain, in the order _run_single_chain unpacks it:
    # (chain_id, corpus, depth, gamma, eta, alpha, max_iterations, general_dir,
    #  seed, shared_state, back_window, check_interval, burn_in)
    seed_offset = int(time.time()) % 1000 if base_seed is None else base_seed
    args_list = []
    for i in range(1, n_chains + 1):
        seed = i * 1000 + seed_offset
        args_list.append((
            i, corpus, depth, gamma, eta, alpha,
            max_iterations, general_dir, seed,
            shared_state, back_window, check_interval, burn_in
        ))

    print(f"🚀 Start {n_chains} hLDA chain，each chain could have {max_iterations} max iterations...")

    # Run all chains in parallel.
    chain_results = Parallel(n_jobs=n_chains, backend='multiprocessing', verbose=10)(
        delayed(_run_single_chain)(args) for args in args_list
    )

    print("✅ All chains finish sampling and waiting for monitor process finish...")

    # Wait (bounded) for the monitor. join() returns silently on timeout, so the
    # thread's liveness decides which message is truthful.
    monitor.join(timeout=60)
    if monitor.is_alive():
        print("⚠️ R-hat monitor process is still running after 60s, continue without waiting for it.")
    else:
        print("✅ R-hat monitor process finished !")

    # Collect the lightweight per-chain results written by the workers.
    full_results = {}
    for result in chain_results:
        with open(result['result_file'], 'rb') as f:
            full_results[result['chain_id']] = pickle.load(f)

    # Export the convergence history.
    rhat_file = os.path.join(general_dir, "shared_state.json")
    if os.path.exists(rhat_file):
        with open(rhat_file, "r") as f:
            shared_state_data = json.load(f)

        # 1. Overall R-hat history -> general_dir.
        rhat_history = shared_state_data.get("rhat_history", [])
        if rhat_history:
            convergence_df = pd.DataFrame(rhat_history)
            convergence_csv = os.path.join(general_dir, "convergence_info.csv")
            convergence_df.to_csv(convergence_csv, index=False, encoding='utf-8')
            print(f"✅ Overall convergence history is saved to CSV: {convergence_csv}")
        else:
            print("⚠️ Does not find any convergence history in shared_state.json.")

        # 2. Per-chain window data -> each chain's own directory.
        for result in chain_results:
            chain_id = result['chain_id']
            chain_dir = os.path.join(general_dir, f"depth_{depth}_gamma_{gamma}_run_{chain_id}")

            # JSON object keys are strings, hence str(chain_id).
            chain_data = shared_state_data.get("chains", {}).get(str(chain_id), {})
            if chain_data:
                chain_records = []
                for record_time, len_docs in chain_data.get('len_stable_docs', {}).items():
                    jaccard_list = chain_data.get('jaccard', {}).get(record_time, [])
                    ratio_list = chain_data.get('ratio', {}).get(record_time, [])

                    chain_records.append({
                        'record_time': int(record_time),
                        'len_stable_docs': len_docs,
                        'jaccard_mean': np.mean(jaccard_list) if jaccard_list else 0,
                        'jaccard_std': np.std(jaccard_list) if jaccard_list else 0,
                        'ratio_mean': np.mean(ratio_list) if ratio_list else 0,
                        'ratio_std': np.std(ratio_list) if ratio_list else 0,
                        'jaccard_list': jaccard_list,
                        'ratio_list': ratio_list
                    })

                if chain_records:
                    chain_df = pd.DataFrame(chain_records)
                    chain_convergence_csv = os.path.join(chain_dir, "chain_convergence_info.csv")
                    chain_df.to_csv(chain_convergence_csv, index=False, encoding='utf-8')
                    print(f"✅ Chain {chain_id} convergence info saved to: {chain_convergence_csv}")
                else:
                    print(f"⚠️ No convergence data found for chain {chain_id}")
            else:
                print(f"⚠️ No chain data found for chain {chain_id}")
    else:
        print("⚠️ Does not find shared_state.json, no convergence history saved.")

    return full_results

In [16]:
import cProfile
# Model hyper-parameters (lower-case names, kept as notebook globals)
depth, gamma, eta, alpha = 4, 0.001, 0.1, 0.1
n_chains = 3  # three chains run in parallel

# Sampler / convergence settings for this run
sampler_options = dict(
    max_iterations=500,
    check_interval=20,
    back_window=5,
    burn_in=50,
    r_hat_threshold=1.1,
    general_dir="machine1_step1_d4_g0001",
)

# Launch the joblib-backed multi-chain run
result = run_multi_chain_hlda(
    corpus=corpus,
    depth=depth,
    gamma=gamma,
    eta=eta,
    alpha=alpha,
    n_chains=n_chains,
    **sampler_options,
)

🔍 Initialize R-hat independent monitor thread (stable docs & Jaccard & Ratio)
🚀 Start R-hat monitor process
🚀 Start 3 hLDA chain，each chain could have 500 max iterations...
⛓️ Chain 1 starts（PID: 36192）
📊 Initialize nCRP Process...
⛓️ Chain 2 starts（PID: 36193）
📊 Initialize nCRP Process...
⛓️ Chain 3 starts（PID: 36194）
📊 Initialize nCRP Process...


[Parallel(n_jobs=3)]: Using backend MultiprocessingBackend with 3 concurrent workers.


📝 Chain 1 initial state after nCRP is recorded: 1 path,  Log-likelihood: -545285.38
📝 Chain 3 initial state after nCRP is recorded: 1 path,  Log-likelihood: -545220.96
📝 Chain 2 initial state after nCRP is recorded: 1 path,  Log-likelihood: -545350.17
🔄 Chain 2 in iteration 1/500,
 📊 path: 58,  new path: 57,  docs changed path: 172,  Log-likelihood: -514533.75
🔄 Chain 1 in iteration 1/500,
 📊 path: 59,  new path: 58,  docs changed path: 171,  Log-likelihood: -516097.04
🔄 Chain 3 in iteration 1/500,
 📊 path: 67,  new path: 66,  docs changed path: 175,  Log-likelihood: -513028.43
🔄 Chain 1 in iteration 2/500,
 📊 path: 96,  new path: 47,  docs changed path: 173,  Log-likelihood: -504549.36
🔄 Chain 2 in iteration 2/500,
 📊 path: 93,  new path: 55,  docs changed path: 182,  Log-likelihood: -503195.56
🔄 Chain 3 in iteration 2/500,
 📊 path: 92,  new path: 43,  docs changed path: 158,  Log-likelihood: -503246.47
🔄 Chain 1 in iteration 3/500,
 📊 path: 98,  new path: 41,  docs changed path: 179,

[Parallel(n_jobs=3)]: Done   3 out of   3 | elapsed: 229.0min finished


✅ All chains finish sampling and waiting for monitor process finish...
✅ R-hat monitor process finished !
✅ Overall convergence history is saved to CSV: machine1_step1_d4_g0001/convergence_info.csv
✅ Chain 1 convergence info saved to: machine1_step1_d4_g0001/depth_4_gamma_0.001_run_1/chain_convergence_info.csv
✅ Chain 2 convergence info saved to: machine1_step1_d4_g0001/depth_4_gamma_0.001_run_2/chain_convergence_info.csv
✅ Chain 3 convergence info saved to: machine1_step1_d4_g0001/depth_4_gamma_0.001_run_3/chain_convergence_info.csv
